Source code for pywick.datasets.tnt.shuffledataset

from .resampledataset import ResampleDataset
import torch


class ShuffleDataset(ResampleDataset):
    """Dataset which shuffles a given dataset.

    `ShuffleDataset` is a sub-class of `ResampleDataset` provided for
    convenience. It samples uniformly from the given `dataset` with, or
    without, `replacement`. The chosen partition can be redrawn by calling
    `resample()`.

    If `replacement` is `True`, then the specified `size` may be larger than
    the underlying `dataset`. If `size` is not provided, then the new dataset
    size will be equal to the underlying `dataset` size.

    Purpose: the easiest way to shuffle a dataset!

    Args:
        dataset (Dataset): Dataset to be shuffled.
        size (int, optional): Desired size of the shuffled dataset. If
            `replacement` is `True`, then can be larger than `len(dataset)`.
            By default, the new dataset will have the same size as `dataset`.
        replacement (bool, optional): True if uniform sampling is to be done
            with replacement. False otherwise. Defaults to False.

    Raises:
        ValueError: If `size` is larger than the size of the underlying
            dataset and `replacement` is False.
    """

    def __init__(self, dataset, size=None, replacement=False):
        if size is not None and not replacement and size > len(dataset):
            # Adjacent literals are concatenated at compile time; a backslash
            # continuation inside the literal would leak the source
            # indentation into the message.
            raise ValueError(
                'size cannot be larger than underlying dataset '
                'size when sampling without replacement')

        # The sampler reads `self.perm` lazily at lookup time, so it always
        # sees the index tensor installed by the most recent `resample()`.
        super(ShuffleDataset, self).__init__(
            dataset,
            lambda dataset, idx: self.perm[idx],
            size)
        self.replacement = replacement
        self.resample()

    def resample(self, seed=None):
        """Resample the dataset.

        Draws a fresh index mapping into the underlying dataset and stores it
        in `self.perm`.

        Args:
            seed (int, optional): Seed for resampling. When given, a private
                generator is seeded with it, so the same seed always yields
                the same sample and the global torch RNG state is left
                untouched. By default no seed is used and the global default
                generator is consumed.
        """
        if seed is not None:
            # Use a dedicated generator rather than `torch.manual_seed`, which
            # would reseed the process-wide RNG as a side effect.
            gen = torch.Generator()
            gen.manual_seed(seed)
        else:
            gen = torch.default_generator

        if self.replacement:
            # Independent uniform draws in [0, len(self.dataset)).
            self.perm = torch.LongTensor(len(self)).random_(
                len(self.dataset), generator=gen)
        else:
            # Full permutation of the source indices, truncated to our length.
            self.perm = torch.randperm(
                len(self.dataset), generator=gen).narrow(0, 0, len(self))