BakaLLM Part 17.5. Sssnake faster! Yet another milestone achieved

Hello fairy dairy diary~

Today was the day. Well, night. It's 3am to be precise, but that is not important. What is important is that Pythia 70m is defeated!

*(Image: Cirno holds a weight)*

Was it a change in architecture? Yes, but an insignificant one! Was it reduced memory usage and a different training routine? Yes.

The results of the new training method are awesome:

| n_ctx | loss | ppl | note |
| --- | --- | --- | --- |
| 1048576 | **3.47578120231628** | **32.3230695329532** | **after 2 epoch ber 8batch (2h56m)** |
| 2048 | 3.55905652046204 | 35.130037033514 | Vanilla pythia 70m-deduped |
| 1048576 | 3.66562509536743 | 39.0805576049546 | after 14 epoch, wide MLP |
| 1048576 | 3.66744780540466 | 39.1518550871408 | after 15 epoch, wide MLP |
| 1048576 | 3.67499995231628 | 39.448654976142 | after 12 epoch, wide MLP |
| 1048576 | 3.67812490463257 | 39.5721229571853 | after 13 epoch, wide MLP |
| 1048576 | 3.68697905540466 | 39.9240562362004 | after 11 epoch, wide MLP |
| 1048576 | 3.69973969459534 | 40.436777078352 | after 10 epoch, wide MLP |
| 1048576 | 3.7109375 | 40.892124929554 | after 9 epoch, wide MLP |
| 1048576 | 3.73046875 | 41.6986498257003 | after 8 epoch, wide MLP |
| 1048576 | 3.74401044845581 | 42.2671609833295 | after 7 epoch, wide MLP |
| 1048576 | 3.76718759536743 | 43.2582339397375 | after 6 epoch, wide MLP |
| 1048576 | **3.77265620231628** | **43.4954442322217** | **after 1 epoch ber 8batch (3h8m)** |
| 1048576 | 3.77838540077209 | 43.7453534703837 | After 8 epoch rmt-copy-staged |
| 1048576 | 3.79479169845581 | 44.4689724850806 | after 5 epoch, wide MLP |
| 1048576 | 3.79869794845581 | 44.6430191223772 | After 7 epoch rmt-copy-staged |
| 1048576 | 3.82682299613953 | 45.9164295891796 | after 4 epoch, wide MLP |
| 1048576 | 3.83984375 | 46.5182054023163 | After 6 epoch rmt-copy-staged |
| 1048576 | 3.85807299613953 | 47.3739735245328 | After 10 epoch, staged post-training |
| 1048576 | **3.87864589691162** | **48.3586880673365** | **after 3 epoch ber 5batch+mamba w. 8 layers** |

New results in bold.

Changes were made. The current experiments are done in the `00x_celp18` branch, which is short for "compiled elephant a=1 d=8".

Zeroth change. Mamboization.

I continued dancing with snakes (I think I covered the start in part 17 already). Originally I used mamba every 2 layers over a network of 12 layers; now I changed it to 8 layers, all of them with mamba. It improved the E3 result to 3.8786.
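In sketch form the layout change is just a different layer recipe; `make_layer` below is a hypothetical stand-in for the real layer constructor, not actual BakaLLM code:

```python
# A toy sketch of the layout change; make_layer is a hypothetical stand-in
# for the real layer constructor and exists only for illustration.
def make_layer(use_mamba: bool) -> dict:
    return {"use_mamba": use_mamba}

# before: 12 layers, mamba mixed in on every 2nd layer
layers_before = [make_layer(use_mamba=(i % 2 == 1)) for i in range(12)]

# now: 8 layers, every one of them with mamba
layers_now = [make_layer(use_mamba=True) for _ in range(8)]
```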

First change. Optimization.

As the branch name celp18 suggests, torch.compile was used to optimize the Elephant activation. I also removed a and inlined d so the optimizer can work better. torch.compile did wonders: it ran as fast as a manually written triton kernel that simply copies a tensor, and it gave the backward pass for free. Yes, I thought of writing a custom triton function for elephant, which is why the branch came to exist in the first place, but measurements showed that it's not needed. Thank goodness. No need to write CUDA.

And it became not just faster, it also saved a lot of memory. Apparently `1/((1+(x/a).abs())**d)` ate lots of VRAM: I don't remember the exact numbers, but I managed to fit in an additional batch just by putting the elephant on a diet.
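For illustration, here is roughly what "compiled elephant a=1 d=8" boils down to; a minimal sketch of the idea, not the actual contents of the branch:

```python
import torch

# Minimal sketch: the Elephant activation 1/((1+(x/a).abs())**d) with a=1
# dropped and d=8 inlined as a constant, wrapped in torch.compile so the
# fused forward (and its backward) is generated automatically instead of a
# hand-written triton/CUDA kernel.
@torch.compile
def elephant(x: torch.Tensor) -> torch.Tensor:
    return 1.0 / (1.0 + x.abs()) ** 8
```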

Second change. Batch. Batch never change

I changed the training to be more effective.

So, previously a batch of BS articles was split into chunks of training_ctx tokens each. A single model input was a tensor of shape [BS, training_ctx] holding the tokens of the current chunks. If one document in the batch ran out of chunks before the rest, it was simply removed, so the next input was [BS-1, training_ctx], and so on. In practice this meant that training started with a batch of 5 articles, but the short ones were soon done and half of the training steps ran with a [1, training_ctx] input, since one Wiki103 article is often much longer than the others. I thought that it was not that big of a deal.

It was.

It was the biggest deal! As huge as the world itself! Fixing it improved the model significantly!

Now each article is still split into chunks of training_ctx tokens.
But. Once we run out of chunks for a particular article, we don't just remove it from the batch, we refill its slot with chunks from the next article.

So if we have articles A1, A2, A3, A4, where AnCm is the m-th chunk of article An, and A2 has just one chunk:

❄ Training step 1: [A1C1, A1C2], [A2C1, padding], [A3C1, A3C2]

❄ Training step 2: [A1C3, A1C4], [A3C3, A3C4], [A4C1, A4C2]

Previously we would wait until A1, A2 and A3 all ran out of chunks before introducing A4. Not anymore! The world can not wait.
The world won't wait. No. Yes. Probably.
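A minimal sketch of this refill idea (not the actual BakaLLM dataloader; `chunked_batches`, `chunks_per_step` and friends are made-up names):

```python
# Minimal sketch of the refill idea, not the BakaLLM dataloader itself.
# `articles` is a list of token lists, each of the `bs` slots works through
# one article, and an empty slot is refilled from the next unread article at
# the start of a training step, so a chunk missing mid-step becomes padding,
# as with A2 in the example above.
def chunked_batches(articles, bs, training_ctx, chunks_per_step=2, pad_id=0):
    def to_chunks(tokens):
        return [tokens[i:i + training_ctx] for i in range(0, len(tokens), training_ctx)]

    queue = iter(articles)
    slots = [[] for _ in range(bs)]            # remaining chunks per slot
    while True:
        for slot in slots:                     # refill empty slots before the step
            if not slot:
                tokens = next(queue, None)
                if tokens is not None:
                    slot.extend(to_chunks(tokens))
        if not any(slots):
            return
        step = []
        for slot in slots:
            taken = [slot.pop(0) if slot else [pad_id] * training_ctx
                     for _ in range(chunks_per_step)]
            # right-pad short chunks up to training_ctx
            step.append([c + [pad_id] * (training_ctx - len(c)) for c in taken])
        yield step                             # [bs, chunks_per_step, training_ctx]
```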

This change required a significant rewrite of the model code, since it previously assumed that all sequences in the batch have the same RMT cache and KV cache. Now I have to process the first chunk of the input in parts, because not every sequence in the batch has a KV cache or RMT state.

For the 2nd chunk and later, no separation is necessary: the KV cache is filled, the RMT state is calculated.
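Conceptually, the first-chunk split looks something like the sketch below; `run_model` and `has_state` are hypothetical stand-ins for illustration, not the real code:

```python
import torch

# Minimal sketch of processing the first chunk in parts: rows of the batch
# that already carry KV/RMT state and brand-new rows are run separately,
# then merged back into a single output tensor.
def first_chunk_forward(x, kv_cache, rmt, has_state, run_model):
    # x: [BS, ctx, dim]; has_state: [BS] bool mask of rows with existing caches
    out = torch.empty_like(x)
    old_rows = has_state.nonzero(as_tuple=True)[0]
    new_rows = (~has_state).nonzero(as_tuple=True)[0]
    if old_rows.numel():
        out[old_rows] = run_model(x[old_rows], kv_cache=kv_cache, rmt=rmt)
    if new_rows.numel():
        out[new_rows] = run_model(x[new_rows], kv_cache=None, rmt=None)
    return out
```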

In some places the code became more complex, in others simpler. At some point I completely removed support for running without flash-attn2.

Oh well, tomorrow (today, after sleep) I'll finish E3 and push it as the mainline version. I'll probably push the weights to HF as well, since this is a significant improvement and beating pythia-70m matters a lot. I'll also need to choose a new model to overtake. Cerebras 111M? Some tiny version of GPT2? Oh well, so many variants to choose from.

But as of now it's 4 am, so chill out~!
