# Source code for pywick.datasets.MultiFolderDataset

import itertools
import os

from PIL import Image
from .FolderDataset import FolderDataset, npy_loader, pil_loader, rgb_image_loader, rgba_image_loader, _find_classes, _finds_inputs_and_targets


class MultiFolderDataset(FolderDataset):
    """
    This class extends the FolderDataset with the ability to supply multiple root directories.
    The ``rel_target_root`` must exist relative to each root directory.
    For a complete description of the functionality see ``FolderDataset``.

    :param roots: (list): list of root directories to traverse. A single path (string or
        path-like object) is also accepted and treated as a one-element list. A leading ``~``
        in any root is expanded to the user's home directory before the directories are scanned.
    :param class_mode: (string in `{'label', 'image', 'path'}`): type of target sample to look for and return.

        `label` = return class folder as target

        `image` = return another image as target (determined by optional target_prefix/postfix).
        NOTE: if class_mode == 'image', in addition to input, you must also provide rel_target_root,
        target_prefix or target_postfix (in any combination).

        `path` = determines paths for inputs and targets and applies the respective loaders to the path
    :param class_to_idx: (dict): if specified (and non-empty), the given class_to_idx map will be used.
        Otherwise one will be derived from the directory structure of all roots.
    :param input_regex: (string `(default: '*')`): pattern used to select input files. The default and
        the example are glob-style patterns, e.g. if all your inputs have the word 'input',
        you'd enter something like input_regex='*input*'
    :param rel_target_root: (string `(default is Nothing)`): root of directory where to look for target
        images RELATIVE to each root dir
    :param target_prefix: (string `(default is Nothing)`): prefix to use (if any) when trying to locate the matching target
    :param target_postfix: (string `(default is Nothing)`): postfix to use (if any) when trying to locate the matching target
    :param target_extension: (string `(default: 'png')`): file extension of target files; passed unchanged
        to the file scan of every root
    :param transform: (torch transform): transform to apply to input sample individually
    :param target_transform: (torch transform): transform to apply to target sample individually
    :param co_transform: (torch transform): transform to apply to both the input and the target
    :param apply_co_transform_first: (bool): whether to apply the co-transform before or after individual
        transforms (default: True = before)
    :param default_loader: (string in `{'npy', 'pil'}` or function `(default: 'pil')`): defines how to load
        samples from file. Will be applied to both input and target unless a separate target_loader is defined.
        If a function is provided, it should take in a file path as input and return the loaded sample.
    :param target_loader: (string in `{'npy', 'pil'}`, function or None `(default: None)`): defines how to
        load target samples from file. The 'npy'/'pil' shorthands are resolved exactly as for default_loader.
        If a function is provided, it should take in a file path as input and return the loaded sample.
        When None, default_loader is used for targets.
    :param exclusion_file: (string): list of files to exclude when enumerating all files. The list must
        contain full paths relative to the root parameter. The same exclusion file is applied when
        scanning every root.
    :param target_index_map: (dict `(default: None)`): a dictionary that maps pixel values in the image to
        classes to be recognized. Used in conjunction with 'image' class_mode to produce a label for semantic
        segmentation; for a binary mask a typical value is {255: 1}. Leave as None to turn this feature off.

    :raises RuntimeError: if no data items are found under any of the roots.
    """

    def __init__(self,
                 roots,
                 class_mode='label',
                 class_to_idx=None,
                 input_regex='*',
                 rel_target_root='',
                 target_prefix='',
                 target_postfix='',
                 target_extension='png',
                 transform=None,
                 target_transform=None,
                 co_transform=None,
                 apply_co_transform_first=True,
                 default_loader='pil',
                 target_loader=None,
                 exclusion_file=None,
                 target_index_map=None):
        # FolderDataset.__init__ is deliberately not called; every attribute is set up here instead,
        # presumably because the parent constructor scans a single root.
        # NOTE(review): any attribute FolderDataset.__init__ sets that is not assigned below will be
        # missing on instances of this class — confirm against FolderDataset.
        self.num_inputs = 1  # these are hardcoded for the fit module to work
        self.num_targets = 1  # these are hardcoded for the fit module to work

        def resolve_loader(loader):
            # Map the 'npy'/'pil' shorthands to loader functions; anything else (a callable, None)
            # is returned unchanged.
            if loader == 'npy':
                return npy_loader
            if loader == 'pil':
                return pil_loader
            return loader

        self.default_loader = resolve_loader(default_loader)
        # separate loading for targets (e.g. for black/white masks); resolved like default_loader so the
        # documented 'npy'/'pil' shorthands work here too instead of being stored as plain strings.
        self.target_loader = resolve_loader(target_loader)

        # Accept a single root for convenience (iterating a bare string would yield its characters),
        # and expand '~' BEFORE scanning so home-relative roots are actually found.
        if isinstance(roots, (str, os.PathLike)):
            roots = [roots]
        roots = [os.path.expanduser(root) for root in roots]
        self.roots = roots

        if class_to_idx:
            self.classes = class_to_idx.keys()
            self.class_to_idx = class_to_idx
        else:
            self.classes, self.class_to_idx = _find_classes(roots)

        # Scan every root with identical settings and concatenate the per-root results in root order.
        scan_kwargs = dict(class_mode=class_mode,
                           class_to_idx=self.class_to_idx,
                           input_regex=input_regex,
                           rel_target_root=rel_target_root,
                           target_prefix=target_prefix,
                           target_postfix=target_postfix,
                           target_extension=target_extension,
                           exclusion_file=exclusion_file)
        per_root_items = (_finds_inputs_and_targets(root, **scan_kwargs)[0] for root in roots)
        self.data = list(itertools.chain.from_iterable(per_root_items))

        if not self.data:
            raise RuntimeError('Found 0 data items in subfolders of: {}'.format(roots))
        print('Found %i data items' % len(self.data))

        self.transform = transform
        self.target_transform = target_transform
        self.co_transform = co_transform
        self.apply_co_transform_first = apply_co_transform_first
        self.target_index_map = target_index_map
        self.class_mode = class_mode