164. Combining Losses

Here is one way to combine losses using PyTorch.

Let’s say we want to train a binary semantic segmentation model using Binary Cross Entropy (BCE) Loss. Looking at the data, you notice that it is quite imbalanced (for example, far more background pixels than foreground pixels). So you search for loss functions suited to this case and find that Dice Loss is a good fit. How can you combine BCE loss and Dice loss?
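Written out, the combined objective is simply the sum of the two terms. Dice measures overlap relative to the size of the predicted and true foreground, so it is less dominated by the many easy background pixels than the mean BCE is. A sketch of the math, assuming logits $z_i$ over all $N$ pixels in the batch, probabilities $p_i = \sigma(z_i)$, binary labels $y_i \in \{0, 1\}$, and a smoothing constant $s$ (the `smooth` argument, default 1):

$$
\mathcal{L}_{\text{BCE}} = -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i \log p_i + (1 - y_i)\log(1 - p_i)\Big]
$$

$$
\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_{i=1}^{N} p_i y_i + s}{\sum_{i=1}^{N} p_i + \sum_{i=1}^{N} y_i + s}
$$

$$
\mathcal{L} = \mathcal{L}_{\text{BCE}} + \mathcal{L}_{\text{Dice}}
$$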


```python
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F


# Binary Cross Entropy Loss + Dice Loss
class DiceBCELogitsLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1.0):
        # Flatten the raw model output (logits)
        flattened_inputs = torch.flatten(inputs)
        # Squash logits into probabilities in [0, 1] for the Dice term
        scaled_inputs = torch.sigmoid(flattened_inputs)
        # Flatten the annotation and cast it to float
        # (reshape also works on non-contiguous tensors, unlike view)
        targets = targets.reshape(-1).float()

        # Soft Dice loss, computed over all pixels in the batch
        intersection = (scaled_inputs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (
            scaled_inputs.sum() + targets.sum() + smooth
        )

        # BCE on the logits (more numerically stable than sigmoid + BCELoss)
        bce_loss = F.binary_cross_entropy_with_logits(
            flattened_inputs, targets, reduction="mean"
        )

        # Combine: unweighted sum of the two losses
        return bce_loss + dice_loss


# Create an instance
criterion = DiceBCELogitsLoss()

# Placeholder tensors standing in for a real model output and annotation
model_output = torch.randn(4, 128, 128)          # logits: [Batch_size, Img_Height, Img_Width]
annotation = torch.randint(0, 2, (4, 128, 128))  # 0/1 mask: [Batch_size, Img_Height, Img_Width]

# Calculate loss
calculated_loss = criterion(model_output, annotation)
```
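The loss above weights BCE and Dice equally. If one term dominates during training, a common variation is a weighted sum. The sketch below is illustrative rather than the author's method: `WeightedDiceBCELogitsLoss`, `bce_weight` and `dice_weight` are hypothetical names introduced here, and the toy 1x1-convolution "model" is only a stand-in for a real segmentation network, used to show one training step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedDiceBCELogitsLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss (illustrative sketch).

    bce_weight and dice_weight are hypothetical hyperparameters for this
    example; with both set to 1.0 this is the unweighted BCE + Dice sum.
    """

    def __init__(self, bce_weight=1.0, dice_weight=1.0, smooth=1.0):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth

    def forward(self, logits, targets):
        # Flatten everything so the Dice term is computed over the whole batch
        logits = logits.reshape(-1)
        targets = targets.reshape(-1).float()

        # BCE operates directly on logits (mean reduction by default)
        bce = F.binary_cross_entropy_with_logits(logits, targets)

        # Soft Dice operates on probabilities
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        dice = 1 - (2.0 * intersection + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth
        )

        return self.bce_weight * bce + self.dice_weight * dice


# Usage: one training step on random data with a toy 1x1-conv "model"
torch.manual_seed(0)
model = nn.Conv2d(3, 1, kernel_size=1)      # hypothetical stand-in for a segmentation net
images = torch.randn(2, 3, 32, 32)          # [Batch_size, Channels, Img_Height, Img_Width]
masks = torch.randint(0, 2, (2, 32, 32))    # [Batch_size, Img_Height, Img_Width]

criterion = WeightedDiceBCELogitsLoss(bce_weight=0.5, dice_weight=1.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

logits = model(images).squeeze(1)           # [Batch_size, Img_Height, Img_Width]
loss = criterion(logits, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```

Setting `dice_weight=0.0` reduces this to plain BCE-with-logits, which makes it easy to compare the two setups. The weights themselves have to be tuned per dataset; nothing here suggests particular values.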