DEV Community

Maykeye


BakaLLM: Designing a new LLM

Hello, fairy dairy diary!

A while ago I set out on a journey, endless perhaps for a mere mortal without an H100 cluster, to make my own LLM, successor to humanity, with blackjack and fairy maidens! The hopelessness of the endeavor shall not dissuade us; after all, it's about the journey rather than the destination.

I might as well write about it, all to the humming of my laptop, while the training loss of my model approaches the perplexity of something you can download from the net! Is it time for ChatGPT to panic? Perhaps not!

(No name is solidified yet, so for this blog I will use BakaLLM, because ⑨ is the strongest.)

(Image: Cirno in battle pose, generated by a NN via Stable Horde.)

So, the premise is simple: local machines lack memory. The wisdom of the streets says we need to download more RAM! That's difficult in the real world, but what if we enhance the memory of the model itself, at design time? What if we let it dare to reach beyond its context window, into lands left so far behind that no llama remembers them ever happening? Many designs have aimed at this. Some were even published with weights, if you know where to look. And these designs do exactly that: they give the model more memory.

Let's stew them together. Who is going to stop us? Sanity? AI police?

In this day and age?

Fairy warning: if you have any idea about ML, you'll probably need a barf bucket (or two) if you dare to venture further into this ongoing story. You have been warned! Though today is a simple day of looking at the building blocks.

Oh, and the style of this blog is intentionally not serious, so that impressionable minds won't take it as a tutorial or an educational resource: proper teaching requires more knowledge than reading a couple of papers and watching YouTube videos (with mandatory naps). These are the ramblings of someone whose GPU is too busy to play video games for a while. But I do leave arXiv links.

So let's see what kind of mix of architectures the current iteration, with 172_062_349 params, which looks like this...

(Image: ChatGPT is free from panicking for now)

...has under its hood.

(Spoiler: the previous iteration eventually reached a loss of ~4.86 on the wikitext103 dataset (more on that at a later date); this iteration will beat it.)

The following models were used (misunderstood, misapplied) during this quest for world conquest:

Transformer-XL is pretty and simple

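The idea of Transformer-XL is simple: you reuse the keys and values (KV) from the previous chunk in the attention phase to create new values, and these values will in turn be reused later. It's like doubling the context size, but with more whistles from the past. And less doubling. Perfect match. More layers = more past: each layer's cache was itself computed with the help of the previous chunk's cache one layer down, so the reach back in time grows with depth.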

Implementing it was trickier than I thought:

By default, PyTorch's F.scaled_dot_product_attention aligns the causal mask to the top-left corner (torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)).
Bad torch!
It means that if we have more keys than queries, the mask will look like this:

```
1000 0000
1100 0000
1110 0000
1111 0000
```

Not a very useful mask. A might-as-well-remove-half-of-the-keys kind of mask. We want a "half-causal" kind of mask:

```
1111 1000
1111 1100
1111 1110
1111 1111
```

So much prettier! It allows us to speak to the past from any position in the current timeline! Unfortunately, doing it is ugly for now.

It will change (https://github.com/pytorch/pytorch/issues/108108), but for now I have to resort to building the mask manually, which tanks performance. PyTorch just can't handle the greatness of this mask.
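A minimal sketch of the manual mask plus the KV reuse, under assumptions: one attention layer, and caching keys/values directly as described above (the original Transformer-XL caches hidden states and re-projects them, and adds relative positional encodings; both are omitted here). The helper names `right_aligned_causal_mask` and `xl_attention` are hypothetical. The trick is only the `diagonal=S - L` shift, which moves the causal triangle to the bottom-right corner.

```python
import torch
import torch.nn.functional as F


def right_aligned_causal_mask(L: int, S: int, device=None) -> torch.Tensor:
    """Boolean mask (True = may attend) aligned to the bottom-right corner.

    With S - L cached keys in front, every query sees the whole cache plus
    the current keys up to and including its own position.
    """
    return torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=S - L)


def xl_attention(q, k, v, mem_k=None, mem_v=None):
    """Attention over [cached past | current chunk].

    q, k, v: [batch, heads, chunk_len, head_dim]; mem_k/mem_v: cache or None.
    Returns the output and the keys/values to cache for the next chunk.
    """
    if mem_k is not None:
        k_all = torch.cat([mem_k, k], dim=2)
        v_all = torch.cat([mem_v, v], dim=2)
    else:
        k_all, v_all = k, v
    mask = right_aligned_causal_mask(q.shape[2], k_all.shape[2], device=q.device)
    out = F.scaled_dot_product_attention(q, k_all, v_all, attn_mask=mask)
    # As in Transformer-XL, no gradients flow into past segments.
    return out, (k.detach(), v.detach())
```

With no cache, `S == L` and the mask collapses to the ordinary causal one, so `xl_attention` matches `is_causal=True` on the first chunk.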

Alternatives were considered. flash_attention has a lot of quirks that make it a no-go: no support for fp32, and limits on head_dim (more on that later). The GitHub issue states that xformers also uses right alignment; I'm not sure that's true, though: at least its materialize method returns the same mask. I also considered doing two attention passes and then just adding them together (lol). It didn't work too well.
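Why plain addition fails: each pass normalizes its softmax over only its own keys. Two partial results can still be merged exactly by weighting each with its share of the total softmax mass, i.e. the log-sum-exp of its scores (the same bookkeeping FlashAttention-style kernels do internally). A sketch, assuming the memory pass needs no mask and the current chunk uses the ordinary causal mask; `partial_attention` and `merged_attention` are hypothetical names.

```python
import torch


def partial_attention(q, k, v, mask=None):
    """Softmax attention over a subset of keys.

    Also returns the log-sum-exp of the scores, so that several partial
    results can be merged into the attention over all keys at once.
    """
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)  # [..., L, 1]
    return torch.softmax(scores, dim=-1) @ v, lse


def merged_attention(q, mem_k, mem_v, k, v):
    """Attention over [memory | current chunk], computed as two passes."""
    L = q.shape[-2]
    causal = torch.ones(L, L, dtype=torch.bool, device=q.device).tril()
    out_mem, lse_mem = partial_attention(q, mem_k, mem_v)     # all memory visible
    out_cur, lse_cur = partial_attention(q, k, v, causal)     # causal inside chunk
    # Each pass's share of the total softmax mass; the two shares sum to 1.
    w = torch.softmax(torch.cat([lse_mem, lse_cur], dim=-1), dim=-1)
    return w[..., :1] * out_mem + w[..., 1:] * out_cur
```

The catch with stock PyTorch is that F.scaled_dot_product_attention returns only the output, not the log-sum-exp, so here the normalizers are computed by hand.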

  • Mega

If you squint, it's like XL

If you don't squint, it looks nothing like XL

From this architecture I borrowed its main mechanism: the (damped exponential) moving average, fed into Q and K but not into V. This adds another source of residual information. It also makes each token a function of all of its predecessors, even at the early stages. Neat. I'll remember this fact for later.

Implementing it was even trickier. Since I'm not that familiar with FFT or linear algebra, I took an implementation from lucidrains, who seems to implement papers as a hobby.
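A much-simplified sketch of that idea, assuming one damped EMA per channel (real Mega expands each channel into several EMA dimensions and adds gating, all omitted here). `alpha` and `delta` are per-channel parameters squashed into (0, 1); `damped_ema` and `EMAQK` are hypothetical names. The point is the wiring: Q and K come from the smoothed sequence, V from the raw input.

```python
import torch
import torch.nn as nn


def damped_ema(x, alpha, delta):
    """h_t = alpha * x_t + (1 - alpha * delta) * h_{t-1}, per channel.

    x: [batch, seq, dim]; alpha, delta: [dim] in (0, 1). Plain O(seq) loop.
    """
    h = torch.zeros_like(x[:, 0])
    out = []
    for t in range(x.shape[1]):
        h = alpha * x[:, t] + (1 - alpha * delta) * h
        out.append(h)
    return torch.stack(out, dim=1)


class EMAQK(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.alpha_logit = nn.Parameter(torch.zeros(dim))
        self.delta_logit = nn.Parameter(torch.zeros(dim))
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)

    def forward(self, x):
        smoothed = damped_ema(x, self.alpha_logit.sigmoid(), self.delta_logit.sigmoid())
        # Every position of `smoothed` already mixes in all earlier tokens;
        # V deliberately sees only its own token.
        return self.q(smoothed), self.k(smoothed), self.v(x)
```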

Let's play a game. Can you spot the bug?


```python
import torch
from torch.fft import rfft, irfft
from scipy.fft import next_fast_len


def append_dims(x, num_dims):
    # assumed helper (defined elsewhere in the original code): appends
    # num_dims trailing singleton dims so the weights broadcast against x
    if num_dims <= 0:
        return x
    return x.view(*x.shape, *((1,) * num_dims))


def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
    # O(N log(N)) 1d convolution using some fourier trick

    assert weight_dim >= dim

    N = x.shape[dim]
    M = weights.shape[weight_dim]

    fast_len = next_fast_len(N + M - 1)

    f_x = rfft(x, n = fast_len, dim = dim)
    f_weight = rfft(weights, n = fast_len, dim = weight_dim)

    f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
    out = irfft(f_v_weight, fast_len, dim = dim)
    out = out.roll(-1, dims = (dim,))

    indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
    out = out.index_select(dim, indices)
    return out
```

That's a trick question! The bug is hidden deeper than this code: it's the use of rfft, the Memory Leaker.
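For the curious, what that convolution computes for the EMA: unrolling h_t = alpha·x_t + (1 - alpha·delta)·h_{t-1} gives h_t = Σ_j alpha·(1 - alpha·delta)^j · x_{t-j}, a causal convolution with an exponentially decaying kernel. That is why an O(N log N) FFT can replace the step-by-step loop. A self-contained single-channel sketch of the equivalence (not the code above; `ema_via_fft` and `ema_via_loop` are hypothetical names), zero-padding to 2N so the FFT's circular convolution becomes a linear one:

```python
import torch


def ema_via_fft(x: torch.Tensor, alpha: float, delta: float) -> torch.Tensor:
    """Causal convolution of x ([..., n]) with the EMA kernel, via FFT."""
    n = x.shape[-1]
    # Kernel k_j = alpha * (1 - alpha * delta) ** j for j = 0..n-1.
    kernel = alpha * (1 - alpha * delta) ** torch.arange(n, dtype=x.dtype)
    size = 2 * n  # >= 2n - 1, so no wrap-around
    y = torch.fft.irfft(torch.fft.rfft(x, n=size) * torch.fft.rfft(kernel, n=size), n=size)
    return y[..., :n]  # outputs for positions 0..n-1


def ema_via_loop(x: torch.Tensor, alpha: float, delta: float) -> torch.Tensor:
    """The same EMA computed one step at a time (the slow way)."""
    h = torch.zeros_like(x[..., 0])
    out = []
    for t in range(x.shape[-1]):
        h = alpha * x[..., t] + (1 - alpha * delta) * h
        out.append(h)
    return torch.stack(out, dim=-1)
```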

I almost switched to JAX.
I also considered calculating the exponents in a loop, but it was so slow.
And what if it's a CUDA issue that exists in JAX too? Plus I'd need to learn something new, and I don't want to become too smart. So I decided to let it leak memory, and just kill the training loop every 100 steps and restart it. Then fast-forward to where we left off.
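The kill-and-restart workaround, sketched under assumptions: the checkpoint layout, the file names and the toy model/data below are all hypothetical, and the data stream must be deterministic for the fast-forward to land on the same batches. Each `run_once` call is meant to be its own short-lived process, so whatever leaked during it goes back to the OS when the process exits.

```python
import itertools
import os

import torch
import torch.nn as nn
import torch.nn.functional as F


def batches(seed: int):
    """Deterministic toy data stream (stand-in for the real dataloader)."""
    g = torch.Generator().manual_seed(seed)
    while True:
        x = torch.randn(8, 4, generator=g)
        yield x, x.sum(dim=1, keepdim=True)


def run_once(ckpt_path: str, steps_per_run: int = 100, total_steps: int = 200) -> int:
    """One short-lived session: resume, train up to steps_per_run steps, save."""
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    step = 0
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt["model"])
        opt.load_state_dict(ckpt["opt"])
        step = ckpt["step"]
    # Fast-forward: skip the batches that earlier sessions already consumed.
    data = itertools.islice(batches(seed=123), step, None)
    for _, (x, y) in zip(range(step, min(step + steps_per_run, total_steps)), data):
        loss = F.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        step += 1
    torch.save({"model": model.state_dict(), "opt": opt.state_dict(), "step": step}, ckpt_path)
    return step
```

A wrapper loop would relaunch the script until `run_once` reports `total_steps`. A quick sanity check: two 100-step sessions should end with exactly the same weights as one 200-step session, since the same batches arrive in the same order.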

(Image: memory usage, ~1 hour to OoM. Not a pretty picture.)

RMT is so simple

We need more retention, and RMT (Recurrent Memory Transformer) provides just that, but from another angle. Essentially it asks: "what if the model autosummarized itself at the token level?" It adds memory tokens, which the model fills in at the end of each chunk as it goes. Then you take the memory from the tail and put it at the head of the next chunk. Simple, elegant.
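A toy sketch of that recurrence, assuming a single `nn.TransformerEncoderLayer` with a causal mask stands in for the decoder stack; `ToyRMT`, `n_mem` and the learned initial memory are hypothetical choices. The current memory is placed both at the head (read) and at the tail (write) of each chunk, and the tail outputs become the next chunk's memory.

```python
import torch
import torch.nn as nn


class ToyRMT(nn.Module):
    def __init__(self, d_model: int = 32, n_mem: int = 4, n_heads: int = 4):
        super().__init__()
        self.n_mem = n_mem
        self.mem_init = nn.Parameter(torch.randn(n_mem, d_model) * 0.02)
        # Stand-in for the real decoder stack; dropout off for determinism.
        self.block = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=64, dropout=0.0, batch_first=True
        )

    def forward(self, chunks):
        """chunks: list of [batch, chunk_len, d_model] tensors, in order."""
        mem = self.mem_init.expand(chunks[0].shape[0], -1, -1)
        outputs = []
        for x in chunks:
            # [read memory | chunk tokens | write memory]
            h = torch.cat([mem, x, mem], dim=1)
            n = h.shape[1]
            future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
            h = self.block(h, src_mask=future)  # True in src_mask = may not attend
            outputs.append(h[:, self.n_mem:-self.n_mem])
            mem = h[:, -self.n_mem:]  # tail memory -> head of the next chunk
        return torch.cat(outputs, dim=1), mem
```

Because the write slots sit at the tail, the causal mask lets them see the whole chunk, while the chunk tokens only see the memory that came before them.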

They (and RETRO, and probably many others) use an external knowledge DB. Basically, given the context, they look up related entries and hand the model more context.

I haven't found a good implementation; it seems kNN is not so simple. lucidrains, as usual, has an implementation. It uses a third-party library (faiss), which has its quirks: "once it hits this capacity, it will be reset for now, due to limitations in faiss' ability to remove entries".

The current intention is to use N memory slots and an attention mechanism to look up memory in linear time. This is left for later, as it can be done after training.
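What the kNN lookup boils down to, as a sketch: keep past keys/values in a buffer, and for each query fetch the top-k keys by dot product and attend only over those. Exact search with `torch.topk` stands in for faiss here, so this scans all N entries per query instead of using an approximate index; `knn_attention` is a hypothetical name.

```python
import torch


def knn_attention(q, mem_k, mem_v, top_k: int = 4):
    """q: [L, d]; mem_k, mem_v: [N, d]. Attend over each query's top_k memories."""
    scores = q @ mem_k.T / q.shape[-1] ** 0.5                 # [L, N]
    best, idx = scores.topk(top_k, dim=-1)                    # [L, top_k]
    weights = torch.softmax(best, dim=-1)                     # over retrieved keys only
    return (weights.unsqueeze(-1) * mem_v[idx]).sum(dim=-2)   # [L, d]
```

With `top_k` equal to the buffer size this reduces to ordinary full attention over the memory, which makes a handy sanity check.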

Non-memory-related borrowings, which make BakaLLM not like the other models:

I took the idea of using LayerNorm for Q and K. It gave better results on TinyStories, so I kept it.
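QK-LayerNorm in a few lines, as a sketch: normalize queries and keys over the head dimension right before the dot product. Using one LayerNorm each, shared across heads, and the name `QKNormAttention` are assumptions. The usual motivation is to stop attention logits from blowing up as training goes on.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class QKNormAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        b, t, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # [b, t, d] -> [b, heads, t, head_dim]
        q, k, v = (z.reshape(b, t, self.n_heads, self.head_dim).transpose(1, 2) for z in (q, k, v))
        q, k = self.q_norm(q), self.k_norm(k)  # the only change vs. vanilla attention
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(y.transpose(1, 2).reshape(b, t, d))
```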

  • From Llama(?):

Gated MLP. In older GPTs, e.g. GPT-NeoX, the MLP is:

```python
x = up(x) # [b, s, d] -> [b, s, 4d]
x = act(x) #  [b, s, 4d] -> [b, s, 4d]
x = down(x) # [b, s, 4d] -> [b, s, d]
```

Llama uses:

```python
x, g = up(x), gate(x) # [b, s, d] -> [b, s, Md]
x = x*act(g) #  [b, s, Md] -> [b, s, Md]
x = down(x) # [b, s, Md] -> [b, s, d]
```

where M is a convoluted constant because 四 is a scary number.

I'm not sure if it's a Llama novelty, but I hadn't seen it before Llama (because I didn't look), and since it's easy to implement, the idea was borrowed.
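A runnable version of the gated MLP, plus the arithmetic behind Llama's M. For the record, gated MLPs predate Llama: Shazeer's "GLU Variants Improve Transformer" (2020) introduced SwiGLU, and PaLM used it before Llama did. Llama uses SiLU as `act` and sets the hidden size to 2/3 of 4d, rounded up to a multiple of 256 (`multiple_of` in its reference code). Three d×(8d/3) matrices hold the same 8d² parameters as two d×4d ones. The names `llama_hidden_dim` and `GatedMLP` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def llama_hidden_dim(d: int, multiple_of: int = 256) -> int:
    hidden = int(2 * (4 * d) / 3)  # 2/3 of the classic 4d
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)  # round up


class GatedMLP(nn.Module):
    def __init__(self, d: int, hidden: int):
        super().__init__()
        self.up = nn.Linear(d, hidden, bias=False)
        self.gate = nn.Linear(d, hidden, bias=False)
        self.down = nn.Linear(hidden, d, bias=False)

    def forward(self, x):  # [b, s, d] -> [b, s, d]
        return self.down(self.up(x) * F.silu(self.gate(x)))


print(llama_hidden_dim(4096))  # 11008, the FFN size of Llama-7B
```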

  • And finally, from RetNet I took the claim that this model is a successor.

In my case, to humanity.
Since BakaLLM already has Mega, no further borrowing is necessary. But this one is very important. Some might say it's the most important part. They will be wrong, but we will not judge.

In the next part, someday, we'll talk about how it's all glued together and speak of the Greatest Conspiracy Against the Small Values.

Mandatory nap is mandatory after all!
