BakaLLM, part 3: it's testing time, it's testing time

Hello, fairy dairy diary!

Today let's look at the current state of the BakaLLM iteration.

Drum roll!

Cirno plays drums

Table roll~! Roll everything.

| Model | n_ctx | Loss | Notes |
|---|---|---|---|
| Pythia-14m | 2048 | 4.3343 | Small models were evaluated at 2K context in F32 |
| Pythia-70m-deduped | 2048 | 3.5699 | |
| Pythia-160m-deduped | 1024 | 3.1113 | (for comparison with BakaPythia160) |
| Pythia-160m-deduped | 2048 | 3.0813 | |
| Pythia-1b-deduped | 2048 | 2.4778 | |
| Cerebras-GPT-111M | 2048 | 3.5955 | |
| Cerebras-GPT-256M | 2048 | 3.2286 | |
| gpt2 (137M) | 1024 | 3.1929 | Maximum allowed context length |
| Open Llama 3B | 2048 | 1.9555 | BF16; starting to hit VRAM limits |
| Open Llama 7B | 1024 | 1.8960 | BF16; still had to reduce context due to VRAM |
| BTLM 3B | 2048 | 2.1689 | BF16 |
| Mistral 7B | 1024 | 1.6808 | BF16 |
| --- | --- | --- | --- |
| BakaPythia160 | 1024 | 4.4512 | |
| bakanet2_mega_qk_v_mismatch | 512++ | 4.0092 | Due to an error in the Mega implementation, QK was passed to the transformer un-mega-fied while V was mega-fied; it should be the other way around |

First of all, the test was done on WikiText-103, validation split.

WikiText-103 was chosen specifically due to the shoutout from the Transformer-XL paper:

> When trained only on WikiText-103, Transformer-XL manages to generate reasonably coherent, novel text articles with thousands of tokens.

By the way, they report a perplexity of 23.09, which means a loss of 3.1394, so around Pythia-160m with 1024 context.
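(Loss here means mean cross-entropy in nats, so the conversion is just loss = ln(ppl); a quick check:)

```python
import math

# Transformer-XL reports ppl 23.09 on WikiText-103;
# mean cross-entropy loss in nats is the natural log of perplexity.
print(round(math.log(23.09), 4))  # 3.1394
```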

The test code was based on HF's perplexity script, but I used loss instead of perplexity, both because after exponentiation BakaNet looks so much worse and because the exponentiation feels like useless extra maths work anyway.

The test is simple:
We slide a window over the tokenized articles with window size = n_ctx and stride = n_ctx//2. If the window was slid (i.e. it's not the first one), we ignore the first n_ctx//2 tokens in the loss calculation. Then we call the model with inputs and targets.
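A minimal sketch of that evaluation loop, assuming an HF-style causal LM that returns `.logits` (variable names are mine; the real script follows HF's perplexity example more closely):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_loss(model, token_ids, n_ctx=2048, device="cuda"):
    """Mean next-token loss over a long token list, stride = n_ctx//2.
    For every window after the first, the overlapping half is masked out
    so each token is scored exactly once."""
    stride = n_ctx // 2
    total_loss, total_tokens = 0.0, 0
    for start in range(0, len(token_ids) - 1, stride):
        window = token_ids[start : start + n_ctx]
        inputs = torch.tensor([window], device=device)
        targets = inputs.clone()
        if start > 0:
            targets[:, :stride] = -100          # already scored by the previous window
        logits = model(inputs).logits           # (1, T, vocab)
        loss = F.cross_entropy(
            logits[:, :-1].transpose(1, 2),     # position i predicts token i+1
            targets[:, 1:],
            ignore_index=-100,
            reduction="sum",
        )
        total_loss += loss.item()
        total_tokens += (targets[:, 1:] != -100).sum().item()
        if start + n_ctx >= len(token_ids):     # last window reached the end
            break
    return total_loss / total_tokens
```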

Due to limited VRAM (16GB on a 3080 Ti laptop), not every test is run the same way. Big models use BF16 just to be loaded at all. 3B models use BF16 + 2K context (it works better than 1K F32 context as the model sees more). GPT2 can't use more than 1K tokens.

Technically the test is not fair: BTLM can handle 8K tokens. But my VRAM is limited, so I used only 2K, same as with the other models, some of which can't use more than 2K anyway.

Technically I could use exllama/GPTQ/etc. loaders and use 2K tokens. But I don't want to, and that's a good enough reason. For now I want to use a single loader, which happens to be good old Transformers. Nothing flashy. Technically I could use the Transformers loader's built-in "load in 8bit/4bit" parameter, but I don't want to go lower than BF16 for now.
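So the big-model setup is just the plain Transformers path, loaded in BF16, roughly like this (the model name is only an example):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "mistralai/Mistral-7B-v0.1"   # example; same pattern for Open Llama / BTLM
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(
    name,
    torch_dtype=torch.bfloat16,      # F32 would not fit into 16GB VRAM
).to("cuda").eval()
```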

From the list we see: Mistral is the GOAT, and GPT2 with a smaller context and a smaller number of params outperformed Cerebras-GPT-256M.

And at the end there are two(!) of my models. bakanet2_mega_qk_v_mismatch is what this blog is all about. It has "512++" context because it keeps a cache; I simply dumped all tokens as input and it split the text into chunks internally, as sketched below.
Other models would OoM (except with some loaders like exllama, which also chunkifies long inputs).
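Roughly, the idea is something like this (a sketch only: `forward_chunk` and the cache object are hypothetical stand-ins for the actual BakaLLM internals):

```python
def eval_long_input(model, token_ids, chunk=512):
    """Score an arbitrarily long input as consecutive 512-token chunks,
    carrying the model's cache/memory state between chunks."""
    cache = None
    total_loss, total_tokens = 0.0, 0
    for start in range(0, len(token_ids) - 1, chunk):
        inputs = token_ids[start : start + chunk]
        targets = token_ids[start + 1 : start + 1 + len(inputs)]
        inputs = inputs[: len(targets)]                             # trim the final chunk
        loss, cache = model.forward_chunk(inputs, targets, cache)   # hypothetical API
        total_loss += loss * len(targets)
        total_tokens += len(targets)
    return total_loss / total_tokens
```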

BakaPythia160 is AutoModelForCausalLM.from_config(pythia160m_config). It was trained similarly to bakanet, but with chunk size=1024 (2K didn't fit).
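Concretely, that is just a randomly initialized model with the Pythia-160m architecture, something like this (assuming the config is pulled from the published checkpoint):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Same architecture as Pythia-160m, but fresh random weights (no pretraining)
config = AutoConfig.from_pretrained("EleutherAI/pythia-160m-deduped")
model = AutoModelForCausalLM.from_config(config)
```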

The idea was that if the BakaLLM architecture is better, BakaLLM would do better (and it does), and losing to the pretrained model happened because the latter probably saw more tokens in a single batch than my whole model did per epoch. It seems so. "Seems" because it's still not an apples-to-apples comparison: BakaLLM uses 2K sequences which it then splits into windows of 512 tokens, but for BakaPythia I hit OoM and had to reduce the model's input.

Pythia-14m is defeated. A model with <1/10 of BakaNet's parameters and 0/10 chances to win~

Cirno celebrates victory

Next target: 3.⑨⑨. It will not defeat the bigger models (yet), but it's an important psychological barrier which must be broken.

Ways to break it are simple:

  • Fix Mega layer
  • Add token
  • More memory retention! The memory stew will keep being cooked until BakaNet remembers everything!

Happy mandatory nap! Next time we'll talk about cleanup and refactoring.

Bonus image:
BakaPythia train loss. Notice the spikes at the beginning of run 1. Due to some bugginess, I had to restart it, erasing results, so run 1 is longer.
Pythia160m

Epoch 3 of BakaNet

At that time I hadn't yet read the docs on how to put several graphs into a single project, only how to append to an existing one, so each epoch used a different project.

Notice how bakanet is much more "volatile", with spikes biased downward. I think it's because memory (such as RMT) does its thing: the longer the text goes, the better it works.
