360. Saving Checkpoints During Training

Implementation

Saving checkpoints of the model weights during training is helpful in cases such as the following:

  1. You want to resume training later.
  2. You want to avoid losing the weights if the process stops mid-training because of an error.
  3. You want to restore the weights from a specific epoch.

Here is one way to implement this using PyTorch.

Saving Checkpoints

import torch
from pathlib import Path

# Create a dictionary holding everything needed to resume training.
checkpoint = {
        'epoch': epoch,
        'state_dict': best_model_wts,         # model weights, e.g. a copy of model.state_dict()
        'optimizer': optimizer.state_dict()   # optimizer state; important when using adaptive optimizers such as Adam
}
f_path = exp_dir / f'checkpoint_e{epoch}.pt'  # exp_dir should be a pathlib.Path to the experiment directory
torch.save(checkpoint, f_path)
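
For context, here is a minimal sketch of how this saving step might sit inside a training loop; model, optimizer, train_one_epoch, num_epochs, and the experiments/run1 directory are placeholders, not part of the original snippet.

import copy
from pathlib import Path

import torch

exp_dir = Path('experiments/run1')                 # hypothetical output directory
exp_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)              # placeholder for one pass over the training data
    best_model_wts = copy.deepcopy(model.state_dict())
    checkpoint = {
        'epoch': epoch,
        'state_dict': best_model_wts,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, exp_dir / f'checkpoint_e{epoch}.pt')

Saving every epoch keeps the most recent state recoverable; in practice you might keep only the latest file or the best-performing one to save disk space.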

Restoring Weights from Checkpoints

# Load the checkpoint from disk (pass map_location if the saving and loading devices differ).
checkpoint = torch.load(checkpoint_fpath)

# Restore the model weights.
model.load_state_dict(checkpoint['state_dict'])

# Restore the optimizer state.
optimizer.load_state_dict(checkpoint['optimizer'])
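
To actually resume training, the stored epoch can serve as the starting point of the loop. Below is a sketch under the same placeholder names; build_model, build_optimizer, train_one_epoch, num_epochs, and the checkpoint path are hypothetical.

import torch

model = build_model()                              # hypothetical: rebuild the same architecture
optimizer = build_optimizer(model)                 # hypothetical: recreate the optimizer

checkpoint_fpath = 'experiments/run1/checkpoint_e9.pt'   # example path to a saved checkpoint
checkpoint = torch.load(checkpoint_fpath, map_location='cpu')

model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

# Resume from the epoch after the one stored in the checkpoint.
start_epoch = checkpoint['epoch'] + 1
for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer)              # placeholder training step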