import os
import random
import math
import numpy as np
import torch as th
class Compose:
    """
    Chains several transforms and applies them one after another.

    :param transforms: (list of transforms):
        callables applied in order; each receives the unpacked output
        of the previous one
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *inputs):
        """
        Run every transform in sequence and return the last result.

        With an empty transform list the positional arguments are
        returned unchanged (as a tuple).
        """
        result = inputs
        for step in self.transforms:
            # A transform may hand back a single object; wrap it so that
            # it can be unpacked into the next transform.
            args = result if isinstance(result, (list, tuple)) else [result]
            result = step(*args)
        return result
class RandomChoiceCompose:
    """
    Applies exactly one transform, picked at random from a collection.

    e.g. to randomly apply EITHER 0-1 or -1-1 normalization to an input:
        >>> transform = RandomChoiceCompose([RangeNormalize(0,1),
                                             RangeNormalize(-1,1)])
        >>> x_norm = transform(x) # only one of the two normalizations is applied

    :param transforms: (list of transforms):
        candidates to choose from; the pick is made with ``random.choice``
        on every call
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *inputs):
        """Pick one transform and return whatever it returns for ``inputs``."""
        return random.choice(self.transforms)(*inputs)
class ToTensor:
    """
    Converts numpy arrays to torch.Tensor via ``torch.from_numpy``.
    """

    def __call__(self, *inputs):
        """
        :param inputs: numpy arrays to convert
        :return: a single tensor when one input is given, otherwise a list
            of tensors (an empty list when called with no inputs)
        """
        outputs = [th.from_numpy(array) for array in inputs]
        # Decide on the number of results rather than on a loop index that
        # stays ``None`` (and cannot be compared) when there are no inputs.
        return outputs[0] if len(outputs) == 1 else outputs
class ToFile:
    """
    Saves every input to a ``.npy`` file and passes the inputs through.

    Useful as a pass-through transform when wanting to observe how
    augmentation affects the data.

    NOTE: Only supports saving to Numpy currently

    :param root: (string):
        path to main directory in which images will be saved; a leading
        ``~`` is expanded to the user's home directory
    """

    def __init__(self, root):
        if root.startswith('~'):
            root = os.path.expanduser(root)
        self.root = root
        # Number of calls made so far; used as the first index in file names.
        self.counter = 0

    def __call__(self, *inputs):
        """
        Write each input to ``<root>/img_<call>_<position>.npy``.

        :param inputs: objects exposing ``.numpy()`` (e.g. torch tensors)
        :return: the ``inputs`` tuple, unchanged
        """
        # ``enumerate`` supplies the per-call position of each input;
        # iterating ``inputs`` directly would try to unpack each tensor.
        for idx, _input in enumerate(inputs):
            fpath = os.path.join(self.root, 'img_%i_%i.npy' % (self.counter, idx))
            np.save(fpath, _input.numpy())
        self.counter += 1
        return inputs
class ToNumpyType:
    """
    Converts an object to a specific numpy type (with the idea to be passed to ToTensor() next)

    :param type: (numpy dtype, e.g. one of `{numpy.float64, numpy.float32,
        numpy.int64, numpy.int32, numpy.uint8}`):
        target element type
    """

    def __init__(self, type):
        self.type = type

    def __call__(self, input_):
        """
        :param input_: a plain list/tuple, or any object with ``.astype``
            (e.g. a numpy array)
        :return: a numpy array built from the sequence, or the result of
            ``input_.astype(self.type)``
        """
        # Plain Python sequences have no ``astype``; build the array directly.
        if isinstance(input_, (list, tuple)):
            return np.array(input_, dtype=self.type)
        return input_.astype(self.type)
class ChannelsLast:
    """
    Transposes a tensor so that the channel dim is last, i.e. dim 0 is
    moved to the final position and the other dims keep their order.

    `HWC` and `DHWC` are aliases for this transform.

    :param safe_check: (bool):
        if true, will check if channels are already last and, if so,
        will just return the inputs. The check is a heuristic on the
        first input only: the last dim being smaller than the first.
    """

    def __init__(self, safe_check=False):
        self.safe_check = safe_check

    def __call__(self, *inputs):
        """
        :param inputs: one or more tensors; the permutation is derived from
            the number of dims of the first one and applied to all
        :return: a single tensor for one input, otherwise a sequence of
            tensors
        """
        first = inputs[0]
        if self.safe_check and first.size(-1) < first.size(0):
            # Already channels-last: hand the inputs back untouched, but
            # unwrap a lone input so both paths return the same shape of
            # result.
            return inputs if len(inputs) > 1 else first
        plist = list(range(1, first.dim())) + [0]
        outputs = [_input.permute(*plist) for _input in inputs]
        return outputs if len(outputs) > 1 else outputs[0]


HWC = ChannelsLast
DHWC = ChannelsLast
class ChannelsFirst:
    """
    Transposes a tensor so that the channel dim is first, i.e. the last
    dim is moved to position 0 and the other dims keep their order.

    `CHW` and `CDHW` are aliases for this transform.

    :param safe_check: (bool):
        if true, will check if channels are already first and, if so,
        will just return the inputs. The check is a heuristic on the
        first input only: the first dim being smaller than the last.
    """

    def __init__(self, safe_check=False):
        self.safe_check = safe_check

    def __call__(self, *inputs):
        """
        :param inputs: one or more tensors; the permutation is derived from
            the number of dims of the first one and applied to all
        :return: a single tensor for one input, otherwise a sequence of
            tensors
        """
        first = inputs[0]
        if self.safe_check and first.size(0) < first.size(-1):
            # Already channels-first: hand the inputs back untouched, but
            # unwrap a lone input so both paths return the same shape of
            # result.
            return inputs if len(inputs) > 1 else first
        ndim = first.dim()
        plist = [ndim - 1] + list(range(0, ndim - 1))
        outputs = [_input.permute(*plist) for _input in inputs]
        return outputs if len(outputs) > 1 else outputs[0]


CHW = ChannelsFirst
CDHW = ChannelsFirst
class TypeCast:
    """
    Cast a torch.Tensor to a different type

    :param dtype: (string or torch.*Tensor literal or list) of such
        data type to which input(s) will be cast.
        If list, it should be the same length as inputs.
        The strings 'byte', 'double', 'float', 'int', 'long' and 'short'
        are translated to the matching torch.*Tensor type; any other
        value is handed to ``Tensor.type`` unchanged.
    """

    # Short names accepted in place of the torch tensor type itself.
    _ALIASES = {
        'byte': th.ByteTensor,
        'double': th.DoubleTensor,
        'float': th.FloatTensor,
        'int': th.IntTensor,
        'long': th.LongTensor,
        'short': th.ShortTensor,
    }

    def __init__(self, dtype='float'):
        if isinstance(dtype, (list, tuple)):
            self.dtype = [self._resolve(dt) for dt in dtype]
        else:
            self.dtype = self._resolve(dtype)

    @classmethod
    def _resolve(cls, dtype):
        """Translate a known short name to its tensor type; pass anything else through."""
        if isinstance(dtype, str):
            return cls._ALIASES.get(dtype, dtype)
        return dtype

    def __call__(self, *inputs):
        """
        :param inputs: tensors to cast; input ``i`` is cast to entry ``i``
            of the dtype list, or all to the single configured dtype
        :return: a single tensor for one input, otherwise a list of
            tensors (an empty list when called with no inputs)
        """
        if isinstance(self.dtype, (tuple, list)):
            dtypes = self.dtype
        else:
            dtypes = [self.dtype] * len(inputs)
        outputs = [_input.type(dtypes[idx]) for idx, _input in enumerate(inputs)]
        return outputs[0] if len(outputs) == 1 else outputs
class AddChannel:
    """
    Adds a dummy channel to an image, also known as expanding an axis or
    unsqueezing a dim.

    This will make an image of size (28, 28) to now be
    of size (1, 28, 28), for example.

    `ExpandAxis` and `Unsqueeze` are aliases for this transform.

    :param axis: (int):
        dimension to be expanded to singleton
    """

    def __init__(self, axis=0):
        self.axis = axis

    def __call__(self, *inputs):
        """
        :param inputs: tensors to unsqueeze at ``self.axis``
        :return: a single tensor for one input, otherwise a list of
            tensors (an empty list when called with no inputs)
        """
        outputs = [_input.unsqueeze(self.axis) for _input in inputs]
        return outputs[0] if len(outputs) == 1 else outputs


ExpandAxis = AddChannel
Unsqueeze = AddChannel
class Transpose:
    """
    Swaps two dimensions of a tensor

    :param dim1: (int):
        first dim to switch
    :param dim2: (int):
        second dim to switch
    """

    def __init__(self, dim1, dim2):
        self.dim1 = dim1
        self.dim2 = dim2

    def __call__(self, *inputs):
        """
        :param inputs: tensors whose ``dim1`` and ``dim2`` are swapped
        :return: a single tensor for one input, otherwise a list of
            tensors (an empty list when called with no inputs)
        """
        outputs = [th.transpose(_input, self.dim1, self.dim2) for _input in inputs]
        return outputs[0] if len(outputs) == 1 else outputs
class RangeNormalize:
    """
    Linearly rescales a th.*Tensor so that its observed minimum and
    maximum land on the provided min and max values.

    Works by calculating :
        a = (max'-min')/(max-min)

        b = max' - a * max

        new_value = a * value + b

    where min' & max' are given values, and min & max are the observed
    min/max.

    * If ``min_val`` and ``max_val`` are plain numbers, min & max are taken
      over the whole tensor.
    * If either is a list/tuple, e.g. min_val: (R, G, B) and
      max_val: (R, G, B), min & max are observed separately for each
      index along dim 0 and each of those slices is mapped to its own
      target range.

    A tensor (or channel) whose values are all equal has max == min, so the
    scale factor divides by zero; no special handling is applied.

    :param min_val: (float or integer, or list/tuple of them):
        Lower bound of normalized tensor
    :param max_val: (float or integer, or list/tuple of them):
        Upper bound of normalized tensor

    Example:
        >>> x = th.rand(3,5,5)
        >>> rn = RangeNormalize((0,0,10),(1,1,11))
        >>> x_norm = rn(x)

    Also works with just one value for min/max:
        >>> x = th.rand(3,5,5)
        >>> rn = RangeNormalize(0,1)
        >>> x_norm = rn(x)
    """

    def __init__(self, min_val, max_val):
        """
        Normalize a tensor between a min and max value
        """
        self.min_val = min_val
        self.max_val = max_val

    def __call__(self, *inputs):
        """
        :param inputs: tensors to rescale, each one independently
        :return: a single tensor for one input, otherwise a list of
            tensors (an empty list when called with no inputs)
        """
        per_channel = (isinstance(self.min_val, (list, tuple))
                       or isinstance(self.max_val, (list, tuple)))
        outputs = []
        for _input in inputs:
            if per_channel:
                outputs.append(self._normalize_per_channel(_input))
            else:
                _min_val = _input.min()
                _max_val = _input.max()
                a = (self.max_val - self.min_val) / (_max_val - _min_val)
                b = self.max_val - a * _max_val
                outputs.append(_input.mul(a).add(b))
        return outputs[0] if len(outputs) == 1 else outputs

    def _normalize_per_channel(self, _input):
        """Rescale each slice along dim 0 to its own target range."""
        n_channels = _input.size(0)
        flat = _input.reshape(n_channels, -1)
        obs_min = flat.min(dim=1)[0]
        obs_max = flat.max(dim=1)[0]
        # Targets are built as floating point so fractional bounds survive
        # even when the input holds integers.
        if _input.is_floating_point():
            bound_dtype = _input.dtype
        else:
            bound_dtype = th.get_default_dtype()
        new_min = th.as_tensor(self.min_val, dtype=bound_dtype, device=_input.device)
        new_max = th.as_tensor(self.max_val, dtype=bound_dtype, device=_input.device)
        a = (new_max - new_min) / (obs_max - obs_min)
        b = new_max - a * obs_max
        # Shape (C, 1, 1, ...) so the coefficients broadcast over each channel.
        bshape = [n_channels] + [1] * (_input.dim() - 1)
        return _input * a.view(bshape) + b.view(bshape)
class StdNormalize:
    """
    Normalize torch tensor to have zero mean and unit std deviation

    The mean and std are taken over the whole tensor. A tensor whose
    ``std()`` is zero is divided by zero; no special handling is applied.
    """

    def __call__(self, *inputs):
        """
        :param inputs: tensors to standardize, each one independently
        :return: a single tensor for one input, otherwise a list of
            tensors (an empty list when called with no inputs)
        """
        outputs = [_input.sub(_input.mean()).div(_input.std()) for _input in inputs]
        return outputs[0] if len(outputs) == 1 else outputs
class Slice2D:
    """
    Take a random 2D slice from a 3D image along
    a given axis. This image should not have a 4th channel dim.

    :param axis: (int `in {0, 1, 2}`):
        the axis on which to take slices
    :param reject_zeros: (bool):
        whether to reject slices that are all zeros. A slice is accepted
        when the sum of the ``y`` slice (if ``y`` is given) or the sum of
        the ``x`` slice is greater than zero.
    """

    def __init__(self, axis=0, reject_zeros=False):
        self.axis = axis
        self.reject_zeros = reject_zeros

    def __call__(self, x, y=None):
        """
        :param x: tensor to slice
        :param y: optional second tensor, sliced at the same index
        :return: the slice of ``x``, or a ``(slice_x, slice_y)`` tuple when
            ``y`` is given
        :raises ValueError: if ``reject_zeros`` is set and no slice along
            the axis is acceptable
        """
        n_slices = x.size(self.axis)
        if not self.reject_zeros:
            return self._take(x, y, random.randint(0, n_slices - 1))
        # Visit every index once in random order: this still picks uniformly
        # among the acceptable slices, but ends after ``n_slices`` attempts
        # instead of resampling forever when every slice is zero.
        for keep_slice in random.sample(range(n_slices), n_slices):
            slice_x = x.select(self.axis, keep_slice)
            slice_y = y.select(self.axis, keep_slice) if y is not None else None
            if slice_y is not None and th.sum(slice_y) > 0:
                return slice_x, slice_y
            if th.sum(slice_x) > 0:
                return slice_x if slice_y is None else (slice_x, slice_y)
        raise ValueError(
            'reject_zeros is set but no slice along axis %i has a positive sum'
            % self.axis)

    def _take(self, x, y, index):
        """Return the slice(s) at ``index`` along ``self.axis``."""
        slice_x = x.select(self.axis, index)
        if y is not None:
            return slice_x, y.select(self.axis, index)
        return slice_x
class RandomCrop:
    """
    Randomly crop a torch tensor

    The crop is taken over dims 1 and 2; dim 0 is kept whole. The same
    randomly chosen window is applied to every input.

    :param size: (tuple or list):
        dimensions of the crop, as (size along dim 1, size along dim 2)
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, *inputs):
        """
        :param inputs: one or more tensors; the valid offset range is
            derived from the first one
        :return: a single tensor for one input, otherwise a list of tensors
        """
        crop_h, crop_w = self.size[0], self.size[1]
        # Both offsets come from the first input so that a single input
        # works and all inputs share one window.
        h_idx = random.randint(0, inputs[0].size(1) - crop_h)
        w_idx = random.randint(0, inputs[0].size(2) - crop_w)
        outputs = [
            _input[:, h_idx:(h_idx + crop_h), w_idx:(w_idx + crop_w)]
            for _input in inputs
        ]
        return outputs if len(outputs) > 1 else outputs[0]
class SpecialCrop:
    """
    Crop a fixed window out of a tensor: the center or one of the four
    corners. Dims 1 and 2 are cropped; dim 0 is kept whole.

    :param size: (tuple or list):
        dimensions of the crop, as (size along dim 1, size along dim 2)
    :param crop_type: (int in `{0,1,2,3,4}`):
        0 = center crop
        1 = top left crop
        2 = top right crop
        3 = bottom right crop
        4 = bottom left crop
    """

    def __init__(self, size, crop_type=0):
        if crop_type not in {0, 1, 2, 3, 4}:
            raise ValueError('crop_type must be in {0, 1, 2, 3, 4}')
        self.size = size
        self.crop_type = crop_type

    def __call__(self, x, y=None):
        """
        :param x: tensor to crop
        :param y: optional second tensor, cropped with the same window
        :return: cropped ``x``, or ``(x, y)`` when ``y`` is given
        """
        height, width = x.size(1), x.size(2)
        crop_h, crop_w = self.size[0], self.size[1]
        kind = self.crop_type
        if kind == 0:
            # Center: when the margin is odd, the extra row/column is
            # dropped from the leading side (ceil before, floor after).
            margin_h = (height - crop_h) / 2.
            margin_w = (width - crop_w) / 2.
            rows = slice(int(math.ceil(margin_h)), height - int(math.floor(margin_h)))
            cols = slice(int(math.ceil(margin_w)), width - int(math.floor(margin_w)))
        elif kind == 1:
            # top left
            rows = slice(0, crop_h)
            cols = slice(0, crop_w)
        elif kind == 2:
            # top right
            rows = slice(0, crop_h)
            cols = slice(width - crop_w, width)
        elif kind == 3:
            # bottom right
            rows = slice(height - crop_h, height)
            cols = slice(width - crop_w, width)
        elif kind == 4:
            # bottom left
            rows = slice(height - crop_h, height)
            cols = slice(0, crop_w)
        x = x[:, rows, cols]
        if y is None:
            return x
        return x, y[:, rows, cols]
class Pad:
    """
    Zero-pads a tensor up to the given size

    Arguments
    ---------
    :param size: (tuple or list):
        target size, one entry per dim of the input, paired with the
        input's shape in order. Dims already at least that large are left
        alone.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, x, y=None):
        """
        :param x: tensor to pad (converted through numpy)
        :param y: optional second tensor, padded by the same amounts as ``x``
        :return: padded ``x`` as a tensor, or ``(x, y)`` when ``y`` is given
        """
        x = x.numpy()
        pad_sizes = []
        for current, target in zip(x.shape, self.size):
            missing = max(int(np.ceil(target - current)), 0)
            # An odd shortfall puts the extra element on the leading side.
            before = int(np.ceil(missing / 2.))
            after = int(np.floor(missing / 2.))
            pad_sizes.append((before, after))
        x = np.pad(x, pad_sizes, mode='constant')
        if y is None:
            return th.from_numpy(x)
        y = np.pad(y.numpy(), pad_sizes, mode='constant')
        return th.from_numpy(x), th.from_numpy(y)
class PadNumpy:
    """
    Zero-pads a Numpy image up to the given size
    Return a Numpy image / image pair

    Arguments
    ---------
    :param size: (tuple or list):
        target size, one entry per dim of the input, paired with the
        input's shape in order. Dims already at least that large are left
        alone.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, x, y=None):
        """
        :param x: numpy array to pad
        :param y: optional second array, padded by the same amounts as ``x``
        :return: padded ``x``, or ``(x, y)`` when ``y`` is given
        """
        shortfalls = (
            max(int(np.ceil(wanted - have)), 0)
            for have, wanted in zip(x.shape, self.size)
        )
        # An odd shortfall puts the extra element on the leading side.
        pad_sizes = [
            (int(np.ceil(gap / 2.)), int(np.floor(gap / 2.)))
            for gap in shortfalls
        ]
        padded_x = np.pad(x, pad_sizes, mode='constant')
        if y is None:
            return padded_x
        return padded_x, np.pad(y, pad_sizes, mode='constant')
class RandomFlip:
    """
    Randomly flip an image horizontally and/or vertically with
    some probability. A horizontal flip reverses axis 2 and a vertical
    flip reverses axis 1.

    :param h: (bool):
        whether to horizontally flip w/ probability p
    :param v: (bool):
        whether to vertically flip w/ probability p
    :param p: (float between [0,1]):
        probability with which to apply allowed flipping operations;
        each enabled flip is drawn independently
    """

    def __init__(self, h=True, v=False, p=0.5):
        self.horizontal = h
        self.vertical = v
        self.p = p

    def __call__(self, x, y=None):
        """
        :param x: tensor to flip (converted through numpy)
        :param y: optional second tensor, flipped the same way as ``x``
        :return: a new tensor, or an ``(x, y)`` pair of new tensors when
            ``y`` is given
        """
        arrays = [x.numpy()]
        if y is not None:
            arrays.append(y.numpy())
        # One draw per enabled direction, shared by x and y.
        if self.horizontal and random.random() < self.p:
            arrays = [np.flip(arr, 2) for arr in arrays]
        if self.vertical and random.random() < self.p:
            arrays = [np.flip(arr, 1) for arr in arrays]
        # Copy because torch cannot wrap arrays with negative strides.
        tensors = [th.from_numpy(arr.copy()) for arr in arrays]
        if y is None:
            return tensors[0]
        return tensors[0], tensors[1]
class RandomOrder:
    """
    Randomly permute the channels of an image

    Channels are the entries along dim 0. One permutation is drawn per
    call and applied to every input.
    """

    def __call__(self, *inputs):
        """
        :param inputs: one or more tensors; the permutation length is the
            size of dim 0 of the first one
        :return: a single tensor for one input, otherwise a list of tensors
        """
        # Permute over the number of channels (size of dim 0), not over the
        # number of dimensions of the tensor.
        order = th.randperm(inputs[0].size(0))
        outputs = [_input.index_select(0, order) for _input in inputs]
        return outputs if len(outputs) > 1 else outputs[0]