Implementation
Saving checkpoints of model weights during training is useful in situations such as the following:
- You want to resume training later.
- You want to avoid losing weight data if the process stops mid-training due to an error.
- You want to restore the weights from a specific epoch.
Here is one way to implement this in PyTorch.
Saving Checkpoints
import torch
from pathlib import Path

# `epoch`, `best_model_wts`, and `optimizer` come from your training loop;
# `best_model_wts` is typically a copy of model.state_dict().
checkpoint = {
    'epoch': epoch,
    'state_dict': best_model_wts,
    'optimizer': optimizer.state_dict(),  # saving this is important when using adaptive optimizers (e.g. Adam).
}

exp_dir = Path('experiments')  # wherever you keep run artifacts
f_path = exp_dir / f'checkpoint_e{epoch}.pt'
torch.save(checkpoint, f_path)
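To make the context concrete, here is a minimal sketch of how the save step might sit inside a training loop, checkpointing every few epochs. The names train_one_epoch, save_every, and num_epochs are hypothetical placeholders standing in for your own code, not part of any PyTorch API.

import copy
import torch
from pathlib import Path

exp_dir = Path('experiments')
exp_dir.mkdir(parents=True, exist_ok=True)
save_every = 5    # hypothetical: write a checkpoint every 5 epochs
num_epochs = 20   # hypothetical total epoch count

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)  # hypothetical training step
    if (epoch + 1) % save_every == 0:
        checkpoint = {
            'epoch': epoch,
            # deep-copy so later training steps don't mutate the snapshot
            'state_dict': copy.deepcopy(model.state_dict()),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(checkpoint, exp_dir / f'checkpoint_e{epoch}.pt')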
Restoring Weights from Checkpoints
# Load the checkpoint dictionary (map_location lets you restore on a different device).
checkpoint = torch.load(checkpoint_fpath, map_location='cpu')

# Restore the model weights.
model.load_state_dict(checkpoint['state_dict'])

# Restore the optimizer state (momentum buffers, adaptive learning rates, etc.).
optimizer.load_state_dict(checkpoint['optimizer'])
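Putting it together, a resume routine might look like the sketch below. Note that load_state_dict only restores values, so the model and optimizer must be constructed with the same architecture and hyperparameters as at save time; make_model, checkpoint_fpath, train_one_epoch, and num_epochs are assumptions standing in for your own setup.

import torch

# Hypothetical setup: recreate the exact model/optimizer used at save time.
model = make_model()                     # assumption: your model factory
optimizer = torch.optim.Adam(model.parameters())

checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

start_epoch = checkpoint['epoch'] + 1    # continue where training left off
model.train()                            # back to training mode

for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer)    # hypothetical training step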