I’d like to share two different ways to save and load a model using PyTorch.
Saving The Entire Model
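import torch

# Save the entire model object (class reference plus weights) with pickle
torch.save(model, PATH)

# Load the model; the class that defines it must be importable from the same module path.
# On PyTorch 2.6 and later, torch.load defaults to weights_only=True, so a fully
# pickled model needs weights_only=False (only do this for files you trust).
model = torch.load(PATH)
model.eval()  # put dropout and batch-norm layers into inference mode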
This save/load process takes the least code to implement. However, because it serializes the entire module with Python’s pickle module, the saved file does not contain the model class itself, only a reference to the file that defines it. The serialized data is therefore bound to that specific class and to the exact directory structure used when the model was saved, so loading can break after a refactor or when you reuse the model in another project.
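To make the class binding concrete, here is a minimal standard-library sketch. It uses OrderedDict as a stand-in for a model class (an assumption made purely for illustration) and shows that pickle writes the module path and class name into the payload rather than the class code. torch.save(model, PATH) relies on pickle by default, so the same constraint applies.

```python
import pickle
from collections import OrderedDict

# Stand-in for a model object; any picklable class instance behaves the same way.
obj = OrderedDict(a=1)

# pickle stores a reference (module path + class name), not the class source.
payload = pickle.dumps(obj)

module_name = type(obj).__module__.encode()   # b"collections"
class_name = type(obj).__qualname__.encode()  # b"OrderedDict"

# Both names are embedded in the serialized bytes, so unpickling has to
# import the class from exactly that location.
has_module = module_name in payload
has_class = class_name in payload
print(has_module, has_class)
```

If the class is later moved or renamed, that stored reference no longer resolves and unpickling fails, which is the refactoring problem described above.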
Saving Using state_dict
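# Save only the parameters and buffers, not the model class
torch.save(model.state_dict(), PATH)

# Load: rebuild the architecture in code first, then copy the saved tensors into it
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # put dropout and batch-norm layers into inference mode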
This is the save/load process recommended by PyTorch. A model’s learnable parameters (its weights and biases) are accessed through model.parameters(). A state_dict is a Python dictionary that maps each layer to its parameter tensors; it also includes registered buffers such as batch-norm running statistics. Because it is an ordinary dictionary, it can be easily saved, updated, and restored.
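As a rough illustration of what a state_dict contains, the sketch below builds a small throwaway network (the layer sizes are arbitrary assumptions, not taken from any particular model) and prints each entry with its tensor shape.

```python
import torch
from torch import nn

# A tiny hypothetical model, used only to show what a state_dict looks like.
model = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),        # no learnable parameters, so it has no state_dict entry
    nn.Linear(3, 2),
)

state = model.state_dict()

# Each key is "<layer name>.<parameter name>"; each value is a tensor.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

In an nn.Sequential the layer names are simply the positions 0, 1, 2, so the keys come out as 0.weight, 0.bias, 2.weight and 2.bias; the ReLU at position 1 contributes nothing.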
This save/load method gives you the most flexibility for restoring models later on.
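A minimal end-to-end sketch of the state_dict round trip follows. It writes to an in-memory buffer instead of a file so it can run anywhere, and make_model is a hypothetical helper standing in for TheModelClass(*args, **kwargs).

```python
import io

import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical architecture; in practice this is your own model class.
    return nn.Sequential(nn.Linear(4, 2))


torch.manual_seed(0)
trained = make_model()

# Save only the state_dict (an in-memory buffer stands in for PATH here).
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Restore: build a fresh instance, then load the saved tensors into it.
restored = make_model()
restored.load_state_dict(torch.load(buffer))
restored.eval()

# The restored model should produce the same output as the original.
x = torch.randn(1, 4)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)
```

Because the architecture is rebuilt from code on load, the saved file stays usable even if the class is moved to a different module, as long as the parameter names still match.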