DEV Community

Maykeye

Posted on

BakaLLM, part 12, 1 step backward, 4 steps forward: starting a new experiment

Hello fairy dairy diary!

Ice skating

I was not satisfied with (my implementation of) RMT so far. One of the "problems" with it was how robust the 16-token-long RMT was: no matter what I tried to throw at it, things got worse. Even when I simply increased the number of tokens to 24, it got much worse.

So while the RMT experiment was the first one to put us below 4.00, it was thrown into a drawer, locked, and left to stew for a while.

I even took a break and played with calculating clipping on the fly (going from 1/2 to 2/e). But after some pondering, I found the simplest way forward.

Behold the super-puper novel architecture Step4: a block that works in parallel with the MLP and attention and plays the role of hyper-local attention.

The basic idea is to apply a linear transformation over several tokens. However, to avoid dealing with a causal mask, we split the text into very small chunks and produce a result only for the last token of each chunk, built from the last N tokens up to and including it.

The implementation is simple. Assume the layer receives these input tokens:

X,Z,A,B,C,D,E,F,G,H,I,J,K,L

We rearrange these tokens into "wide" tokens by concatenating every 4 consecutive tokens along the feature dimension (throwing away tokens from the left if the sequence can't be split into groups of exactly 4):

ABCD,EFGH,IJKL (3 wide tokens; X and Z were thrown away)
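A minimal NumPy sketch of this rearrangement, assuming the layer input is a single sequence stored as a (seq_len, dim) array (no batch axis). `make_wide` is a hypothetical helper name for illustration, not code from BakaLLM:

```python
from typing import Tuple

import numpy as np


def make_wide(x: np.ndarray, n: int = 4) -> Tuple[np.ndarray, int]:
    """Concatenate every n consecutive tokens along the feature axis.

    x has shape (seq_len, dim). Tokens that don't fill a complete group of n
    are dropped from the LEFT. Returns the wide tokens, shape
    (seq_len // n, n * dim), and the number of dropped tokens.
    """
    seq_len, dim = x.shape
    dropped = seq_len % n
    kept = x[dropped:]  # throw away leftover tokens from the left
    # Row-major reshape glues rows [i*n, ..., i*n + n - 1] into one wide row
    return kept.reshape(seq_len // n, n * dim), dropped
```

With 14 one-feature tokens numbered 0..13 (X=0, Z=1, A=2, ...), the two leftmost are dropped and the wide rows hold tokens 2-5, 6-9 and 10-13, i.e. ABCD, EFGH, IJKL.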

Now we apply the same Linear layer, which maps 4 * dim_model to dim_model, to each wide token:

ABCD->D',EFGH->H',IJKL->L'
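A sketch of the shared projection in NumPy, assuming the weight uses the same (out_features, in_features) layout as `torch.nn.Linear`; `step4_linear` is a hypothetical name:

```python
import numpy as np


def step4_linear(wide: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Apply one shared Linear(4 * dim -> dim) to every wide token.

    wide:   (num_groups, 4 * dim), e.g. the rows ABCD, EFGH, IJKL
    weight: (dim, 4 * dim), same layout as torch.nn.Linear.weight
    bias:   (dim,)
    Returns (num_groups, dim): the rows D', H', L'.
    """
    # Every wide row is multiplied by the same matrix: weights are shared
    return wide @ weight.T + bias
```

Because the weight sees all 4 tokens of a chunk at once, it can, for example, learn to just copy the last token: a weight that is zero except for an identity block over the last `dim` input features returns D, H, L unchanged.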

Then we rearrange the result again, but this time the 4 token positions become a new dimension.

Then we reshape it back to 0 0 0 D' 0 0 0 H' 0 0 0 L', zero-pad the positions we threw away (0 0 0 0 0 D' 0 0 0 H' 0 0 0 L'), and return it (it will be added to the residual).
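Putting the steps together, here is a minimal NumPy sketch of the whole block for one (seq_len, dim) sequence. It follows the description above but is not the actual BakaLLM code; the name `step4` and the argument layout are assumptions:

```python
import numpy as np


def step4(x: np.ndarray, weight: np.ndarray, bias: np.ndarray, n: int = 4) -> np.ndarray:
    """Hyper-local "Step4" block (sketch).

    x: (seq_len, dim); weight: (dim, n * dim); bias: (dim,).
    Returns (seq_len, dim) where only the last token of each complete
    n-token chunk is non-zero; everything else, including the dropped
    leftmost tokens, is zero.
    """
    seq_len, dim = x.shape
    dropped, groups = seq_len % n, seq_len // n
    # 1. concatenate every n tokens along the feature axis ("wide" tokens)
    wide = x[dropped:].reshape(groups, n * dim)
    # 2. one shared Linear(n * dim -> dim) per wide token: D', H', L'
    last = wide @ weight.T + bias
    # 3. new axis of n token slots, result goes into the last one: 0 0 0 D'
    slots = np.zeros((groups, n, dim), dtype=last.dtype)
    slots[:, -1, :] = last
    # 4. flatten back to tokens and left-pad the thrown-away positions
    out = slots.reshape(groups * n, dim)
    pad = np.zeros((dropped, dim), dtype=out.dtype)
    return np.concatenate([pad, out], axis=0)
```

The output has the same shape as the input, so inside a layer it can simply be summed into the residual next to the attention and MLP outputs, e.g. `h = x + attn(x) + mlp(x) + step4(x, W, b)`, where `attn` and `mlp` stand for the layer's own sub-blocks (hypothetical names). No causal mask is needed because each output row only reads tokens at or before its own position.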

It's like linear attention over a range of 4 tokens! But with no softmax, only a Linear: every 4th token is built from tokens 1, 2, 3, 4 of its chunk!
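Written per position, the block computes the following. The notation is mine, not from the post: tokens are 1-indexed $x_1, \dots, x_T$, $r$ is the number of tokens thrown away on the left, $[\cdot\,;\cdot]$ is concatenation along features, and the shared weight $W \in \mathbb{R}^{d \times 4d}$ is split into four $d \times d$ blocks $W_0, \dots, W_3$, one per token of the chunk:

```latex
y_t =
\begin{cases}
W\,[x_{t-3};\,x_{t-2};\,x_{t-1};\,x_t] + b
  = W_0 x_{t-3} + W_1 x_{t-2} + W_2 x_{t-1} + W_3 x_t + b,
  & t > r \text{ and } (t - r) \equiv 0 \pmod{4},\\[4pt]
0, & \text{otherwise,}
\end{cases}
\qquad r = T \bmod 4,\quad W = [\,W_0\;W_1\;W_2\;W_3\,].
```

For the 14-token example, $r = 2$ and the non-zero outputs land at $t = 6, 10, 14$ (D', H', L'). Where attention would weight the 4 tokens with softmax scores, here each relative position gets its own fixed learned matrix.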

Compared to RMT-16 (the graph shows Loss(RMT16) - Loss, so values above zero mean the new model is better):

Blue=RMT

It started well, dipped down, then became OK-ish, which is already better than the previous experiments.

I also tried doing it outside of the layers, once before them, but that was bad. After 1 epoch, the results on the validation split are promising: the loss is 4.4529, lower than RMT's 4.4648. Granted, RMT had only 190M parameters while this has 218M, but that can be optimised, for example by sharing weights in the middle.
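To get a feel for what weight sharing could save, here is a back-of-the-envelope sketch of the cost of one Step4 Linear. It assumes dim_model = 768 (the feature count mentioned at the end of the post) and a Linear with bias; the layer counts passed in are hypothetical, not BakaLLM's configuration:

```python
def step4_params(dim_model: int, n: int = 4, bias: bool = True) -> int:
    """Parameters of one Linear(n * dim_model -> dim_model)."""
    # weight is dim_model x (n * dim_model), plus an optional bias vector
    return n * dim_model * dim_model + (dim_model if bias else 0)


def saved_by_sharing(dim_model: int, layers_sharing: int, n: int = 4, bias: bool = True) -> int:
    """Parameters saved if `layers_sharing` layers reuse a single Step4 Linear."""
    # k layers sharing one Linear keep 1 copy instead of k
    return (layers_sharing - 1) * step4_params(dim_model, n, bias)


per_layer = step4_params(768)          # 4 * 768 * 768 + 768
saved = saved_by_sharing(768, 4)       # e.g. 4 middle layers share one copy
print(f"{per_layer:,} params per layer, {saved:,} saved by sharing across 4")
```

Under these assumptions each layer's Step4 costs about 2.36M parameters, so sharing one Linear across a few middle layers removes several million of them.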

The next experiments are straightforward: train 2 more epochs. Then "shift" the inputs so that the output is not 25% filled but practically fully filled, except for the first 3 tokens (after processing X,Z,A,B,C,D,E,F,G,H,I,J,K,L, process X,Z,A,B,C,D,E,F,G,H,I,J,K, etc.), and share weights.
Maybe go full MLP-Mixer and mix 768 times over 4 features. Maybe compress. Well, anyway, a new record after epoch 1 has been set.
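A sketch of the planned "shift" in NumPy: run the block on the full sequence, then on the sequence with its last token removed, and so on for 4 shifts, and merge the outputs. This assumes the same Linear is reused for all shifts (one reading of "share weights"); `step4` and `step4_shifted` are hypothetical names:

```python
import numpy as np


def step4(x, weight, bias, n=4):
    """Fill only the last token of each complete n-chunk, zero elsewhere."""
    seq_len, dim = x.shape
    dropped, groups = seq_len % n, seq_len // n
    last = x[dropped:].reshape(groups, n * dim) @ weight.T + bias
    slots = np.zeros((groups, n, dim), dtype=last.dtype)
    slots[:, -1, :] = last
    pad = np.zeros((dropped, dim), dtype=slots.dtype)
    return np.concatenate([pad, slots.reshape(groups * n, dim)], axis=0)


def step4_shifted(x, weight, bias, n=4):
    """Run step4 on x, x[:-1], ..., x[:-(n - 1)] and merge the results.

    Each truncation moves the chunk boundaries by one token, so the n runs
    fill disjoint positions and together cover every position except the
    first n - 1.
    """
    seq_len, dim = x.shape
    out = np.zeros((seq_len, dim), dtype=np.result_type(x, weight, bias))
    for shift in range(n):
        if seq_len - shift <= 0:
            break
        # positions filled by different shifts never overlap, so += merges them
        out[: seq_len - shift] += step4(x[: seq_len - shift], weight, bias, n)
    return out
```

Merged this way, every position t from the 4th on gets the Linear applied to itself and its 3 predecessors, which makes the shifted version equivalent to a causal sliding window of width 4 (a 1D causal convolution with kernel size 4).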

Chill!
