Hello fairy dairy diary!
How do you feel about gates?
So, last time Step4 was introduced to the mix. Results of the first epoch were promising; in the end the curse of RMT-16 was not broken, but we managed to get really close.
(Using step4 on every 4th token was meh: 4.04 valid loss.)
Also, you may notice I tried to finetune each component separately. Conclusions: STEP4 finetune is the best, ATTN finetune is very good, MLP finetune is the biggest waste of FLOPs. I also didn't notice any difference between including or excluding the LayerNorms, though that was not measured on the validation set; the training loss curves were identical.
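For reference, this is roughly how that kind of per-component finetune can be set up: freeze everything, then unfreeze one sub-block. A minimal sketch, assuming the sub-blocks sit under attribute names like "attn", "mlp" or "step_forward" as in the layer code further down; the helper name is mine.
import torch.nn as nn

def freeze_all_but(model: nn.Module, component: str) -> None:
    # Disable gradients everywhere first...
    for p in model.parameters():
        p.requires_grad = False
    # ...then re-enable them only for modules whose name ends with the
    # chosen component ("attn", "mlp", "step_forward", ...).
    for name, module in model.named_modules():
        if name.endswith(component):
            for p in module.parameters():
                p.requires_grad = True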
Now, that was yesterweek. Forward must we go. And backward. For such is the fate of LLM training. So, speaking of the past:
I had an idea. Probably a good one, an occasion so rare it always warrants a write-up. One of 000_OLD's quirks was that it got a massive buff in performance once I started gating the results. I decided to revive this. Now layers don't have attention, step4 or MLP. They have their gated variants! The current gatekeeping is defined as follows:
class BakaGated(nn.Module):
    def __init__(self, base) -> None:
        super().__init__()
        self.base = base
        self.config = base.config
        # Gate: a bias-free projection followed by a nonlinearity
        self.fc_gate = nn.Linear(self.config.dim_model, self.config.dim_model, False)
        self.fn_gate = nn.ReLU()

    def forward(self, input: BakaState | Tensor):
        # Run the wrapped sub-block (attention / MLP / step4) first
        y = self.base(input)
        if isinstance(input, BakaState):
            assert input.input is not None
            input = input.input
        # TODO: sigmoid(x) -- worse, but start is much better
        # TODO: sigmoid(y) -- completely worse
        # TODO: relu(x) -- better in E1
        # TODO: replace sigmoid with something other (lerelu?)
        # Gate is computed from the layer input and multiplied into the output
        gate = self.fn_gate(self.fc_gate(input))
        return y * gate
...
class BakaLayer(nn.Module):
    def __init__(self, config: BakaConfig, step_forward_source: Optional[BakaStepForward] = None) -> None:
        super().__init__()
        assert step_forward_source is None or isinstance(step_forward_source, BakaStepForward)
        self.config = config
        self.norm_in = BakaRMSNorm(config.dim_model)
        # Each sub-block is wrapped into its gated variant
        self.attn = BakaGated(BakaAttention(config))
        self.mlp = BakaGated(BakaMLP(config))
        # step_forward_source allows sharing one Step4 module between layers
        self.step_forward = BakaGated(step_forward_source or BakaStepForward(config))
I really wanted to reintroduce elephants, but unfortunately I started running OoM with them. In fact, I couldn't even compute y * GATE(y - x); I could only compute y * GATE(y) or y * GATE(x). That's how close we are to 16GB.
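To make the memory trade-off concrete, here is how the three variants would look inside BakaGated.forward. This is just a sketch; x stands for the raw layer input and y for the sub-block output, names which are mine, not the code's.
# Inside BakaGated.forward, with x = the raw layer input and y = self.base(input):
gate = self.fn_gate(self.fc_gate(x))         # y * GATE(x) -- what the class above does
# gate = self.fn_gate(self.fc_gate(y))       # y * GATE(y) -- also fits in 16GB
# gate = self.fn_gate(self.fc_gate(y - x))   # y * GATE(y - x) -- the extra (y - x)
#                                            # activation is what runs out of memory
return y * gate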
Drum roll, here are the preliminary results:
The blue horizontal line is the baseline, step4_1 (a training run where every 4th token is step4'ed); everything else is plotted as a delta to it, which is why higher is better: if step4_1 had the worst possible loss of 100 and nstep had an ideal loss of 0, the delta would be 100 - 0 = 100, higher than 0. Faint blue is nstep4 (new step4), where every token except the first three is step4'ed. Maybe I should use the XL cache to extract inputs so those tokens can be stepped as well.
It's interesting that it starts off worse but then gets better. Red = y * sigmoid(x), and green = y * relu(x). The interesting part here is how closely they align. Assume sigmoid as the base:
ReLU seems better, just barely, but better. It has the potential to lift the RMT-16 curse. I'll train 2 more epochs with it and start thinking about other gatekeeping options.
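For clarity, the red and green runs differ only in which gate nonlinearity is used; presumably that's a one-line change in BakaGated.__init__, something like:
self.fn_gate = nn.ReLU()       # green curve: y * relu(x)
# self.fn_gate = nn.Sigmoid()  # red curve: y * sigmoid(x)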
Oh, by the way: because of the OoM I tried to share weights between Step4 modules while playing around with elephants; it didn't help that much. Maybe writing a custom forward/backward function will help.
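For reference, that kind of weight sharing is what the step_forward_source argument of BakaLayer is for. A sketch of how it could be wired; the layer count and the ModuleList construction are assumptions, not the actual training code:
# One BakaStepForward instance handed to every layer: all layers then share
# the same Step4 weights, so only a single copy lives in memory.
n_layers = 12  # placeholder; the real config field may be named differently
shared_step = BakaStepForward(config)
layers = nn.ModuleList(
    [BakaLayer(config, step_forward_source=shared_step) for _ in range(n_layers)]
)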
Cirno + Hong OTP confirmed.