Maykeye

BakaLLM, part (9). The epoch of Stalling has begun!

(The title should include ⑨, but the site cuts it out.)
Hello, fairy dairy diary~

I'd say good day, but it's midnight in our comfy Abyss, and we've run out of good in the good department as well!

Less than happy Cirno

Lots of experiments were done this week. Most of them ended in failure, some more than others. Some were funny, if you consider squiggly lines the peak (pun intended) of comedy:

RMT

2 RMTs graph

The two lines that are almost identical are the baseline RMT implementation and the current top of the mainline: pause. And there's one funny line where I used a GLU as the memory solidifier (in the prototype I used a whole layer, which worked wonders; I wanted to check whether a simpler approach would do). So while the baseline worked, it didn't improve things much. I also had to reduce the number of batches to 3, which is the least I'm comfortable with.
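To make "GLU as memory solidifier" concrete: the raw memory-token states coming out of a chunk get squeezed through a gated linear unit before being carried into the next chunk. A minimal sketch with made-up names and sizes, not the actual BakaLLM code:

```python
import torch
import torch.nn as nn

class GLUMemorySolidifier(nn.Module):
    """Hypothetical sketch: turn raw memory-token states into the memory
    carried to the next chunk, using a GLU instead of a whole layer."""
    def __init__(self, dim_model: int):
        super().__init__()
        self.proj = nn.Linear(dim_model, 2 * dim_model, bias=False)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        a, b = self.proj(mem).chunk(2, dim=-1)
        return a * torch.sigmoid(b)  # GLU: values gated by a sigmoid

# mem: (batch, n_mem_tokens, dim_model)
solidify = GLUMemorySolidifier(dim_model=768)
new_mem = solidify(torch.randn(2, 16, 768))
```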

The least-failed experiment was Anti Small-Value-Conspiracy (A-SVC), which I had already tried back in the BakaLLM prototype. The first epoch was good; the third was not. In the end it got 4.04 loss on the validation set, while the previous best was 4.02. However, since A-SVC is tweakable, it's not considered a lost cause. It's also the only experiment that lived through 3 epochs this week.

BRT

Now, on to prototyping a Block-Recurrent Transformer-inspired layer.
The original looks like this:

Block-Recurrent Transformer

Well, it was implemented loosely, for better or worse.
I moved everything into a new layer with two attentions: cache_reader, which used (cache + hidden_state) to produce an output that was then added to the MLP and XL-attention outputs, and cache_writer, which used the same input and produced the new cache.
This interpretation of the original architecture is too loose: while both versions share keys/values, mine does it worse: the same K,V projections are used for both the cache and the input, which is not good. Oh well, I'll fix it later. With this BRT-Loose I also got this funny graph:

NaNs, NaNs everywhere

It has not-a-number values. Lots of them. I remember being told that if you use bfloat16 in SDPA, it may not be stable (which is why HF's transformers casts it to f32 and back). I made my code use f32 as well... and got out-of-memory. With batch size 2 it was working, but the ETA was so big my brain cells couldn't comprehend such a number.
I managed to reduce the ETA and the number of parameters by decreasing the number of heads and dim_attn, but the results were still not good enough. I need to add proper sharing to improve it further.
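For reference, here is roughly what BRT-Loose looked like. This is a reconstruction from the description above, not the real code: the names q_reader/q_writer and the exact wiring are my guesses, causal masking is omitted for brevity, the shared K,V projection is the very flaw mentioned above, and the f32 cast around SDPA is the stability workaround that ate the memory.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BRTLooseLayer(nn.Module):
    """Sketch of the 'loose' BRT-inspired layer described above (not BakaLLM's code)."""
    def __init__(self, dim_model: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim_model // n_heads
        # The flaw: one shared K,V projection serves both the cache and the hidden state.
        self.kv = nn.Linear(dim_model, 2 * dim_model, bias=False)
        self.q_reader = nn.Linear(dim_model, dim_model, bias=False)  # cache_reader queries
        self.q_writer = nn.Linear(dim_model, dim_model, bias=False)  # cache_writer queries
        self.out = nn.Linear(dim_model, dim_model, bias=False)

    def _split(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        return x.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

    def _attn(self, q, k, v):
        # bfloat16 SDPA was unstable here, so attention runs in f32 and casts back
        y = F.scaled_dot_product_attention(q.float(), k.float(), v.float())
        return y.to(q.dtype)

    def forward(self, hidden: torch.Tensor, cache: torch.Tensor):
        src = torch.cat([cache, hidden], dim=1)  # (cache + hidden_state)
        k, v = map(self._split, self.kv(src).chunk(2, dim=-1))

        # cache_reader: output goes into the residual stream alongside MLP / XL-attention
        read = self._attn(self._split(self.q_reader(hidden)), k, v)
        read = read.transpose(1, 2).reshape_as(hidden)

        # cache_writer: same keys/values, but its output becomes the new cache
        new_cache = self._attn(self._split(self.q_writer(cache)), k, v)
        new_cache = new_cache.transpose(1, 2).reshape_as(cache)

        return self.out(read), new_cache
```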

TL

Thin layers (an additional layer with self-attention only) were also trained and tried. However, they made everything worse. After watching the latest hu-po video, I got the insight that the first and last layers are the most important. I had inserted the thin layers at the very beginning, which might explain why everything went astray: there was no gentle massage from Mr Elephant to tell each token how awesome it is. The experiment can be repeated with them placed in the middle.
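For clarity, by a thin layer I mean something like the following: norm, causal self-attention, residual, and no MLP (my reconstruction of the idea, not BakaLLM's actual code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ThinLayer(nn.Module):
    """Sketch of a 'thin layer': norm + causal self-attention with a residual, no MLP."""
    def __init__(self, dim_model: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim_model // n_heads
        self.norm = nn.LayerNorm(dim_model)
        self.qkv = nn.Linear(dim_model, 3 * dim_model, bias=False)
        self.out = nn.Linear(dim_model, dim_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, d = x.shape
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=-1)
        split = lambda z: z.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        y = F.scaled_dot_product_attention(split(q), split(k), split(v), is_causal=True)
        return x + self.out(y.transpose(1, 2).reshape(b, t, d))  # no MLP block
```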

Split LN

One difference between GPT-NeoX and StableLM-alpha-2 is that StableLM uses one layer norm per layer: it passes the same normalized values to both the attention and the MLP. GPT-NeoX uses separate norms for the attention and the MLP. At this point layer norm is basically the only place where the model can apply a bias: in the linear layers biases are disabled.
Well, using the GPT-NeoX approach didn't improve the situation; it made it worse. Not much worse, but when improvement is expected and epoch 1 is already worse, further epochs are not worth training.
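To make the difference concrete, here's a sketch of the two layouts, assuming a parallel attention + MLP block; attn and mlp stand in for the real sublayers, and none of this is BakaLLM's actual code:

```python
import torch
import torch.nn as nn

class ParallelBlock(nn.Module):
    """Sketch of the two layer-norm layouts for a parallel attention + MLP block."""
    def __init__(self, dim_model: int, attn: nn.Module, mlp: nn.Module, split_ln: bool):
        super().__init__()
        self.attn, self.mlp, self.split_ln = attn, mlp, split_ln
        self.norm_attn = nn.LayerNorm(dim_model)  # layer norms keep their biases;
        self.norm_mlp = nn.LayerNorm(dim_model) if split_ln else None  # linears don't

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.split_ln:
            # GPT-NeoX style: attention and MLP each get their own norm
            return x + self.attn(self.norm_attn(x)) + self.mlp(self.norm_mlp(x))
        # StableLM-alpha-2 style: one norm, the same normalized values go to both
        h = self.norm_attn(x)
        return x + self.attn(h) + self.mlp(h)
```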

Plans to-do

  • BRT is yet to be given a chance to shine, as in its current form it's borked
  • Memory transformers
  • XL compressed. I remember reading that people compressed the XL cache. I don't remember whether it was compressed as "for each non-overlapping sequential pair of input tokens, produce one token's worth of data" (which can be done with a linear layer of shape (2 × dim_model → dim_model)) or as "reduce the number of params per token" (dim_model → dim_model // 8). Both variants can be trained; see the sketch after this list. The idea is to go even further than storing 512 tokens of raw history. I want a 2K+ context, where 512 is the current chunk, 512 is raw past K,V, and 1024 is compressed tokens which represent even more data.
  • Thin layers in the middle.
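To illustrate the two compression variants from the XL item above (shapes are illustrative only; none of this exists in BakaLLM yet):

```python
import torch
import torch.nn as nn

dim_model = 768  # illustrative size

# Variant 1: merge each non-overlapping pair of tokens into a single token
pair_compress = nn.Linear(2 * dim_model, dim_model, bias=False)

def compress_pairs(x: torch.Tensor) -> torch.Tensor:
    b, t, d = x.shape  # x: (batch, seq, dim_model), seq assumed even
    return pair_compress(x.reshape(b, t // 2, 2 * d))  # halves the sequence length

# Variant 2: keep every token but shrink its features
token_compress = nn.Linear(dim_model, dim_model // 8, bias=False)

x = torch.randn(1, 512, dim_model)
print(compress_pairs(x).shape)   # torch.Size([1, 256, 768])
print(token_compress(x).shape)   # torch.Size([1, 512, 96])
```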

While progress has slowed down significantly, from "I can't wait for the calculations to finish to see better results" to "I can't wait for the calculations to finish to see if the results are better", there are lots of other retention mechanisms left to try.

The elephants will prevail!
