Hello fairy dairy diary~
Today was the day. Well, night. It's 3am to be precise, but that's not important. What's important is that Pythia 70m is defeated!
Was it a change in architecture? Yes, but an insignificant one! Was it a change that reduces memory, plus a different training routine? Yes.
The results of the new training method are awesome:
| n_ctx | loss | ppl | note |
|---|---|---|---|
| 1048576 | **3.47578120231628** | **32.3230695329532** | **after 2 epoch ber 8batch (2h56m)** |
| 2048 | 3.55905652046204 | 35.130037033514 | Vanilla pythia 70m-deduped |
| 1048576 | 3.66562509536743 | 39.0805576049546 | after 14 epoch, wide MLP |
| 1048576 | 3.66744780540466 | 39.1518550871408 | after 15 epoch, wide MLP |
| 1048576 | 3.67499995231628 | 39.448654976142 | after 12 epoch, wide MLP |
| 1048576 | 3.67812490463257 | 39.5721229571853 | after 13 epoch, wide MLP |
| 1048576 | 3.68697905540466 | 39.9240562362004 | after 11 epoch, wide MLP |
| 1048576 | 3.69973969459534 | 40.436777078352 | after 10 epoch, wide MLP |
| 1048576 | 3.7109375 | 40.892124929554 | after 9 epoch, wide MLP |
| 1048576 | 3.73046875 | 41.6986498257003 | after 8 epoch, wide MLP |
| 1048576 | 3.74401044845581 | 42.2671609833295 | after 7 epoch, wide MLP |
| 1048576 | 3.76718759536743 | 43.2582339397375 | after 6 epoch, wide MLP |
| 1048576 | **3.77265620231628** | **43.4954442322217** | **after 1 epoch ber 8batch (3h8m)** |
| 1048576 | 3.77838540077209 | 43.7453534703837 | After 8 epoch rmt-copy-staged |
| 1048576 | 3.79479169845581 | 44.4689724850806 | after 5 epoch, wide MLP |
| 1048576 | 3.79869794845581 | 44.6430191223772 | After 7 epoch rmt-copy-staged |
| 1048576 | 3.82682299613953 | 45.9164295891796 | after 4 epoch, wide MLP |
| 1048576 | 3.83984375 | 46.5182054023163 | After 6 epoch rmt-copy-staged |
| 1048576 | 3.85807299613953 | 47.3739735245328 | After 10 epoch, staged post-training |
| 1048576 | **3.87864589691162** | **48.3586880673365** | **after 3 epoch ber 5batch+mamba w. 8 layers** |
New results in bold.
Changes were made. Current experiments are done in the 00x_celp18 branch, which is short for "compiled elephant a=1 d=8".
Zeroth change. Mamboization.
I continued dancing with snakes (I think I covered the start in part 17 already). But while originally I used Mamba in every second layer of a 12-layer network, I changed it to 8 layers, all with Mamba. It improved the E3 result to 3.8786.
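Roughly, the layer layout change looks like this (an illustrative sketch with made-up labels, not the actual config):

```python
# Sketch only: "mamba"/"attention" labels are illustrative, not real config keys.
old_layout = ["mamba" if i % 2 == 0 else "attention" for i in range(12)]  # 12 layers, Mamba every 2nd
new_layout = ["mamba"] * 8                                                # 8 layers, all Mamba
```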
First change. Optimization.
As the name of the branch, celp18, suggests, `torch.compile` was used to optimize the Elephant activation. I also removed `a` and inlined `d` so the optimizer can work better. `torch.compile` did wonders: it ran as fast as a manually written Triton kernel that merely copies a tensor, and it gave `backward` for free. Yes, I thought of making a custom Triton function for Elephant, which is why the branch came to exist, but measurements showed it's not needed. Thank goodness. No need to write CUDA.

But it became not just faster, it also saved a lot of memory. Apparently `1/((1+(x/a).abs())**d)` ate lots of VRAM. I don't remember exactly, but I managed to fit in an additional batch just by putting the elephant on a diet.
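For reference, a minimal sketch of what the compiled activation can look like, assuming `a=1` is simply dropped and `d=8` is hard-coded; the actual module in the branch may differ:

```python
import torch

# Minimal sketch, assuming a=1 is removed and d=8 is inlined as a constant.
@torch.compile
def elephant(x: torch.Tensor) -> torch.Tensor:
    # 1 / (1 + |x|)**8  -- no division by `a`, no tensor-valued `d`
    return 1.0 / (1.0 + x.abs()) ** 8
```

Compiling the whole expression into one fused kernel means the intermediate tensors of the naive formula never have to materialize, which is presumably where both the speed and the VRAM savings come from.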
Second change. Batch. Batch never changes.
I changed the training to be more effective.
So, previously a batch of `BS` articles was split into chunks of `training_ctx` tokens. A single input to the model was a tensor of shape `[BS, training_ctx]` with the tokens for the current chunks. If one document in the batch ran out of chunks before the rest, it was removed from the batch, so the next input was `[BS-1, training_ctx]`, and so on. For the most part this meant that a batch started with 5 articles, but the short ones were soon done and half of the training steps were `[1, training_ctx]`, since one article in Wiki103 is often much longer than the others. I thought that it was not that big a deal.
It was.
It was the biggest deal! As huge as the world itself! Fixing it improved the model significantly!
Now each article is still split into chunks of `training_ctx` tokens.
But. Once we run out of chunks for a particular article, we don't just remove it, we add chunks from the next article.
So if we have articles A1, A2, A3, A4 with chunks A1C1, A1C2, and so on, where A2 has just one chunk:
❄ Training step 1: [A1C1, A1C2], [A2C1, padding], [A3C1, A3C2]
❄ Training step 2: [A1C3, A1C4], [A3C3, A3C4], [A4C1, A4C2]
Previously we would wait until A1 and A2 ran out of chunks before introducing A4. Not anymore! The world cannot wait.
The world won't wait. No. Yes. Probably.
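Here's a minimal sketch of that refill logic, assuming one chunk per row per step and made-up names (`chunk_article`, `refill_batches`, pad id 0); the real data loader is more involved:

```python
def chunk_article(tokens, training_ctx, pad_id=0):
    """Split one article's tokens into fixed-size chunks, padding the last one."""
    for i in range(0, len(tokens), training_ctx):
        chunk = tokens[i:i + training_ctx]
        yield chunk + [pad_id] * (training_ctx - len(chunk))

def refill_batches(articles, BS, training_ctx):
    """Yield (chunks, fresh): `chunks` is a BS x training_ctx list of token lists,
    `fresh[i]` is True when row i just switched to a new article (no cache yet).
    Assumes every article is non-empty."""
    articles = iter(articles)
    rows = [iter(()) for _ in range(BS)]          # start every row empty, so it refills
    while True:
        chunks, fresh = [], []
        for i in range(BS):
            chunk = next(rows[i], None)
            is_new = False
            if chunk is None:                     # this row's article ran out of chunks
                nxt = next(articles, None)
                if nxt is None:                   # no articles left: stop
                    return
                rows[i] = chunk_article(nxt, training_ctx)
                chunk = next(rows[i])
                is_new = True
            chunks.append(chunk)
            fresh.append(is_new)
        yield chunks, fresh                       # batch never shrinks below BS
```

The `fresh` flag is my own addition to the sketch: it marks rows that just switched articles, which matters for the cache handling described next.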
This change required a significant rewrite of the model code, since it previously assumed that all rows of a batch have the same RMT cache and KV cache. Now I have to process the first chunk of the input in parts, because not every row has a KV cache or RMT state yet.
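A sketch of that first-chunk "in parts" handling, with an assumed model/cache interface (not the real signatures):

```python
import torch

# Sketch only: `model`, `kv_cache`, and `rmt_mem` interfaces are assumptions.
def forward_first_chunk(model, chunk, fresh, kv_cache, rmt_mem):
    """chunk: [BS, ctx] token tensor; fresh: [BS] bool tensor, True for rows without caches."""
    hidden = torch.empty(*chunk.shape, model.hidden_size, device=chunk.device)
    cont = ~fresh
    if fresh.any():   # rows starting a new article: run without any cache
        hidden[fresh] = model(chunk[fresh], kv_cache=None, rmt_mem=None)
    if cont.any():    # rows continuing an article: reuse their KV cache and RMT state
        hidden[cont] = model(chunk[cont], kv_cache=kv_cache[cont], rmt_mem=rmt_mem[cont])
    return hidden
```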
For the 2nd chunk and later, no separation is necessary: the KV cache is filled and the RMT state is calculated.
In some places the code became more complex, in others simpler. At some point I completely removed support for running without flash-attn2.
Oh well, tomorrow (today, after sleep) I'll finish E3 and push it as the mainline version. I'll probably push the weights to HF as well, since this is a significant improvement and beating pythia-70m matters a lot. I'll also need to choose a new model to overtake. Cerebras 111M? Some tiny version of GPT2? Oh well, so many variants to choose from.
But as of now it's 4 am, so chill out~!