DEV Community

Rijul Rajesh

Building LSTMs with PyTorch and Lightning AI Part 2: Starting the LSTM Unit Implementation

In the previous article, we began building the LSTM by defining the class and initializing the weights and biases.

In this article, we will continue by implementing the lstm_unit() function.

The lstm_unit() function requires three inputs:

  • The current input value
  • The current long-term memory value
  • The current short-term memory value
def lstm_unit(self, input_value, long_memory, short_memory):
    ...  # the body is built up stage by stage in the sections that follow

Stage 1: Determine How Much Long-Term Memory to Remember

We use the parameters defined in the __init__() method to determine what percentage of the existing long-term memory should be retained.

For this, we create long_remember_percent:

def lstm_unit(self, input_value, long_memory, short_memory):
    # Stage 1: percentage of the existing long-term memory to remember
    long_remember_percent = torch.sigmoid(
        (short_memory * self.wlr1) +
        (input_value * self.wlr2) +
        self.blr1
    )

Here:

  • The short-term memory is multiplied by its associated weight.
  • The input value is multiplied by its associated weight.
  • Both results are added together along with the bias.
  • The final sum is passed through the sigmoid activation function.

The sigmoid function produces a value between 0 and 1, which represents the percentage of long-term memory that should be remembered.
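To see the numbers, here is a minimal plain-Python sketch of Stage 1. It uses `math.exp` in place of `torch.sigmoid` so it runs without PyTorch, and the weight and bias values are hypothetical, picked only so the arithmetic is easy to follow:

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid: maps any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def long_remember_percent(input_value: float, short_memory: float,
                          wlr1: float, wlr2: float, blr1: float) -> float:
    """Stage 1: weighted short-term memory + weighted input + bias, then sigmoid."""
    return sigmoid(short_memory * wlr1 + input_value * wlr2 + blr1)


# Hypothetical parameters (not trained values): 0.0*2.0 + 1.0*1.5 - 1.5 = 0.0
pct = long_remember_percent(input_value=1.0, short_memory=0.0,
                            wlr1=2.0, wlr2=1.5, blr1=-1.5)
print(pct)  # 0.5, i.e. keep half of the long-term memory

# Large positive sums push the percentage toward 1, large negative sums toward 0
print(sigmoid(10.0), sigmoid(-10.0))
```

A weighted sum of exactly 0 lands in the middle of the range, and the output can never leave (0, 1), which is what makes the sigmoid a natural fit for a "percentage to keep".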

Stage 2: Determine How Much Potential Memory to Remember

Next, we calculate what percentage of the potential memory (the new information) should be added to the long-term memory.

For this, we create potential_remember_percent:

potential_remember_percent = torch.sigmoid(
    (short_memory * self.wpr1) +
    (input_value * self.wpr2) +
    self.bpr1
)

This follows the same pattern as Stage 1, but with its own weights (wpr1, wpr2) and bias (bpr1).
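Because the two percentages share one formula and differ only in their parameters, one way to see the pattern is a single hypothetical helper, `sigmoid_gate`, called once with Stage 1 parameters and once with Stage 2 parameters. This is a plain-Python sketch; the helper name and all numeric values are illustrative assumptions, not part of the article's class:

```python
import math


def sigmoid_gate(input_value: float, short_memory: float,
                 w_short: float, w_input: float, bias: float) -> float:
    """Hypothetical helper: sigmoid(short_memory * w_short + input_value * w_input + bias)."""
    z = short_memory * w_short + input_value * w_input + bias
    return 1.0 / (1.0 + math.exp(-z))


input_value, short_memory = 1.0, 0.5

# Stage 1 stand-ins for wlr1, wlr2, blr1: 0.5*2.0 + 1.0*1.0 - 2.0 = 0.0
long_remember_percent = sigmoid_gate(input_value, short_memory,
                                     w_short=2.0, w_input=1.0, bias=-2.0)

# Stage 2 stand-ins for wpr1, wpr2, bpr1: 0.5*(-1.0) + 1.0*1.0 + 0.5 = 1.0
potential_remember_percent = sigmoid_gate(input_value, short_memory,
                                          w_short=-1.0, w_input=1.0, bias=0.5)

print(long_remember_percent, potential_remember_percent)
```

Same inputs, different parameters, different percentages: training is what lets each gate learn its own role.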


Stage 2 (continued): Calculate the Potential Memory

We also need to calculate the candidate memory value itself.

For this, we use the tanh activation function:

potential_memory = torch.tanh(
    (short_memory * self.wp1) +
    (input_value * self.wp2) +
    self.bp1
)

The tanh function produces values between -1 and 1, so the potential memory can be either positive or negative. When it is added to the long-term memory, it can therefore increase or decrease it.
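A quick comparison shows why tanh is used here rather than sigmoid: for the same weighted sum, sigmoid never drops below 0, while tanh can return a negative candidate value. This plain-Python sketch uses `math.tanh` in place of `torch.tanh`, with hypothetical stand-ins for wp1, wp2, and bp1:

```python
import math

short_memory, input_value = 0.5, 1.0

# Hypothetical stand-ins for wp1, wp2, bp1: 0.5*(-2.0) + 1.0*(-1.0) + 0.5 = -1.5
z = short_memory * -2.0 + input_value * -1.0 + 0.5

potential_memory = math.tanh(z)           # in (-1, 1); negative for this z
as_sigmoid = 1.0 / (1.0 + math.exp(-z))   # in (0, 1); can never be negative

print(potential_memory, as_sigmoid)
```

With the sigmoid, this candidate could only ever add to the long-term memory; with tanh it pulls the memory down.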


Updating the Long-Term Memory

Now we can update the long-term memory.

We do this by:

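  1. Multiplying the existing long-term memory by long_remember_percent, which keeps only the portion that should be remembered.
  2. Adding the potential memory multiplied by potential_remember_percent, which is the portion of the new information that should be stored.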
updated_long_memory = (
    (long_memory * long_remember_percent) +
    (potential_remember_percent * potential_memory)
)

This gives us the updated long-term memory value for the LSTM.
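Putting Stages 1 and 2 together, here is a minimal plain-Python sketch of the computations covered so far. The class name `ScalarLSTMSketch`, the method name `stages_1_and_2`, and every parameter value are hypothetical; the attribute names mirror the ones used above, `math` stands in for `torch`, and the sketch stops at the updated long-term memory because the short-term memory and output are covered in the next part:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class ScalarLSTMSketch:
    """Hypothetical scalar stand-in for the article's LSTM class (Stages 1-2 only)."""

    def __init__(self):
        # Illustrative values only, not trained weights
        self.wlr1, self.wlr2, self.blr1 = 2.0, 1.0, -2.0  # Stage 1: % of long memory to keep
        self.wpr1, self.wpr2, self.bpr1 = 1.0, 1.0, -1.5  # Stage 2: % of potential memory
        self.wp1, self.wp2, self.bp1 = 1.0, 1.0, -1.0     # Stage 2: potential memory

    def stages_1_and_2(self, input_value, long_memory, short_memory):
        long_remember_percent = sigmoid(
            short_memory * self.wlr1 + input_value * self.wlr2 + self.blr1)
        potential_remember_percent = sigmoid(
            short_memory * self.wpr1 + input_value * self.wpr2 + self.bpr1)
        potential_memory = math.tanh(
            short_memory * self.wp1 + input_value * self.wp2 + self.bp1)
        # Keep part of the old memory, then add part of the candidate memory
        return ((long_memory * long_remember_percent) +
                (potential_remember_percent * potential_memory))


lstm = ScalarLSTMSketch()
updated = lstm.stages_1_and_2(input_value=1.0, long_memory=2.0, short_memory=0.5)
print(updated)  # 2.0*0.5 + 0.5*tanh(0.5), approximately 1.2311
```

With these values both percentages come out to exactly 0.5, so half of the old long-term memory (2.0) survives and half of the candidate value tanh(0.5) is added. In the PyTorch class, the same expressions operate on tensors through `torch.sigmoid` and `torch.tanh`.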

In the next article, we will continue with the third stage of the LSTM, where we create the updated short-term memory and determine what percentage of the long-term memory should be sent to the output.
