BakaLLM, part 12.4, gatekeeping is all you need

Maykeye

Hello fairy dairy diary!

How do you feel about gates?

Cirno at the gate

So, last time Step4 was introduced to the mix. Results of the first epoch were promising; in the end the curse of RMT-16 was not broken, but we got really close.

Valid split Evaluation
(using step4 at each 4th token was meh: 4.04 valid loss)

You may also notice I tried to finetune each component separately. The conclusions: the STEP4 finetune is the best, the ATTN finetune is very good, and the MLP finetune is the biggest waste of FLOPS. I also didn't notice any difference between including or excluding LayerNorms, though that was not measured on the valid split; the training loss graphs were identical.
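
Not the actual training script, but per-component finetuning here basically amounts to freezing everything and unfreezing one submodule; a minimal sketch, with the component names ("attn", "mlp", "step_forward") taken from the layer code shown below:

import torch.nn as nn

def finetune_only(model: nn.Module, component: str) -> None:
    """Freeze all parameters, then unfreeze a single component.

    `component` is matched against parameter names, e.g. "attn",
    "mlp" or "step_forward" (names assumed from BakaLayer below).
    """
    for name, param in model.named_parameters():
        param.requires_grad = component in name

# e.g. finetune_only(model, "step_forward")  # STEP4-only finetune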

Now. That was yesterweek. Forward must we go. And backward. For this is the fate of LLM training. So, speaking of the past:

I had an idea. Probably a good one, an occasion so rare it always warrants writing it down. One of 000_OLD's quirks was that it got a massive performance buff once I started gating the results. I decided to revive this. Now layers don't have attention, step4 or mlp; they have their gated variants! Current gatekeeping is defined as follows:

from typing import Optional

import torch.nn as nn
from torch import Tensor


class BakaGated(nn.Module):
    """Wraps a sublayer (attention, MLP or step4) and multiplies its output by a learned gate."""

    def __init__(self, base) -> None:
        super().__init__()
        self.base = base
        self.config = base.config
        # Gate projection: bias-free linear layer of model width
        self.fc_gate = nn.Linear(self.config.dim_model, self.config.dim_model, False)
        self.fn_gate = nn.ReLU()

    def forward(self, input: BakaState | Tensor):
        y = self.base(input)
        if isinstance(input, BakaState):
            assert input.input is not None
            input = input.input
        # TODO: sigmoid(x) -- worse, but start is much better
        # TODO: sigmoid(y) -- completely worse
        # TODO: relu(x) -- better in E1
        # TODO: replace sigmoid with something other (leaky relu?)
        # The gate is computed from the sublayer's input x, not its output y
        gate = self.fn_gate(self.fc_gate(input))
        return y * gate

...

class BakaLayer(nn.Module):
    def __init__(self, config: BakaConfig, step_forward_source: Optional[BakaStepForward] = None) -> None:
        super().__init__()
        assert step_forward_source is None or isinstance(step_forward_source, BakaStepForward)
        self.config = config
        self.norm_in = BakaRMSNorm(config.dim_model)
        # Every sublayer is now wrapped in its gated variant
        self.attn = BakaGated(BakaAttention(config))
        self.mlp = BakaGated(BakaMLP(config))
        self.step_forward = BakaGated(step_forward_source or BakaStepForward(config))

I really wanted to reintroduce elephants, but unfortunately I started to run OoM with them. In fact I couldn't even calculate y * GATE(y-x), only y * GATE(y) or y * GATE(x); that's how close we are to 16GB.
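
For reference, the three variants differ only in what gets fed to the gate projection. A minimal sketch of that choice, where `mode` is a hypothetical switch (the BakaGated code above only implements the GATE(x) case):

import torch.nn as nn
from torch import Tensor

def gated_forward(gated: nn.Module, x: Tensor, mode: str = "x") -> Tensor:
    # gated is assumed to be a BakaGated-like module with base, fc_gate, fn_gate
    y = gated.base(x)            # sublayer output
    if mode == "x":              # y * GATE(x) -- what BakaGated above does
        gate_in = x
    elif mode == "y":            # y * GATE(y)
        gate_in = y
    else:                        # y * GATE(y - x) -- the variant that ran OoM
        gate_in = y - x
    return y * gated.fn_gate(gated.fc_gate(gate_in))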

Drum roll, here are the preliminary results:

Training graph
(higher is better)

The blue horizontal line is the baseline, step4_1, a training run where every 4th token is step4'ed. Everything else is plotted as its delta to the baseline, which is why higher is better: assume step4_1 has the worst loss of 100 and nstep has an ideal loss of 0, then delta = 100 - 0 = 100, higher than 0. Faint blue is nstep4 (new step4), where every token except the first three is step4'ed. Maybe I should use the XL cache to extract inputs and step those as well.
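
To spell out the plotting arithmetic, a tiny sketch (the loss values are made up):

import numpy as np

baseline_loss = np.array([100.0, 90.0, 80.0])   # step4_1, drawn as the horizontal at 0
variant_loss = np.array([100.0, 85.0, 70.0])    # e.g. nstep4 or a gated run

# What gets plotted: delta to the baseline, so higher = lower loss = better
delta = baseline_loss - variant_loss
print(delta)  # [ 0.  5. 10.]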

It's interesting that it starts off worse, but gets better. Red = y * sigmoid(x), green = y * relu(x). The interesting part here is that they align so closely. Now assume sigmoid as the base:

Sigmoid base

ReLU seems better, just barely, but better. It has the potential to lift the RMT-16 curse. I'll train 2 more epochs with it and start thinking about other kinds of gatekeeping.
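
If I do swap the gate nonlinearity, only one line of BakaGated changes; a sketch, where the gate_fn config field is a hypothetical addition:

import torch.nn as nn

# Hypothetical lookup from a config string to the gate nonlinearity;
# only ReLU and Sigmoid have actually been compared so far.
GATE_FNS = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "leaky_relu": nn.LeakyReLU,
}

# Inside BakaGated.__init__, instead of hardcoding nn.ReLU():
# self.fn_gate = GATE_FNS[getattr(self.config, "gate_fn", "relu")]()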

Oh, by the way, because of OoM I tried to share weights between Step4 layers while playing around with elephants; it didn't help that much. Maybe writing a custom forward/backward function will help.
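
For reference, the step_forward_source argument in BakaLayer above is what enables that sharing: build one BakaStepForward and pass it to every layer (the per-layer gates stay separate). A sketch; the config.n_layers name is an assumption:

import torch.nn as nn

shared_step = BakaStepForward(config)        # one Step4, shared across layers
layers = nn.ModuleList(
    [BakaLayer(config, step_forward_source=shared_step) for _ in range(config.n_layers)]
)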

OTP

Cirno + Hong OTP confirmed.
