How to save/load a model in PyTorch and resume training from a checkpoint
Sometimes you need to pause training partway through and pick it up again later, for machine-related or human reasons. This comes up in environments that cap continuous runtime, such as Colab, and when you decide to train for more epochs than you originally planned. In both cases you want to save the model to a file and load it later to resume.
Saving only the model won't let you resume at the same accuracy
Saving and loading a model looks like this. However, this is only enough for inference in eval mode. If you use it to resume training, you'll notice the loss and accuracy don't continue from where they were before saving. Instead, they fall back toward their early-training values.
```python
import torch

# save
torch.save(model.state_dict(), PATH)

# load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
```
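To see why the weights alone aren't enough, here is a small sketch. It assumes SGD with momentum and a toy `nn.Linear` standing in for `TheModelClass`, and it uses an in-memory buffer instead of `PATH`. After a few training steps, the original optimizer holds one momentum buffer per parameter. A new optimizer built for the restored model starts with no state at all, even though the weights themselves round-trip exactly.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 2)  # toy stand-in for TheModelClass (assumption)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# A few training steps so SGD accumulates momentum buffers.
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
for _ in range(3):
    optimizer.zero_grad()
    nn.functional.cross_entropy(model(x), y).backward()
    optimizer.step()

# Save and load only the weights, as in the snippet above
# (an in-memory buffer replaces PATH).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))
fresh_optimizer = optim.SGD(restored.parameters(), lr=0.1, momentum=0.9)

print(len(optimizer.state))        # 2: one momentum buffer per parameter (weight, bias)
print(len(fresh_optimizer.state))  # 0: the momentum history did not survive
```

The first few updates after such a reload therefore behave like the start of training from the optimizer's point of view. This is why the optimizer state needs to go into the checkpoint as well.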
You should save the optimizer too
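To resume training, you need to save and load the optimizer's state in addition to the model weights. That state includes, for example, the momentum buffers that SGD with momentum builds up. It is also useful to store the current epoch and the last loss.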
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# save (epoch, loss, model and optimizer come from your training loop)
save_path = "my_model_training_state.pt"
torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss},
           save_path)

# load
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = TheModelClass(*args, **kwargs)
# Move the model to the device before building the optimizer, so the
# optimizer refers to the parameters that will actually be trained.
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

PATH = "my_model_training_state.pt"
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Move the optimizer state to the current device. Without this you can get a
# device mismatch between before and after saving.
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

epoch = checkpoint['epoch']
loss = checkpoint['loss']

# model.eval()  # use this instead if you only want to run inference
model.train()    # resume training
criterion = nn.CrossEntropyLoss()
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
```
Now you can resume from the loss and accuracy you had before saving.
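The saved `epoch` is what lets you continue past your original epoch count. Here is a minimal sketch under a few assumptions: a toy `nn.Linear` in place of `TheModelClass`, a fixed in-memory dataset, and a hypothetical `train` helper. The loop resumes from `checkpoint['epoch'] + 1` and writes a checkpoint after every epoch. It also saves the StepLR scheduler's state, which the snippet above recreates from scratch. Training for 2 epochs, stopping, and resuming to 4 should produce the same weights as training 4 epochs in one go.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler


def make_data():
    # Hypothetical toy dataset, fixed so every run sees identical batches.
    g = torch.Generator().manual_seed(0)
    x = torch.randn(64, 4, generator=g)
    y = (x.sum(dim=1) > 0).long()
    return x, y


def train(num_epochs, ckpt_path):
    """Hypothetical helper: train up to num_epochs, resuming from ckpt_path if it exists."""
    torch.manual_seed(0)
    model = nn.Linear(4, 2)  # toy stand-in for TheModelClass
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    start_epoch = 0

    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1  # continue after the last finished epoch

    x, y = make_data()
    model.train()
    for epoch in range(start_epoch, num_epochs):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        # Checkpoint every epoch so an interrupted run loses at most one epoch.
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': loss.item()},
                   ckpt_path)
    return model


with tempfile.TemporaryDirectory() as tmp:
    # Reference run: 4 epochs in one go.
    full = train(4, os.path.join(tmp, "full.pt"))
    # "Interrupted" run: stop after 2 epochs, then resume up to 4 in a new call.
    ckpt = os.path.join(tmp, "resumed.pt")
    train(2, ckpt)
    resumed = train(4, ckpt)
    last_epoch = torch.load(ckpt)['epoch']

same = all(torch.equal(a, b) for a, b in zip(full.parameters(), resumed.parameters()))
print(same, last_epoch)  # expected: True 3
```

The scheduler state matters because a freshly built StepLR starts its step counter again. Without the saved state, its learning-rate drops would land at different epochs than in an uninterrupted run. Exact equality holds here because the toy run is deterministic on the CPU. With data shuffling, dropout, or GPU kernels, you would also need to handle random-number-generator state to get matching numbers.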
Originally published in Japanese on Qiita. I build apps with Core ML and ARKit and write about ML/AR.