Source code for pywick.samplers

"""
Samplers are used during the training phase and are especially useful when your training data is not uniformly distributed among
all of your classes.
"""

import math

import torch
import torchvision
from torch.utils.data.sampler import Sampler

from .datasets import ClonedFolderDataset, FolderDataset, MultiFolderDataset, PredictFolderDataset
from .utils import th_random_choice


class StratifiedSampler(Sampler):
    """Stratified Sampling.

    Provides equal representation of target classes in each batch.

    :param class_vector: (torch tensor): a vector of class labels
    :param batch_size: (int): size of the batch
    """
    def __init__(self, class_vector, batch_size):
        self.n_splits = int(class_vector.size(0) / batch_size)
        self.class_vector = class_vector

    def gen_sample_array(self):
        try:
            from sklearn.model_selection import StratifiedShuffleSplit
        except ImportError as e:
            raise ImportError('StratifiedSampler requires scikit-learn to be installed') from e
        import numpy as np

        s = StratifiedShuffleSplit(n_splits=self.n_splits, test_size=0.5)
        # StratifiedShuffleSplit requires a feature matrix X, but only the
        # labels matter for stratification, so dummy random features suffice.
        X = torch.randn(self.class_vector.size(0), 2).numpy()
        y = self.class_vector.numpy()
        train_index, test_index = next(s.split(X, y))
        return np.hstack([train_index, test_index])

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        return len(self.class_vector)

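# A minimal usage sketch (an editor's illustration, not part of the pywick
# API): because StratifiedSampler iterates over its generated index array,
# it can be passed directly as the ``sampler`` argument of a DataLoader.
# The ``features``/``labels`` tensors below are hypothetical stand-ins.
def _example_stratified_sampler():
    from torch.utils.data import DataLoader, TensorDataset

    features = torch.randn(100, 8)
    # an imbalanced label vector: 70 of class 0, 20 of class 1, 10 of class 2
    labels = torch.tensor([0] * 70 + [1] * 20 + [2] * 10)
    dataset = TensorDataset(features, labels)
    sampler = StratifiedSampler(class_vector=labels, batch_size=10)
    loader = DataLoader(dataset, batch_size=10, sampler=sampler)
    for _x, _y in loader:
        pass  # consume batches drawn in the stratified order
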
class MultiSampler(Sampler):
    """Samples elements more than once in a single pass through the data.

    This allows the number of samples drawn per epoch to be larger than the
    number of samples in the dataset itself, which can be useful when
    training on 2D slices taken from 3D images, for instance.

    :param nb_samples: (int): the number of samples in the dataset
    :param desired_samples: (int): number of samples per pass you want.
        If this is not an even multiple of ``nb_samples``, the remainder is
        chosen at random from the original samples. E.g. if ``nb_samples = 3``
        and ``desired_samples = 4``, all 3 samples will be included and the
        last sample will be randomly chosen from the original 3.
    :param shuffle: (bool): whether to shuffle the indices or not
        (note: currently stored but not applied by ``gen_sample_array``)

    Example:
        >>> m = MultiSampler(2, 6)
        >>> x = m.gen_sample_array()
        >>> print(x) # [0, 1, 0, 1, 0, 1]
    """
    def __init__(self, nb_samples, desired_samples, shuffle=False):
        self.data_samples = nb_samples
        self.desired_samples = desired_samples
        self.shuffle = shuffle

    def gen_sample_array(self):
        # whole repeats of the dataset ...
        n_repeats = self.desired_samples // self.data_samples
        cat_list = [torch.arange(0, self.data_samples) for _ in range(n_repeats)]
        # ... plus the leftover samples, chosen at random
        left_over = self.desired_samples % self.data_samples
        if left_over > 0:
            cat_list.append(th_random_choice(self.data_samples, left_over))
        self.sample_idx_array = torch.cat(cat_list).long()
        return self.sample_idx_array

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        return self.desired_samples

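# A minimal usage sketch (an editor's illustration, not part of the pywick
# API): oversampling a 4-item dataset so that a single pass yields 10 draws,
# i.e. two full repeats of the data plus 2 indices chosen at random.
def _example_multi_sampler():
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(4, 3))
    sampler = MultiSampler(nb_samples=len(dataset), desired_samples=10)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    assert len(sampler) == 10
    assert sum(batch.size(0) for (batch,) in loader) == 10
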
# Source: https://github.com/ufoym/imbalanced-dataset-sampler (MIT license)
class ImbalancedDatasetSampler(Sampler):
    """Samples elements randomly from a given list of indices, weighting each
    sample inversely to the frequency of its class, which rebalances an
    imbalanced dataset.

    :param dataset: the dataset to sample from
    :param indices: (list, optional): a list of indices
    :param num_samples: (int, optional): number of samples to draw
    """
    def __init__(self, dataset, indices=None, num_samples=None):
        # if indices is not provided, all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # distribution of classes in the dataset
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample: the inverse of its class count
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

    @staticmethod
    def _get_label(dataset, idx):
        dataset_type = type(dataset)
        if dataset_type is torchvision.datasets.MNIST:
            return dataset.train_labels[idx].item()
        elif dataset_type is torchvision.datasets.ImageFolder:
            return dataset.imgs[idx][1]
        elif dataset_type in (FolderDataset, MultiFolderDataset, PredictFolderDataset, ClonedFolderDataset):
            return dataset.getdata()[idx][1]
        else:
            raise NotImplementedError

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples

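# A minimal usage sketch (an editor's illustration, not part of the pywick
# API): rebalancing a torchvision ImageFolder. '/path/to/images' is a
# placeholder for a real directory of per-class subfolders.
def _example_imbalanced_sampler():
    from torch.utils.data import DataLoader

    dataset = torchvision.datasets.ImageFolder('/path/to/images',
                                               transform=torchvision.transforms.ToTensor())
    sampler = ImbalancedDatasetSampler(dataset)
    # minority-class images are now drawn about as often as majority-class ones
    return DataLoader(dataset, batch_size=32, sampler=sampler)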