import torch
import torch.nn as nn


class TinyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        # Set up every layer we are going to need for training.
        self.char_embedding = nn.Embedding(65, 64)
        self.pos_embedding = nn.Embedding(64, 64)
        self.query = nn.Linear(64, 64)
        self.key = nn.Linear(64, 64)
        self.value = nn.Linear(64, 64)
        # Causal mask, registered as a buffer so it moves with the model in .to(device).
        self.register_buffer("mask", torch.tril(torch.ones(64, 64)))
        # Feed-forward layers: ff1 widens each token vector from 64 to 128 so the
        # model has a richer intermediate representation to work with,
        self.ff1 = nn.Linear(64, 128)
        # and ff2 projects it back down to 64.
        self.ff2 = nn.Linear(128, 64)
        self.output_head = nn.Linear(64, 65)
        self.norm1 = nn.LayerNorm(64)
        self.norm2 = nn.LayerNorm(64)
        self.out_proj = nn.Linear(64, 64)

    def forward(self, x):
        # x holds character indices with shape (batch, 64).
        batch_size = x.shape[0]
        positions = torch.arange(64, device=x.device)
        x = self.char_embedding(x) + self.pos_embedding(positions)

        # --- attention starts here ---
        # Split the 64 features into 2 heads of 32 each: (batch, 2, 64, 32).
        Q = self.query(x).view(batch_size, 64, 2, 32).transpose(1, 2)
        K = self.key(x).view(batch_size, 64, 2, 32).transpose(1, 2)
        V = self.value(x).view(batch_size, 64, 2, 32).transpose(1, 2)
        A = (Q @ K.transpose(-2, -1)) / 32**0.5
        A = A.masked_fill(self.mask == 0, float("-inf"))
        # dim=-1 tells softmax to normalize over the last dimension (the key positions).
        At = A.softmax(dim=-1)
        output = At @ V
        output = output.transpose(1, 2).contiguous().view(batch_size, 64, 64)
        output = self.out_proj(output)
        # --- attention ends here ---

        # Adding normalization here improved the loss: it bottomed out around 1.8
        # before, and reached roughly 1.5 after.
        x = x + output
        x = self.norm1(x)

        # Feed-forward part, which leads to the predictions.
        output = self.ff1(x)
        output = torch.relu(output)
        output = self.ff2(output)
        x = x + output  # ← merge back into main flow
        x = self.norm2(x)
        x = self.output_head(x)
        return x
This code is basically boilerplate at this point for anyone in the AI space who has trained a transformer.
I just want to understand one little line in here, one that has a really interesting history behind it.
x = x + output
Why are we doing this, x = x + output?
Neural networks learn through backpropagation, which basically means they look for the changes to the weights that move us closer to the correct predictions. Now there is a particular problem with this: as we move backward from one layer to the next, the gradient can become smaller and smaller. This is huge, because the early layers then barely update, and training gets harder and more computationally expensive. It happens basically because of the chain rule for partial derivatives: the gradient at an early layer is a product with one factor for every layer after it. But how does this line solve that?
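To make the chain rule point concrete, here is a toy sketch in plain Python. It assumes every layer contributes the same local derivative of 0.5, which is a made-up number chosen only for illustration; real networks have different factors at every layer.

```python
# Toy illustration of the chain rule across layers (hypothetical numbers).
# Each layer contributes one local derivative; backprop multiplies them all.

def gradient_at_first_layer(local_derivatives):
    """Multiply the per-layer derivatives, as the chain rule does."""
    grad = 1.0
    for d in local_derivatives:
        grad *= d
    return grad

shallow = gradient_at_first_layer([0.5] * 5)   # 5 layers
deep = gradient_at_first_layer([0.5] * 50)     # 50 layers

print(f"5 layers:  {shallow:.6f}")
print(f"50 layers: {deep:.2e}")
```

With factors below 1 the product shrinks exponentially with depth, which is the vanishing gradient effect described above.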
*History of this problem*
read this paper here --
This paper is by the Microsoft Research team, and it is basically about how they solved the problem that more is not always better. Before this paper, when training deep learning models, the more depth a model had (i.e. the more layers), the more error it produced too. That was a huge problem and people didn't know how to solve it, because on one hand more depth promised better understanding, and on the other hand it gave you more errors.
*Solution*
Now we might think that the solution is to just add the original embedding vectors (in my case) to the context matrix we get after all the computations, and you would be right. But it works for a different reason than you might think: the paper itself says that vanishing gradients are not the cause of this problem.
"We argue that this optimization difficulty is unlikely to be caused by vanishing gradients."
Why? The reason we can drop our suspicion of the diminishing gradient is that there are already things in place to minimize it, such as batch normalization and, in this case, ReLU. Here is where we do it in our code:
x = self.norm1(x) # layer normalization, the transformer counterpart of batch normalization
output = self.ff1(x)
output = torch.relu(output) # ReLU also helps against the vanishing gradient problem
output = self.ff2(output)
x = x + output # ← merge back into main flow
x = self.norm2(x)
x = self.output_head(x)
As you can see, these do address the vanishing gradient problem, and yet if we remove x = x + output the result gets worse. You know what, let's try it.
This is the run where we train normally and don't change anything. Now let's change one thing, remove the line x = x + output, and see how it affects the loss.
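One way to run this comparison without editing the model by hand is sketched below. The `FeedForwardBlock` class and its `use_residual` flag are hypothetical names, not part of the model above, and only the feed-forward half is shown. One detail worth noticing: if the line is simply deleted, `output` is never used again, so the sketch replaces it with `x = output` to keep the sublayer in the computation.

```python
import torch
import torch.nn as nn

class FeedForwardBlock(nn.Module):
    """Feed-forward half of the block, with the skip connection switchable."""

    def __init__(self, dim=64, hidden=128, use_residual=True):
        super().__init__()
        self.use_residual = use_residual  # hypothetical flag for the ablation
        self.ff1 = nn.Linear(dim, hidden)
        self.ff2 = nn.Linear(hidden, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        output = self.ff2(torch.relu(self.ff1(x)))
        if self.use_residual:
            x = x + output  # normal run: merge back into the main flow
        else:
            x = output      # ablation: the original x is thrown away
        return self.norm(x)

torch.manual_seed(0)
x = torch.randn(2, 8, 64)
with_skip = FeedForwardBlock(use_residual=True)
without_skip = FeedForwardBlock(use_residual=False)
without_skip.load_state_dict(with_skip.state_dict())  # same weights in both

out_with = with_skip(x)
out_without = without_skip(x)
print(out_with.shape, out_without.shape)
```

Copying the weights across means any difference between the two outputs comes from the skip connection alone.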
So the loss jumped from 1.70 to 2.47 because of just this one line. It might not seem like a lot, but remember that this is just a 1-layer model for simplicity, and the more layers we add, the more the error climbs too. To solidify my point, I made a few small adjustments to print the gradient norms live during training:
step 44500, loss: 2.5130
char_embedding.weight grad_norm: 0.006248
pos_embedding.weight grad_norm: 0.005838
ff1.weight grad_norm: 0.024721
ff2.weight grad_norm: 0.053932
output_head.weight grad_norm: 0.163109
norm1.weight grad_norm: 0.007271
norm2.weight grad_norm: 0.024594
step 45000, loss: 2.4751
char_embedding.weight grad_norm: 0.005574
pos_embedding.weight grad_norm: 0.005913
ff1.weight grad_norm: 0.023506
ff2.weight grad_norm: 0.056331
output_head.weight grad_norm: 0.161182
norm1.weight grad_norm: 0.007898
norm2.weight grad_norm: 0.020992
step 45500, loss: 2.4623
char_embedding.weight grad_norm: 0.006224
pos_embedding.weight grad_norm: 0.006075
ff1.weight grad_norm: 0.025461
ff2.weight grad_norm: 0.051210
output_head.weight grad_norm: 0.145062
norm1.weight grad_norm: 0.008452
norm2.weight grad_norm: 0.018521
step 46000, loss: 2.4764
char_embedding.weight grad_norm: 0.006709
pos_embedding.weight grad_norm: 0.006148
ff1.weight grad_norm: 0.026940
ff2.weight grad_norm: 0.057071
output_head.weight grad_norm: 0.163159
norm1.weight grad_norm: 0.008988
norm2.weight grad_norm: 0.025112
step 46500, loss: 2.4746
char_embedding.weight grad_norm: 0.006127
pos_embedding.weight grad_norm: 0.006181
ff1.weight grad_norm: 0.025931
ff2.weight grad_norm: 0.056799
output_head.weight grad_norm: 0.158272
norm1.weight grad_norm: 0.008369
norm2.weight grad_norm: 0.025981
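For reference, logs in this format could be produced with a small helper like the one below. This is a guess at the kind of adjustment involved, not the code that was actually used: `log_grad_norms` is a hypothetical name, the filter on names ending in "weight" is assumed from the log format, and the tiny model and random data are stand-ins that only exercise the helper.

```python
import torch
import torch.nn as nn

def log_grad_norms(model, step, loss):
    """Print the L2 norm of each weight gradient; call after loss.backward()."""
    print(f"step {step}, loss: {loss:.4f}")
    norms = {}
    for name, param in model.named_parameters():
        # Parameters that took no part in the forward pass have grad None.
        if param.grad is None or not name.endswith("weight"):
            continue
        norms[name] = param.grad.norm().item()
        print(f"{name} grad_norm: {norms[name]:.6f}")
    return norms

# Tiny stand-in model and random data, only to exercise the helper.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
inputs = torch.randn(32, 8)
targets = torch.randint(0, 4, (32,))

loss = nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
norms = log_grad_norms(model, step=0, loss=loss.item())
```

In a real training loop the call would sit between `loss.backward()` and the optimizer step, every few hundred steps.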
So here is what I was saying: even though it looks like this line is there to solve the diminishing gradient, in fact that is not what it does at all.
What it really does is something way more interesting.
Every layer and every non-linearity changes the values a little, and these changes compound fast, like really fast. For something like 20 layers it might still work, but the effect shoots up further when we go to something like 50 layers. Each change on its own is very, very small, yet together these "small" changes can alter the values a lot, and what comes out at the end may be completely different from the original. Here is an example:
[5, 3, 8] ---> [5, 3, 8] ---> [0.3, 0.01, 0.2]
Notice that this is not caused by anything like the diminishing gradient. It is caused by the small changes we make in between. So what happens if we add the original back in?
[5, 3, 8] + [0.3, 0.01, 0.2] = [5.3, 3.01, 8.2]
So the result is very close to the original, right? That is the main idea of the residual. There are many residual variants too, but for simplicity we are just going to stick with good old addition, and frankly it's better this way.
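The same arithmetic as a tiny Python sketch. The `small_change` function is a hypothetical stand-in for whatever a layer computes; it just returns the numbers from the example so the two variants can be compared side by side.

```python
def small_change(x):
    """Hypothetical layer output: a small update computed from x."""
    return [0.3, 0.01, 0.2]  # the numbers from the example above

def plain_layer(x):
    # Without the skip connection the original values are lost.
    return small_change(x)

def residual_layer(x):
    # With the skip connection: x + output, element by element.
    return [a + b for a, b in zip(x, small_change(x))]

x = [5.0, 3.0, 8.0]
plain = plain_layer(x)
residual = residual_layer(x)

print(plain)     # [0.3, 0.01, 0.2]
print(residual)  # approximately [5.3, 3.01, 8.2], up to float rounding
```

The plain version has to rebuild the whole signal from scratch, while the residual version only has to learn the small correction on top of it.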

