This is what an image-classification AI "sees" when it looks at my dog.
I found a blog post that visualizes the feature maps of a classification model, so I tried it out!
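I’m going to use ResNet-18 and visualize the feature maps of its 17 main convolution layers: the stem convolution plus the two convolutions in each of its eight residual blocks (the three 1×1 shortcut convolutions are skipped).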
import torch
import torch.nn as nn
from torchvision import models, transforms
import matplotlib.pyplot as plt
from PIL import Image
import pathlib
# Transform for inference: resize, convert to a float tensor, then normalize
# with the ImageNet per-channel mean/std that torchvision's pretrained weights
# expect (a mean of 0 and std of 1 would leave the input unnormalized).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Transform for displaying: resize only, so the original colours are kept.
transform2 = transforms.Compose([
    transforms.Resize((224, 224)),
])
# Load image. convert('RGB') guarantees the 3 channels the first conv layer
# needs (grayscale/RGBA/palette files would otherwise fail) and returns an
# in-memory copy, so the file handle can be closed by the context manager.
with Image.open(str(pathlib.Path.cwd() / 'co.jpg')) as img_file:
    ori_image = img_file.convert('RGB')
plt.imshow(ori_image)
# Load model with pretrained weights and put it in inference mode so
# batch-norm layers use their stored running statistics.
loaded_model = models.resnet18(pretrained=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_model = loaded_model.to(device)
loaded_model.eval()
# Process image: a resized PIL copy for display, and a normalized batch of
# one image (leading dim added by unsqueeze(0)) on the model's device.
processed_ori_image = transform2(ori_image)
image = transform(ori_image).unsqueeze(0).to(device)
# Collect every Conv2d that is either a direct child of the model or a direct
# child of a block inside one of the model's nn.Sequential stages, along with
# its weight tensor. Modules nested deeper than that are not descended into
# (in torchvision's ResNet this presumably skips the 1x1 `downsample`
# shortcut convs — confirm if the layer count matters).
model_weights = []  # weight tensor of each collected conv layer
conv_layers = []    # the collected Conv2d modules, in model order
model_children = list(loaded_model.children())
counter = 0  # number of conv layers collected
for child in model_children:
    if isinstance(child, nn.Conv2d):
        found = [child]
    elif isinstance(child, nn.Sequential):
        found = [
            sub_module
            for block in child
            for sub_module in block.children()
            if isinstance(sub_module, nn.Conv2d)
        ]
    else:
        found = []
    for conv in found:
        counter += 1
        model_weights.append(conv.weight)
        conv_layers.append(conv)
# Record the output of each collected conv layer during ONE real forward pass
# of the whole network. Feeding the image through the conv layers one after
# another by hand would skip every batch-norm, ReLU, pooling and residual
# addition in between, so those maps would not be what the network computes.
outputs = []
names = []
captured = {}


def _make_hook(index):
    """Return a forward hook that stores the module's output under *index*."""
    def hook(module, inputs, output):
        captured[index] = output.detach()
    return hook


handles = [
    layer.register_forward_hook(_make_hook(index))
    for index, layer in enumerate(conv_layers)
]
loaded_model.eval()  # inference mode: BN uses running stats, not this batch
try:
    with torch.no_grad():  # no autograd graph needed for visualization
        loaded_model(image)
finally:
    # Always detach the hooks so later forward passes are unaffected.
    for handle in handles:
        handle.remove()
# Build the lists in conv_layers order (independent of hook firing order).
for index, layer in enumerate(conv_layers):
    outputs.append(captured[index])
    names.append(str(layer))
# Reduce each conv output to a single 2-D map for display: drop the batch
# dimension of size 1 (the input was a single image via unsqueeze(0)), then
# average over the channel axis. detach() replaces the discouraged `.data`.
processed = []
for feature_map in outputs:
    gray_scale = feature_map.squeeze(0).mean(dim=0)
    processed.append(gray_scale.detach().cpu().numpy())
# Plot the resized original image followed by one grayscale panel per feature
# map, then save the figure. The grid has 6 columns and as many rows as the
# panel count requires (3 rows, as before, for 17 maps + the original).
n_cols = 6
n_panels = len(processed) + 1  # +1 for the original image
n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
fig = plt.figure(figsize=(50, 10 * n_rows))
a = fig.add_subplot(n_rows, n_cols, 1)
a.imshow(processed_ori_image)
a.axis("off")
a.set_title('Original Image', fontsize=40)
for i, (gray_map, name) in enumerate(zip(processed, names)):
    a = fig.add_subplot(n_rows, n_cols, i + 2)
    a.imshow(gray_map, cmap='gist_gray')
    a.axis("off")
    # Title: layer type (text before '('), layer index, and map shape.
    a.set_title(name.split('(')[0] + " :" + str(i) + ' ' + str(gray_map.shape), fontsize=40)
fig.tight_layout()
fig.savefig('feature_maps.jpg', bbox_inches='tight')
Reference: blog post