BakaLLM, p4. Building LLM from scratch: New baseline is ready.

Hello, fairy dairy diary.

First of all, instead of Cirno, here's a funny graph of experiments with positional embeddings.

NoPE=1/2xpos + 1/2linspace

It shows 2.5+ experiments. The best (on the bottom) uses rotary embeddings. The worst uses positional embeddings calculated with linspace, and the middle one uses NoPE (no positional embedding at all).
The funny part is how NoPE bounces between the other two. Humor may be subjective, but this one is objectively fun. Yes.

This graph also shows why this post has no Cirno: GPU is busy.

Code for linspace and NoPE (which I got by commenting everything out) is below:

import torch
from torch import nn, Tensor


class BakaPositions(nn.Module):
    def __init__(self, n_dim, edge=0.1) -> None:
        super().__init__()
        # self.n_dim = n_dim
        # self.edge = edge
        # self.scale = nn.Parameter(torch.ones(1, n_dim))

    def forward(self, state: Tensor, seq_dim: int, offset=0):
        # offset is unused by baka/pristine so it's always 0; I just pass it to remember it exists
        # assert seq_dim == -2
        return state  # Use NoPE for now: bail out before the linspace path below
        # Dead code below: the original linspace version, kept for reference
        n_seq = state.shape[seq_dim]
        add = torch.linspace(offset, offset + self.edge, n_seq, device=state.device, dtype=state.dtype)
        add = add.view(n_seq, 1)
        return state + add * self.scale

The idea of using linspace (only K was affected by it) was to scale K up towards the end of the sequence, so when the N-th query looks for tokens, it gravitates to the biggest K it can see, and thanks to the causal mask those are the tokens around N and before it. A toy demonstration follows.
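To make that concrete, here's a toy demonstration (not the training code: the vectors and sizes are made up, and the learned scale parameter is left out). Adding a linspace ramp to K changes the attention logits by a term that grows linearly with the key position.

import torch

torch.manual_seed(0)
n_seq, n_dim, edge = 8, 16, 0.1   # toy sizes, not the real config

q = torch.randn(n_seq, n_dim)
k = torch.randn(n_seq, n_dim)

# Linspace ramp: one scalar per position, broadcast over the feature dim.
ramp = torch.linspace(0, edge, n_seq).view(n_seq, 1)
k_biased = k + ramp   # learned `scale` omitted; in the module it multiplies the ramp per feature

causal = torch.tril(torch.ones(n_seq, n_seq, dtype=torch.bool))
plain = (q @ k.T).masked_fill(~causal, float("-inf")).softmax(-1)
biased = (q @ k_biased.T).masked_fill(~causal, float("-inf")).softmax(-1)

# The bias adds ramp[j] * q[i].sum() to the logit of key j, a term that grows
# linearly with the key position; the learned scale lets the model pick its sign.
print(plain[-1])   # attention of the last query without the ramp
print(biased[-1])  # ... and with it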

Oh, and BakaLLM is finally on GitHub. And Hugging Face. HF has the weights, GitHub has branch 000 with almost all the code I used up to the previous part. Except utils, which are actually important; the code is included for historical reasons.

Once I get results that are good enough, I'll even make sure it can be run effectively. This blog will be used for braindumping ideas and findings.

The Pristine Baseline.

The branch 001_pristine is a new baseline for BakaLLM (working title!). It uses rotary positional embeddings, a gated SiLU MLP, a sliding window of 512 tokens (without any attention between windows), and BakaState, which is a placeholder for caches, data, etc. For now it contains inputs/outputs and an offset, which will be used to mark that token 1 of the window is actually token offset+1 globally.
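A minimal sketch of what BakaState might look like at this stage (field names other than offset are my guesses, not the actual code):

from dataclasses import dataclass, field
from typing import Optional

from torch import Tensor


@dataclass
class BakaState:
    # Global position of the first token of this window: token 1 here is token offset+1 overall.
    offset: int = 0
    # Per-window tensors; names are illustrative, the real class may differ.
    input_ids: Optional[Tensor] = None
    output: Optional[Tensor] = None
    # Placeholder for future per-layer caches (KV, memory tokens, etc.).
    cache: dict = field(default_factory=dict)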

Positions and do we need them?

The previous iteration of BakaLLM (branch 000, henceforth referred to as Branch Zero) already used an RMT-inspired approach with memory tokens. How exactly should those be positioned? They are not "real" tokens. They are a "summary" of everything that happened before. So what are their relative or absolute positions? Ideally I would prefer to use NoPE and not care about it. The original RMT paper uses relative positional embeddings. While that works for pure RMT, "relative" behaves like "absolute" if you stretch it long enough. Assume that Memorizing Transformers will be added to the architecture, as I plan to do one day. What does that leave us with? If memory grabs history from 10K tokens away, well, a relative distance of 10K tokens is both enormous and meaningless. If we grab the history, it doesn't matter whether it was 1K, 10K or 100K tokens away. But with relative positional embeddings it does matter.
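A toy illustration of that last point, using a single rotary frequency: the same query/key content produces different attention scores depending on how far apart they sit, which is exactly what you don't want for retrieved memories.

import torch

def rotate2d(x: torch.Tensor, pos: float, theta: float = 1e-2) -> torch.Tensor:
    # Rotate a 2-d vector by angle pos*theta: one frequency of rotary embeddings.
    a = torch.tensor(pos * theta)
    c, s = torch.cos(a), torch.sin(a)
    return torch.stack([x[0] * c - x[1] * s, x[0] * s + x[1] * c])

q = torch.tensor([1.0, 0.0])  # same content every time
k = torch.tensor([1.0, 0.0])

for dist in (1_000, 10_000, 100_000):
    score = rotate2d(q, dist) @ rotate2d(k, 0.0)
    print(f"distance {dist:>7}: score {score.item():+.3f}")
# Identical content, different distances -> different scores.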

Road to the future

There are two ways forward:
a) NoPE. Do less than nothing and throw away posemb entirely.
b) Recalculate positions each time. Branch Zero cached KV where K was already rotated. We don't need that: the cache can hold unrotated K, and once the prompt is assembled from history and the previous chunk's cache, everything can be rotated in one go (see the sketch below).
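A minimal sketch of option (b); apply_rotary here is a stand-in of my own, the real code would use the model's rotary/xpos module:

import torch
from torch import Tensor


def apply_rotary(x: Tensor, positions: Tensor) -> Tensor:
    # Stand-in rotary implementation: rotate pairs of dims by position-dependent angles.
    d = x.shape[-1]
    freqs = 1.0 / (10000 ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = positions[:, None].float() * freqs[None, :]   # (n_seq, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


# Cache keeps *unrotated* K from the previous chunk.
k_cached = torch.randn(512, 64)   # previous window, stored raw
k_new = torch.randn(512, 64)      # current window, also raw
k_full = torch.cat([k_cached, k_new], dim=0)

# Rotate everything at attention time with whatever positions we decide on now.
positions = torch.arange(k_full.shape[0])   # e.g. local positions starting from 0
k_rotated = apply_rotary(k_full, positions)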

NoPE is the KISSest approach. I think I will give it a full 3 epochs of training and then decide what to do. At this stage I want to throw away as much as possible, and NoPE is the number one candidate for that. If it's only slightly worse I may even keep it: adding is easier than removing.

The second candidate is returning to a one-head approach. Pristine uses parameters copy-pasted from Pythia 160M, with 12 heads.
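For reference, a sketch of the kind of config that implies, with Pythia-160M-ish numbers; treat it as illustrative rather than the exact BakaLLM config:

from dataclasses import dataclass


@dataclass
class BakaConfig:
    # Pythia-160M-style numbers; assumptions, not the actual BakaLLM hyperparameters.
    n_layers: int = 12
    n_heads: int = 12          # the candidate simplification: drop this to 1
    d_model: int = 768
    d_ffn: int = 3072
    window: int = 512          # sliding window from the pristine baseline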

NoPE is volatile so far, or: lying with wandb and statistics!

NoPE is extremely volatile, jumping all over the place, to the point that depending on the graph scale it may look as good as xpos or as bad as linspace. But it's always in between.

Volatility example 1

Note the point around 8.5K: here the graph shows everything is fine, close to rotary embeddings. Let's zoom in.

Volatility example 2

At no point in time is NoPE actually close to rotary.

This is why NoPE will be trained for several epochs: it's interesting to see if it will chill out.
