DEV Community

Rijul Rajesh


Building Word Embeddings with PyTorch and Lightning AI Part 2: Creating Labels for Next-Word Prediction

In the previous article, we created the training inputs for our word embedding model using one-hot encoding. In this article, we will define the labels and prepare the data for training.

Let's continue by setting up the labels.

Our goal is to predict the next token given the current token.

Consider the sentence:

The Incredibles is great

If the current token is "The Incredibles", the next token is "is".

So, the correct label is the one-hot encoding for "is".

Next, if the current token is "is", the model should predict the one-hot encoding for "great".

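After "great", the sentence ends. To carry training into the next sentence, we use the one-hot encoding for "Despicable Me", the first token of that sentence, as the label. Finally, when the current token is "Despicable Me", the next token is "is", so the fourth label is the one-hot encoding for "is" again.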
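The pairing rule above, where each token's label is the token that follows it, can also be generated instead of typed by hand. The following is a minimal sketch under some assumptions. The vocabulary order (The Incredibles, is, great, Despicable Me) is inferred from the label rows in this article. The token stream is hypothetical: the first sentence, followed by "Despicable Me" and "is", so that it yields exactly four pairs. `vocab`, `stream`, and `word_to_index` are illustrative names and do not come from the article.

```python
import torch
import torch.nn.functional as F

# Assumed vocabulary order: column i of every one-hot vector stands for vocab[i].
vocab = ["The Incredibles", "is", "great", "Despicable Me"]
word_to_index = {word: i for i, word in enumerate(vocab)}

# Hypothetical token stream: "The Incredibles is great", then the start of
# the next sentence ("Despicable Me is ..."), trimmed to give four pairs.
stream = ["The Incredibles", "is", "great", "Despicable Me", "is"]
indices = torch.tensor([word_to_index[word] for word in stream])

# Shift by one: every token except the last is an input,
# and the token right after it is that input's label.
inputs = F.one_hot(indices[:-1], num_classes=len(vocab)).float()
labels = F.one_hot(indices[1:], num_classes=len(vocab)).float()

print(labels)
```

The resulting `labels` matches the hand-written tensor in this article row for row. Because the labels are just the token stream shifted by one position, the same few lines work for longer streams without writing each row manually.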

Now let's convert these labels into a PyTorch tensor.

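import torch

# One row per training example: the one-hot encoding of the NEXT token.
# Columns: The Incredibles, is, great, Despicable Me
labels = torch.tensor([
    [0., 1., 0., 0.],  # The Incredibles -> is
    [0., 0., 1., 0.],  # is -> great
    [0., 0., 0., 1.],  # great -> Despicable Me
    [0., 1., 0., 0.]   # Despicable Me -> is
])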
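Each row of `labels` contains a single 1, so each row can also be read as a class index. The sketch below assumes the same column order as above (The Incredibles, is, great, Despicable Me). It decodes the rows back to words, then checks that PyTorch's `nn.CrossEntropyLoss`, which the training step will use, gives the same value for one-hot targets as for the equivalent class indices. The random `logits` are a stand-in for model outputs and are purely illustrative.

```python
import torch
import torch.nn as nn

# Assumed column order, matching the one-hot vectors above.
vocab = ["The Incredibles", "is", "great", "Despicable Me"]

labels = torch.tensor([
    [0., 1., 0., 0.],
    [0., 0., 1., 0.],
    [0., 0., 0., 1.],
    [0., 1., 0., 0.]
])

# The position of the 1 in each row is the index of the next token.
target_indices = labels.argmax(dim=1)
next_words = [vocab[i] for i in target_indices.tolist()]
print(next_words)  # ['is', 'great', 'Despicable Me', 'is']

# Stand-in for model outputs: one row of raw scores (logits) per input token.
torch.manual_seed(0)
logits = torch.randn(4, len(vocab))

loss_fn = nn.CrossEntropyLoss()
loss_from_one_hot = loss_fn(logits, labels)          # probability-style targets
loss_from_indices = loss_fn(logits, target_indices)  # class-index targets
print(torch.allclose(loss_from_one_hot, loss_from_indices))  # True
```

For one-hot targets, both forms reduce to the negative log-probability of the correct next token, averaged over the batch. Keeping the one-hot labels is therefore fine; the index form is simply a more compact alternative.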

At this point, we have finished encoding our training data.

Next, we combine the inputs and labels into a TensorDataset, and then use that dataset to create a DataLoader.

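from torch.utils.data import TensorDataset, DataLoader
# Pair each input row with its label row, then wrap the pairs in a DataLoader.
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)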
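To see what the DataLoader hands to the training step, here is a minimal sketch that iterates over it. It assumes the inputs from Part 1 form the 4 × 4 identity matrix, one one-hot row per token in the same column order as the labels. That is an assumption, so substitute your own `inputs` tensor. Called with no extra arguments, `DataLoader` uses `batch_size=1` and `shuffle=False`, so it yields the four (input, label) pairs one at a time, in order. Each pair carries a leading batch dimension of size 1.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Assumption: Part 1's inputs are one one-hot row per token,
# i.e. the 4x4 identity matrix for The Incredibles, is, great, Despicable Me.
inputs = torch.eye(4)
labels = torch.tensor([
    [0., 1., 0., 0.],
    [0., 0., 1., 0.],
    [0., 0., 0., 1.],
    [0., 1., 0., 0.]
])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)  # defaults: batch_size=1, shuffle=False

print(len(dataset))  # 4
batches = list(dataloader)
for x, y in batches:
    # Each batch holds one pair; the extra first dimension is the batch size.
    print(x.shape, y.shape)  # torch.Size([1, 4]) torch.Size([1, 4])
```

With batches of one, the optimizer updates the weights after every single (token, next-token) pair. Passing a larger `batch_size` would average the loss over several pairs per update instead.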

Creating the Neural Network

Now let's start implementing the word embedding model.

import lightning as L


class WordEmbedding(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Initialize the weight tensors for the embedding network and loss function

    def forward(self, x):
        # Make a forward pass through the embedding network (x: one-hot encoded tokens)
        pass

    def configure_optimizers(self):
        # Configure the Adam optimizer
        pass

    def training_step(self, batch, batch_idx):
        # Calculate the loss (Cross Entropy Loss)
        pass
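Before the Lightning version is filled in, it may help to see what the four methods amount to once Lightning's Trainer runs them. The sketch below is plain PyTorch and is not the author's implementation, which the next article covers. Its assumptions are labeled:

- a hypothetical `TinyWordEmbedding` with two bias-free linear layers and a 2-value embedding per word
- identity-matrix inputs, assumed to match Part 1
- an arbitrary learning rate of 0.1
- a hypothetical helper, `epoch_loss`

The hand-written loop does roughly what Lightning automates: it calls `configure_optimizers` once, then runs `training_step` followed by backward and optimizer steps for every batch.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(42)

# Assumed data: identity-matrix inputs (Part 1) and the labels defined above.
inputs = torch.eye(4)
labels = torch.tensor([
    [0., 1., 0., 0.],
    [0., 0., 1., 0.],
    [0., 0., 0., 1.],
    [0., 1., 0., 0.]
])
dataloader = DataLoader(TensorDataset(inputs, labels))


class TinyWordEmbedding(nn.Module):  # hypothetical stand-in for WordEmbedding
    def __init__(self, vocab_size=4, embedding_dim=2):
        super().__init__()
        # __init__: weights from one-hot input to embedding, and embedding to output.
        self.input_to_hidden = nn.Linear(vocab_size, embedding_dim, bias=False)
        self.hidden_to_output = nn.Linear(embedding_dim, vocab_size, bias=False)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        # forward: one-hot -> 2 embedding values -> one raw score (logit) per word.
        return self.hidden_to_output(self.input_to_hidden(x))


model = TinyWordEmbedding()
# configure_optimizers: Adam over all weights (lr=0.1 is an arbitrary choice).
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)


def epoch_loss():
    # Hypothetical helper: mean cross-entropy over all four pairs, no gradients.
    with torch.no_grad():
        return model.loss(model(inputs), labels).item()


loss_before = epoch_loss()
for epoch in range(100):
    for x, y in dataloader:
        # training_step: compute the cross-entropy loss for one batch...
        loss = model.loss(model(x), y)
        # ...then the steps Lightning would otherwise run automatically.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
loss_after = epoch_loss()
print(f"loss before: {loss_before:.3f}, after: {loss_after:.3f}")
```

After training, the two values each word produces in `input_to_hidden` serve as its embedding. In this toy setup, that is simply the word's column of `model.input_to_hidden.weight`.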

In the next article, we will begin implementing each of these methods one by one.

