"""
Losses are critical to training a neural network well. Training can only make progress if you
provide a meaningful measure of loss for each training step. What the loss looks like usually depends
on your application. Pytorch has a number of `loss functions <https://pytorch.org/docs/stable/nn.html#loss-functions/>`_ that
you can use out of the box. However, some more advanced and cutting-edge loss functions exist that are not (yet) part of
Pytorch. We include those below for your experimentation.

**Caution:** if you decide to use one of these, you will definitely want to peruse the source code first, as it has
many additional useful notes and references which will help you.

Keep in mind that losses are specific to the type of task. Classification losses are computed differently from segmentation losses.
Within the segmentation domain, make sure to use BCE (Binary Cross Entropy) for any work involving binary masks (e.g. num_classes = 1).
Make sure to read the documentation and notes (in the code) for each loss to understand how it is applied.
`Read this blog post <https://gombru.github.io/2018/05/23/cross_entropy_loss/>`_

Note:
Logit is the vector of raw (non-normalized) predictions that a classification model generates, which is ordinarily then passed to a normalization function.
If the model is solving a multi-class classification problem, logits typically become an input to the softmax function. The softmax function then generates
a vector of (normalized) probabilities with one value for each possible class.

For example, BCEWithLogitsLoss is a BCE that accepts raw logits in (-inf, inf) and internally applies torch.sigmoid to map them to [0, 1].
However, if you use one-hot encoding or similar methods where you need to convert a tensor to pytorch from another source (e.g. numpy), you will need to
make sure to apply the correct type to the resulting tensor. E.g. if y_hot is of type long and the BCE loss expects a Tensor of type float, then you
can try converting y_hot with y_hot = y_hot.type_as(output).

To convert predictions into the (0, 1) range you will sometimes need to use either softmax or sigmoid:
softmax is used for multi-class classification, whereas sigmoid is used for binary classification.
"""
## Various loss calculation functions ##
# Sources: https://github.com/bermanmaxim/jaccardSegment/blob/master/losses.py (?)
# https://github.com/doodledood/carvana-image-masking-challenge/blob/master/losses.py (MIT)
# https://github.com/atlab/attorch/blob/master/attorch/losses.py (MIT)
# https://github.com/EKami/carvana-challenge (MIT)
# https://github.com/DingKe/pytorch_workplace (MIT)
import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd import Variable
from torch import Tensor
from typing import Iterable, Set
__all__ = ['ActiveContourLoss', 'ActiveContourLossAlt', 'AngularPenaltySMLoss', 'AsymLoss', 'BCELoss2d', 'BCEDiceLoss',
'BCEWithLogitsViewLoss', 'BCEDiceTL1Loss', 'BCEDicePenalizeBorderLoss', 'BCEDiceFocalLoss', 'BinaryFocalLoss',
'ComboBCEDiceLoss', 'ComboSemsegLossWeighted', 'EncNetLoss', 'FocalLoss', 'FocalLoss2',
'HausdorffERLoss', 'HausdorffDTLoss', 'LovaszSoftmax', 'mIoULoss', 'MixSoftmaxCrossEntropyOHEMLoss',
'MSE3D', 'OhemCELoss', 'OhemCrossEntropy2d', 'OhemBCEDicePenalizeBorderLoss', 'PoissonLoss',
'PoissonLoss3d', 'RecallLoss', 'RMILoss', 'RMILossAlt', 'RMIBCEDicePenalizeBorderLoss', 'SoftInvDiceLoss',
'SoftDiceLoss', 'StableBCELoss', 'TverskyLoss', 'ThresholdedL1Loss', 'WeightedSoftDiceLoss', 'WeightedBCELoss2d',
'BDLoss', 'L1Loss3d', 'WingLoss', 'BoundaryLoss']
VOID_LABEL = 255
N_CLASSES = 1
class StableBCELoss(nn.Module):
def __init__(self, **_):
super(StableBCELoss, self).__init__()
    @staticmethod
def forward(input_, target, **_):
neg_abs = - input_.abs()
loss = input_.clamp(min=0) - input_ * target + (1 + neg_abs.exp()).log()
return loss.mean()
# WARN: Only applicable to Binary Segmentation!
def binaryXloss(logits, label):
mask = (label.view(-1) != VOID_LABEL)
nonvoid = mask.long().sum()
if nonvoid == 0:
# only void pixels, the gradients should be 0
return logits.sum() * 0.
# if nonvoid == mask.numel():
# # no void pixel, use builtin
# return F.cross_entropy(logits, label)
target = label.contiguous().view(-1)[mask]
logits = logits.contiguous().view(-1)[mask]
# loss = F.binary_cross_entropy(logits, target.float())
loss = StableBCELoss()(logits, target.float())
return loss
def naive_single(logit, label):
# single images
mask = (label.view(-1) != 255)
num_preds = mask.long().sum()
if num_preds == 0:
# only void pixels, the gradients should be 0
return logit.sum() * 0.
target = label.contiguous().view(-1)[mask].float()
logit = logit.contiguous().view(-1)[mask]
prob = torch.sigmoid(logit)
intersect = target * prob
union = target + prob - intersect
loss = (1. - intersect / union).sum()
return loss
# WARN: Only applicable to Binary Segmentation!
def hingeloss(logits, label):
mask = (label.view(-1) != 255)
num_preds = mask.long().sum().item()
if num_preds == 0:
# only void pixels, the gradients should be 0
return logits.sum().item() * 0.
target = label.contiguous().view(-1)[mask]
target = 2. * target.float() - 1. # [target == 0] = -1
logits = logits.contiguous().view(-1)[mask]
hinge = 1. / num_preds * F.relu(1. - logits * target).sum().item()
return hinge
def gamma_fast(gt, permutation):
p = len(permutation)
gt = gt.gather(0, permutation)
gts = gt.sum()
intersection = gts - gt.float().cumsum(0)
union = gts + (1 - gt).float().cumsum(0)
jaccard = 1. - intersection / union
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
# WARN: Only applicable to Binary Segmentation right now (zip function needs to be replaced)!
def lovaszloss(logits, labels, prox=False, max_steps=20, debug=None):
"""
`The Lovasz-Softmax loss <https://arxiv.org/abs/1705.08790>`_
:param logits:
:param labels:
:param prox:
:param max_steps:
:param debug:
:return:
"""
if debug is None:
debug = {}
# image-level Lovasz hinge
if logits.size(0) == 1:
# single image case
loss = lovasz_single(logits.squeeze(0), labels.squeeze(0), prox, max_steps, debug)
else:
losses = []
# assert len(logits[0]) == len(labels[0])
for logit, label in zip(logits, labels):
loss = lovasz_single(logit, label, prox, max_steps, debug)
losses.append(loss)
loss = sum(losses) / len(losses)
return loss
def naiveloss(logits, labels):
# image-level Lovasz hinge
if logits.size(0) == 1:
# single image case
loss = naive_single(logits.squeeze(0), labels.squeeze(0))
else:
losses = []
for logit, label in zip(logits, labels):
loss = naive_single(logit, label)
losses.append(loss)
loss = sum(losses) / len(losses)
return loss
def iouloss(pred, gt):
# works for one binary pred and associated target
# make byte tensors
pred = (pred == 1)
mask = (gt != 255)
gt = (gt == 1)
union = (gt | pred)[mask].long().sum()
if not union:
return 0.
else:
intersection = (gt & pred)[mask].long().sum()
return 1. - intersection / union
def compute_step_length(x, grad, active, eps=1e-6):
# compute next intersection with an edge in the direction grad
# OR next intersection with a 0 - border
# returns: delta in ind such that:
# after a step delta in the direction grad, x[ind] and x[ind+1] will be equal
delta = np.inf
ind = -1
if active > 0:
numerator = (x[:active] - x[1:active + 1]) # always positive (because x is sorted)
denominator = (grad[:active] - grad[1:active + 1])
# indices corresponding to negative denominator won't intersect
# also, we are not interested in indices in x that are *already equal*
valid = (denominator > eps) & (numerator > eps)
valid_indices = valid.nonzero()
intersection_times = numerator[valid] / denominator[valid]
if intersection_times.size():
delta, ind = intersection_times.min(0)
ind = valid_indices[ind]
delta, ind = delta[0], ind[0, 0]
if grad[active] > 0:
intersect_zero = x[active] / grad[active]
if intersect_zero > 0. and intersect_zero < delta:
return intersect_zero, -1
return delta, ind
def project(gam, active, members):
tovisit = set(range(active + 1))
while tovisit:
v = tovisit.pop()
if len(members[v]) > 1:
avg = 0.
for k in members[v]:
if k != v: tovisit.remove(k)
avg += gam[k] / len(members[v])
for k in members[v]:
gam[k] = avg
if active + 1 < len(gam):
gam[active + 1:] = 0.
def find_proximal(x0, gam, lam, eps=1e-6, max_steps=20, debug=None):
if debug is None:
debug = {}
# x0: sorted margins data
# gam: initial gamma_fast(target, perm)
# regularisation parameter lam
x = x0.clone()
act = (x >= eps).nonzero()
finished = False
if not act.size():
finished = True
else:
active = act[-1, 0]
members = {i: {i} for i in range(active + 1)}
if active > 0:
equal = (x[:active] - x[1:active + 1]) < eps
for i, e in enumerate(equal):
if e:
members[i].update(members[i + 1])
members[i + 1] = members[i]
project(gam, active, members)
step = 0
while not finished and step < max_steps and active > -1:
step += 1
res = compute_step_length(x, gam, active, eps)
delta, ind = res
if ind == -1:
active = active - len(members[active])
stop = torch.dot(x - x0, gam) / torch.dot(gam, gam) + 1. / lam
if 0 <= stop < delta:
delta = stop
finished = True
x = x - delta * gam
if not finished:
if ind >= 0:
repr = min(members[ind])
members[repr].update(members[ind + 1])
for m in members[ind]:
if m != repr:
members[m] = members[repr]
project(gam, active, members)
if "path" in debug:
debug["path"].append(x.numpy())
if "step" in debug:
debug["step"] = step
if "finished" in debug:
debug["finished"] = finished
return x, gam
def lovasz_binary(margins, label, prox=False, max_steps=20, debug=None):
if debug is None:
debug = {}
# 1d vector inputs
# Workaround: can't sort Variable bug
# prox: False or lambda regularization value
_, perm = torch.sort(margins.detach(), dim=0, descending=True)
margins_sorted = margins[perm]
grad = gamma_fast(label, perm)
loss = torch.dot(F.relu(margins_sorted), grad)
if prox is not False:
xp, gam = find_proximal(margins_sorted.detach(), grad, prox, max_steps=max_steps, eps=1e-6, debug=debug)
hook = margins_sorted.register_hook(lambda grad: (margins_sorted.detach() - xp))
return loss, hook, gam
else:
return loss
def lovasz_single(logit, label, prox=False, max_steps=20, debug=None):
if debug is None:
debug = {}
# single images
mask = (label.view(-1) != 255)
num_preds = mask.long().sum()
if num_preds == 0:
# only void pixels, the gradients should be 0
return logit.sum() * 0.
target = label.contiguous().view(-1)[mask]
signs = 2. * target.float() - 1.
logit = logit.contiguous().view(-1)[mask]
margins = (1. - logit * signs)
loss = lovasz_binary(margins, target, prox, max_steps, debug=debug)
return loss
def dice_coefficient(logit, label, isCuda=True):
'''
WARNING THIS IS VERY SLOW FOR SOME REASON!!
:param logit: calculated guess (expects torch.Tensor)
:param label: truth label (expects torch.Tensor)
:return: dice coefficient
'''
A = label.view(-1)
B = logit.view(-1)
A = A.clone()
B = B.clone()
if len(A) != len(B):
raise AssertionError
for i in list(range(len(A))):
if A[i] > 0.5:
A[i] = 1.0
else:
A[i] = 0.0
if B[i] > 0.5:
B[i] = 1.0
else:
B[i] = 0.0
if isCuda:
A = A.type(torch.cuda.ByteTensor)
else:
A = A.type(torch.ByteTensor)
dice = torch.masked_select(B, A).sum()*2.0 / (B.sum() + A.sum())
return dice
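# Note: the element-wise Python loop above is the main reason dice_coefficient is slow. A vectorized
# sketch of the same thresholded computation (kept here only as a reference, not used by this module):
#
#     A = (label.view(-1) > 0.5).float()
#     B = (logit.view(-1) > 0.5).float()
#     dice = 2.0 * (A * B).sum() / (A.sum() + B.sum())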
# ==================================== #
# Source: https://github.com/EKami/carvana-challenge
class WeightedSoftDiceLoss(torch.nn.Module):
def __init__(self, **_):
super(WeightedSoftDiceLoss, self).__init__()
    @staticmethod
def forward(logits, labels, weights, **_):
probs = torch.sigmoid(logits)
num = labels.size(0)
w = weights.view(num,-1)
w2 = w*w
m1 = probs.view(num,-1)
m2 = labels.view(num,-1)
intersection = (m1 * m2)
score = 2. * ((w2*intersection).sum(1)+1) / ((w2*m1).sum(1) + (w2*m2).sum(1)+1)
score = 1 - score.sum()/num
return score
def dice_coeff(pred, target):
smooth = 1.
num = pred.size(0)
m1 = pred.view(num, -1) # Flatten
m2 = target.view(num, -1) # Flatten
intersection = (m1 * m2).sum()
return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
def dice_coeff_hard_np(y_true, y_pred):
smooth = 1.
    y_true_f = y_true.flatten()
    y_pred_f = np.round(y_pred.flatten())
intersection = np.sum(y_true_f * y_pred_f)
score = (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)
return score
# ==================================== #
# Source: https://github.com/doodledood/carvana-image-masking-challenge/blob/master/losses.py
# TODO Replace this with nn.BCEWithLogitsLoss??
class BCELoss2d(nn.Module):
def __init__(self, weight=None, size_average=True, **_):
super(BCELoss2d, self).__init__()
self.bce_loss = nn.BCELoss(weight, size_average)
    def forward(self, logits, labels, **_):
probs = torch.sigmoid(logits)
probs_flat = probs.view(-1)
targets_flat = labels.view(-1)
return self.bce_loss(probs_flat, targets_flat)
class SoftDiceLoss(nn.Module):
def __init__(self, smooth=1.0, **_):
super(SoftDiceLoss, self).__init__()
self.smooth = smooth
    def forward(self, logits, labels, **_):
num = labels.size(0)
probs = torch.sigmoid(logits)
m1 = probs.view(num, -1)
m2 = labels.view(num, -1)
intersection = (m1 * m2)
# smooth = 1.
score = 2. * (intersection.sum(1) + self.smooth) / (m1.sum(1) + m2.sum(1) + self.smooth)
score = 1 - score.sum() / num
return score
class FocalLoss(nn.Module):
    """
    Weighs the contribution of each sample to the loss based on the classification error.
    If a sample is already classified correctly by the CNN, its contribution to the loss decreases.

    :param l: Focusing parameter. l=0 is equivalent to BCE loss.
    :param eps: Small constant added inside the log for numerical stability.
    """
    def __init__(self, l=0.5, eps=1e-6, **_):
        super(FocalLoss, self).__init__()
        self.l = l
        self.eps = eps

    def forward(self, logits, labels, **_):
labels = labels.view(-1)
probs = torch.sigmoid(logits).view(-1)
losses = -(labels * torch.pow((1. - probs), self.l) * torch.log(probs + self.eps) + \
(1. - labels) * torch.pow(probs, self.l) * torch.log(1. - probs + self.eps))
loss = torch.mean(losses)
return loss
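# Minimal usage sketch for FocalLoss (the shapes below are illustrative assumptions, not a required format):
#
#     criterion = FocalLoss(l=2.0)                              # l is the focusing exponent
#     logits = torch.randn(4, 1, 64, 64, requires_grad=True)    # raw, unnormalized predictions
#     masks = torch.randint(0, 2, (4, 1, 64, 64)).float()
#     loss = criterion(logits, masks)
#     loss.backward()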
class ThresholdedL1Loss(nn.Module):
def __init__(self, threshold=0.5, **_):
super(ThresholdedL1Loss, self).__init__()
self.threshold = threshold
    def forward(self, logits, labels, **_):
labels = labels.view(-1)
probs = torch.sigmoid(logits).view(-1)
probs = (probs > self.threshold).float()
losses = torch.abs(labels - probs)
loss = torch.mean(losses)
return loss
class BCEDiceTL1Loss(nn.Module):
def __init__(self, threshold=0.5, **_):
super(BCEDiceTL1Loss, self).__init__()
self.bce = nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
self.dice = SoftDiceLoss()
self.tl1 = ThresholdedL1Loss(threshold=threshold)
    def forward(self, logits, labels, **_):
return self.bce(logits, labels) + self.dice(logits, labels) + self.tl1(logits, labels)
class BCEDiceFocalLoss(nn.Module):
'''
    :param focal_param: (float,double) focusing parameter for the focal term; values > 0 reduce the relative loss for
            well-classified examples (p>0.5), putting more focus on hard, misclassified examples
    :param weights: (list(), default = [1,1,1]) Optional weighting (0.0-1.0) of the losses in order of [bce, dice, focal]
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch (passed through to the underlying losses).
'''
def __init__(self, focal_param, weights=None, **kwargs):
if weights is None:
weights = [1.0,1.0,1.0]
super(BCEDiceFocalLoss, self).__init__()
self.bce = BCEWithLogitsViewLoss(weight=None, size_average=True, **kwargs)
self.dice = SoftDiceLoss(**kwargs)
self.focal = FocalLoss(l=focal_param, **kwargs)
self.weights = weights
    def forward(self, logits, labels, **_):
return self.weights[0] * self.bce(logits, labels) + self.weights[1] * self.dice(logits, labels) + self.weights[2] * self.focal(logits.unsqueeze(1), labels.unsqueeze(1))
class BCEDiceLoss(nn.Module):
def __init__(self, **_):
super(BCEDiceLoss, self).__init__()
self.bce = BCELoss2d()
self.dice = SoftDiceLoss()
    def forward(self, logits, labels, **_):
return self.bce(logits, labels) + self.dice(logits, labels)
class WeightedBCELoss2d(nn.Module):
def __init__(self, **_):
super(WeightedBCELoss2d, self).__init__()
    @staticmethod
def forward(logits, labels, weights, **_):
w = weights.view(-1) # (-1 operation flattens all the dimensions)
z = logits.view(-1) # (-1 operation flattens all the dimensions)
t = labels.view(-1) # (-1 operation flattens all the dimensions)
loss = w*z.clamp(min=0) - w*z*t + w*torch.log(1 + torch.exp(-z.abs()))
loss = loss.sum()/w.sum()
return loss
class BCEDicePenalizeBorderLoss(nn.Module):
def __init__(self, kernel_size=55, **_):
super(BCEDicePenalizeBorderLoss, self).__init__()
self.bce = WeightedBCELoss2d()
self.dice = WeightedSoftDiceLoss()
self.kernel_size = kernel_size
    def to(self, device):
super().to(device=device)
self.bce.to(device=device)
self.dice.to(device=device)
    def forward(self, logits, labels, **_):
a = F.avg_pool2d(labels, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
ind = a.ge(0.01) * a.le(0.99)
ind = ind.float()
weights = torch.ones(a.size()).to(device=logits.device)
w0 = weights.sum()
weights = weights + ind * 2
w1 = weights.sum()
weights = weights / w1 * w0
loss = self.bce(logits, labels, weights) + self.dice(logits, labels, weights)
return loss
# ==== Focal Loss with extra parameters ==== #
# Source: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
# License: MIT
class FocalLoss2(nn.Module):
"""
    This is an implementation of Focal Loss with smooth label cross entropy support, as proposed in
    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
        Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
    Params:
        :param num_class: number of classes
        :param alpha: (tensor, list, or float) per-class scaling factor for this criterion
        :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5), putting more
                focus on hard, misclassified examples
        :param smooth: (float,double) label smoothing value applied when building the one-hot targets
        :param balance_index: (int) balance class index, should be specified when alpha is a float
        :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
"""
def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True, **_):
super(FocalLoss2, self).__init__()
self.num_class = num_class
self.alpha = alpha
self.gamma = gamma
self.smooth = smooth
self.size_average = size_average
if self.alpha is None:
self.alpha = torch.ones(self.num_class, 1)
elif isinstance(self.alpha, (list, np.ndarray)):
if len(self.alpha) != self.num_class:
raise AssertionError
self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
self.alpha = self.alpha / self.alpha.sum()
elif isinstance(self.alpha, float):
alpha = torch.ones(self.num_class, 1)
alpha = alpha * (1 - self.alpha)
alpha[balance_index] = self.alpha
self.alpha = alpha
else:
raise TypeError('Not support alpha type')
if self.smooth is not None:
if self.smooth < 0 or self.smooth > 1.0:
raise ValueError('smooth value should be in [0,1]')
    def forward(self, logits, labels, **_):
# logits = F.softmax(logits, dim=1)
if logits.dim() > 2:
# N,C,d1,d2 -> N,C,m (m=d1*d2*...)
logits = logits.view(logits.size(0), logits.size(1), -1)
logits = logits.permute(0, 2, 1).contiguous()
logits = logits.view(-1, logits.size(-1))
labels = labels.view(-1, 1)
# N = input.size(0)
# alpha = torch.ones(N, self.num_class)
# alpha = alpha * (1 - self.alpha)
# alpha = alpha.scatter_(1, target.long(), self.alpha)
epsilon = 1e-10
alpha = self.alpha.to(logits.device)
idx = labels.cpu().long()
one_hot_key = torch.FloatTensor(labels.size(0), self.num_class).zero_()
one_hot_key = one_hot_key.scatter_(1, idx, 1)
one_hot_key = one_hot_key.to(logits.device)
if self.smooth:
one_hot_key = torch.clamp(one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
pt = (one_hot_key * logits).sum(1) + epsilon
logpt = pt.log()
gamma = self.gamma
alpha = alpha[idx]
loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
if self.size_average:
loss = loss.mean()
else:
loss = loss.sum()
return loss
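# Usage sketch for FocalLoss2 (shapes below are illustrative assumptions). Note that forward() expects
# per-class probabilities, so apply softmax to raw logits first:
#
#     criterion = FocalLoss2(num_class=3, gamma=2, smooth=0.1)
#     probs = F.softmax(torch.randn(2, 3, 8, 8), dim=1)
#     labels = torch.randint(0, 3, (2, 8, 8))
#     loss = criterion(probs, labels)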
# -------- #
# Source: https://github.com/huaifeng1993/DFANet/blob/master/loss.py
class FocalLoss3(nn.Module):
"""
    This criterion is an implementation of Focal Loss, which is proposed in Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Params:
:param alpha: (1D Tensor, Variable) - the scalar factor for this criterion
:param gamma: (float, double) - gamma > 0
        :param size_average: (bool) - By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are instead summed for each minibatch.
"""
def __init__(self, class_num, alpha=None, gamma=2, size_average=True, **_):
super(FocalLoss3, self).__init__()
if alpha is None:
self.alpha = Variable(torch.ones(class_num+1))
else:
if isinstance(alpha, Variable):
self.alpha = alpha
else:
self.alpha = Variable(alpha)
self.gamma = gamma
self.class_num = class_num
self.size_average = size_average
def forward(self, inputs, labels, **_): # variables
        P = F.softmax(inputs, dim=1)
if len(inputs.size()) == 3:
torch_out = torch.zeros(inputs.size())
else:
b,c,h,w = inputs.size()
torch_out = torch.zeros([b,c+1,h,w])
if inputs.is_cuda:
torch_out = torch_out.cuda()
class_mask = Variable(torch_out)
class_mask.scatter_(1, labels.long(), 1.)
class_mask = class_mask[:,:-1,:,:]
if inputs.is_cuda and not self.alpha.is_cuda:
self.alpha = self.alpha.cuda()
# print('alpha',self.alpha.size())
alpha = self.alpha[labels.data.view(-1)].view_as(labels)
# print (alpha.size(),class_mask.size(),P.size())
probs = (P * class_mask).sum(1) # + 1e-6#.view(-1, 1)
log_p = probs.log()
batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
# -------- #
# -------- #
# Source: https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/4
class BinaryFocalLoss(nn.Module):
'''
Implementation of binary focal loss. For multi-class focal loss use one of the other implementations.
gamma = 0 is equivalent to BinaryCrossEntropy Loss
'''
def __init__(self, gamma=1.333, eps=1e-6, alpha=1.0, **_):
super().__init__()
self.gamma = gamma
self.eps = eps
self.alpha = alpha
    def forward(self, inputs, labels, **_):
BCE_loss = F.binary_cross_entropy_with_logits(inputs, labels, reduction='none')
pt = torch.exp(-BCE_loss) # prevents nans when probability 0
F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
return F_loss.mean()
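# Usage sketch for BinaryFocalLoss (shapes below are illustrative assumptions):
#
#     criterion = BinaryFocalLoss(gamma=2.0)
#     logits = torch.randn(8, 1, 32, 32, requires_grad=True)   # raw logits; sigmoid is applied inside BCE-with-logits
#     masks = torch.randint(0, 2, (8, 1, 32, 32)).float()
#     loss = criterion(logits, masks)
#     loss.backward()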
# -------- #
# ==== Additional Losses === #
# Source: https://github.com/atlab/attorch/blob/master/attorch/losses.py
# License: MIT
class PoissonLoss(nn.Module):
def __init__(self, bias=1e-12, **_):
super().__init__()
self.bias = bias
    def forward(self, output, labels, **_):
# _assert_no_grad(target)
        with torch.no_grad():  # Pytorch 0.4.0 replacement (should be ok to use like this)
return (output - labels * torch.log(output + self.bias)).mean()
class PoissonLoss3d(nn.Module):
def __init__(self, bias=1e-12, **_):
super().__init__()
self.bias = bias
    def forward(self, output, target, **_):
# _assert_no_grad(target)
        with torch.no_grad():  # Pytorch 0.4.0 replacement (should be ok to use like this)
lag = target.size(1) - output.size(1)
return (output - target[:, lag:, :] * torch.log(output + self.bias)).mean()
class L1Loss3d(nn.Module):
def __init__(self, bias=1e-12, **_):
super().__init__()
self.bias = bias
    @staticmethod
def forward(output, target, **_):
# _assert_no_grad(target)
        with torch.no_grad():  # Pytorch 0.4.0 replacement (should be ok to use like this)
lag = target.size(1) - output.size(1)
return (output - target[:, lag:, :]).abs().mean()
class MSE3D(nn.Module):
def __init__(self, **_):
super().__init__()
    @staticmethod
def forward(output, target, **_):
# _assert_no_grad(target)
        with torch.no_grad():  # Pytorch 0.4.0 replacement (should be ok to use like this)
lag = target.size(1) - output.size(1)
return (output - target[:, lag:, :]).pow(2).mean()
# ==== Custom ==== #
class BCEWithLogitsViewLoss(nn.BCEWithLogitsLoss):
'''
    Thin wrapper of nn.BCEWithLogitsLoss that simply flattens both input and target to 1-D before computing the loss.
'''
def __init__(self, weight=None, size_average=True, **_):
super().__init__(weight=weight, size_average=size_average)
    def forward(self, input_, target, **_):
'''
:param input_:
:param target:
:return:
Simply passes along input.view(-1), target.view(-1)
'''
return super().forward(input_.view(-1), target.view(-1))
# ===================== #
# Source: https://discuss.pytorch.org/t/one-hot-encoding-with-autograd-dice-loss/9781/5
# For calculating dice loss on images where multiple classes are present at the same time
def multi_class_dice_loss(output, target, weights=None, ignore_index=None):
# output : NxCxHxW float tensor
# target : NxHxW long tensor
# weights : C float tensor
# ignore_index : int value to ignore from loss
smooth = 1.
loss = 0.
output = output.exp()
encoded_target = output.detach().clone().zero_()
if ignore_index is not None:
mask = target == ignore_index
target = target.clone()
target[mask] = 0
encoded_target.scatter_(1, target.unsqueeze(1), 1)
mask = mask.unsqueeze(1).expand_as(encoded_target)
encoded_target[mask] = 0
else:
encoded_target.scatter_(1, target.unsqueeze(1), 1)
if weights is None:
weights = torch.ones(output.size(1)).type_as(output.detach())
intersection = output * encoded_target
numerator = 2 * intersection.sum(3).sum(2).sum(0) + smooth
denominator = (output + encoded_target).sum(3).sum(2).sum(0) + smooth
loss_per_channel = weights * (1 - (numerator / denominator))
return loss_per_channel.sum() / output.size(1)
# ====================== #
# Source: https://discuss.pytorch.org/t/how-to-implement-soft-iou-loss/15152
# Calculation of soft-IOU loss
def to_one_hot(tensor, nClasses):
n, h, w = tensor.size()
one_hot = torch.zeros(n, nClasses, h, w).scatter_(1, tensor.view(n, 1, h, w), 1)
return one_hot
# ====================== #
# Source: https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08
# Another calculation of dice loss over multiple classes. Input is numpy matrices.
def soft_multiclass_dice_loss(y_true, y_pred, epsilon=1e-6):
'''
Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
Assumes the `channels_last` format.
# Arguments
y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax)
epsilon: Used for numerical stability to avoid divide by zero errors
# References
V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
https://arxiv.org/abs/1606.04797
More details on Dice loss formulation
https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)
Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
'''
# skip the batch and class axis for calculating Dice score
axes = tuple(range(1, len(y_pred.shape) - 1))
numerator = 2. * np.sum(y_pred * y_true, axes)
denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
return 1 - np.mean(numerator / (denominator + epsilon)) # average over classes and batch
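# Usage sketch for soft_multiclass_dice_loss (numpy, channels_last; values below are made up):
#
#     y_pred = np.random.rand(2, 16, 16, 3)
#     y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)       # must sum to 1 over the class channel
#     y_true = np.eye(3)[np.random.randint(0, 3, (2, 16, 16))]   # one-hot ground truth
#     loss = soft_multiclass_dice_loss(y_true, y_pred)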
class mIoULoss(nn.Module):
def __init__(self, weight=None, size_average=True, num_classes=2, **_):
super(mIoULoss, self).__init__()
self.classes = num_classes
    def forward(self, inputs, target_oneHot, **_):
# inputs => N x Classes x H x W
# target_oneHot => N x Classes x H x W
N = inputs.size()[0]
# predicted probabilities for each pixel along channel
inputs = F.softmax(inputs, dim=1)
# Numerator Product
inter = inputs * target_oneHot
## Sum over all pixels N x C x H x W => N x C
inter = inter.view(N, self.classes, -1).sum(2)
# Denominator
union = inputs + target_oneHot - (inputs * target_oneHot)
## Sum over all pixels N x C x H x W => N x C
union = union.view(N, self.classes, -1).sum(2)
loss = inter / union
## Return average loss over classes and batch
return -loss.mean()
# ====================== #
# Source: https://github.com/snakers4/mnasnet-pytorch/blob/master/src/models/semseg_loss.py
# Combination Loss from BCE and Dice
class ComboBCEDiceLoss(nn.Module):
"""
    Combination BinaryCrossEntropy (BCE) and Dice loss with an optional running mean and loss weighting.
"""
def __init__(self, use_running_mean=False, bce_weight=1, dice_weight=1, eps=1e-6, gamma=0.9, combined_loss_only=True, **_):
"""
:param use_running_mean: - bool (default: False) Whether to accumulate a running mean and add it to the loss with (1-gamma)
:param bce_weight: - float (default: 1.0) Weight multiplier for the BCE loss (relative to dice)
:param dice_weight: - float (default: 1.0) Weight multiplier for the Dice loss (relative to BCE)
        :param eps: - float (default: 1e-6) Small constant added to the denominator for numerical stability
        :param gamma: - float (default: 0.9) Momentum of the running means (only used when use_running_mean=True)
:param combined_loss_only: - bool (default: True) whether to return a single combined loss or three separate losses
"""
super().__init__()
'''
Note: BCEWithLogitsLoss already performs a torch.sigmoid(pred)
before applying BCE!
'''
self.bce_logits_loss = nn.BCEWithLogitsLoss()
self.dice_weight = dice_weight
self.bce_weight = bce_weight
self.eps = eps
self.gamma = gamma
self.combined_loss_only = combined_loss_only
self.use_running_mean = use_running_mean
self.bce_weight = bce_weight
self.dice_weight = dice_weight
if self.use_running_mean is True:
self.register_buffer('running_bce_loss', torch.zeros(1))
self.register_buffer('running_dice_loss', torch.zeros(1))
self.reset_parameters()
    def to(self, device):
super().to(device=device)
self.bce_logits_loss.to(device=device)
    def reset_parameters(self):
self.running_bce_loss.zero_()
self.running_dice_loss.zero_()
    def forward(self, outputs, labels, **_):
# inputs and targets are assumed to be BxCxWxH (batch, color, width, height)
outputs = outputs.squeeze() # necessary in case we're dealing with binary segmentation (color dim of 1)
if len(outputs.shape) != len(labels.shape):
raise AssertionError
# assert that B, W and H are the same
        if outputs.size(0) != labels.size(0):
raise AssertionError
if outputs.size(-1) != labels.size(-1):
raise AssertionError
if outputs.size(-2) != labels.size(-2):
raise AssertionError
bce_loss = self.bce_logits_loss(outputs, labels)
dice_target = (labels == 1).float()
dice_output = torch.sigmoid(outputs)
intersection = (dice_output * dice_target).sum()
union = dice_output.sum() + dice_target.sum() + self.eps
dice_loss = (-torch.log(2 * intersection / union))
if self.use_running_mean is False:
bmw = self.bce_weight
dmw = self.dice_weight
# loss += torch.clamp(1 - torch.log(2 * intersection / union),0,100) * self.dice_weight
else:
self.running_bce_loss = self.running_bce_loss * self.gamma + bce_loss.data * (1 - self.gamma)
self.running_dice_loss = self.running_dice_loss * self.gamma + dice_loss.data * (1 - self.gamma)
bm = float(self.running_bce_loss)
dm = float(self.running_dice_loss)
bmw = 1 - bm / (bm + dm)
dmw = 1 - dm / (bm + dm)
loss = bce_loss * bmw + dice_loss * dmw
if self.combined_loss_only:
return loss
else:
return loss, bce_loss, dice_loss
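# Usage sketch for ComboBCEDiceLoss (binary masks; shapes below are assumptions for illustration).
# Note that forward() squeezes the logits, so pass the labels without a channel dimension:
#
#     criterion = ComboBCEDiceLoss(bce_weight=1.0, dice_weight=1.0, combined_loss_only=False)
#     logits = torch.randn(4, 1, 64, 64, requires_grad=True)
#     masks = torch.randint(0, 2, (4, 64, 64)).float()
#     loss, bce_part, dice_part = criterion(logits, masks)       # three values because combined_loss_only=False
#     loss.backward()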
class ComboSemsegLossWeighted(nn.Module):
def __init__(self,
use_running_mean=False,
bce_weight=1,
dice_weight=1,
eps=1e-6,
gamma=0.9,
use_weight_mask=False,
combined_loss_only=False, **_
):
super().__init__()
self.use_weight_mask = use_weight_mask
self.nll_loss = nn.BCEWithLogitsLoss()
self.dice_weight = dice_weight
self.bce_weight = bce_weight
self.eps = eps
self.gamma = gamma
self.combined_loss_only = combined_loss_only
self.use_running_mean = use_running_mean
self.bce_weight = bce_weight
self.dice_weight = dice_weight
if self.use_running_mean is True:
self.register_buffer('running_bce_loss', torch.zeros(1))
self.register_buffer('running_dice_loss', torch.zeros(1))
self.reset_parameters()
    def to(self, device):
super().to(device=device)
self.nll_loss.to(device=device)
    def reset_parameters(self):
self.running_bce_loss.zero_()
self.running_dice_loss.zero_()
    def forward(self, logits, labels, weights, **_):
# logits and labels are assumed to be BxCxWxH
if len(logits.shape) != len(labels.shape):
raise AssertionError
# assert that B, W and H are the same
if logits.size(0) != labels.size(0):
raise AssertionError
if logits.size(2) != labels.size(2):
raise AssertionError
if logits.size(3) != labels.size(3):
raise AssertionError
# weights are assumed to be BxWxH
# assert that B, W and H are the are the same for target and mask
if logits.size(0) != weights.size(0):
raise AssertionError
if logits.size(2) != weights.size(1):
raise AssertionError
if logits.size(3) != weights.size(2):
raise AssertionError
if self.use_weight_mask:
bce_loss = F.binary_cross_entropy_with_logits(input=logits,
target=labels,
weight=weights)
else:
bce_loss = self.nll_loss(input=logits,
target=labels)
dice_target = (labels == 1).float()
dice_output = torch.sigmoid(logits)
intersection = (dice_output * dice_target).sum()
union = dice_output.sum() + dice_target.sum() + self.eps
dice_loss = (-torch.log(2 * intersection / union))
if self.use_running_mean is False:
bmw = self.bce_weight
dmw = self.dice_weight
# loss += torch.clamp(1 - torch.log(2 * intersection / union),0,100) * self.dice_weight
else:
self.running_bce_loss = self.running_bce_loss * self.gamma + bce_loss.data * (1 - self.gamma)
self.running_dice_loss = self.running_dice_loss * self.gamma + dice_loss.data * (1 - self.gamma)
bm = float(self.running_bce_loss)
dm = float(self.running_dice_loss)
bmw = 1 - bm / (bm + dm)
dmw = 1 - dm / (bm + dm)
loss = bce_loss * bmw + dice_loss * dmw
if self.combined_loss_only:
return loss
else:
return loss, bce_loss, dice_loss
# ====================== #
# Source: https://github.com/Tramac/awesome-semantic-segmentation-pytorch/blob/master/core/utils/loss.py
# Description: http://www.erogol.com/online-hard-example-mining-pytorch/
# Online Hard Example Loss
class OhemCrossEntropy2d(nn.Module):
def __init__(self, thresh=0.6, min_kept=0, ignore_index=-100, is_binary=True, **kwargs):
super().__init__()
self.ignore_label = ignore_index
self.is_binary = is_binary
self.thresh = float(thresh)
self.min_kept = int(min_kept)
self.criterion = BCEWithLogitsViewLoss(**kwargs)
    def forward(self, logits, labels, **_):
"""
Args:
predict:(n, c, h, w)
labels:(n, h, w)
"""
if self.is_binary:
predict = torch.sigmoid(logits)
else:
predict = F.softmax(logits, dim=1)
n, c, h, w = predict.size()
input_label = labels.detach().cpu().numpy().ravel().astype(np.int32)
x = np.rollaxis(predict.detach().cpu().numpy(), 1).reshape((c, -1))
input_prob = np.exp(x - x.max(axis=0).reshape((1, -1)))
input_prob /= input_prob.sum(axis=0).reshape((1, -1))
valid_flag = input_label != self.ignore_label
valid_inds = np.where(valid_flag)[0]
label = input_label[valid_flag]
num_valid = valid_flag.sum()
if self.min_kept >= num_valid:
print('Labels: {}'.format(num_valid))
elif num_valid > 0:
prob = input_prob[:, valid_flag]
pred = prob[label-1, np.arange(len(label), dtype=np.int32)]
threshold = self.thresh
if self.min_kept > 0:
index = pred.argsort()
threshold_index = index[min(len(index), self.min_kept) - 1]
if pred[threshold_index] > self.thresh:
threshold = pred[threshold_index]
kept_flag = pred <= threshold
valid_inds = valid_inds[kept_flag]
#print('hard ratio: {} = {} / {} '.format(round(len(valid_inds)/num_valid, 4), len(valid_inds), num_valid))
label = input_label[valid_inds].copy()
input_label.fill(self.ignore_label)
input_label[valid_inds] = label
#print(np.sum(input_label != self.ignore_label))
labels = torch.from_numpy(input_label.reshape(labels.size())).type_as(predict).to(labels.device)
predict = predict.squeeze() # in case we're dealing with B/W images instead of RGB
return self.criterion(predict, labels)
# ====================== #
# Source: https://github.com/Tramac/awesome-semantic-segmentation-pytorch/blob/master/core/utils/loss.py
# Loss used for EncNet
class EncNetLoss(nn.CrossEntropyLoss):
"""
2D Cross Entropy Loss with SE Loss
Specifically used for EncNet.
se_loss is the Semantic Encoding Loss from the paper `Context Encoding for Semantic Segmentation <https://arxiv.org/pdf/1803.08904v1>`_.
It computes probabilities of contexts appearing together.
Without SE_loss and Aux_loss this class simply forwards inputs to Torch's Cross Entropy Loss (nn.CrossEntropyLoss)
"""
def __init__(self, se_loss=True, se_weight=0.2, nclass=19, aux=False, aux_weight=0.4, weight=None, ignore_index=-1, **_):
super(EncNetLoss, self).__init__(weight, None, ignore_index)
self.se_loss = se_loss
self.aux = aux
self.nclass = nclass
self.se_weight = se_weight
self.aux_weight = aux_weight
self.bceloss = nn.BCELoss(weight)
    def forward(self, *inputs, **_):
preds, target = tuple(inputs)
inputs = tuple(list(preds) + [target])
if not self.se_loss and not self.aux:
return super(EncNetLoss, self).forward(*inputs)
elif not self.se_loss:
pred1, pred2, target = tuple(inputs)
loss1 = super(EncNetLoss, self).forward(pred1, target)
loss2 = super(EncNetLoss, self).forward(pred2, target)
return dict(loss=loss1 + self.aux_weight * loss2)
elif not self.aux:
pred, se_pred, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
loss1 = super(EncNetLoss, self).forward(pred, target)
loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
return dict(loss=loss1 + self.se_weight * loss2)
else:
pred1, se_pred, pred2, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
loss1 = super(EncNetLoss, self).forward(pred1, target)
loss2 = super(EncNetLoss, self).forward(pred2, target)
loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
return dict(loss=loss1 + self.aux_weight * loss2 + self.se_weight * loss3)
@staticmethod
def _get_batch_label_vector(target, nclass):
# target is a 3D Variable BxHxW, output is 2D BxnClass
batch = target.size(0)
tvect = Variable(torch.zeros(batch, nclass))
for i in range(batch):
hist = torch.histc(target[i].cpu().data.float(),
bins=nclass, min=0,
max=nclass - 1)
vect = hist > 0
tvect[i] = vect
return tvect
class MixSoftmaxCrossEntropyOHEMLoss(OhemCrossEntropy2d):
"""
    Loss that takes class and segmentation targets into consideration together, while also using OHEM (online hard example mining).
"""
def __init__(self, aux=False, aux_weight=0.4, weight=None, ignore_index=-1, **kwargs):
super(MixSoftmaxCrossEntropyOHEMLoss, self).__init__(ignore_index=ignore_index)
self.aux = aux
self.aux_weight = aux_weight
self.bceloss = nn.BCELoss(weight)
    def to(self, device):
super().to(device=device)
self.bceloss.to(device=device)
def _aux_forward(self, *inputs, **_):
*preds, target = tuple(inputs)
loss = super(MixSoftmaxCrossEntropyOHEMLoss, self).forward(preds[0], target)
for i in range(1, len(preds)):
aux_loss = super(MixSoftmaxCrossEntropyOHEMLoss, self).forward(preds[i], target)
loss += self.aux_weight * aux_loss
return loss
    def forward(self, *inputs, **_):
preds, target = tuple(inputs)
inputs = tuple(list(preds) + [target])
if self.aux:
return dict(loss=self._aux_forward(*inputs))
else:
return dict(loss=super(MixSoftmaxCrossEntropyOHEMLoss, self).forward(preds, target))
# ====================== #
# Source: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/loss.py
# OHEM Segmentation Loss
class OHEMSegmentationLosses(OhemCrossEntropy2d):
"""
2D Cross Entropy Loss with Auxiliary Loss
"""
def __init__(self, se_loss=False, se_weight=0.2, num_classes=1,
aux=False, aux_weight=0.4, weight=None,
ignore_index=-1):
        super(OHEMSegmentationLosses, self).__init__(ignore_index=ignore_index)
self.se_loss = se_loss
self.aux = aux
self.num_classes = num_classes
self.se_weight = se_weight
self.aux_weight = aux_weight
self.bceloss = nn.BCELoss(weight)
def to(self, device):
super().to(device=device)
self.bceloss.to(device=device)
def forward(self, *inputs, **_):
if not self.se_loss and not self.aux:
return super(OHEMSegmentationLosses, self).forward(*inputs)
elif not self.se_loss:
pred1, pred2, target = tuple(inputs)
loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target)
loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target)
return loss1 + self.aux_weight * loss2
elif not self.aux:
pred, se_pred, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.num_classes).type_as(pred)
loss1 = super(OHEMSegmentationLosses, self).forward(pred, target)
loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.se_weight * loss2
else:
pred1, se_pred, pred2, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.num_classes).type_as(pred1)
loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target)
loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target)
loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
@staticmethod
def _get_batch_label_vector(target, nclass):
# target is a 3D Variable BxHxW, output is 2D BxnClass
batch = target.size(0)
tvect = Variable(torch.zeros(batch, nclass))
for i in range(batch):
hist = torch.histc(target[i].cpu().data.float(),
bins=nclass, min=0,
max=nclass-1)
vect = hist>0
tvect[i] = vect
return tvect
# ====================== #
# Source: https://github.com/yinmh17/DNL-Semantic-Segmentation/blob/master/model/seg/loss/ohem_ce_loss.py
# OHEM CrossEntropy Loss
class OhemCELoss(nn.Module):
def __init__(self, configer, is_binary=False):
super(OhemCELoss, self).__init__()
self.configer = configer
weight = self.configer.get('loss.params.ohem_ce_loss.weight', default=None)
self.weight = torch.FloatTensor(weight) if weight is not None else weight
self.reduction = self.configer.get('loss.params.ohem_ce_loss.reduction', default='mean')
self.ignore_index = self.configer.get('loss.params.ohem_ce_loss.ignore_index', default=-100)
self.thresh = self.configer.get('loss.params.ohem_ce_loss.thresh', default=0.7)
self.min_kept = max(1, self.configer.get('loss.params.ohem_ce_loss.minkeep', default=5))
self.is_binary = is_binary
    def forward(self, logits, labels, **_):
"""
Args:
logits:(n, c, h, w)
labels:(n, h, w)
weight (Tensor, optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size "nclasses"
"""
batch_kept = self.min_kept * labels.size(0)
labels = self._scale_target(labels, (logits.size(2), logits.size(3)))
if self.is_binary:
prob_out = torch.sigmoid(logits)
else:
prob_out = F.softmax(logits, dim=1)
tmp_target = labels.clone()
tmp_target[tmp_target == self.ignore_index] = 0
prob = prob_out.gather(1, tmp_target.unsqueeze(1))
mask = labels.contiguous().view(-1, ) != self.ignore_index
sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort()
min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0
threshold = max(min_threshold, self.thresh)
loss_matrix = F.cross_entropy(logits, labels,
weight=self.weight.to(logits.device) if self.weight is not None else None,
ignore_index=self.ignore_index, reduction='none')
loss_matrix = loss_matrix.contiguous().view(-1, )
        sort_loss_matrix = loss_matrix[mask][sort_indices]
        select_loss_matrix = sort_loss_matrix[sort_prob < threshold]
if self.reduction == 'sum' or select_loss_matrix.numel() == 0:
return select_loss_matrix.sum()
elif self.reduction == 'mean':
return select_loss_matrix.mean()
else:
raise NotImplementedError('Reduction Error!')
@staticmethod
def _scale_target(targets_, scaled_size):
targets = targets_.clone().unsqueeze(1).float()
targets = F.interpolate(targets, size=scaled_size, mode='nearest')
return targets.squeeze(1).long()
# ===================== #
# Source: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/LovaszSoftmax/lovasz_loss.py
def lovasz_grad(gt_sorted):
"""
Computes gradient of the Lovasz extension w.r.t sorted errors
See Alg. 1 in paper
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1. - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
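# Hand-worked example of lovasz_grad on a sorted ground-truth vector (for intuition only):
#
#     gt_sorted    = [1, 1, 0, 1]          -> gts = 3
#     intersection = 3 - cumsum(gt)        -> [2, 1, 1, 0]
#     union        = 3 + cumsum(1 - gt)    -> [3, 3, 4, 4]
#     jaccard      = 1 - I/U               -> [1/3, 2/3, 3/4, 1]
#     after the difference step            -> [1/3, 1/3, 1/12, 1/4]  (sums to 1)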
class LovaszSoftmax(nn.Module):
def __init__(self, reduction='mean', **_):
super(LovaszSoftmax, self).__init__()
self.reduction = reduction
    @staticmethod
def prob_flatten(input, target):
if input.dim() not in [4, 5]:
raise AssertionError
num_class = input.size(1)
if input.dim() == 4:
input = input.permute(0, 2, 3, 1).contiguous()
input_flatten = input.view(-1, num_class)
elif input.dim() == 5:
input = input.permute(0, 2, 3, 4, 1).contiguous()
input_flatten = input.view(-1, num_class)
target_flatten = target.view(-1)
return input_flatten, target_flatten
    def lovasz_softmax_flat(self, inputs, targets):
num_classes = inputs.size(1)
losses = []
for c in range(num_classes):
target_c = (targets == c).float()
if num_classes == 1:
input_c = inputs[:, 0]
else:
input_c = inputs[:, c]
loss_c = (torch.autograd.Variable(target_c) - input_c).abs()
loss_c_sorted, loss_index = torch.sort(loss_c, 0, descending=True)
target_c_sorted = target_c[loss_index]
losses.append(torch.dot(loss_c_sorted, torch.autograd.Variable(lovasz_grad(target_c_sorted))))
losses = torch.stack(losses)
if self.reduction == 'none':
loss = losses
elif self.reduction == 'sum':
loss = losses.sum()
else:
loss = losses.mean()
return loss
    def forward(self, inputs, targets, **_):
inputs, targets = self.prob_flatten(inputs, targets)
losses = self.lovasz_softmax_flat(inputs, targets)
return losses
# ===================== #
# Source: https://github.com/xuuuuuuchen/Active-Contour-Loss/blob/master/Active-Contour-Loss.py (MIT)
class ActiveContourLoss(nn.Module):
"""
`Learning Active Contour Models for Medical Image Segmentation <http://openaccess.thecvf.com/content_CVPR_2019/papers/Chen_Learning_Active_Contour_Models_for_Medical_Image_Segmentation_CVPR_2019_paper.pdf>`_
    Note that this only works for B/W masks right now... which is kind of the point of this loss as contours in RGB should be cast to B/W
before computing the loss.
Params:
:param mu: (float, default=1.0) - Scales the inner region loss relative to outer region (less or more prominent)
    :param lambdaP: (float, default=5.0) - Scales the combined region loss compared to the length loss (less or more prominent)
"""
def __init__(self, lambdaP=5., mu=1., is_binary: bool = False, **_):
super(ActiveContourLoss, self).__init__()
self.lambdaP = lambdaP
self.mu = mu
self.is_binary = is_binary
    def forward(self, logits, labels, **_):
if self.is_binary:
logits = torch.sigmoid(logits)
else:
logits = F.softmax(logits, dim=1)
        if labels.shape != logits.shape:
            if len(logits.shape) > len(labels.shape):
                labels = labels.unsqueeze(dim=1)
            else:
                raise Exception(f'Non-matching shapes for logits ({logits.shape}) and labels ({labels.shape})')
"""
length term
"""
x = logits[:,:,1:,:] - logits[:,:,:-1,:] # horizontal gradient (B, C, H-1, W)
y = logits[:,:,:,1:] - logits[:,:,:,:-1] # vertical gradient (B, C, H, W-1)
delta_x = x[:,:,1:,:-2]**2 # (B, C, H-2, W-2)
delta_y = y[:,:,:-2,1:]**2 # (B, C, H-2, W-2)
delta_u = torch.abs(delta_x + delta_y)
        epsilon = 1e-8  # a small parameter to avoid the square root being zero in practice.
length = torch.mean(torch.sqrt(delta_u + epsilon)) # eq.(11) in the paper, mean is used instead of sum.
"""
region term
"""
C_in = torch.ones_like(logits)
C_out = torch.zeros_like(labels)
region_in = torch.abs(torch.mean(logits[:,0,:,:] * ((labels[:, 0, :, :] - C_in) ** 2))) # equ.(12) in the paper, mean is used instead of sum.
region_out = torch.abs(torch.mean((1-logits[:,0,:,:]) * ((labels[:, 0, :, :] - C_out) ** 2))) # equ.(12) in the paper
return length + self.lambdaP * (self.mu * region_in + region_out)
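# Usage sketch for ActiveContourLoss (binary masks with an explicit channel dimension; shapes assumed):
#
#     criterion = ActiveContourLoss(is_binary=True)
#     logits = torch.randn(2, 1, 64, 64, requires_grad=True)
#     masks = torch.randint(0, 2, (2, 1, 64, 64)).float()
#     loss = criterion(logits, masks)
#     loss.backward()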
class ActiveContourLossAlt(nn.Module):
"""
`Learning Active Contour Models for Medical Image Segmentation <http://openaccess.thecvf.com/content_CVPR_2019/papers/Chen_Learning_Active_Contour_Models_for_Medical_Image_Segmentation_CVPR_2019_paper.pdf>`_
    Note that this only works for B/W masks right now... which is kind of the point of this loss as contours in RGB should be cast to B/W
before computing the loss.
Params:
:param len_w: (float, default=1.0) - The multiplier to use when adding boundary loss.
:param reg_w: (float, default=1.0) - The multiplier to use when adding region loss.
    :param apply_log: (bool, default=True) - Whether to transform the loss into log space.
"""
def __init__(self, len_w=1., reg_w=1., apply_log=True, is_binary: bool = False, **_):
super(ActiveContourLossAlt, self).__init__()
self.len_w = len_w
self.reg_w = reg_w
self.epsilon = 1e-8 # a parameter to avoid square root = zero issues
self.apply_log = apply_log
self.is_binary = is_binary
    def forward(self, logits, labels, **_):
# must convert raw logits to predicted probabilities for each pixel along channel
if self.is_binary:
probs = torch.sigmoid(logits)
else:
probs = F.softmax(logits, dim=1)
        if labels.shape != logits.shape:
            if len(logits.shape) > len(labels.shape):
                labels = labels.unsqueeze(dim=1)
            else:
                raise Exception(f'Non-matching shapes for logits ({logits.shape}) and labels ({labels.shape})')
"""
length term:
- Subtract adjacent pixels from each other in X and Y directions
- Determine where they differ from the ground truth (targets)
- Calculate MSE
"""
# horizontal and vertical directions
x = probs[:, :, 1:, :] - probs[:, :, :-1, :] # differences in horizontal direction
y = probs[:, :, :, 1:] - probs[:, :, :, :-1] # differences in vertical direction
target_x = labels[:, :, 1:, :] - labels[:, :, :-1, :]
target_y = labels[:, :, :, 1:] - labels[:, :, :, :-1]
# find difference between values of probs and targets
delta_x = (target_x - x).abs() # do we need to subtract absolute values or relative?
delta_y = (target_y - y).abs()
# get MSE of the differences per pixel
# importantly because deltas are mostly < 1, a simple square of the error will actually yield LOWER results
# so we select 0.5 as the middle ground where small error will be further minimized while large error will
# be highlighted (pushed to be > 1 and up to 2.5 for maximum error).
# len_error_sq = ((delta_x + 0.5) ** 2) + ((delta_y + 0.5) ** 2)
# length = torch.sqrt(len_error_sq.sum() + self.epsilon)
# the length loss here is simply the MSE of x and y deltas
length_loss = torch.sqrt(delta_x.sum() ** 2 + delta_y.sum() ** 2 + self.epsilon)
"""
region term (should this be done in log space to avoid instabilities?)
- compute the error produced by all pixels that are not equal to 0 outside of the ground truth mask
- compute error produced by all pixels that are not equal to 1 inside the mask
"""
# reference code for selecting masked values from a tensor
# t_m_bool = t_mask.type(torch.ByteTensor)
# t_result = t_in.masked_select(t_m_bool)
# C_1 = torch.ones((image_size, image_size), device=target.device)
# C_2 = torch.zeros((image_size, image_size), device=target.device)
# the sum of all pixel values that are not equal 0 outside of the ground truth mask
error_in = probs[:, 0, :, :] * ((labels[:, 0, :, :] - 1) ** 2) # invert the ground truth mask and multiply by probs
# the sum of all pixel values that are not equal 1 inside of the ground truth mask
probs_diff = (probs[:, 0, :, :] - labels[:, 0, :, :]).abs() # subtract mask from probs giving us the errors
error_out = (probs_diff * labels[:, 0, :, :]) # multiply mask by error, giving us the error terms inside the mask.
if self.apply_log:
loss = torch.log(length_loss) + torch.log(error_in.sum() + error_out.sum())
else:
# loss = self.len_w * length_loss
loss = self.reg_w * (error_in.sum() + error_out.sum())
return torch.clamp(loss, min=0.0) # make sure we don't return negative values
# ===================== #
# Sources: https://github.com/JunMa11/SegLoss
# https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunet (Apache 2.0)
def uniq(a: Tensor) -> Set:
return set(torch.unique(a.cpu()).numpy())
def sset(a: Tensor, sub: Iterable) -> bool:
return uniq(a).issubset(sub)
def simplex(t: Tensor, axis=1) -> bool:
_sum = t.sum(axis).type(torch.float32)
_ones = torch.ones_like(_sum, dtype=torch.float32)
return torch.allclose(_sum, _ones)
def one_hot(t: Tensor, axis=1) -> bool:
return simplex(t, axis) and sset(t, [0, 1])
def numpy_haussdorf(pred: np.ndarray, target: np.ndarray) -> float:
from scipy.spatial.distance import directed_hausdorff
if len(pred.shape) != 2:
raise AssertionError
if pred.shape != target.shape:
raise AssertionError
return max(directed_hausdorff(pred, target)[0], directed_hausdorff(target, pred)[0])
def haussdorf(preds: Tensor, target: Tensor) -> Tensor:
if preds.shape != target.shape:
raise AssertionError
if not one_hot(preds):
raise AssertionError
if not one_hot(target):
raise AssertionError
B, C, _, _ = preds.shape
res = torch.zeros((B, C), dtype=torch.float32, device=preds.device)
n_pred = preds.detach().cpu().numpy()
n_target = target.detach().cpu().numpy()
for b in range(B):
if C == 2:
res[b, :] = numpy_haussdorf(n_pred[b, 0], n_target[b, 0])
continue
for c in range(C):
res[b, c] = numpy_haussdorf(n_pred[b, c], n_target[b, c])
return res
def softmax_helper(x):
rpt = [1 for _ in range(len(x.size()))]
rpt[1] = x.size(1)
x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
e_x = torch.exp(x - x_max)
return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
def sum_tensor(inp, axes, keepdim=False):
axes = np.unique(axes).astype(int)
if keepdim:
for ax in axes:
inp = inp.sum(int(ax), keepdim=True)
else:
for ax in sorted(axes, reverse=True):
inp = inp.sum(int(ax))
return inp
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
"""
net_output must be (b, c, x, y(, z)))
gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
if mask is provided it must have shape (b, 1, x, y(, z)))
:param net_output:
:param gt:
:param axes:
:param mask: mask must be 1 for valid pixels and 0 for invalid pixels
:param square: if True then fp, tp and fn will be squared before summation
:return:
"""
if axes is None:
axes = tuple(range(2, len(net_output.size())))
shp_x = net_output.shape
shp_y = gt.shape
with torch.no_grad():
if len(shp_x) != len(shp_y):
gt = gt.view((shp_y[0], 1, *shp_y[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(shp_x)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
return tp, fp, fn
# ===================== #
# Boundary Loss
# Source: https://github.com/JunMa11/SegLoss/blob/71b14900e91ea9405d9705c95b451fc819f24c70/test/loss_functions/boundary_loss.py#L102
def compute_sdf(img_gt, out_shape):
"""
compute the signed distance map of binary mask
img_gt: segmentation, shape = (batch_size, x, y, z)
    out_shape: shape of the output Signed Distance Map (SDM)
sdf(x) = 0; x in segmentation boundary
-inf|x-y|; x in segmentation
+inf|x-y|; x out of segmentation
"""
from scipy.ndimage import distance_transform_edt
from skimage import segmentation as skimage_seg
img_gt = img_gt.astype(np.uint8)
gt_sdf = np.zeros(out_shape)
for b in range(out_shape[0]): # batch size
for c in range(1, out_shape[1]): # channel
            posmask = img_gt[b][c].astype(bool)
if posmask.any():
negmask = ~posmask
posdis = distance_transform_edt(posmask)
negdis = distance_transform_edt(negmask)
boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
sdf = negdis - posdis
sdf[boundary==1] = 0
gt_sdf[b][c] = sdf
return gt_sdf
class BDLoss(nn.Module):
def __init__(self, is_binary: bool = False, **_):
"""
        Compute the boundary loss.
        Only computes the loss of the foreground.
ref: https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L74
"""
self.is_binary = is_binary
super(BDLoss, self).__init__()
# self.do_bg = do_bg
    def forward(self, logits, labels, **_):
"""
net_output: (batch_size, class, x,y,z)
target: ground truth, shape: (batch_size, 1, x,y,z)
bound: precomputed distance map, shape (batch_size, class, x,y,z)
"""
if self.is_binary:
logits = torch.sigmoid(logits)
else:
logits = F.softmax(logits, dim=1)
with torch.no_grad():
if len(logits.shape) != len(labels.shape):
labels = labels.view((labels.shape[0], 1, *labels.shape[1:]))
if all([i == j for i, j in zip(logits.shape, labels.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = labels
else:
labels = labels.long()
y_onehot = torch.zeros(logits.shape)
if logits.device.type == "cuda":
y_onehot = y_onehot.cuda(logits.device.index)
y_onehot.scatter_(1, labels, 1)
gt_sdf = compute_sdf(y_onehot.cpu().numpy(), logits.shape)
# cast to float32 and move to the logits device (the numpy SDF comes back as float64 on CPU)
phi = torch.from_numpy(gt_sdf).type(torch.float32).to(logits.device)
# pred = net_output[:, 1:, ...].type(torch.float32)
# phi = phi[:,1:, ...].type(torch.float32)
multipled = torch.einsum("bcxyz,bcxyz->bcxyz", logits[:, 1:, ...], phi[:, 1:, ...])
bd_loss = multipled.mean()
return bd_loss
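# Illustrative usage sketch (added; not part of the original source). BDLoss as written
# uses the "bcxyz" einsum pattern, i.e. 3D volumes; the sizes below are hypothetical.
def _example_bd_loss():
    criterion = BDLoss()
    logits = torch.randn(1, 2, 16, 16, 16)             # (b, c, x, y, z) raw scores
    labels = torch.randint(0, 2, (1, 16, 16, 16))      # (b, x, y, z) label map
    return criterion(logits, labels)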
# ===================== #
# Source: https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
class TverskyLoss(nn.Module):
"""Computes the Tversky loss [1].
Args:
:param alpha: controls the penalty for false positives.
:param beta: controls the penalty for false negatives.
:param eps: added to the denominator for numerical stability.
Returns:
tversky_loss: the Tversky loss.
Notes:
alpha = beta = 0.5 => dice coeff
alpha = beta = 1 => tanimoto coeff
alpha + beta = 1 => F beta coeff
References:
[1]: https://arxiv.org/abs/1706.05721
"""
def __init__(self, alpha, beta, eps=1e-7, **_):
super(TverskyLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.eps = eps
def forward(self, logits, labels, **_):
"""
Args:
:param logits: a tensor of shape [B, C, H, W]. Corresponds to the raw output or logits of the model.
:param labels: a tensor of shape [B, H, W] or [B, 1, H, W].
:return: loss
"""
num_classes = logits.shape[1]
if num_classes == 1:
true_1_hot = torch.eye(num_classes + 1)[labels.squeeze(1).long()]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
true_1_hot_f = true_1_hot[:, 0:1, :, :]
true_1_hot_s = true_1_hot[:, 1:2, :, :]
true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
pos_prob = torch.sigmoid(logits)
neg_prob = 1 - pos_prob
probas = torch.cat([pos_prob, neg_prob], dim=1)
else:
true_1_hot = torch.eye(num_classes)[labels.squeeze(1).long()]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(logits, dim=1)
true_1_hot = true_1_hot.type(logits.type())
dims = (0,) + tuple(range(2, logits.ndimension()))
intersection = torch.sum(probas * true_1_hot, dims)
fps = torch.sum(probas * (1 - true_1_hot), dims)
fns = torch.sum((1 - probas) * true_1_hot, dims)
num = intersection
denom = intersection + (self.alpha * fps) + (self.beta * fns)
tversky_loss = (num / (denom + self.eps)).mean()
return 1 - tversky_loss
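# Illustrative usage sketch (added; not part of the original source). Binary (single-logit)
# case; alpha < beta penalizes false negatives more heavily. Shapes are hypothetical.
def _example_tversky_loss():
    criterion = TverskyLoss(alpha=0.3, beta=0.7)
    logits = torch.randn(4, 1, 64, 64)                 # (B, 1, H, W) raw scores
    labels = torch.randint(0, 2, (4, 1, 64, 64))       # (B, 1, H, W) binary mask
    return criterion(logits, labels)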
# ===================== #
# Source: https://github.com/cvqluu/Angular-Penalty-Softmax-Losses-Pytorch
class AngularPenaltySMLoss(nn.Module):
def __init__(self, in_features, out_features, loss_type='arcface', eps=1e-7, s=None, m=None, **_):
'''
Angular Penalty Softmax Loss
Three 'loss_types' available: ['arcface', 'sphereface', 'cosface']
These losses are described in the following papers:
ArcFace: https://arxiv.org/abs/1801.07698
SphereFace: https://arxiv.org/abs/1704.08063
CosFace/Ad Margin: https://arxiv.org/abs/1801.05599
- Example -
criterion = AngularPenaltySMLoss(in_features, out_features, loss_type='arcface') # loss_type in ['arcface', 'sphereface', 'cosface']
'''
super(AngularPenaltySMLoss, self).__init__()
loss_type = loss_type.lower()
if loss_type not in ['arcface', 'sphereface', 'cosface']:
raise AssertionError
if loss_type == 'arcface':
self.s = 64.0 if not s else s
self.m = 0.5 if not m else m
if loss_type == 'sphereface':
self.s = 64.0 if not s else s
self.m = 1.35 if not m else m
if loss_type == 'cosface':
self.s = 30.0 if not s else s
self.m = 0.4 if not m else m
self.loss_type = loss_type
self.in_features = in_features
self.out_features = out_features
self.fc = nn.Linear(in_features, out_features, bias=False)
self.eps = eps
def forward(self, x, labels, **_):
'''
input shape (N, in_features)
'''
if len(x) != len(labels):
raise AssertionError
if torch.min(labels) < 0:
raise AssertionError
if torch.max(labels) >= self.out_features:
raise AssertionError
# normalize the classifier weights in place so that wf measures cosine similarity
# (reassigning the loop variable, as the original source did, has no effect)
with torch.no_grad():
    for W in self.fc.parameters():
        W.copy_(F.normalize(W, p=2, dim=1))
x = F.normalize(x, p=2, dim=1)
wf = self.fc(x)
if self.loss_type == 'cosface':
numerator = self.s * (torch.diagonal(wf.transpose(0, 1)[labels]) - self.m)
if self.loss_type == 'arcface':
numerator = self.s * torch.cos(torch.acos(torch.clamp(torch.diagonal(wf.transpose(0, 1)[labels]), -1. + self.eps, 1 - self.eps)) + self.m)
if self.loss_type == 'sphereface':
numerator = self.s * torch.cos(self.m * torch.acos(torch.clamp(torch.diagonal(wf.transpose(0, 1)[labels]), -1. + self.eps, 1 - self.eps)))
excl = torch.cat([torch.cat((wf[i, :y], wf[i, y + 1:])).unsqueeze(0) for i, y in enumerate(labels)], dim=0)
denominator = torch.exp(numerator) + torch.sum(torch.exp(self.s * excl), dim=1)
L = numerator - torch.log(denominator)
return -torch.mean(L)
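# Illustrative usage sketch (added; not part of the original source). Complements the
# construction example in the docstring by showing the forward call; the module holds its
# own final linear layer, so it is fed embeddings plus integer class labels.
def _example_angular_penalty():
    criterion = AngularPenaltySMLoss(in_features=128, out_features=10, loss_type='arcface')
    embeddings = torch.randn(32, 128)                  # (N, in_features)
    labels = torch.randint(0, 10, (32,))               # class indices in [0, out_features)
    return criterion(embeddings, labels)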
# ===================== #
# Source: https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py
class AsymLoss(nn.Module):
def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1., square=False, **_):
"""
paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8573779
"""
super(AsymLoss, self).__init__()
self.square = square
self.do_bg = do_bg
self.batch_dice = batch_dice
self.apply_nonlin = apply_nonlin
self.smooth = smooth
self.beta = 1.5
def forward(self, logits, labels, loss_mask=None, **_):
shp_x = logits.shape
if self.batch_dice:
axes = [0] + list(range(2, len(shp_x)))
else:
axes = list(range(2, len(shp_x)))
if self.apply_nonlin is not None:
logits = self.apply_nonlin(logits)
tp, fp, fn = get_tp_fp_fn(logits, labels, axes, loss_mask, self.square)# shape: (batch size, class num)
weight = (self.beta**2)/(1+self.beta**2)
asym = (tp + self.smooth) / (tp + weight*fn + (1-weight)*fp + self.smooth)
if not self.do_bg:
if self.batch_dice:
asym = asym[1:]
else:
asym = asym[:, 1:]
asym = asym.mean()
return -asym
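# Illustrative usage sketch (added; not part of the original source). apply_nonlin maps raw
# scores into [0, 1]; channel-wise softmax is assumed here. Note the loss returns -asym.
def _example_asym_loss():
    criterion = AsymLoss(apply_nonlin=lambda x: F.softmax(x, dim=1))
    logits = torch.randn(2, 3, 32, 32)                 # (b, c, x, y) raw scores
    labels = torch.randint(0, 3, (2, 32, 32))          # (b, x, y) label map
    return criterion(logits, labels)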
# ===================== #
# Source: https://github.com/BloodAxe/pytorch-toolbelt
# Used to enhance facial segmentation
def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
"""
https://arxiv.org/pdf/1711.06753.pdf
:param output:
:param target:
:param width:
:param curvature:
:param reduction:
:return:
"""
diff_abs = (target - output).abs()
loss = diff_abs.clone()
idx_smaller = diff_abs < width
idx_bigger = diff_abs >= width
loss[idx_smaller] = width * torch.log(1 + diff_abs[idx_smaller] / curvature)
C = width - width * math.log(1 + width / curvature)
loss[idx_bigger] = loss[idx_bigger] - C
if reduction == "sum":
loss = loss.sum()
if reduction == "mean":
loss = loss.mean()
return loss
class WingLoss(nn.modules.loss._Loss):
"""
Used to enhance facial segmentation
"""
def __init__(self, width=5, curvature=0.5, reduction="mean", **_):
super(WingLoss, self).__init__(reduction=reduction)
self.width = width
self.curvature = curvature
def forward(self, prediction, target, **_):
return wing_loss(prediction, target, self.width, self.curvature, self.reduction)
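# Illustrative usage sketch (added; not part of the original source). Wing loss is a
# regression-style loss, so prediction and target share a shape; landmark-like
# coordinates are assumed here.
def _example_wing_loss():
    criterion = WingLoss(width=5, curvature=0.5)
    prediction = torch.randn(8, 68, 2)                 # e.g. 68 predicted (x, y) landmarks
    target = torch.randn(8, 68, 2)
    return criterion(prediction, target)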
# ===================== #
# Source: https://github.com/JUNHAOYAN/FPN/tree/master/RMI
# ..which is adapted from: https://github.com/ZJULearning/RMI (MIT License)
# Segmentation loss (memory intensive)
class RMILoss(nn.Module):
"""
region mutual information
I(A, B) = H(A) + H(B) - H(A, B)
This version needs a lot of memory if it does not downsample.
"""
def __init__(self,
num_classes=1,
rmi_radius=3,
rmi_pool_way=0,
rmi_pool_size=3,
rmi_pool_stride=3,
loss_weight_lambda=0.5,
lambda_way=1,
device="cuda", **_):
super(RMILoss, self).__init__()
self._CLIP_MIN = 1e-6 # min clip value after softmax or sigmoid operations
self._CLIP_MAX = 1.0 # max clip value after softmax or sigmoid operations
self._POS_ALPHA = 5e-4 # add this factor to ensure the AA^T is positive definite
self._IS_SUM = 1 # sum the loss per channel
self.num_classes = num_classes
# radius choices
if rmi_radius not in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
raise AssertionError
self.rmi_radius = rmi_radius
if rmi_pool_way not in [0, 1, 2, 3]:
raise AssertionError
self.rmi_pool_way = rmi_pool_way
# set the pool_size = rmi_pool_stride
if rmi_pool_size != rmi_pool_stride:
raise AssertionError
self.rmi_pool_size = rmi_pool_size
self.rmi_pool_stride = rmi_pool_stride
self.weight_lambda = loss_weight_lambda
self.lambda_way = lambda_way
# dimension of the distribution
self.half_d = self.rmi_radius * self.rmi_radius
self.d = 2 * self.half_d
self.kernel_padding = self.rmi_pool_size // 2
# ignore class
self.ignore_index = 255
self.device = device
def forward(self, logits, labels, **_):
if self.num_classes == 1:
loss = self.forward_sigmoid(logits, labels)
else:
loss = self.forward_softmax_sigmoid(logits, labels)
return loss
def forward_softmax_sigmoid(self, inputs, targets):
"""
Using both softmax and sigmoid operations.
Args:
inputs : [N, C, H, W], dtype=float32
targets : [N, H, W], dtype=long
"""
# PART I -- get the normal cross entropy loss
normal_loss = F.cross_entropy(input=inputs,
target=targets.long(),
ignore_index=self.ignore_index,
reduction='mean')
# PART II -- get the lower bound of the region mutual information
# get the valid label and logits
# valid label, [N, C, H, W]
label_mask_3D = targets < self.num_classes
valid_onehot_labels_4D = F.one_hot(targets.long() * label_mask_3D.long(),
num_classes=self.num_classes).float()
label_mask_3D = label_mask_3D.float()
valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3)
valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False)
# valid probs
probs_4D = torch.sigmoid(inputs) * label_mask_3D.unsqueeze(dim=1)
probs_4D = probs_4D.clamp(min=self._CLIP_MIN, max=self._CLIP_MAX)
# get region mutual information
rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D)
# add together
final_loss = (self.weight_lambda * normal_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way
else normal_loss + rmi_loss * self.weight_lambda)
return final_loss
def forward_sigmoid(self, logits_4D, labels_4D):
"""
Uses the sigmoid operation for both the BCE and the RMI terms.
Args:
logits_4D : [N, C, H, W], dtype=float32
labels_4D : [N, H, W], dtype=long
"""
# label mask -- [N, H, W, 1]
label_mask_3D = labels_4D < self.num_classes
# valid label
valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(),
num_classes=self.num_classes).float()
label_mask_3D = label_mask_3D.float()
label_mask_flat = label_mask_3D.view([-1, ])
valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3)
valid_onehot_labels_4D.requires_grad_(False)
# PART I -- calculate the sigmoid binary cross entropy loss
valid_onehot_label_flat = valid_onehot_labels_4D.view([-1, self.num_classes]).requires_grad_(False)
logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes])
# binary loss, multiplied by the not_ignore_mask
valid_pixels = torch.sum(label_mask_flat)
binary_loss = F.binary_cross_entropy_with_logits(logits_flat,
target=valid_onehot_label_flat,
weight=label_mask_flat.unsqueeze(dim=1),
reduction='sum')
bce_loss = torch.div(binary_loss, valid_pixels + 1.0)
# PART II -- get rmi loss
# onehot_labels_4D -- [N, C, H, W]
probs_4D = logits_4D.sigmoid() * label_mask_3D.unsqueeze(dim=1) + self._CLIP_MIN
valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False)
# get region mutual information
rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D)
# add together
final_loss = (self.weight_lambda * bce_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way
else bce_loss + rmi_loss * self.weight_lambda)
return final_loss
def rmi_lower_bound(self, labels_4D, probs_4D):
"""
calculate the lower bound of the region mutual information.
Args:
labels_4D : [N, C, H, W], dtype=float32
probs_4D : [N, C, H, W], dtype=float32
"""
if labels_4D.size() != probs_4D.size():
raise AssertionError
p, s = self.rmi_pool_size, self.rmi_pool_stride
if self.rmi_pool_stride > 1:
if self.rmi_pool_way == 0:
labels_4D = F.max_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding)
probs_4D = F.max_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding)
elif self.rmi_pool_way == 1:
labels_4D = F.avg_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding)
probs_4D = F.avg_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding)
elif self.rmi_pool_way == 2:
# interpolation
shape = labels_4D.size()
new_h, new_w = shape[2] // s, shape[3] // s
labels_4D = F.interpolate(labels_4D, size=(new_h, new_w), mode='nearest')
probs_4D = F.interpolate(probs_4D, size=(new_h, new_w), mode='bilinear', align_corners=True)
else:
raise NotImplementedError("Pool way of RMI is not defined!")
# we do not need the gradient of label.
label_shape = labels_4D.size()
n, c = label_shape[0], label_shape[1]
# combine the high dimension points from label and probability map. new shape [N, C, radius * radius, H, W]
la_vectors, pr_vectors = self.map_get_pairs(labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0)
la_vectors = la_vectors.view([n, c, self.half_d, -1]).type(torch.double).to(self.device).requires_grad_(False)
pr_vectors = pr_vectors.view([n, c, self.half_d, -1]).type(torch.double).to(self.device)
# small diagonal matrix, shape = [1, 1, radius * radius, radius * radius]
diag_matrix = torch.eye(self.half_d).unsqueeze(dim=0).unsqueeze(dim=0)
# the mean and covariance of these high dimension points
# Var(X) = E(X^2) - E(X) E(X), N * Var(X) = X^2 - X E(X)
la_vectors = la_vectors - la_vectors.mean(dim=3, keepdim=True)
la_cov = torch.matmul(la_vectors, la_vectors.transpose(2, 3))
pr_vectors = pr_vectors - pr_vectors.mean(dim=3, keepdim=True)
pr_cov = torch.matmul(pr_vectors, pr_vectors.transpose(2, 3))
# https://github.com/pytorch/pytorch/issues/7500
# waiting for batched torch.cholesky_inverse()
pr_cov_inv = torch.inverse(pr_cov + diag_matrix.type_as(pr_cov) * self._POS_ALPHA)
# if the dimension of the point is less than 9, you can use the function below
# to accelerate the computation.
# pr_cov_inv = utils.batch_cholesky_inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA)
la_pr_cov = torch.matmul(la_vectors, pr_vectors.transpose(2, 3))
# the approximation of the variance, det(c A) = c^n det(A), A is in n x n shape;
# then log det(c A) = n log(c) + log det(A).
# appro_var = appro_var / n_points, we do not divide the appro_var by number of points here,
# and the purpose is to avoid underflow issue.
# If A = A^T, A^-1 = (A^-1)^T.
appro_var = la_cov - torch.matmul(la_pr_cov.matmul(pr_cov_inv), la_pr_cov.transpose(-2, -1))
# appro_var = la_cov - torch.chain_matmul(la_pr_cov, pr_cov_inv, la_pr_cov.transpose(-2, -1))
# appro_var = torch.div(appro_var, n_points.type_as(appro_var)) + diag_matrix.type_as(appro_var) * 1e-6
# The lower bound. If A is nonsingular, ln( det(A) ) = Tr( ln(A) ).
rmi_now = 0.5 * self.log_det_by_cholesky(appro_var + diag_matrix.type_as(appro_var) * self._POS_ALPHA)
# rmi_now = 0.5 * torch.logdet(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA)
# mean over N samples. sum over classes.
rmi_per_class = rmi_now.view([-1, self.num_classes]).mean(dim=0).float()
# is_half = False
# if is_half:
# rmi_per_class = torch.div(rmi_per_class, float(self.half_d / 2.0))
# else:
rmi_per_class = torch.div(rmi_per_class, float(self.half_d))
rmi_loss = torch.sum(rmi_per_class) if self._IS_SUM else torch.mean(rmi_per_class)
return rmi_loss
@staticmethod
def map_get_pairs(labels_4D, probs_4D, radius=3, is_combine=True):
"""get map pairs
Args:
labels_4D : labels, shape [N, C, H, W]
probs_4D : probabilities, shape [N, C, H, W]
radius : the square radius
Return:
tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)]
"""
# pad to ensure the following slice operation is valid
# pad_beg = int(radius // 2)
# pad_end = radius - pad_beg
# the original height and width
label_shape = labels_4D.size()
h, w = label_shape[2], label_shape[3]
new_h, new_w = h - (radius - 1), w - (radius - 1)
# https://pytorch.org/docs/stable/nn.html?highlight=f%20pad#torch.nn.functional.pad
# padding = (pad_beg, pad_end, pad_beg, pad_end)
# labels_4D, probs_4D = F.pad(labels_4D, padding), F.pad(probs_4D, padding)
# get the neighbors
la_ns = []
pr_ns = []
# for x in range(0, radius, 1):
for y in range(0, radius, 1):
for x in range(0, radius, 1):
la_now = labels_4D[:, :, y:y + new_h, x:x + new_w]
pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w]
la_ns.append(la_now)
pr_ns.append(pr_now)
if is_combine:
# for calculating RMI
pair_ns = la_ns + pr_ns
p_vectors = torch.stack(pair_ns, dim=2)
return p_vectors
else:
# for other purpose
la_vectors = torch.stack(la_ns, dim=2)
pr_vectors = torch.stack(pr_ns, dim=2)
return la_vectors, pr_vectors
@staticmethod
def log_det_by_cholesky(matrix):
"""
Args:
matrix: matrix must be a positive definite matrix.
shape [N, C, D, D].
Ref:
https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/linalg/linalg_impl.py
"""
# This uses the property that the log det(A) = 2 * sum(log(real(diag(C))))
# where C is the cholesky decomposition of A.
chol = torch.linalg.cholesky(matrix)  # torch.cholesky is deprecated in recent PyTorch
# return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-6), dim=-1)
return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-8), dim=-1)
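# Illustrative usage sketch (added; not part of the original source). With num_classes=1
# the forward pass routes to forward_sigmoid; the device argument must match the inputs.
def _example_rmi_loss():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    criterion = RMILoss(num_classes=1, device=device)
    logits = torch.randn(2, 1, 96, 96, device=device)                  # (N, 1, H, W) raw scores
    labels = torch.randint(0, 2, (2, 96, 96), device=device).float()   # (N, H, W) binary mask
    return criterion(logits, labels)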
# ===================== #
# Source: https://github.com/RElbers/region-mutual-information-pytorch
# Segmentation loss (memory intensive)
class RMILossAlt(nn.Module):
"""
PyTorch Module which calculates the Region Mutual Information loss (https://arxiv.org/abs/1910.12037).
"""
def __init__(self,
with_logits,
radius=3,
bce_weight=0.5,
downsampling_method='max',
stride=3,
use_log_trace=True,
use_double_precision=True,
epsilon=0.0005, **_):
"""
:param with_logits:
If True, apply the sigmoid function to the prediction before calculating loss.
:param radius:
RMI radius.
:param bce_weight:
Weight of the binary cross entropy. Must be between 0 and 1.
:param downsampling_method:
Downsampling method used before calculating RMI. Must be one of ['avg', 'max', 'region-extraction'].
If 'region-extraction', then downscaling is done during the region extraction phase. Meaning that the stride is the spacing between consecutive regions.
:param stride:
Stride used for downsampling.
:param use_log_trace:
Whether to calculate the log of the trace, instead of the log of the determinant. See equation (15).
:param use_double_precision:
Calculate the RMI using doubles in order to fix potential numerical issues.
:param epsilon:
Magnitude of the entries added to the diagonal of M in order to fix potential numerical issues.
"""
super().__init__()
self.use_double_precision = use_double_precision
self.with_logits = with_logits
self.bce_weight = bce_weight
self.stride = stride
self.downsampling_method = downsampling_method
self.radius = radius
self.use_log_trace = use_log_trace
self.epsilon = epsilon
def forward(self, logits, labels, **_):
labels = labels.unsqueeze(1)
# Calculate BCE if needed
if self.bce_weight != 0:
if self.with_logits:
bce = F.binary_cross_entropy_with_logits(logits, target=labels)
else:
bce = F.binary_cross_entropy(logits, target=labels)
bce = bce.mean() * self.bce_weight
else:
bce = 0.0
# Apply sigmoid to get probabilities. See final paragraph of section 4.
if self.with_logits:
logits = torch.sigmoid(logits)
# Calculate RMI loss
rmi = self.rmi_loss(input_=logits, target=labels)
rmi = rmi.mean() * (1.0 - self.bce_weight)
return rmi + bce
def rmi_loss(self, input_, target):
"""
Calculates the RMI loss between the prediction and target.
:return:
RMI loss
"""
if input_.shape != target.shape:
raise AssertionError
vector_size = self.radius * self.radius
# Get region vectors
y = self.extract_region_vector(target)
p = self.extract_region_vector(input_)
# Convert to doubles for better precision
if self.use_double_precision:
y = y.double()
p = p.double()
# Small diagonal matrix to fix numerical issues
eps = torch.eye(vector_size, dtype=y.dtype, device=y.device) * self.epsilon
eps = eps.unsqueeze(dim=0).unsqueeze(dim=0)
# Subtract mean
y = y - y.mean(dim=3, keepdim=True)
p = p - p.mean(dim=3, keepdim=True)
# Covariances
y_cov = y @ transpose(y)
p_cov = p @ transpose(p)
y_p_cov = y @ transpose(p)
# Approximated posterior covariance matrix of Y given P
m = y_cov - y_p_cov @ transpose(inverse(p_cov + eps)) @ transpose(y_p_cov)
# Lower bound of RMI
if self.use_log_trace:
rmi = 0.5 * log_trace(m + eps)
else:
rmi = 0.5 * log_det(m + eps)
# Normalize
rmi = rmi / float(vector_size)
# Sum over classes, mean over samples.
return rmi.sum(dim=1).mean(dim=0)
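# NOTE: rmi_loss above calls self.extract_region_vector, which does not appear in this
# listing; the sketch below follows the cited RElbers implementation and is added here
# so the class is self-contained.
def extract_region_vector(self, x):
    """Downsamples x and unfolds it into flattened radius*radius region vectors."""
    x = self.downsample(x)
    stride = self.stride if self.downsampling_method == 'region-extraction' else 1
    x_regions = F.unfold(x, kernel_size=self.radius, stride=stride)
    x_regions = x_regions.view((*x.shape[:2], self.radius ** 2, -1))
    return x_regions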
def downsample(self, x):
# Skip if stride is 1
if self.stride == 1:
return x
# Skip if we pool during region extraction.
if self.downsampling_method == 'region-extraction':
return x
padding = self.stride // 2
if self.downsampling_method == 'max':
return F.max_pool2d(x, kernel_size=self.stride, stride=self.stride, padding=padding)
if self.downsampling_method == 'avg':
return F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride, padding=padding)
raise ValueError(self.downsampling_method)
def transpose(x):
return x.transpose(-2, -1)
def inverse(x):
return torch.inverse(x)
def log_trace(x):
x = torch.linalg.cholesky(x)  # torch.cholesky is deprecated in recent PyTorch
diag = torch.diagonal(x, dim1=-2, dim2=-1)
return 2 * torch.sum(torch.log(diag + 1e-8), dim=-1)
def log_det(x):
return torch.logdet(x)
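# Illustrative usage sketch (added; not part of the original source). RMILossAlt in its
# binary form; with_logits=True means raw scores are passed and sigmoid is applied inside.
def _example_rmi_loss_alt():
    criterion = RMILossAlt(with_logits=True)
    logits = torch.randn(2, 1, 96, 96)                 # (B, 1, H, W) raw scores
    labels = torch.randint(0, 2, (2, 96, 96)).float()  # (B, H, W) binary mask
    return criterion(logits, labels)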
# ====================== #
# Source: https://github.com/NRCan/geo-deep-learning/blob/develop/losses/boundary_loss.py
class BoundaryLoss(nn.Module):
"""Boundary Loss proposed in:
Alexey Bokhovkin et al., Boundary Loss for Remote Sensing Imagery Semantic Segmentation
https://arxiv.org/abs/1905.07852
"""
# in previous implementations theta0=3, theta=5
def __init__(self, theta0=19, theta=19, ignore_index=None, weight=None, is_binary: bool = False, **_):
super().__init__()
self.theta0 = theta0
self.theta = theta
self.ignore_index = ignore_index
self.weight = weight
self.is_binary = is_binary
def forward(self, logits, labels, **_):
"""
Input:
- logits: the output from model (before softmax)
shape (N, C, H, W)
- labels: ground truth map
shape (N, H, W)
Return:
- boundary loss, averaged over mini-batch
"""
n, c, _, _ = logits.shape
# sigmoid / softmax so that predicted map can be distributed in [0, 1]
if self.is_binary:
logits = torch.sigmoid(logits)
else:
logits = torch.softmax(logits, dim=1)
# one-hot vector of ground truth
# print(gt.shape)
# zo = F.one_hot(gt, c)
# print(zo.shape)
if self.is_binary:
one_hot_gt = labels
else:
one_hot_gt = F.one_hot(labels.long(), c).permute(0, 3, 1, 2).squeeze(dim=-1).contiguous().float()  # pass num_classes explicitly so the channel count matches logits
# boundary map
gt_b = F.max_pool2d(1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
gt_b -= 1 - one_hot_gt
pred_b = F.max_pool2d(1 - logits, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
pred_b -= 1 - logits
# extended boundary map
gt_b_ext = F.max_pool2d(gt_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)
pred_b_ext = F.max_pool2d(pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)
# reshape
gt_b = gt_b.view(n, c, -1)
pred_b = pred_b.view(n, c, -1)
gt_b_ext = gt_b_ext.view(n, c, -1)
pred_b_ext = pred_b_ext.view(n, c, -1)
# Precision, Recall
eps = 1e-7
P = (torch.sum(pred_b * gt_b_ext, dim=2) + eps) / (torch.sum(pred_b, dim=2) + eps)
R = (torch.sum(pred_b_ext * gt_b, dim=2) + eps) / (torch.sum(gt_b, dim=2) + eps)
# Boundary F1 Score
BF1 = (2 * P * R + eps) / (P + R + eps)
# summing BF1 Score for each class and average over mini-batch
loss = torch.mean(1 - BF1)
return loss
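# Illustrative usage sketch (added; not part of the original source). Multi-class case
# using the theta0/theta values mentioned in the comment above; shapes are hypothetical.
def _example_boundary_loss():
    criterion = BoundaryLoss(theta0=3, theta=5)
    logits = torch.randn(2, 4, 64, 64)                 # (N, C, H, W) raw scores
    labels = torch.randint(0, 4, (2, 64, 64))          # (N, H, W) label map
    return criterion(logits, labels)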
# ====================== #
"""
Hausdorff loss implementation based on paper:
https://arxiv.org/pdf/1904.10030.pdf
copy pasted from - all credit goes to original authors:
https://github.com/SilmarilBearer/HausdorffLoss
"""
from scipy.ndimage import distance_transform_edt as edt  # scipy.ndimage.morphology is deprecated
from scipy.ndimage import convolve
import cv2
class HausdorffDTLoss(nn.Module):
"""Binary Hausdorff loss based on distance transform"""
def __init__(self, alpha=2.0, **_):
super(HausdorffDTLoss, self).__init__()
self.alpha = alpha
# staticmethod must be the outermost decorator so that no_grad wraps the plain function
@staticmethod
@torch.no_grad()
def distance_field(img: np.ndarray) -> np.ndarray:
field = np.zeros_like(img)
for batch in range(len(img)):
fg_mask = img[batch] > 0.5
if fg_mask.any():
bg_mask = ~fg_mask
fg_dist = edt(fg_mask)
bg_dist = edt(bg_mask)
field[batch] = fg_dist + bg_dist
return field
def forward(self, logits: torch.Tensor, labels: torch.Tensor, debug=False, **_) -> torch.Tensor:
"""
Uses one binary channel: 1 - fg, 0 - bg
logits: (b, 1, x, y, z) or (b, 1, x, y) raw scores
labels: (b, x, y, z) or (b, x, y) binary mask (a channel dim is added internally)
"""
labels = labels.unsqueeze(1)
if logits.dim() not in (4, 5):
raise AssertionError("Only 2D and 3D supported")
if (logits.dim() != labels.dim()):
raise AssertionError("Prediction and target need to be of same dimension")
# this is necessary for binary loss
logits = torch.sigmoid(logits)
pred_dt = torch.from_numpy(self.distance_field(logits.detach().cpu().numpy())).float()
target_dt = torch.from_numpy(self.distance_field(labels.detach().cpu().numpy())).float()
pred_error = (logits - labels) ** 2
distance = pred_dt.to(logits.device) ** self.alpha + target_dt.to(logits.device) ** self.alpha
dt_field = pred_error * distance
loss = dt_field.mean()
if debug:
return (
loss.detach().cpu().numpy(),
(
dt_field.detach().cpu().numpy()[0, 0],
pred_error.detach().cpu().numpy()[0, 0],
distance.detach().cpu().numpy()[0, 0],
pred_dt.detach().cpu().numpy()[0, 0],
target_dt.detach().cpu().numpy()[0, 0],
),
)
else:
return loss
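# Illustrative usage sketch (added; not part of the original source). Binary 2D case;
# labels are passed without a channel dim because forward unsqueezes them.
def _example_hausdorff_dt_loss():
    criterion = HausdorffDTLoss(alpha=2.0)
    logits = torch.randn(2, 1, 64, 64)                  # (b, 1, x, y) raw scores
    labels = torch.randint(0, 2, (2, 64, 64)).float()   # (b, x, y) binary mask
    return criterion(logits, labels)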
class HausdorffERLoss(nn.Module):
"""Binary Hausdorff loss based on morphological erosion"""
def __init__(self, alpha=2.0, erosions=10, **kwargs):
super(HausdorffERLoss, self).__init__()
self.alpha = alpha
self.erosions = erosions
self.prepare_kernels()
def prepare_kernels(self):
cross = np.array([cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))])
bound = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])
self.kernel2D = cross * 0.2
self.kernel3D = np.array([bound, cross, bound]) * (1 / 7)
def forward(self, pred: torch.Tensor, target: torch.Tensor, debug=False) -> torch.Tensor:
"""
Uses one binary channel: 1 - fg, 0 - bg
pred: (b, 1, x, y, z) or (b, 1, x, y) raw scores
target: (b, x, y, z) or (b, x, y) binary mask (a channel dim is added internally)
"""
target = target.unsqueeze(1)
if pred.dim() not in (4, 5):
raise AssertionError("Only 2D and 3D supported")
if (pred.dim() != target.dim()):
raise AssertionError("Prediction and target need to be of same dimension")
pred = torch.sigmoid(pred)
if debug:
eroted, erosions = self.perform_erosion(pred.detach().cpu().numpy(), target.detach().cpu().numpy(), debug)
return eroted.mean(), erosions
else:
eroted = torch.from_numpy(self.perform_erosion(pred.detach().cpu().numpy(), target.detach().cpu().numpy(), debug)).float()
loss = eroted.mean()
return loss
# ====================== #
"""
Recall Loss
copy pasted from - all credit goes to original authors:
https://github.com/shuaizzZ/Recall-Loss-PyTorch/blob/master/recall_loss.py
"""
class RecallLoss(nn.Module):
""" An unofficial implementation of
<Recall Loss for Imbalanced Image Classification and Semantic Segmentation>
Created by: Zhang Shuai
Email: shuaizzz666@gmail.com
recall = TP / (TP + FN)
Args:
weight: An array of shape [C,]
predict: A float32 tensor of shape [N, C, *], for Semantic segmentation task is [N, C, H, W]
target: A int64 tensor of shape [N, *], for Semantic segmentation task is [N, H, W]
Return:
diceloss
"""
def __init__(self, weight=None, **_):
super(RecallLoss, self).__init__()
if weight is not None:
weight = torch.Tensor(weight)
self.weight = weight / torch.sum(weight) # Normalized weight
self.smooth = 1e-5
def forward(self, logits, labels, **_):
N, C = logits.size()[:2]
_, predict = torch.max(logits, 1)# # (N, C, *) ==> (N, 1, *)
predict = predict.view(N, 1, -1) # (N, 1, *)
labels = labels.view(N, 1, -1) # (N, 1, *)
last_size = labels.size(-1)
## convert predict & target (N, 1, *) into one hot vector (N, C, *)
predict_onehot = torch.zeros((N, C, last_size), device=logits.device) # (N, 1, *) ==> (N, C, *)
predict_onehot.scatter_(1, predict, 1) # (N, C, *)
target_onehot = torch.zeros((N, C, last_size), device=logits.device) # (N, 1, *) ==> (N, C, *)
target_onehot.scatter_(1, labels.long(), 1) # (N, C, *)
true_positive = torch.sum(predict_onehot * target_onehot, dim=2) # (N, C)
total_target = torch.sum(target_onehot, dim=2) # (N, C)
## Recall = TP / (TP + FN)
recall = (true_positive + self.smooth) / (total_target + self.smooth) # (N, C)
if hasattr(self, 'weight'):
if self.weight.type() != logits.type():
self.weight = self.weight.type_as(logits)
recall = recall * self.weight * C # (N, C)
recall_loss = 1 - torch.mean(recall) # 1
return recall_loss
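# Illustrative usage sketch (added; not part of the original source). Per-class weights
# are optional; the weights and shapes below are hypothetical.
def _example_recall_loss():
    criterion = RecallLoss(weight=[1.0, 2.0, 1.0])
    logits = torch.randn(4, 3, 32, 32)                 # (N, C, H, W) raw scores
    labels = torch.randint(0, 3, (4, 32, 32))          # (N, H, W) label map
    return criterion(logits, labels)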
# ====================== #
class SoftInvDiceLoss(torch.nn.Module):
"""
Well-performing loss for binary segmentation
"""
def __init__(self, smooth=1., is_binary=True, **_):
super(SoftInvDiceLoss, self).__init__()
self.smooth = smooth
self.is_binary = is_binary
def forward(self, logits, labels, **_):
# sigmoid / softmax so that predicted map can be distributed in [0, 1]
if self.is_binary:
logits = torch.sigmoid(logits)
else:
logits = torch.softmax(logits, dim=1)
iflat = 1 - logits.view(-1)
tflat = 1 - labels.view(-1)
intersection = (iflat * tflat).sum()
return 1 - ((2. * intersection + self.smooth) / (iflat.sum() + tflat.sum() + self.smooth))
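# Illustrative usage sketch (added; not part of the original source). Inverted soft Dice
# computed on the background, binary case; shapes are hypothetical.
def _example_soft_inv_dice():
    criterion = SoftInvDiceLoss()
    logits = torch.randn(4, 1, 64, 64)                    # (B, 1, H, W) raw scores
    labels = torch.randint(0, 2, (4, 1, 64, 64)).float()  # (B, 1, H, W) binary mask
    return criterion(logits, labels)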
# ======================= #
# --- COMBINED LOSSES --- #
class OhemBCEDicePenalizeBorderLoss(OhemCrossEntropy2d):
"""
Combined OHEM (Online Hard Example Mining) process with BCE-Dice penalized loss
"""
def __init__(self, thresh=0.6, min_kept=0, ignore_index=-100, kernel_size=21, **_):
super().__init__()
self.ignore_label = ignore_index
self.thresh = float(thresh)
self.min_kept = int(min_kept)
self.criterion = BCEDicePenalizeBorderLoss(kernel_size=kernel_size)
class RMIBCEDicePenalizeBorderLoss(RMILossAlt):
"""
Combined RMI and BCEDicePenalized Loss
"""
def __init__(self, kernel_size=21, rmi_weight=1.0, bce_weight=1.0, **kwargs):
super().__init__(**kwargs)
self.bce = BCEDicePenalizeBorderLoss(kernel_size=kernel_size)
self.bce_weight = bce_weight
self.rmi_weight = rmi_weight
def to(self, device):
    super().to(device=device)
    self.bce.to(device=device)
    return self
def forward(self, logits, labels, **_):
    if labels.shape != logits.shape:
        if len(logits.shape) > len(labels.shape):
            labels = labels.unsqueeze(dim=1)  # add the missing channel dim (the result must be reassigned)
        else:
            raise Exception(f'Non-matching shapes for logits ({logits.shape}) and labels ({labels.shape})')
# Calculate RMI loss
rmi = self.rmi_loss(input_=torch.sigmoid(logits), target=labels)
bce = self.bce(logits, labels)
# rmi = rmi.mean() * (1.0 - self.bce_weight)
return self.rmi_weight * rmi + self.bce_weight * bce
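# Illustrative usage sketch (added; not part of the original source). The combined loss
# keeps RMILossAlt's constructor arguments (with_logits, radius, ...) and adds the
# BCE-Dice-border term; BCEDicePenalizeBorderLoss is defined earlier in this module.
def _example_rmi_bce_dice_penalize_border():
    criterion = RMIBCEDicePenalizeBorderLoss(with_logits=True)
    logits = torch.randn(2, 1, 96, 96)                 # (B, 1, H, W) raw scores
    labels = torch.randint(0, 2, (2, 96, 96)).float()  # (B, H, W) binary mask
    return criterion(logits, labels)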