from torch.utils.data import DataLoader
class Dataset:
    """Base dataset providing chainable transformation helpers.

    Subclasses are expected to override ``__len__`` and ``__getitem__``;
    the implementations here are no-op stubs. The helper methods each wrap
    ``self`` in another dataset type, so calls can be chained, e.g.
    ``ds.shuffle().batch(32).parallel()``.
    """

    def __init__(self):
        # No base-class state; subclasses set up their own storage.
        pass

    def __len__(self):
        # Stub: subclasses must override to return the dataset size.
        # NOTE(review): as written this returns None, so len(ds) on the
        # base class raises TypeError — kept as-is, since subclasses are
        # expected to override; confirm no caller relies on the base form.
        pass

    def __getitem__(self, idx):
        # Bounds-check against the subclass-provided length.
        # NOTE(review): negative indices are not range-checked here and
        # fall through to the stub body — confirm whether that is intended.
        if idx >= len(self):
            # Fixed copy-paste error: message previously named "CustomRange",
            # a different class, instead of this one.
            raise IndexError("Dataset index out of range")
        # Stub: subclasses override to return the element at ``idx``.
        pass

    def batch(self, *args, **kwargs):
        """Return a BatchDataset wrapping this dataset (args forwarded verbatim)."""
        # Imported lazily to avoid a circular import at module load time.
        from .batchdataset import BatchDataset
        return BatchDataset(self, *args, **kwargs)

    def shuffle(self, *args, **kwargs):
        """Return a ShuffleDataset wrapping this dataset (args forwarded verbatim)."""
        # Imported lazily to avoid a circular import at module load time.
        from .shuffledataset import ShuffleDataset
        return ShuffleDataset(self, *args, **kwargs)

    def parallel(self, *args, **kwargs):
        """Return a torch DataLoader over this dataset (args forwarded verbatim)."""
        return DataLoader(self, *args, **kwargs)

    def partition(self, *args, **kwargs):
        """Return a MultiPartitionDataset wrapping this dataset (args forwarded verbatim)."""
        # Imported lazily to avoid a circular import at module load time.
        from .multipartitiondataset import MultiPartitionDataset
        return MultiPartitionDataset(self, *args, **kwargs)