Maykeye

BakaLLM Part 6: Training and eXtra Large Struggles (Part 1) (hopefully last part of the struggles)

Hello, fairy-dairy diary! Today there will be a lot of pictures and a general explanation of how training is done!

Adding XL turned out to be way more !!FUN!! than anticipated. Losses were lost where they were supposed to skyrocket, and gained where they were supposed to dwindle.

First, mandatory Cirno (Back to WD15 with this one)

Cirno frustrated

Let's see which experiments were run and what we have learned, with this raunchy graph of eldritch horrors:

Training graph

The gray line is 🐘 good old elephant.

Its training was done like this:

Elephant training

The training set is split into sequences, where each sequence is an article from Wikipedia.

  • During training, sequences are glued together into a batch. If one sequence is shorter than the longest, it's padded from the right side (the side is important). In the elephant run I glued together 6 sequences.

For illustrative purposes, let's assume each word is a token and the batch is two sequences: "There are many fairies in Gensokyo, but only one is the bestest" and "I'm not scared of youkai or gods ..."

  • Next, the batch is further split into minibatches, where each subsequence has a much smaller length. Subsequences overlap (stride = context/2). This is the "training window" (sketched in code below). In the illustration the minibatch is colored with a peachy background.

  • This minibatch is fed to the model. Entirely.
    Inside the model the minibatch is once again split into even smaller windows, context windows. In the illustrations they are black rectangles around tokens.

  • The model iterates over these smallest windows. Once a context window is calculated, the model remembers its output and moves to the next chunk (no overlaps there).

  • After all inputs are passed, the loss is calculated and AdamW updates the weights. Curtains~!

  • Ofs=0: Rotary Embeddings are used. Since there are several context windows and the second window goes after the first, the first token of the second window should be rotated as if it were the 512th token (context size in Baka = 512) if we want to pretend that the context windows of a sequence belong to one long sequence. However, we don't care about that in Elephant. It doesn't matter that much.

  • Note that during training, minibatches that start from a <pad> token are thrown away and the model doesn't see them. If a minibatch contains a context window that is filled with <pad>, the model still calculates it.

This is why right-side padding is pretty much essential for BakaLLM. Since <pad>ding is right-sided, the sequence itself never sees the padding, so we don't care about removing padding through an attention mask: causal attention takes care of it.

In fact, BakaLLM ignores the attention_mask argument it receives on input: the only reason it exists is that it's easier to add an unused argument to the model than to battle with the DataCollator and tokenizer to not return it.
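To make the elephant schedule concrete, here is a minimal sketch of the splitting logic. It is not BakaLLM's actual code; the names (CTX, TRAIN_WINDOW, STRIDE, iter_minibatches, iter_context_chunks) and the exact window sizes are assumptions for illustration.

import torch

CTX = 512                      # context window size in Baka
TRAIN_WINDOW = CTX * 4         # minibatch ("training window") length
STRIDE = TRAIN_WINDOW // 2     # elephant: minibatches advance by half a window

def iter_minibatches(batch_ids: torch.Tensor, pad_id: int):
    # Yield overlapping training windows; rows that start with <pad> are thrown away.
    n_tokens = batch_ids.shape[1]
    for start in range(0, n_tokens, STRIDE):
        window = batch_ids[:, start:start + TRAIN_WINDOW]
        keep = window[:, 0] != pad_id
        if keep.any():
            yield window[keep]

def iter_context_chunks(minibatch: torch.Tensor):
    # Inside the model, the minibatch is split into non-overlapping context windows.
    for start in range(0, minibatch.shape[1], CTX):
        yield minibatch[:, start:start + CTX]

Since padding is right-sided, any <pad> tokens can only appear at the tail of a window, so causal attention alone keeps real tokens from ever attending to them.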

XL-no-mem-fwd

Green lines on the training graph above (there are 2 lines: the first is for epoch #1, the second for epoch #2 -- nothing unusual here, epoch 2 is better (lower loss) than epoch 1).

XL, attempt Uno

This is basically elephant training, but with minor changes.
For illustration, let's consider a bigger sequence.

Now when the model iterates over context chunks, after iterating a chunk it remembers its K,V values. K is stored in the cache AFTER rotary embeddings are applied.

Thus when calculating rotations, we basically get

q = rot(q, offset=chunk.offset)   # rotate current chunk's Q by its position in the minibatch
k = rot(k, offset=chunk.offset)   # same for K
next_past_k = k                   # remember the already-rotated K for the next chunk
k = cat(past_k, k)                # prepend the cached (pre-rotated) keys from the previous chunk
past_k = next_past_k

In other words, Q and K are rotated as usual, then K gets the pre-rotated past K prepended to it.
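For reference, rot here is just rotary embedding applied at an absolute offset. A minimal sketch (not BakaLLM's implementation; the interleaved-pair layout and tensor shape are assumptions):

import torch

def rot(x: torch.Tensor, offset: int, theta: float = 10_000.0) -> torch.Tensor:
    # Rotary embedding for x of shape (batch, seq, heads, head_dim),
    # pretending the first token of x sits at absolute position `offset`.
    _, n_seq, _, head_dim = x.shape
    inv_freq = 1.0 / theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    pos = torch.arange(offset, offset + n_seq, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)        # (seq, head_dim/2)
    cos = angles.cos()[None, :, None, :]
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]        # interleaved pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

With this, chunk.offset is simply the chunk's starting position inside the minibatch, and the cached keys keep whatever rotation they were given in their own chunk.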

Each minibatch is its own microcosm. It doesn't know anything existed before it, so the first chunk it sees is located at pos=0; the minibatch "And her name is" has no direct knowledge that "There are many fairies in Gensokyo" ever existed.

I don't stop gradients. I do "detach" the past in the sense that the model sees k,v from the last chunk only, but gradients in this chunk can see everything. This increased memory usage, so I had to reduce the batch count to 5, which increased the number of times backprop/optimizer got called, and overall training time increased by ~2 hours. Not to the point where I start thinking of halving parameters, but very close to it.

Well, overall this is the best result.

The red line in the middle is the same algorithm, but with the stride of minibatches increased to ctx_size, halving everything.
However it was meaningless, as at that point memory didn't pass through minibatch boundaries, so after seeing bad performance I killed it.

XL-Repos

Blue line on the graph. Repos stands for "repositioning", as I've started "repositioning" Q and K in such a way as if they were always at the start of the sequence:

Repositioning

This was supposed to be an enhancement over the previous step.
K is no longer cached rotated; it is cached before rotation. Then, during the attention phase, the cached keys are rotated as if they were at position 0, and the current q/k are rotated as if they immediately followed every token from the past chunk (so essentially their offset = context window size):

next_past_k = k                            # cache the UNrotated K for the next chunk
q = rot(q, offset=last_chunk.n_seq)        # current Q/K pretend they follow right after the past chunk
k = rot(k, offset=last_chunk.n_seq)
past_k = rot(past_k, offset=0)             # past keys pretend they start at position 0
k = cat(past_k, k)
past_k = next_past_k

Intuitively it makes sense: instead of learning a whole bunch of positions, the model learns everything for positions 0 to 2×context only. Granted, during training the training window = context_window × 4, so the model didn't meet that many different positions anyway. But still, everything that previously was learned for position 1.5K is now learned for position 0.5K, which is much more common.
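As a toy illustration of what repositioning does to positions (assuming CTX = 512 and full chunks):

CTX = 512   # context window size

for chunk_idx in range(4):   # training window = 4 context chunks
    no_repos = (chunk_idx * CTX, (chunk_idx + 1) * CTX - 1)   # offsets without repositioning
    past_pos = (0, CTX - 1)            # cached keys, rotated with offset=0
    cur_pos = (CTX, 2 * CTX - 1)       # current q/k, rotated with offset=CTX
    print(chunk_idx, no_repos, "->", past_pos, cur_pos)

So the last chunk of a training window, which used to be rotated around position ~1.5K, now lives around ~0.5K regardless of where it actually sits in the sequence.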

FWIU, this is what StreamingLLM's reset does.

And yet. It was a fail. A grandiose catastrophe. It was so bad. I still gave it a little chance to show itself and sank more hours into the early stages of epoch 2. It didn't work.

And it was only a harbinger of !!FUN!!

XL-Long-NoReset

It was time to change the training procedure and connect the past to the future.

XL-LONG-NORESET

Now I pass history between minibatches. And I don't reset the offset, so every token is rotated according to its true position in the sequence.

It also meant I can no longer move the minibatch forward by half of its size:
If my sequence is "ABCDEF" and, after pondering on "ABCD", I pass it "CDEF", the model can look up the answer in the history and get the "CD" tokens for free. Which is not what I want, at least for now. I did consider putting in a restraint like a "prompt must be N tokens long" constraint, but for now went against it: the current context size is 512 tokens and I want to use lots of it.

So the training schedule was changed.

Now the training window has a full stride, so I can pass correct history. To restore the number of training cycles, after the minibatches are done iterating, the first (training_size//2) tokens are thrown away from the big batch and the model runs once again.
So basically, if previously on batch ABCDE
we ran AB, BC, CD, DE, now we run AB, CD, E and then BC, DE.
Gradients are still not being cleared.
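A tiny sketch of this two-pass schedule (a hypothetical helper, not the real training loop); with each letter standing for half a training window it reproduces the example above:

def long_noreset_windows(tokens, window: int):
    # Two passes with full stride; the second pass drops the first half-window,
    # so chunk boundaries land in different places. XL memory (and the rotary
    # offset) would be carried across windows inside a pass and reset between passes.
    out = []
    for drop in (0, window // 2):
        for start in range(drop, len(tokens), window):
            out.append(tokens[start:start + window])
    return out

print(long_noreset_windows("ABCDE", window=2))   # ['AB', 'CD', 'E', 'BC', 'DE']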

And the results are mixed. To say the least. At first, for around 7.5K steps, it runs circles around everything. Then it stumbles, and for the next 20K steps it's the worst, even worse than the elephant without XL. Then it overtakes the elephant. And it goes so fast it caught up to repos.

And that's where the experiments ended today. Well, yesterday. Taking notes takes time.

Cirno

Well, I will definitely let it run for 2 more epochs. The goal is to overtake no-mem-fwd.

There are some more experiments that can be done:

  • Position reset. Combine no-mem-fwd and repos: every N tokens, say "OK, that was long enough, let's pretend the last chunk started at position 0".

  • Position correction. One trait of repos is that position 0 plays two roles: it's either the start of the sequence or a continuation of the history. It may confuse the model. To overcome this, we can set the offset of the current tokens to 1K and the offset of the past to 1K - past_length. This way the past will always be at positions [512..1024] and the current window at [1024..1536] (see the sketch after this list).

  • Use a sigmoid gate over the past before rotation/concatenation. In Baka2 this helped almost always. But the current mascot is Cirno, not Meiling, so I want to avoid it.

  • Combine ideas above somehow
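For the position-correction idea, a quick check of where the offsets land (assuming CTX = 512, so 1K = 2*CTX, and a past of at most one context window):

CTX = 512

def corrected_offsets(past_length: int):
    # Current window always starts at 2*CTX (= 1K); the past is shifted so it
    # ends right before it, which keeps it inside [CTX, 2*CTX].
    cur_offset = 2 * CTX
    past_offset = 2 * CTX - past_length
    return past_offset, cur_offset

for past_length in (128, 512):
    past_offset, cur_offset = corrected_offsets(past_length)
    print(past_length,
          (past_offset, past_offset + past_length - 1),   # past key positions
          (cur_offset, cur_offset + CTX - 1))             # current window positions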

More experiments were done and found to be useless:

  • Decrease theta to 2K in rotary embeddings (instead of 10K).
    The rationale was that at bf16 precision, rotation with a 10K theta is not too precise. It didn't improve things. I didn't try theta=20K.

  • Zero rot.
    Keys were prepended first and then the query and new keys were rotated. So the query queried past keys rather than its own. It was a silly idea, and the rationale was that I was going to sleep and didn't want the computer to idle.

In the next part we'll explore more eXtra Large wonders and see what was chosen, if anything.

But for now, Extra Large Dreams, Dreamy Diary.
