Here is one way to combine losses using PyTorch.
Let’s say we want to train a binary semantic segmentation model using Binary Cross Entropy Loss. Looking at the data, you notice that the classes are quite imbalanced, so you search for loss functions better suited to this case and find that Dice Loss is a good fit. How can you combine Binary Cross Entropy Loss and Dice Loss?
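For context, Dice loss on its own is one minus the (smoothed) Dice coefficient between the predicted probabilities and the binary mask; because it measures overlap rather than per-pixel error, it is much less sensitive to a dominant background class. Here is a minimal standalone sketch of that term (the function name dice_loss and the smooth default are just illustrative choices, not part of the combined class below):

import torch

def dice_loss(probs, targets, smooth=1.0):
    # probs holds probabilities in [0, 1], targets holds 0/1 labels
    probs = probs.reshape(-1)
    targets = targets.reshape(-1).float()
    intersection = (probs * targets).sum()
    return 1 - (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

The class below computes exactly this Dice term and adds a Binary Cross Entropy term on top of it.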
#Imports
import torch
import torch.nn as nn
import torch.nn.functional as F

# Binary Cross Entropy Loss + Dice Loss
class DiceBCELogitsLoss(nn.Module):
    def __init__(self):
        super(DiceBCELogitsLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Flatten the raw model output (logits)
        flattened_input = torch.flatten(inputs)
        # Squash the logits into the 0~1 range for the Dice term
        scaled_inputs = torch.sigmoid(flattened_input)
        # Flatten the annotation data and match dtype
        targets = targets.view(-1).float()
        # Dice loss: 1 - Dice coefficient, smoothed to avoid division by zero
        intersection = (scaled_inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (scaled_inputs.sum() + targets.sum() + smooth)
        # Binary Cross Entropy, computed directly on the logits for numerical stability
        BCE = F.binary_cross_entropy_with_logits(flattened_input, targets, reduction='mean')
        # Combine the two losses by simple addition
        Dice_BCE_LOGITS = BCE + dice_loss
        return Dice_BCE_LOGITS
#Create an instance of the combined loss
criterion = DiceBCELogitsLoss()

#Calculate the loss
#model_output: raw logits with shape [Batch_size, Img_Height, Img_Width]
#annotation:   binary masks with shape [Batch_size, Img_Height, Img_Width]
calculated_loss = criterion(model_output, annotation)
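If you want to sanity-check the combined loss before wiring it into a real training loop, a quick run on random tensors is enough. The shapes and toy tensors below are placeholders chosen for illustration, not values from the original example:

#Smoke test with dummy data
batch_size, height, width = 4, 64, 64
model_output = torch.randn(batch_size, height, width, requires_grad=True)  # fake logits
annotation = torch.randint(0, 2, (batch_size, height, width))              # fake binary masks

criterion = DiceBCELogitsLoss()
loss = criterion(model_output, annotation)
loss.backward()  # gradients flow back through both the BCE and the Dice term
print(loss.item())

If one of the two terms dominates training in practice, a common tweak is to return a weighted sum such as BCE + dice_weight * dice_loss instead of a plain addition, and tune dice_weight as a hyperparameter for your data.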