Source code for pywick.datasets.tnt.splitdataset

from .dataset import Dataset
import numpy as np


class SplitDataset(Dataset):
    """
    Dataset to partition a given dataset.

    Partition a given `dataset`, according to the specified `partitions`. Use
    the method `select()` to select the current partition in use.

    The `partitions` is a dictionary where a key is a user-chosen string
    naming the partition, and value is a number representing the weight (as a
    number between 0 and 1) or the size (in number of samples) of the
    corresponding partition.

    Partitioning is achieved linearly (no shuffling). See `ShuffleDataset` if
    you want to shuffle the dataset before partitioning.

    Partitions are laid out back to back in the alphabetical order of their
    names, not in the insertion order of the `partitions` dict.

    Args:
        dataset (Dataset): Dataset to be split.
        partitions (dict): Dictionary where key is a user-chosen string
            naming the partition, and value is a number representing the
            weight (as a number between 0 and 1) or the size (in number of
            samples) of the corresponding partition.
        initial_partition (str, optional): Initial partition to be selected.

    Raises:
        AssertionError: If `partitions` is not a dict, has fewer than two
            entries, contains a negative value, contains only zeros, or
            holds non-integer values whose sum exceeds 1.
    """

    def __init__(self, dataset, partitions, initial_partition=None):
        super(SplitDataset, self).__init__()

        self.dataset = dataset
        self.partitions = partitions

        # A few sanity checks on the requested partitioning.
        if not isinstance(partitions, dict):
            raise AssertionError('partitions must be a dict')
        if len(partitions) < 2:
            raise AssertionError('SplitDataset should have at least two partitions')
        if min(partitions.values()) < 0:
            raise AssertionError('partition sizes cannot be negative')
        if max(partitions.values()) <= 0:
            raise AssertionError('all partitions cannot be empty')

        # Sorting the names makes the layout independent of dict ordering.
        self.partition_names = sorted(self.partitions)
        self.partition_index = {name: position
                                for position, name in enumerate(self.partition_names)}
        self.partition_sizes = [self.partitions[name]
                                for name in self.partition_names]

        if sum(self.partition_sizes) <= 1:
            # Values are weights: convert each one to a sample count.
            # NOTE: every weight is rounded independently, so the resulting
            # counts are not guaranteed to add up to exactly len(dataset).
            total = len(dataset)
            self.partition_sizes = [int(round(weight * total))
                                    for weight in self.partition_sizes]
        else:
            for size in self.partition_sizes:
                if size != int(size):
                    raise AssertionError('partition sizes should be integer'
                                         ' numbers, or sum up to <= 1 ')
            # Whole-number floats (e.g. 2.0) pass the check above; store them
            # as ints because __len__ must return an integer.
            self.partition_sizes = [int(size) for size in self.partition_sizes]

        # partition_cum_sizes[i] is the index one past the end of partition i.
        self.partition_cum_sizes = np.cumsum(self.partition_sizes)

        if initial_partition is not None:
            self.select(initial_partition)

    def select(self, partition):
        """
        Select the partition.

        Args:
            partition (str): Partition to be selected.

        Raises:
            KeyError: If `partition` is not one of the configured names.
        """
        self.current_partition_idx = self.partition_index[partition]

    def _selected_index(self):
        """Return the index of the active partition, or raise if none is selected."""
        try:
            return self.current_partition_idx
        except AttributeError:
            # select() has never been called, so the attribute does not exist.
            raise ValueError("Select a partition before accessing data.")

    def __len__(self):
        """Return the number of samples in the currently selected partition.

        Raises:
            ValueError: If no partition has been selected yet.
        """
        return self.partition_sizes[self._selected_index()]

    def __getitem__(self, idx):
        """Return sample `idx` of the currently selected partition.

        Args:
            idx (int): Position of the sample inside the selected partition.

        Raises:
            ValueError: If no partition has been selected yet.
        """
        # Delegate to the base class first, as the original code did
        # (presumably an index bounds check; confirm against Dataset).
        super(SplitDataset, self).__getitem__(idx)

        # Only the partition lookup is guarded, so an AttributeError raised by
        # the wrapped dataset itself is not misreported as a missing selection.
        partition_idx = self._selected_index()
        if partition_idx == 0:
            return self.dataset[idx]

        # Skip over all samples that belong to the preceding partitions.
        offset = int(self.partition_cum_sizes[partition_idx - 1])
        return self.dataset[offset + idx]