Hello, fairy dairy diary.
And it finally happened: all it took to tame mamba was adding a good normalizer layer and putting mamba after Attn+MLP!
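A minimal sketch of what that block order could look like in PyTorch (the module names, the choice of LayerNorm, and the pre-norm residual wiring are my assumptions; the post doesn't show code):

```python
import torch.nn as nn

class BakaBlock(nn.Module):
    """One decoder block: mamba sits *after* the usual Attn+MLP pair,
    behind its own normalizer (hypothetical reconstruction)."""
    def __init__(self, d_model, attn, mlp, mamba):
        super().__init__()
        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_mlp = nn.LayerNorm(d_model)
        self.norm_mamba = nn.LayerNorm(d_model)  # the "good normalizer" in front of mamba
        self.attn, self.mlp, self.mamba = attn, mlp, mamba

    def forward(self, x):
        x = x + self.attn(self.norm_attn(x))    # pre-norm attention residual
        x = x + self.mlp(self.norm_mlp(x))      # pre-norm MLP residual
        x = x + self.mamba(self.norm_mamba(x))  # mamba last, after Attn+MLP
        return x
```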
model_type | n_params | loss | ppl | note |
---|---|---|---|---|
baka_mamba | 213139584 | 3.88385415077209 | 48.611209417978 | after 3 epoch ber 5batch+mamba again |
baka_mamba | 213139584 | 3.99270844459534 | 54.2014924798642 | after 2 epoch ber 5batch+mamba again |
baka_mamba | 213139584 | 4.26171875 | 70.9317927632271 | after 1 epoch ber 5batch+mamba again |
baka_mamba | 201822336 | 3.94270825386047 | 51.5580446647835 | after 3 epoch ber 5batch+mamba-each-4th |
baka_mamba | 201822336 | 4.05729150772095 | 57.8175005609198 | after 2 epoch ber 5batch+mamba-each-4th |
baka_mamba | 201822336 | 4.36744785308838 | 78.8421579437085 | after 1 epoch ber 5batch+mamba-each-4th |
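For reference, the ppl column is just `exp(loss)`, which is easy to sanity-check:

```python
import math

# perplexity = exp(cross-entropy loss); e.g. the best run above
print(math.exp(3.88385415077209))  # ≈ 48.611, matching the ppl column
```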
(I learned that sqlite can output markdown; yes, my life now divides into before and after.)
I got a new record after E3: 3.883. It happened when I added mamba to every 2nd layer. Adding it to every 4th layer made results worse, so now I'll look into fitting in more mamba layers, maybe by shrinking the MLPs or removing them entirely. Hyperparameter tuning is so tedious when each experiment takes many hours.
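Roughly, the "mamba each 2nd layer" and "mamba each 4th layer" variants differ only in how often a block gets the extra mamba sub-block. A hypothetical stacking helper (`make_block` is an assumed factory, not from the post):

```python
def build_layers(n_layers, mamba_every, make_block):
    """make_block(with_mamba: bool) -> block; attach mamba to every
    `mamba_every`-th layer (every 2nd beat every 4th in the runs above)."""
    return [make_block(with_mamba=((i + 1) % mamba_every == 0))
            for i in range(n_layers)]

# e.g. build_layers(12, mamba_every=2, make_block=...) puts mamba on layers 2, 4, ..., 12
```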
It also beats Llama* (which I reverted), and it's close to the 3.8890 I got after 5 epochs of staged training (1st epoch: train the first 4 layers, then freeze them; 2nd: train layers 5-8, then freeze them; 3rd: train layers 9-12, then unfreeze everything; 4th and 5th: everything unfrozen).
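For comparison, that staged schedule could look roughly like this (my reading of the post; 1-indexed layers, 12 total, and whether the not-yet-trained layers are frozen before their stage is my assumption):

```python
def apply_stage(layers, epoch):
    """Freeze/unfreeze per epoch: 1 -> layers 1-4, 2 -> 5-8, 3 -> 9-12,
    4 and later -> everything trainable."""
    windows = {1: range(1, 5), 2: range(5, 9), 3: range(9, 13)}
    trainable = windows.get(epoch, range(1, len(layers) + 1))
    for idx, layer in enumerate(layers, start=1):
        for p in layer.parameters():
            p.requires_grad = idx in trainable
```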
I also tried different combinations like (mamba(mlp(attn))), but they didn't work as well at E1, so they were all dropped.
Also, as it turned out, the score of 4.10598945617676 that I celebrated so much last time was measured on incorrect weights after 1.5 epochs (the 2nd epoch broke in the middle, but the weights were saved). Still, mamba after E1 showed good results. And good results are good.
Chill out!