
Maykeye

BakaLLM, part 2

Hello, fairy dairy diary!

In today's episode of BakaLLM we'll discuss the current internals of BakaLLM. It tries to follow the OOP architecture of the transformers library, but doesn't inherit from its classes so far. Because why not.

[Image: Cirno]

1) Start of the journey: BakaNetForCausalLM

[Diagram: CausalLM]

Based on AutoModelForCausalLM: it calls the base model, normalizes the output, and then maps the hidden dimension to the vocabulary size (BakaLLM proper uses a vocab size of 32000, borrowed from OpenLLaMA). Following in the footsteps of its big-sis models, it also handles loss calculation there, but the main job is computing the output.
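Roughly, that wrapper might look like the sketch below. The exact names, the kind of final norm, and the loss shifting are my assumptions, not the real code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BakaNetForCausalLM(nn.Module):
    """Hypothetical sketch: wrap the base model, normalize, project to vocab, compute loss."""
    def __init__(self, model: nn.Module, dim: int = 768, vocab_size: int = 32000):
        super().__init__()
        self.model = model                          # BakaNet proper
        self.norm = nn.LayerNorm(dim)               # output normalization (kind assumed)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor | None = None):
        hidden = self.model(input_ids)              # (B, T, dim)
        logits = self.lm_head(self.norm(hidden))    # (B, T, vocab)
        loss = None
        if labels is not None:
            # usual causal-LM shift: position t predicts token t+1
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                                   labels[:, 1:].reshape(-1))
        return logits, loss
```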

2) The core: BakaNet

[Diagram: BakaNet architecture]

This is the main class. First it translates token IDs into token embeddings; embed_id is a learnable nn.Embedding layer. I've read several times that embeddings are supposed to be frozen, but cutting them out of OpenLLaMA is too tiresome, so they are trained from scratch along with everything else.

Then the input sequence is split into chunks of W tokens each. Each chunk is passed through the RMT-style memory mechanism, here called the memory injector, to pick up the memory from the previous chunk (if any). The memory injector injects M tokens at the front and at the end of the chunk. Detachment returns the M tokens at the end, which will be attached to the start of the next chunk in the loop. If no past chunk is present, learnable tokens are inserted instead.

Then each layer gets called, predicting both its output and memory for the future.

In the end, the results of all chunks are concatenated.
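As a sketch, the chunked forward pass might look roughly like this. The names self.window, self.injector, and self.layers, plus the injector API, are my guesses at the shape of the code:

```python
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
    x = self.embed_id(input_ids)                     # (B, T, dim), learnable nn.Embedding
    chunks = x.split(self.window, dim=1)             # chunks of W tokens each
    memory = None                                    # no past chunk yet -> injector uses learnable tokens
    outputs = []
    for chunk in chunks:
        chunk = self.injector.attach(chunk, memory)  # +M memory tokens at the front and the end
        for layer in self.layers:
            chunk = layer(chunk)
        chunk, memory = self.injector.detach(chunk)  # strip memory tokens; keep memory for the next chunk
        outputs.append(chunk)
    return torch.cat(outputs, dim=1)                 # back to (B, T, dim)
```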

Now, there is an important distinction from RMT: in the original RMT the gradient never stops, as they have lots of VRAM.

[Image: There are no breaks on RMT train-ing]

My poor laptop can barely run 3 batches at once, so I cut the gradient every second chunk. I.e., the memory of chunk 2 backprops to chunk 1, the memory of chunk 4 backprops to chunk 3, but chunk 3 doesn't send gradient to chunk 2, breaking the chain of VRAM oppression.

Also, when a layer is called, it gets the cached result of the previous chunk.
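Folding those two points into the loop above, the body might look roughly like this. Again a hedged sketch; the cache plumbing (each layer returning a "present" to be used as "past" next chunk) is my guess:

```python
cache = [None] * len(self.layers)
for i, chunk in enumerate(chunks):
    chunk = self.injector.attach(chunk, memory)
    new_cache = []
    for layer, past in zip(self.layers, cache):
        chunk, present = layer(chunk, past=past)     # each layer sees the previous chunk's cached result
        new_cache.append(present)
    cache = new_cache
    chunk, memory = self.injector.detach(chunk)
    if i % 2 == 1:
        memory = memory.detach()                     # chunk 3 won't backprop into chunk 2, etc.
    outputs.append(chunk)
```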

3.1) Memory injector/attachment

[Diagram: Memory injector (attachment)]

Here we add M memory tokens to the front and the end. For the memory itself, Baka actually uses M-1 tokens. Special twist. Secret sauce! Like BERT, BakaLLM uses a special [SEP] token to separate memories from the main stream of tokens; they are not the same after all, so it feels good to have some separation. I don't remember if it's actually good numerically or if I simply added it for the fun of it and forgot to measure.
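In rough pseudo-PyTorch, attachment might look like this. The names (init_mem, sep, solidifier) and the exact placement of [SEP] relative to the M-1 memory tokens are my guesses, not gospel:

```python
import torch
import torch.nn as nn

class BakaMemoryInjector(nn.Module):
    def __init__(self, dim: int = 768, n_mem: int = 32, solidifier: nn.Module | None = None):
        super().__init__()
        self.n_mem = n_mem
        self.solidifier = solidifier                      # non-causal "memory solidifier" (section 3.2)
        # learnable memory used when there is no past chunk, plus a learnable [SEP]
        self.init_mem = nn.Parameter(torch.randn(n_mem - 1, dim) * 0.02)
        self.sep = nn.Parameter(torch.randn(1, dim) * 0.02)

    def attach(self, chunk: torch.Tensor, memory: torch.Tensor | None) -> torch.Tensor:
        b = chunk.size(0)
        if memory is None:
            memory = self.init_mem.expand(b, -1, -1)      # (B, M-1, dim)
        sep = self.sep.expand(b, -1, -1)                  # (B, 1, dim)
        prefix = torch.cat([memory, sep], dim=1)          # M tokens in front
        suffix = torch.cat([sep, memory], dim=1)          # M tokens at the end
        return torch.cat([prefix, chunk, suffix], dim=1)  # (B, M + W + M, dim)
```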

3.2) Memory injector/detachment

[Diagram: Memory injector (detachment)]

This is actually more complicated. I don't believe that the last layer's outputs live in the same space as the first layer's inputs (maybe they should, if we were to tie embed_out to embed_in.mT), so there is a special layer called the "memory solidifier". It has two properties: a) it's not causal: every token can see every other token. There is no reason why memory token 2 should see memory token 1 but not memory token 3. b) Internally it computes new values only for the last M-1 tokens (excluding SEP). The role of the layer is to take the output of the last layer and massage it into a form acceptable to the first layer. Since it affects only the last ~30 tokens, it's not that slow.

It returns two values: the state without the memory tokens and separators, and the output memory.
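Continuing the injector sketch above, detachment might boil down to something like this; the slicing is my guess at how the pieces line up:

```python
    def detach(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        M = self.n_mem
        core = state[:, M:-M, :]                 # the chunk itself, without the injected tokens
        solid = self.solidifier(state)           # non-causal layer, only recomputes the tail
        memory = solid[:, -(M - 1):, :]          # last M-1 tokens (SEP excluded) -> next chunk's memory
        return core, memory
```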

4) BakaBlock. The layer. The spaghetto!

[Diagram: Block is so baka]

A single layer produces its output from four parts: the residual connection (this one is easy, yay!), MEGA, MLP, and self-attention.

Twists:

  • MLP and Mega are gated. The Mega output is passed through an FC layer, the result goes through a sigmoid, and the output is multiplied by it. Think SiLU on steroids: if SiLU is Y = X * sigmoid(X), this one is Y = X * sigmoid(XW). This produced better results (see the sketch after this list).
  • MLP is gated by the self-attention output rather than by itself. On the TinyStories dataset this produced better results.
  • This is embarrassing, but I mixed up Q&K and V: Mega is supposed to pass Q and K through EMA, not V. Oops.
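Under those assumptions, a single block might be wired roughly like this. How the four parts combine (here a plain sum) and what the attention branch actually reads are my guesses; the sub-modules are sketched in the sections below:

```python
import torch
import torch.nn as nn

class BakaBlock(nn.Module):
    def __init__(self, dim: int, mega: nn.Module, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.mega, self.attn, self.mlp = mega, attn, mlp   # MEGA, self-attention, MLP
        self.mega_gate = nn.Linear(dim, dim)
        self.mlp_gate = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mega_out = self.mega(x)
        mega_out = mega_out * torch.sigmoid(self.mega_gate(mega_out))   # Y = X * sigmoid(XW)
        attn_out = self.attn(x)
        mlp_out = self.mlp(x) * torch.sigmoid(self.mlp_gate(attn_out))  # MLP gated by attention output
        return x + mega_out + attn_out + mlp_out                        # residual + the three branches
```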

5) MegaLayer.

No graphviz picture for you

I copied lucidrains's implementation, changed it to have no heads (as BakaLLM doesn't use multiple heads), and that's it. The only other change I've added is cache usage, so instead of T tokens we process 2T (and here T includes memory tokens, both prefix and suffix, in the present chunk as well as in the past one).
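The cache usage boils down to something like this little sketch (shapes assumed to be (B, T, dim)):

```python
import torch

def with_past(x: torch.Tensor, past: torch.Tensor | None) -> torch.Tensor:
    """Prepend the previous chunk's T tokens (memory included) so the layer sees 2T."""
    if past is None:
        return x                        # first chunk: nothing to prepend
    return torch.cat([past, x], dim=1)  # (B, 2T, dim)
```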

6) MLP

[Diagram: MLP]

The usual MLP you can find in llama. ReLU is used as the activation. I've read it's somewhat worse on NLP tasks than the modern SiLU, but I already have enough gates, and ReLU is so neat to backprop through that I couldn't resist.
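For reference, one reading of "llama MLP with ReLU" is the usual gate/up/down projections with ReLU swapped in as the activation; the hidden size here is an assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BakaMLP(nn.Module):
    def __init__(self, dim: int = 768, hidden: int = 2048):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # llama's gated MLP, but with ReLU instead of SiLU
        return self.down_proj(F.relu(self.gate_proj(x)) * self.up_proj(x))
```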

7) Attention to small values

[Diagram: Attention scheme]

This one is a doozy.

First of all, it uses a norm like Persimmon does. Secondly, it also concatenates 1-q and 1-k to q and k.


Consider Q=[1,0,0,0,0], K1=[1,1,1,1,1], K2=[1,0,0,0,0]. If you ask self-attention, Q is as similar to K1 as it is to K2 (both dot products equal 1). Say no to disregarding small values!

[Image: Shinmyoumaru Sukuna, from the Touhou wiki]

The conspiracy against small numbers will be no more! Concatenating with n-Q, where n is a learnable scalar, improves training quite a bit, as such vectors will now be treated differently!
(For the record, I tested this in a small GPT-NeoX by concatenating right before the matmul, but it didn't work well.)
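In pseudo-PyTorch the trick boils down to something like this sketch; where exactly the concatenation sits relative to the norm is my assumption:

```python
import torch
import torch.nn as nn

class SmallValueQK(nn.Module):
    """Sketch: concatenate (n - q) and (n - k) to q and k, with n a learnable scalar."""
    def __init__(self):
        super().__init__()
        self.n = nn.Parameter(torch.ones(1))    # starts as the plain 1-q / 1-k variant

    def forward(self, q: torch.Tensor, k: torch.Tensor):
        q2 = torch.cat([q, self.n - q], dim=-1)
        k2 = torch.cat([k, self.n - k], dim=-1)
        return q2, k2                           # feed these into the usual q @ k^T scores

# With Q=[1,0,0,0,0], K1=[1,1,1,1,1], K2=[1,0,0,0,0]:
# the plain scores are Q.K1 = Q.K2 = 1, but the extra (1-Q).(1-K) term adds 0 for K1
# and 4 for K2, so matching the zeros finally counts for something.
```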

Next, we have only one head, under the influence of MEGA. They stated that gating the output is as good as having several heads, and you don't have to perform the O(N*N) calculation once per head (I'm not sure if optimized kernels actually do it that way, but one head sounds better than several).

It's also the reason flash attention can't be used: it's limited to a head_dim of 256, and we have 768.
A is the number of values to attend to; a reminder that in the memory solidifier we update only the memory. (In the code I actually take the first (T-A) tokens and concatenate them back to keep T tokens, but I was too lazy to draw it and it's not important.)

And with that, the BakaLLM architecture is described! I promised you'd need barf buckets. Well, I will fix the EMA->attention inputs at some later time. Next time we'll talk about evaluation and current results. They are not fantastic. The world can sleep safely, for now.

Mandatory nap is mandatory after all!
