176. CrossEntropyLoss for Segmentation Models

torch.nn.CrossEntropyLoss()

Using torch.nn.CrossEntropyLoss() as a loss function for semantic segmentation models was confusing for me at first, so I'd like to share what I learned here.

CrossEntropyLoss is for multi-class models, and it expects two arguments: one for the model prediction and one for the label data.

Arg1: Model Prediction

A semantic segmentation model in PyTorch outputs a tensor of shape [Batch_size, Number_of_classes, Img_Height, Img_Width]. (Img_Height and Img_Width may differ from the input image dimensions depending on which architecture you are using.) CrossEntropyLoss expects this tensor, as raw logits with no softmax or other scaling applied, as the model prediction argument; the loss applies log-softmax internally.
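For a quick shape check, here is a minimal sketch; the 1x1 convolution is only a hypothetical stand-in for a real segmentation head, but the raw, unscaled output it produces is exactly what CrossEntropyLoss wants as its first argument.

import torch
import torch.nn as nn

# Hypothetical stand-in for a segmentation head: any network whose final
# layer produces one channel per class yields the same output shape.
num_classes = 3
model = nn.Conv2d(in_channels=3, out_channels=num_classes, kernel_size=1)

images = torch.randn(4, 3, 64, 64)   # [Batch_size, Channels, Img_Height, Img_Width]
logits = model(images)               # raw scores, no softmax applied
print(logits.shape)                  # torch.Size([4, 3, 64, 64])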

Arg2: Label Data

This was the part that confused me. For the label data, after reading the label image, you need to convert it to a tensor of shape [Batch_size, Img_Height, Img_Width] with dtype torch.long, where each pixel value is the class index of that pixel (not a one-hot encoding).
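As an example, this is roughly how such a target tensor can be built, assuming the annotation image already stores class indices as pixel values; the file name is hypothetical, and datasets with color-coded labels would need an extra color-to-index mapping step first.

import numpy as np
import torch
from PIL import Image

# "label.png" is a hypothetical annotation file that already stores one
# class index per pixel (Pascal VOC-style palette image).
label_img = Image.open("label.png")
annotation = torch.as_tensor(np.array(label_img), dtype=torch.long)  # [Img_Height, Img_Width]
annotation = annotation.unsqueeze(0)  # add batch dim -> [1, Img_Height, Img_Width]
print(annotation.shape, annotation.dtype)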

Code
import torch

criterion = torch.nn.CrossEntropyLoss()

# model_output: [Batch_size, Num_of_Classes, Img_Height, Img_Width] (raw logits, no softmax)
# annotation:   [Batch_size, Img_Height, Img_Width] (class indices, dtype torch.long)
calculated_loss = criterion(model_output, annotation)
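If you want to verify the shapes without a real model or dataset, random tensors are enough; the sizes below are arbitrary.

import torch

# Dummy tensors with the expected shapes, just to sanity-check the call.
num_classes, batch_size, height, width = 5, 2, 32, 32
model_output = torch.randn(batch_size, num_classes, height, width)       # fake logits
annotation = torch.randint(0, num_classes, (batch_size, height, width))  # fake class-index labels

criterion = torch.nn.CrossEntropyLoss()
print(criterion(model_output, annotation))  # a single scalar loss value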