Maykeye

BakaLLM. Part 10. Progress? `Dim` light in thinning tunnel? Thin layers revisited: turning thin layers into thicc MLP

Hello, Fairy Dairy Diary! My GPU is busy as usual, so I'm venting a little!

Let's talk about thin layers and hourglasses! A contender (results will be known in a couple of days unless something bad happens) to be moved from the experiment branches to mainline as the next step towards the model becoming a true successor of humanity!

Cirno and elephant look at hourglass

Now, initially I was playing around with thin layers: throw away the MLP, insert a layer. Inserting them at the beginning was a bad idea. Terrible loss, -9M out of 10. Would not recommend. I danced around, initialized them with almost-zero weights, still no bueno:

Thin layers
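For reference, here is a minimal sketch of what I mean by a thin layer: a pre-norm attention block with the MLP thrown away, and the output projection initialized near zero so a freshly inserted layer starts close to an identity mapping. The class name and exact layout are hypothetical, not the actual BakaLLM code.

```python
import torch.nn as nn

class ThinLayer(nn.Module):
    """Attention-only block: the MLP sub-block is thrown away entirely.
    (Hypothetical sketch; the real BakaLLM internals differ.)"""
    def __init__(self, dim_model: int, n_heads: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        self.attn = nn.MultiheadAttention(dim_model, n_heads, batch_first=True)
        # "initialized them with almost zero": scale the output projection down
        # so the new layer barely perturbs the residual stream at first
        nn.init.normal_(self.attn.out_proj.weight, std=1e-4)
        nn.init.zeros_(self.attn.out_proj.bias)

    def forward(self, x, attn_mask=None):
        h = self.norm(x)
        h, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        return x + h  # residual only; no MLP sub-block
```

Dropping these at the start of the stack is what gave the terrible loss above; the next experiment pushed them toward the middle instead.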

However, something interesting caught my attention: once the thin layers moved toward the middle, the loss improved substantially.
Not to the point where it was better (it was comparable to XL, but with 9M more parameters and two more training hours: up to 14H now; though I was also converting terabytes of English MadLad400 from gzip to zstd, so that might have had an effect).

Well, so I threw the extra layers away. Get outta here! Pew, pew, dd, dd, dd!

And then 9M extra parms were plumped into layer 0, following LoRAShear: if the first layers are too important to be LoRAfied, they need more params to help the future layers. First I called another layer of the same architecture in parallel (it was bad), then I increased the MLP size, and that was good. Loss improved substantially once again; now it was the best (recorded) loss after epoch 1. However, while it was better, it was not 9M-extra-parms better (4.5164; compare to 4.5648 of long-reset). I also wanted to increase dim_attn, but meh'ed it for now: layer 1 wants to have K,V from layer 0, so I'd either have to add a linear transform attn->model so k,v keep the same shape as in future layers, which would probably eat up half of the parms, or just cut out some of the results, which seemed like a bad idea.
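Concretely, widening layer 0 just means giving its MLP a larger intermediate size. A sketch, assuming a gated (upcast/gate/downcast) MLP; `GatedMLP`, the SiLU activation, and the placeholder width are my guesses, not lifted from the repo:

```python
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    """Three matrices: upcast, gate, downcast (the activation is a guess)."""
    def __init__(self, dim_model: int, dim_ff: int):
        super().__init__()
        self.up = nn.Linear(dim_model, dim_ff, bias=False)
        self.gate = nn.Linear(dim_model, dim_ff, bias=False)
        self.down = nn.Linear(dim_ff, dim_model, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

def n_parms(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

dim_model = 768                               # placeholder width, not BakaLLM's actual one
plain = GatedMLP(dim_model, 4 * dim_model)    # default Pythia-style ratio
fat = GatedMLP(dim_model, 8 * dim_model)      # layer 0 after the widening
print(n_parms(fat) - n_parms(plain))          # 3 * 4 * dim_model**2 extra parms
```

At the model's actual width, that difference is where the roughly 9M extra parms above come from.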

The next logical question was: how bad would it be if we removed these extra 9M parms but reshuffled the existing parameters around instead? Dance, numbers, dance!
And here the Hourglass (all titles WIP!!!!) architecture came to be.

Since Bakanet follows Pythia, it has dim_ff = 4 dim_model, and each MLP consists of 3 linear layers (upcast/downcast/gate) of 4 dim_model^2 parms apiece. Ignore that there are three of them and consider the following. We take four layers in the middle and reduce their intermediate MLP size by 1 dim_model. Now the thin layers have 3 dim_model^2 parms per matrix, and we have 4 dim_model^2 to spare. Well, let's plump them into the first layer! Now its MLP has immsize = 8 dim_model (8 dim_model^2 parms per matrix), just like the extra-parms version.

And to my surprise, the result was even better: 4.5070 loss on the valid split after epoch 1. New record! Probably. (I didn't record post-epoch-1 losses for pauses, so the network will have to beat 4.0221 after epoch 3.)

So now the model's MLP intermediate sizes look like this:

Hourglass shape

Scaling in the MLP goes: layer 0: x8, layers 1-3: x4, layers 4-7: x3, layers 8-11: x4 again. And 8 + 3x4 + 4x3 + 4x4 = 48 = 12x4. Math checks out. The number of parms is the same (and printing the number of parameters to the console confirmed it, hooray).
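The same budget check in a few lines of Python (a sketch: the multiplier list is just the layout above, the width is a placeholder):

```python
dim_model = 768  # placeholder width, not BakaLLM's actual one

baseline = [4] * 12                                 # Pythia-style: dim_ff = 4 * dim_model everywhere
hourglass = [8, 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4]    # fat layer 0, thin middle

assert sum(hourglass) == sum(baseline) == 48

# each of the 3 MLP matrices in a layer holds mult * dim_model**2 parms
def mlp_parms(mults):
    return sum(3 * m * dim_model**2 for m in mults)

assert mlp_parms(hourglass) == mlp_parms(baseline)
# model-level double check: compare sum(p.numel() for p in model.parameters()) before/after
```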

So far, training in epoch 2 is going smoothly,

Progress

The problem with wandb smoothing is that it tends to draw an optimistic picture at the beginning of an epoch, but by the time the end is reached, it shifts the beginning up for... reasons? I swear to every yokai, if by the end of the epoch it decides to move the graph so it fully overlaps the previous graph, I'll start dumping the loss into csv/sqlite3 myself.
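If it comes to that, logging the loss myself is only a few lines (a sketch; the table and column names are made up):

```python
import sqlite3
import time

con = sqlite3.connect("losses.sqlite3")
con.execute("CREATE TABLE IF NOT EXISTS loss "
            "(ts REAL, epoch INTEGER, step INTEGER, split TEXT, value REAL)")

def log_loss(epoch: int, step: int, split: str, value: float) -> None:
    # append one raw, unsmoothed loss value per call
    con.execute("INSERT INTO loss VALUES (?, ?, ?, ?, ?)",
                (time.time(), epoch, step, split, value))
    con.commit()

# inside the training loop, something like:
# log_loss(epoch=2, step=global_step, split="train", value=loss.item())
```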

Hourglass also opens up another possibility for optimization: remove more parameters from the middle layers and move them into layer -1 or layer 1. Or put them towards RMT and other architectural tweaks.

Now time to read about retention mechanisms and drink some coffee. Lots of coffee.

Somnum delenda est, (probably)
