▮ Elements
Metric learning aims to measure the similarity between samples by learning with distance metrics. According to a survey from 2019, this field seems to be becoming more and more important. So for this post, I’d like to share the 3 main elements of deep metric learning.
- Model Network Structure
- Metric Loss Function
- Informative Input Sample Selection
▮ Model Network Structure
I’ve written a post on the differences, including the network structure, between plain classification and classification using metric learning.
If you are interested, please check it out.
Post: 393. Different Approaches For Image Classification
▮ Metric Loss Function
Numerous types of loss functions can be used for deep metric learning. For this post, I’d like to share 2 of the most popular ones.
Contrastive Loss
The contrastive loss takes a pair of images as input. If the pair is similar, it pulls their embeddings closer together. On the other hand, if the pair is dissimilar, it pushes their embeddings farther apart.
This means the distances between all similar images will eventually approach 0. This separates classes (inter-class distances) fairly explicitly, but leaves little room for variation within a class (intra-class distances). The model can explicitly tell whether an image is a dog or a cat, but can’t distinguish a sleeping dog from a running dog.
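Here is a minimal PyTorch sketch of the contrastive loss described above. The function name, the margin value, and the convention that a label of 1 marks a similar pair are my own assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(emb1, emb2, label, margin=1.0):
    # label = 1 for a similar pair, 0 for a dissimilar pair (assumed convention)
    dist = F.pairwise_distance(emb1, emb2)             # Euclidean distance per pair
    pull = label * dist.pow(2)                         # similar pairs: pulled toward distance 0
    push = (1 - label) * F.relu(margin - dist).pow(2)  # dissimilar pairs: pushed apart, up to the margin
    return (pull + push).mean()
```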
Triplet Loss
Triplet loss is a less greedy approach compared to the contrastive loss. The triplet loss takes a set of 3 images as input: an “anchor” image, a “positive” image that is similar to the anchor, and a “negative” image that is dissimilar to the anchor. The loss encourages the anchor-negative distance to exceed the anchor-positive distance by at least a certain margin. Since this approach is not trying to re-embed all similar images to the same point in the vector space, it can tolerate some intra-class variance.
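As a rough PyTorch sketch under the same assumptions as above (illustrative function name and margin value), the triplet loss can be written as:

```python
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=0.2):
    d_pos = F.pairwise_distance(anchor, positive)  # anchor-positive distance
    d_neg = F.pairwise_distance(anchor, negative)  # anchor-negative distance
    # penalize triplets where the negative is not at least `margin` farther than the positive
    return F.relu(d_pos - d_neg + margin).mean()
```

PyTorch also ships a built-in version of this loss as torch.nn.TripletMarginLoss.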
▮ Informative Input Sample Selection
In addition to the 2 elements previously stated, informative sample selection also plays a huge role in tasks such as classification and clustering. If this sampling is done poorly it can slow down, or even stop, the learning process.
For example, let’s say we are sampling inputs for a multi-class classification model that uses the triplet loss, and the currently learned embeddings look like the figure below.
We can tell that embeddings A and B are much harder to distinguish than embeddings A and E. To take an extreme case, if the sampling method constantly selects Image A and Image E for training, these inputs would not be very informative, because it’s already obvious that those two images come from 2 completely different classes. It would be much better if the sampling method selected harder-to-distinguish images such as Image A and Image B.
For this reason, properly “mining” the right samples becomes critical.
Mining Methods
There are mainly 3 methods for triplet mining.
- Hard Negative Mining: select negative samples that are closer to the anchor than the positive sample.
- Semi-hard Negative Mining: select negative samples that are farther from the anchor than the positive sample, but still within the margin.
- Easy Negative Mining: select negative samples that are farther from the anchor than the positive sample by more than the margin.
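To make the three categories concrete, here is a small sketch that splits a batch of candidate negatives by their distance to the anchor. The function name, tensor shapes, and margin value are assumptions for illustration, not a fixed API.

```python
import torch

def categorize_negatives(anchor, positive, candidates, margin=0.2):
    # anchor, positive: (D,) embeddings; candidates: (N, D) candidate negative embeddings
    d_pos = torch.norm(anchor - positive)            # anchor-positive distance
    d_neg = torch.norm(candidates - anchor, dim=1)   # anchor-candidate distances

    hard = d_neg < d_pos                                     # closer than the positive
    semi_hard = (d_neg >= d_pos) & (d_neg < d_pos + margin)  # beyond the positive but inside the margin
    easy = d_neg >= d_pos + margin                           # beyond the margin
    return hard, semi_hard, easy                             # boolean masks over the candidates
```

Semi-hard negatives are often preferred in practice (notably in FaceNet), since training on only the hardest negatives can destabilize learning early on.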