DEV Community

Daisuke Majima

Posted on • Originally published at qiita.com

Resuming training in PyTorch: save and load the optimizer too

How to save/load a model in PyTorch and resume training from a checkpoint

Sometimes you want to pause training partway through (for machine or human reasons) and resume it later. In environments with a limit on continuous usage time, such as Colab, or when you want to train for more epochs than you originally planned, you need to save the model weights to a file and load them later to resume.

Saving the model alone won't resume from the same accuracy

Saving and loading a model goes like this, but this is only enough for inference in eval mode. If you try to resume training from it, you'll notice that the loss and accuracy don't continue from where they were before saving; they fall back toward their initial values.

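import torch

# save: only the model's parameters and buffers
torch.save(model.state_dict(), PATH)

# load: rebuild the model, then copy the saved values into it
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # put dropout and batch norm layers into inference mode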
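Why is a weights-only checkpoint not enough? With `momentum=0.9`, SGD keeps a per-parameter momentum buffer inside the optimizer, and `model.state_dict()` never includes it. The minimal sketch below makes that visible; the toy `nn.Linear` model and the random data are assumptions standing in for `TheModelClass` and a real dataset.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical toy stand-in for TheModelClass, with random data
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

# One ordinary training step; SGD with momentum now creates its buffers
loss = nn.CrossEntropyLoss()(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The model's state_dict holds only its parameters (and buffers)
model_keys = sorted(model.state_dict().keys())
print(model_keys)  # ['bias', 'weight']

# The momentum buffers live in the optimizer, keyed by parameter index
opt_state = optimizer.state_dict()["state"]
for index, param_state in opt_state.items():
    print(index, list(param_state.keys()))
```

Saving only the model throws these buffers away, so a freshly created optimizer starts its momentum from scratch.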

You should save the optimizer too

To resume training, you need to save/load the optimizer's state in addition to the model weights.

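import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# save: store everything needed to resume in a single dictionary
save_path = "my_model_training_state.pt"
torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,},
           save_path)

# load: rebuild the model and optimizer, then restore their saved states
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = TheModelClass(*args, **kwargs)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

PATH = "my_model_training_state.pt"
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Move the optimizer state (e.g. SGD momentum buffers) to the current device.
# Without this you can get a device mismatch between before and after saving.
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)
epoch = checkpoint['epoch']  # epoch recorded when the checkpoint was saved
loss = checkpoint['loss']

# model.eval()  # for inference
# - or -
model.train()  # to resume training

model = model.to(device)
criterion = nn.CrossEntropyLoss()
# The scheduler's state is not part of this checkpoint, so StepLR
# counts its steps from zero again after resuming.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)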
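One way to check that restoring the optimizer actually matters is to compare three runs after a checkpoint: the original run, a run that restores both states, and a run that restores only the model. This is a minimal sketch under stated assumptions: a toy `nn.Linear` model and fixed random data stand in for the real setup, `make_model_and_optimizer`, `train_step` and `same_params` are hypothetical helpers, and an in-memory `io.BytesIO` replaces the checkpoint file.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

def make_model_and_optimizer():
    # Hypothetical stand-in for TheModelClass plus the article's SGD settings
    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    return model, optimizer

def train_step(model, optimizer, x, y):
    loss = nn.CrossEntropyLoss()(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def same_params(a, b):
    return all(torch.allclose(p, q) for p, q in zip(a.parameters(), b.parameters()))

torch.manual_seed(0)
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

# Original run: train a few steps, then checkpoint model and optimizer together
model_a, opt_a = make_model_and_optimizer()
for _ in range(3):
    loss = train_step(model_a, opt_a, x, y)
buffer = io.BytesIO()  # in-memory stand-in for my_model_training_state.pt
torch.save({'epoch': 3,
            'model_state_dict': model_a.state_dict(),
            'optimizer_state_dict': opt_a.state_dict(),
            'loss': loss}, buffer)

buffer.seek(0)
checkpoint = torch.load(buffer)

# Run B restores both states; run C restores only the model
model_b, opt_b = make_model_and_optimizer()
model_b.load_state_dict(checkpoint['model_state_dict'])
opt_b.load_state_dict(checkpoint['optimizer_state_dict'])

model_c, opt_c = make_model_and_optimizer()
model_c.load_state_dict(checkpoint['model_state_dict'])

# Take one more step in all three runs
for model, opt in [(model_a, opt_a), (model_b, opt_b), (model_c, opt_c)]:
    train_step(model, opt, x, y)

restored_matches = same_params(model_a, model_b)
model_only_matches = same_params(model_a, model_c)
print(restored_matches)  # True
print(model_only_matches)
```

Run B continues exactly like the original run, while run C drifts because its fresh optimizer has lost the accumulated momentum.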

Now you can resume from the loss and accuracy you had before saving.
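The saved `epoch` is what lets the training loop pick up where it stopped. The following is a minimal sketch, assuming a toy `nn.Linear` model, fixed random data, and a checkpoint written at the end of every epoch; the `train()` helper and its `stop_after` argument (which simulates an interrupted session, such as hitting a Colab time limit) are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

def train(num_epochs, path, stop_after=None):
    """Train a toy model, saving a checkpoint at the end of every epoch.

    If a checkpoint already exists at `path`, resume after the saved epoch.
    `stop_after` (hypothetical) simulates the session ending early.
    """
    torch.manual_seed(0)
    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8,))

    start_epoch = 0
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1  # the saved epoch already finished

    for epoch in range(start_epoch, num_epochs):
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss.item()}, path)
        if stop_after is not None and epoch + 1 == stop_after:
            break
    return model

with tempfile.TemporaryDirectory() as d:
    uninterrupted = train(5, os.path.join(d, "a.pt"))
    path = os.path.join(d, "b.pt")
    train(5, path, stop_after=2)  # the "session" ends after 2 epochs
    resumed = train(5, path)      # picks up at epoch 2
    print(torch.allclose(uninterrupted.weight, resumed.weight))  # True
```

Because both the model and optimizer states are restored, the interrupted-then-resumed run ends with the same weights as the run that was never interrupted.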


Originally published in Japanese on Qiita. I build apps with Core ML and ARKit and write about ML/AR. GitHub / X
