import numpy as np
import torch as th
from torchvision import transforms
from .data_utils import is_tuple_or_list
class BaseDataset:
    """An abstract class representing a Dataset.

    All other datasets should subclass it. Subclasses must implement
    ``__getitem__`` (integer indexing in range from 0 to ``len(self)``
    exclusive) and may override ``__len__``, which provides the size of the
    dataset.

    This base class assigns no state itself; its methods read attributes a
    subclass is expected to set:

    * ``inputs`` -- a single indexable collection, or a tuple/list of them.
    * ``num_inputs`` / ``num_targets`` -- integer slot counts (``load``
      compares them to 1 and passes them to ``range``).
    * ``has_target`` -- whether ``__getitem__`` returns ``(input, target)``.
    * ``input_transform`` / ``target_transform`` / ``co_transform`` --
      mutable per-slot sequences of callables, updated in place by the
      ``add_*_transform`` methods.
    * ``min_inputs_or_targets`` -- slot count used by ``add_co_transform``;
      presumably ``min(num_inputs, num_targets)`` -- TODO confirm in subclasses.
    """

    def __len__(self):
        """Return the number of samples in the dataset.

        When ``self.inputs`` is a tuple/list of parallel collections, the
        length of the first collection is used.
        """
        if isinstance(self.inputs, (tuple, list)):
            return len(self.inputs[0])
        return len(self.inputs)

    def _chain_transform(self, slots_attr, count_attr, transform, add_to_front, idx):
        """Wrap ``transform`` around selected entries of a per-slot transform list.

        :param slots_attr: name of the attribute holding the per-slot
            transforms (updated in place)
        :param count_attr: name of the attribute holding the slot count; only
            read when ``idx`` is None
        :param transform: callable to chain onto each selected slot
        :param add_to_front: if True, ``transform`` runs before the existing
            slot transform; otherwise it runs after it
        :param idx: slot index, tuple/list of indices, or None for all slots
        """
        slots = getattr(self, slots_attr)
        if idx is None:
            count = getattr(self, count_attr)
            # The num_* counts are plain ints (load() passes them to range),
            # so the former ``len(count)`` raised TypeError. A sized value is
            # still accepted for backward compatibility.
            n = count if isinstance(count, (int, np.integer)) else len(count)
            idx = range(n)
        elif not is_tuple_or_list(idx):
            idx = [idx]
        for i in idx:
            if add_to_front:
                pipeline = [transform, slots[i]]
            else:
                pipeline = [slots[i], transform]
            slots[i] = transforms.Compose(pipeline)

    def add_input_transform(self, transform, add_to_front=True, idx=None):
        """Chain ``transform`` onto the input transform(s).

        :param transform: callable applied to input samples
        :param add_to_front: run ``transform`` before (True) or after (False)
            the existing input transform
        :param idx: input slot index, tuple/list of indices, or None for all
            ``self.num_inputs`` slots
        """
        self._chain_transform('input_transform', 'num_inputs', transform, add_to_front, idx)

    def add_target_transform(self, transform, add_to_front=True, idx=None):
        """Chain ``transform`` onto the target transform(s).

        :param transform: callable applied to target samples
        :param add_to_front: run ``transform`` before (True) or after (False)
            the existing target transform
        :param idx: target slot index, tuple/list of indices, or None for all
            ``self.num_targets`` slots
        """
        self._chain_transform('target_transform', 'num_targets', transform, add_to_front, idx)

    def add_co_transform(self, transform, add_to_front=True, idx=None):
        """Chain ``transform`` onto the co-transform(s).

        :param transform: callable stored in ``self.co_transform``
        :param add_to_front: run ``transform`` before (True) or after (False)
            the existing co-transform
        :param idx: slot index, tuple/list of indices, or None for all
            ``self.min_inputs_or_targets`` slots
        """
        self._chain_transform('co_transform', 'min_inputs_or_targets', transform, add_to_front, idx)

    @staticmethod
    def _sample_shape(x):
        """Return the per-sample buffer shape for one input/target slot value."""
        if isinstance(x, (list, tuple)):
            return (len(x),)
        if isinstance(x, th.Tensor):
            return tuple(x.size())
        if isinstance(x, np.ndarray) and x.ndim > 0:
            # n-d arrays used to get a (1,) slot and then fail to broadcast on
            # assignment; 0-d arrays keep the old (1,) slot for compatibility.
            return x.shape
        # Scalars and anything else occupy a single element.
        return (1,)

    @classmethod
    def _allocate(cls, sample, num_slots, count):
        """Create uninitialised float64 buffer(s) with ``count`` rows shaped like ``sample``.

        Returns one array when ``num_slots == 1``, else a list with one array
        per slot (``sample`` is then indexed per slot).
        """
        if num_slots == 1:
            return np.empty((count,) + cls._sample_shape(sample))
        return [np.empty((count,) + cls._sample_shape(sample[i])) for i in range(num_slots)]

    @staticmethod
    def _store(buffers, sample, num_slots, row):
        """Copy ``sample`` into row ``row`` of the buffer(s) made by ``_allocate``."""
        if num_slots == 1:
            buffers[row] = sample
        else:
            for i in range(num_slots):
                buffers[i][row] = sample[i]

    def load(self, num_samples=None, load_range=None):
        """
        Load all data or a subset of the data into actual memory.

        For instance, if the inputs are paths to image files, then this
        function will actually load those images.

        :param num_samples: (int (optional)):
            number of samples to load. if None, will load all.
            Ignored when ``load_range`` is given.
        :param load_range: (numpy array of integers (optional)):
            the index range of images to load
            e.g. np.arange(4) loads the first 4 inputs+targets
        :return: ``(inputs, targets)`` if ``self.has_target`` else ``inputs``.
            Each is a float64 ``np.ndarray`` of shape ``(len(load_range), ...)``
            when the corresponding ``num_*`` count is 1, otherwise a list of
            such arrays (one per slot). Per-sample shapes are taken from the
            first loaded sample.
        :raises ValueError: if there is nothing to load (empty ``load_range``,
            ``num_samples == 0`` or an empty dataset), since buffer shapes
            cannot be inferred.
        """
        if load_range is None:
            if num_samples is None:
                num_samples = len(self)
            load_range = np.arange(num_samples)
        count = len(load_range)
        if count == 0:
            # Previously this fell through to an UnboundLocalError on return.
            raise ValueError('load() needs at least one sample to infer buffer shapes')

        has_target = self.has_target  # read once, before iterating
        inputs = targets = None
        for row, sample_idx in enumerate(load_range):
            sample = self.__getitem__(sample_idx)
            input_sample, target_sample = sample if has_target else (sample, None)
            if row == 0:
                # Buffers are sized from the first sample's shape.
                inputs = self._allocate(input_sample, self.num_inputs, count)
                if has_target:
                    targets = self._allocate(target_sample, self.num_targets, count)
            self._store(inputs, input_sample, self.num_inputs, row)
            if has_target:
                self._store(targets, target_sample, self.num_targets, row)
        return (inputs, targets) if has_target else inputs