Source code for pywick.datasets.tnt.transformdataset
import copy

from .dataset import Dataset
class TransformDataset(Dataset):
    """
    Dataset which transforms a given dataset with a given function.

    Given a function `transform`, and a `dataset`, `TransformDataset` applies
    the function in an on-the-fly manner when querying a sample with
    `__getitem__(idx)` and therefore returning `transform(dataset[idx])`.

    `transform` can also be a dict with functions as values. In this case, it
    is assumed that `dataset[idx]` is a dict which has all the keys in
    `transform`. Then, `transform[key]` is applied to `dataset[idx][key]` for
    each key in `transform`. The transformed values are written into a shallow
    copy of the sample, so the object held by the underlying dataset is never
    modified and repeated lookups of the same index give the same result.

    The size of the new dataset is equal to the size of the underlying
    `dataset`.

    Purpose: when performing pre-processing operations, it is convenient to be
    able to perform on-the-fly transformations to a dataset.

    Args:
        dataset (Dataset): Dataset which has to be transformed.
        transforms (function/dict): Function or dict with function as values.
            These functions will be applied to data.

    Raises:
        AssertionError: If `transforms` is neither a dict nor a callable, or
            if it is a dict and one of its values is not callable.
    """

    def __init__(self, dataset, transforms):
        super(TransformDataset, self).__init__()

        # Validate eagerly so a bad transform fails at construction time
        # rather than on the first sample access.
        if not (isinstance(transforms, dict) or callable(transforms)):
            raise AssertionError('expected a dict of transforms or a function')
        if isinstance(transforms, dict):
            for k, v in transforms.items():
                if not callable(v):
                    raise AssertionError(str(k) + ' is not a function')

        self.dataset = dataset
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Return the transformed sample at position `idx`.

        Args:
            idx: Index forwarded to the underlying dataset.

        Returns:
            `transforms(dataset[idx])` when `transforms` is a callable;
            otherwise a shallow copy of `dataset[idx]` in which every key
            present in `transforms` has been replaced by
            `transforms[key](dataset[idx][key])`.
        """
        # Delegate to the base class before touching the data; its return
        # value is discarded.
        # NOTE(review): presumably this performs index validation - confirm
        # against Dataset.__getitem__.
        super(TransformDataset, self).__getitem__(idx)
        z = self.dataset[idx]

        if isinstance(self.transforms, dict):
            # Work on a shallow copy: assigning into the object returned by
            # the underlying dataset would overwrite its stored sample, and
            # the transform would then be re-applied to already-transformed
            # values on every later access.
            z = copy.copy(z)
            for k, transform in self.transforms.items():
                z[k] = transform(z[k])
        else:
            z = self.transforms(z)
        return z