Source code for medicaltorch.losses

import torch
from torch.nn import Module

[docs]def dice_loss(input, target): """Dice loss. :param input: The input (predicted) :param target: The target (ground truth) :returns: the Dice score between 0 and 1. """ eps = 0.0001 iflat = input.view(-1) tflat = target.view(-1) intersection = (iflat * tflat).sum() union = iflat.sum() + tflat.sum() dice = (2.0 * intersection + eps) / (union + eps) return - dice
[docs]class MaskedDiceLoss(Module): """A masked version of the Dice loss. :param ignore_value: the value to ignore. """ def __init__(self, ignore_value=-100.0): super().__init__() self.ignore_value = ignore_value def forward(self, input, target): eps = 0.0001 masking = target == self.ignore_value masking = masking.sum(3).sum(2) masking = masking == 0 masking = masking.squeeze() labeled_target = target.index_select(0, masking.nonzero().squeeze()) labeled_input = input.index_select(0, masking.nonzero().squeeze()) iflat = labeled_input.view(-1) tflat = labeled_target.view(-1) intersection = (iflat * tflat).sum() union = iflat.sum() + tflat.sum() dice = (2.0 * intersection + eps) / (union + eps) return - dice
class ConfidentMSELoss(Module): def __init__(self, threshold=0.96): self.threshold = threshold super().__init__() def forward(self, input, target): n = input.size(0) conf_mask = torch.gt(target, self.threshold).float() input_flat = input.view(n, -1) target_flat = target.view(n, -1) conf_mask_flat = conf_mask.view(n, -1) diff = (input_flat - target_flat)**2 diff_conf = diff * conf_mask_flat loss = diff_conf.mean() return loss