DEV Community

Shashank-Holla


Call for your attention! Do you remember LSTM?

This is my attempt to jot down the ideas behind Recurrent Neural Network variants such as LSTM and GRU, and the thought process behind the attention mechanism.

Idea behind RNN

Deep Learning has provided many solutions to interesting problems in the computer vision, speech and audio domains. One weapon from Deep Learning's arsenal for tackling sequence data such as text and audio is the Recurrent Neural Network (RNN). Recurrent neural networks allow us to extract the gist of a piece of text (sentiment analysis), describe an image with a sequence of words (image captioning) or even generate a new sequence from an existing one (language translation). A recurrent network has the ability to persist information, which allows it to operate on a sequence of inputs and produce a sequence of output vectors.

Need for LSTMs and GRUs

The persistence of information in an RNN is made possible by loops in the network. When an RNN makes a prediction, it considers the current input as well as what it has learned from past inputs. To make sense of this, an RNN can be viewed as an unrolled chain of copies of the same network, with the learning from each time step passed on to the next.

[Figure: RNN]
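The unrolled view above can be sketched in a few lines of plain Python. This is a minimal sketch, not the author's code: it assumes a vanilla (Elman-style) RNN with a tanh activation, and the weight names `W_xh`, `W_hh`, `b_h` and the toy numbers are illustrative choices. Note that the same weights are reused at every time step; only the hidden state changes as it is passed along.

```python
import math


def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """One vanilla RNN step: h_t = tanh(W_xh x_t + W_hh h_prev + b_h)."""
    h_t = []
    for i in range(len(b_h)):
        z = b_h[i]
        z += sum(W_xh[i][j] * x_t[j] for j in range(len(x_t)))        # current input
        z += sum(W_hh[i][k] * h_prev[k] for k in range(len(h_prev)))  # what was learned so far
        h_t.append(math.tanh(z))
    return h_t


def rnn_forward(xs, W_xh, W_hh, b_h):
    """Unroll the same cell over a sequence, passing the hidden state to the next step."""
    h = [0.0] * len(b_h)  # initial hidden state h(0) = 0
    hidden_states = []
    for x_t in xs:
        h = rnn_step(x_t, h, W_xh, W_hh, b_h)  # the same weights at every time step
        hidden_states.append(h)
    return hidden_states


if __name__ == "__main__":
    # Toy sizes: 2-dimensional inputs, 3-dimensional hidden state, arbitrary weights.
    W_xh = [[0.5, -0.2], [0.1, 0.4], [-0.3, 0.2]]
    W_hh = [[0.1, 0.0, 0.2], [0.0, 0.3, -0.1], [0.2, 0.1, 0.0]]
    b_h = [0.0, 0.1, -0.1]
    xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    for t, h in enumerate(rnn_forward(xs, W_xh, W_hh, b_h), start=1):
        print(f"h({t}) =", [round(v, 3) for v in h])
```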

RNNs perform really well on problems with short contexts. But as the context gets longer and the number of time steps increases, the network fails to carry forward what it learned in the initial steps. In other words, the network suffers from short-term memory, caused by vanishing gradients. Remember that the network learns by calculating gradients and using them to adjust its internal weights. When the error is backpropagated through many time steps, the gradient is multiplied by one factor per step; if these factors are small, the gradient that reaches the early steps becomes vanishingly small, and the network fails to learn anything meaningful from them. As an example, consider the following sentence:

The General's method is to have his troops ready by dawn.

Our recurrent network might find it hard to work out the context of the word 'General' and whether it is used as an adjective or a noun in our sentence. LSTMs and GRUs try to solve these problems.
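To see why the gradients vanish, here is a toy calculation. It assumes a one-unit RNN h(t) = tanh(w * h(t-1) + u * x(t)) with arbitrary weights w = 0.9 and u = 0.5 (hypothetical values chosen only for illustration). By the chain rule, the gradient of the last hidden state with respect to the first is a product of one factor w * (1 - h(t)^2) per step, and with |w| < 1 every factor is below 1, so the product shrinks as the sequence gets longer.

```python
import math


def gradient_through_time(w, u, xs, h0=0.0):
    """Return |dh_T / dh_0| for a scalar RNN h_t = tanh(w * h_{t-1} + u * x_t).

    By the chain rule, dh_T/dh_0 is the product of the per-step factors
    dh_t/dh_{t-1} = w * (1 - h_t**2).
    """
    h = h0
    grad = 1.0
    for x in xs:
        h = math.tanh(w * h + u * x)
        grad *= w * (1.0 - h * h)  # local derivative of this time step
    return abs(grad)


if __name__ == "__main__":
    w, u = 0.9, 0.5  # arbitrary toy weights
    for T in (1, 5, 10, 20, 50):
        g = gradient_through_time(w, u, [1.0] * T)
        print(f"T={T:2d}  |dh_T/dh_0| = {g:.2e}")
```

The printed magnitudes fall off roughly exponentially with T, so the earliest time steps contribute almost nothing to the weight update.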

Talking about LSTM

The core idea behind LSTM (Long Short-Term Memory) is the cell state, a pathway that carries relevant information all the way down the sequence chain. It works somewhat like the skip connections used in ResNet models: it is very easy for information to travel along it unchanged through the sequence. The LSTM can add information to or remove information from the cell state through structures called gates. An LSTM has three such gates (the input gate, the forget gate and the output gate) that learn over time what information is important. At each time step, these gates look at the hidden state from the previous time step and the current input, and use them to update the previous cell state.

[Figure: LSTM]

The first contributor in an LSTM is the forget gate. Its role is to decide which part of the cell state's information needs to be thrown away and which part kept. This gate uses a sigmoid function, whose output lies between 0 and 1 for every feature, so it acts like a selector that keeps a feature of the cell state (value close to 1) or removes it (value close to 0). In our example, the forget gate looks at the previous hidden state h(t-1) and the input x(t) and might decide to forget the adjective sense of the word 'general'.
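The forget gate can be sketched as follows. This assumes the standard formulation f(t) = sigmoid(W_f [h(t-1), x(t)] + b_f) without peephole connections; `W_f`, `b_f` and the helper names are hypothetical, and vectors are plain Python lists.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forget_gate(h_prev, x_t, W_f, b_f):
    """f(t) = sigmoid(W_f [h(t-1), x(t)] + b_f): one value in (0, 1) per cell-state feature."""
    concat = h_prev + x_t  # list concatenation gives [h(t-1), x(t)]
    return [sigmoid(b + sum(w * v for w, v in zip(row, concat)))
            for row, b in zip(W_f, b_f)]


def apply_forget(f_t, c_prev):
    """Elementwise product: features with f near 0 are erased, features with f near 1 are kept."""
    return [f * c for f, c in zip(f_t, c_prev)]


if __name__ == "__main__":
    # Zero weights and hand-picked biases: keep feature 0, forget feature 1.
    f = forget_gate([0.0], [0.0], [[0.0, 0.0], [0.0, 0.0]], [10.0, -10.0])
    print(apply_forget(f, [2.0, 3.0]))
```

With zero weights and biases of 10 and -10, `forget_gate` returns about 0.99995 and 0.00005, so `apply_forget` keeps the first cell-state feature almost untouched and all but erases the second. In a trained LSTM these weights and biases are learned rather than chosen by hand.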

The second contributor is the input gate, which decides which part of the input's information needs to be added to the cell state. Here, a sigmoid function selects which values are to be updated. Next, a tanh function creates a new candidate vector from the inputs. Because of the tanh, the values of the candidate vector are squished into the range -1 to 1, which regulates the network and keeps the values from growing out of control. By elementwise multiplication, the output of the sigmoid layer decides which features of the new candidate vector are important. In our 'General' example, the model might decide that the word is being used as a noun and add this part of speech to the cell state.
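With the forget gate deciding which part of the previous cell state to forget and the input gate deciding which part of the input to add, the new cell state is calculated: the previous cell state is multiplied elementwise by the forget gate's output, and the input gate's selection of the candidate vector is added to it.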
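The input gate, the tanh candidate and the cell-state update can be sketched in the same style. This assumes the common formulation c(t) = f(t) * c(t-1) + i(t) * c~(t) with elementwise products; the forget-gate output `f_t` is passed in, and the weight names `W_i`, `b_i`, `W_c`, `b_c` are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(W, b, v):
    """W v + b, with W given as a list of rows."""
    return [bi + sum(w * vj for w, vj in zip(row, v)) for row, bi in zip(W, b)]


def update_cell_state(c_prev, f_t, h_prev, x_t, W_i, b_i, W_c, b_c):
    """c(t) = f(t) * c(t-1) + i(t) * c~(t), all products elementwise."""
    concat = h_prev + x_t                                        # [h(t-1), x(t)]
    i_t = [sigmoid(z) for z in affine(W_i, b_i, concat)]         # input gate, in (0, 1)
    c_tilde = [math.tanh(z) for z in affine(W_c, b_c, concat)]   # candidate, in (-1, 1)
    # Keep the part of the old state allowed by the forget gate, add the gated candidate.
    return [f * c + i * ct for f, c, i, ct in zip(f_t, c_prev, i_t, c_tilde)]


if __name__ == "__main__":
    # One-feature toy example: forget the old value, let a candidate of tanh(0.5) in.
    print(update_cell_state([5.0], [0.0], [0.0], [0.0],
                            [[0.0, 0.0]], [10.0], [[0.0, 0.0]], [0.5]))
```

The addition in the last line is what makes the cell state behave like the skip connection mentioned earlier: when the forget gate is near 1 and the input gate near 0, the old state passes through unchanged.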

The last contributor is the output gate, which decides the next hidden state. The hidden state is a filtered version of the new cell state: a tanh function squashes the new cell state into the range -1 to 1, and a sigmoid function decides which of its features are important enough to pass on. In the 'General' example, this would amount to the model deciding that noun is the right part of speech to carry forward and assist in further time steps.
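Putting the three gates together, one LSTM time step might look like the sketch below. It assumes the standard LSTM equations without peephole connections, with the output gate computing h(t) = o(t) * tanh(c(t)); the `params` dictionary layout and its key names are illustrative, not any library's API.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(W, b, v):
    """W v + b, with W given as a list of rows."""
    return [bi + sum(w * vj for w, vj in zip(row, v)) for row, bi in zip(W, b)]


def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step. params maps a gate name to (W, b), each W acting on [h(t-1), x(t)]."""
    concat = h_prev + x_t
    f = [sigmoid(z) for z in affine(*params["forget"], concat)]           # forget gate
    i = [sigmoid(z) for z in affine(*params["input"], concat)]            # input gate
    c_tilde = [math.tanh(z) for z in affine(*params["candidate"], concat)]  # candidate
    o = [sigmoid(z) for z in affine(*params["output"], concat)]           # output gate
    c_t = [fk * ck + ik * ctk for fk, ck, ik, ctk in zip(f, c_prev, i, c_tilde)]
    # The output gate filters a squashed copy of the new cell state to form the hidden state.
    h_t = [ok * math.tanh(ck) for ok, ck in zip(o, c_t)]
    return h_t, c_t


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed, arbitrary toy weights
    n_in, n_hidden = 2, 3

    def init_layer():
        W = [[rng.uniform(-0.5, 0.5) for _ in range(n_hidden + n_in)] for _ in range(n_hidden)]
        return W, [0.0] * n_hidden

    params = {g: init_layer() for g in ("forget", "input", "candidate", "output")}
    h, c = [0.0] * n_hidden, [0.0] * n_hidden
    for x_t in ([1.0, 0.0], [0.0, 1.0], [1.0, 1.0]):
        h, c = lstm_step(x_t, h, c, params)
    print("h:", [round(v, 3) for v in h])
    print("c:", [round(v, 3) for v in c])
```

The loop at the bottom unrolls the cell over a three-step toy sequence with random weights; in practice a framework layer would learn these weights by backpropagation through time.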

Talking about GRU

The Gated Recurrent Unit (GRU) is similar to the LSTM, but with a few modifications. GRUs don't have a cell state and use only the hidden state to transfer information. They use two gates, the update gate and the reset gate, rather than the three gates used in LSTM.

[Figure: GRU]

The first gate is the update gate, which acts like the forget gate and input gate of an LSTM combined. It helps the model determine how much of the past information needs to be passed along to the future.

The second gate is the reset gate, which decides how much of the past information to forget when computing the new candidate hidden state.
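A GRU step can be sketched in the same style. This follows one common formulation and is an assumption in its details: the reset gate scales h(t-1) before it enters the candidate state, and the update gate z weights the previous state. Sources differ on whether z or 1 - z multiplies h(t-1) and on exactly where the reset gate is applied. The `params` keys are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(W, b, v):
    """W v + b, with W given as a list of rows."""
    return [bi + sum(w * vj for w, vj in zip(row, v)) for row, bi in zip(W, b)]


def gru_step(x_t, h_prev, params):
    """One GRU time step; returns only the new hidden state (there is no cell state)."""
    concat = h_prev + x_t
    z = [sigmoid(a) for a in affine(*params["update"], concat)]  # update gate
    r = [sigmoid(a) for a in affine(*params["reset"], concat)]   # reset gate
    # The reset gate scales the previous state before it enters the candidate.
    reset_concat = [rk * hk for rk, hk in zip(r, h_prev)] + x_t
    h_tilde = [math.tanh(a) for a in affine(*params["candidate"], reset_concat)]
    # The update gate blends the old state (weight z) and the candidate (weight 1 - z).
    return [zk * hk + (1.0 - zk) * htk for zk, hk, htk in zip(z, h_prev, h_tilde)]


if __name__ == "__main__":
    # Toy 1-unit GRU whose update gate is near 1, so the old state is mostly kept.
    params = {"update": ([[0.0, 0.0]], [10.0]),
              "reset": ([[0.0, 0.0]], [0.0]),
              "candidate": ([[0.0, 0.0]], [0.5])}
    print(gru_step([0.3], [0.7], params))
```

The function returns a single vector: the hidden state is the GRU's only carrier of information from one step to the next.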

Differences and which is better?

Unlike the LSTM, the GRU doesn't have a separate internal memory (the cell state). The GRU also has 3 internal neural network layers (the update gate, the reset gate and the candidate hidden state), compared to the LSTM's 4 (the forget, input and output gates plus the candidate cell state). Hence, the GRU has fewer tensor operations and is a little faster to train than the LSTM. However, neither model clearly outperforms the other. When training on less data or when speed matters, a GRU can be considered, while for longer sequences where long-distance context must be maintained, an LSTM might be preferred.
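The "3 versus 4 internal networks" difference translates directly into parameter counts. A rough count, assuming each internal network is a single fully connected layer that maps [h(t-1), x(t)] to the hidden size with one bias vector (some frameworks add a second bias vector, so their counts can differ slightly):

```python
def gated_rnn_params(n_input, n_hidden, n_layers_inside):
    """Parameters of a gated cell built from n_layers_inside fully connected layers,
    each mapping [h(t-1), x(t)] (size n_hidden + n_input) to n_hidden outputs plus a bias."""
    per_layer = n_hidden * (n_hidden + n_input) + n_hidden
    return n_layers_inside * per_layer


def lstm_params(n_input, n_hidden):
    return gated_rnn_params(n_input, n_hidden, 4)  # forget, input, output gates + candidate


def gru_params(n_input, n_hidden):
    return gated_rnn_params(n_input, n_hidden, 3)  # update, reset gates + candidate


if __name__ == "__main__":
    n_input, n_hidden = 100, 128
    print("LSTM:", lstm_params(n_input, n_hidden))
    print("GRU: ", gru_params(n_input, n_hidden))
```

For 100-dimensional inputs and 128 hidden units this gives 117,248 parameters for the LSTM and 87,936 for the GRU; under this assumption the GRU always has exactly three quarters of the LSTM's parameters.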

Attention

For problems such as machine translation, which require a many-to-many mapping, encoder-decoder RNN models are used. In the traditional RNN encoder-decoder architecture, the encoder provides a single context vector to the decoder. The problem with this approach is that it requires the encoder to squeeze into one ideal vector all the information the decoder needs to produce the output, which is a complex task for both, especially for long input sequences.

Attention in RNNs is a mechanism that lets the decoder focus on specific parts of the input sequence when predicting each part of the output. This part of the post focuses on the idea behind attention and how it tries to solve the above problem.

[Figure: attention mechanism]

With the attention mechanism, all the hidden states from the encoder are offered for the decoder's consideration. In our example, the encoder hidden states are [h1, h2, h3, h4], while s(i-1) is the previous state of the decoder, which is used for the calculation at the current time step.

The attention mechanism begins by calculating attention weights. s(i-1) is concatenated with each hidden state vector and fed to a shallow fully connected layer, which produces one score per hidden state. A softmax function is then applied to these scores, which ensures the resulting attention weights sum to 1. The hidden state vectors [h1, h2, h3, h4] are scaled by their attention weights and summed to produce the context vector for the current time step. Because of the softmax, each weight captures the degree of relevance of its hidden vector: if a weight is close to 1, the decoder is heavily influenced by that particular hidden state. With this approach, the encoder is relieved of the burden of encoding all the information into a single vector, and the decoder gets richer context for each prediction.
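The steps above (concatenate, score with a shallow fully connected layer, softmax, weighted sum) can be sketched as follows. This resembles Bahdanau-style additive attention, but it is a sketch under assumptions: the scoring network here is one tanh layer with hypothetical weights `W_a` followed by a dot product with a hypothetical vector `v_a`, and the four encoder states are toy numbers.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def score(s_prev, h_j, W_a, v_a):
    """Shallow fully connected network on the concatenation [s(i-1), h_j] -> one scalar."""
    concat = s_prev + h_j
    hidden = [math.tanh(sum(w * c for w, c in zip(row, concat))) for row in W_a]
    return sum(v * a for v, a in zip(v_a, hidden))


def attention(s_prev, encoder_states, W_a, v_a):
    """Return (attention weights, context vector) for one decoder step."""
    weights = softmax([score(s_prev, h_j, W_a, v_a) for h_j in encoder_states])
    dim = len(encoder_states[0])
    # Context vector: encoder states scaled by their weights and summed.
    context = [sum(a * h[k] for a, h in zip(weights, encoder_states)) for k in range(dim)]
    return weights, context


if __name__ == "__main__":
    hs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]  # h1..h4 (dimension 2)
    s_prev = [0.2, -0.1]                                    # previous decoder state s(i-1)
    W_a = [[0.3, -0.2, 0.5, 0.1], [0.0, 0.4, -0.3, 0.2], [0.1, 0.1, 0.2, -0.4]]
    v_a = [1.0, -0.5, 0.8]
    weights, context = attention(s_prev, hs, W_a, v_a)
    print("weights:", [round(a, 3) for a in weights], "sum =", round(sum(weights), 6))
    print("context:", [round(c, 3) for c in context])
```

The weights always sum to 1 and the context vector is a weighted average of h1..h4, so a hidden state with a weight close to 1 dominates it.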
