DEV Community

Quoc Bao


Convert my PyTorch model to PyTorch Lightning

Hello, everybody! Today I am going to show you how to convert my model from PyTorch to PyTorch Lightning. PyTorch Lightning is a lightweight deep learning framework built on top of PyTorch. It removes a lot of boilerplate code (standard code that can be found in almost any deep learning pipeline) and adds many functions (hooks) that let you step into training at specific points.

First, I install and import the libraries.

# Install first from a shell: pip install pytorch-lightning
import torch
import torch.nn as nn
import pytorch_lightning as pl

A PyTorch Lightning LightningModule resembles nn.Module (it is in fact a subclass of it), so the forward function can be defined in a pl class exactly as before.

# An nn.Module class can be converted to a Lightning class by changing
# the base class from nn.Module to pl.LightningModule.

class NeuralNet(nn.Module):
# --> class NeuralNet(pl.LightningModule):
    def __init__(self, input_size, num_classes):
        super().__init__()
        # Layers still come from torch.nn; these should not be changed!
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, num_classes)

    # forward() is written exactly as in a plain nn.Module.
    def forward(self, x):
        out = self.fc1(x)
        out = torch.sigmoid(out)
        out = self.fc2(out)
        return out

Read more here.


