Maykeye

BakaLLM, part 14: (MLP) Sharing is caring(?)

Hello, fairy dairy diary!

A new paper that I skimmed diagonally and got inspired by!
One Wide Feedforward is All You Need

Wide Cirno

They base their experiments mainly on full encoder-decoder transformers, so I blame my ~~illiteracy~~ ~~laziness~~ non-reading on the fact that we are a decoder-only baka anyway. So for now I just made it possible to share MLPs by specifying in the config how to share the weights. I didn't believe that one feedforward is all I needed, so for the first experiment I reduced them to 3 MLPs across 12 layers.

  from dataclasses import dataclass

  @dataclass
  class BakaConfig:
      dim_model: int = 1024
      dim_ff: int = 0
      ff_pattern: str = 'aaaabbbbcccc'  # new: layers with the same letter share one MLP

Also, as a control test, instead of increasing dim_ff I just shared some MLPs using the "standard" dim_ff = 4*dim_model.
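
To make the sharing concrete, here is a minimal sketch of how I read ff_pattern (an illustration, not the actual BakaLLM code; build_shared_mlps and the plain GELU MLP are my own stand-ins): every distinct letter gets one MLP instance, and each layer reuses the instance its letter points to, so 'aaaabbbbcccc' gives 3 MLPs over 12 layers.

  import torch.nn as nn

  def build_shared_mlps(config: "BakaConfig") -> nn.ModuleList:
      # One MLP per distinct letter in ff_pattern; layers with the same letter share it.
      dim_ff = config.dim_ff or 4 * config.dim_model  # guess: dim_ff=0 means "use 4*dim_model"
      mlp_per_letter = {}
      mlps = []
      for letter in config.ff_pattern:  # one character per layer
          if letter not in mlp_per_letter:
              mlp_per_letter[letter] = nn.Sequential(
                  nn.Linear(config.dim_model, dim_ff),
                  nn.GELU(),
                  nn.Linear(dim_ff, config.dim_model),
              )
          mlps.append(mlp_per_letter[letter])  # same object => tied weights
      # ModuleList may hold the same module several times; .parameters() deduplicates,
      # so shared weights are counted (and updated) only once.
      return nn.ModuleList(mlps)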

The training graph follows RMT closely.

Training Graph
Though there are some differences that suggest it may be worse:

Diff of graphs

During evaluation, it turned out to be even better:
Results

The results turned out to be fine, and I can also test other patterns like 'abcabcabcabc'!
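
For instance, 'abcabcabcabc' keeps the same 3 unique MLPs but interleaves them across layers instead of grouping them in blocks. A quick sanity check of how I read these strings (each position is a layer, each distinct letter one MLP instance):

  for pattern in ('aaaabbbbcccc', 'abcabcabcabc'):
      # both use 3 unique MLPs over 12 layers; they differ only in which layers share which one
      print(pattern, '->', len(set(pattern)), 'unique MLPs,', len(pattern), 'layers')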

Considering that it sheared off ~70M parameters, this needs deeper investigation with different numbers of shared MLPs.
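
As a back-of-the-envelope check (assuming a plain two-matrix MLP with dim_model=1024 and dim_ff=4*1024, no gating, biases ignored; the real BakaLLM MLP may differ):

  dim_model, dim_ff = 1024, 4 * 1024
  params_per_mlp = 2 * dim_model * dim_ff         # up-projection + down-projection
  saved = (12 - 3) * params_per_mlp               # 9 layers no longer carry their own MLP
  print(f"~{saved / 1e6:.0f}M parameters saved")  # ~75M, same ballpark as the ~70M above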

The paper also showed the following:

> These experiments show that, contrary to the FFN, individual layer-specific attention weights are more important and not as redundant, as sharing the attention leads to significant accuracy drops

Which also reminded me of previous experiments where I doubled whole layers (y = layer2(layer2(layer1(layer1(x))))) and found that finetuning attention or step4 affected results much more than finetuning the MLP.
