Hello, fairy dairy diary!
Tried flash fft conv, and even though I updated torch to 2.1, where the memory leak was fixed, it still ate way too much memory. But! I remembered another flash! Flash Attention. Now, I assumed it's used by F.SDPA by default. In fact, maybe it is. But I use nasty masks for XL's sake, and with nasty masks nothing is guaranteed. Flash Attention, however, fixed its alignment issues a long time ago, so I added an option to use it directly instead. For simplicity I had to drop RMT knowing every other RMT token, but it didn't turn out too bad: most of the loss was identical, only the beginning of the first one differed.
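Roughly the shape of the switch (a minimal sketch with made-up shapes and a stand-in mask, not my actual XL/RMT masking code):

```python
# Sketch only: show that SDPA won't take the flash path once you hand it an
# explicit attn_mask, and what the direct flash-attn call looks like instead.
import torch
import torch.nn.functional as F

B, H, T, D = 2, 8, 512, 64
q = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)

# Stand-in for the nasty XL/RMT mask. With an explicit attn_mask and only the
# flash backend enabled, SDPA has no kernel left to dispatch to and errors out.
mask = torch.ones(T, T, device="cuda", dtype=torch.bool).tril()
try:
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
except RuntimeError as e:
    print("flash backend refused the custom mask:", e)

# Direct flash-attn call (pip install flash-attn). It expects (B, T, H, D)
# layout and only does causal/windowed masking, so the custom mask has to go.
from flash_attn import flash_attn_func

out = flash_attn_func(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True
)  # -> (B, T, H, D)
```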
Here are their differences against step4.
The loss was so insignificantly different that it might as well be attributed to my use of nondeterministic algos.
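If I wanted to rule that out properly, the usual knobs are roughly these (a sketch, not something I actually ran):

```python
# Sketch: pin seeds and ask torch to flag nondeterministic kernels,
# then rerun both attention paths and diff the loss curves again.
import torch

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.benchmark = False                     # no autotuned, varying algos
torch.use_deterministic_algorithms(True, warn_only=True)   # warn on nondeterministic ops
```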
By the end of the week I'll try to play around with increased batch size or maybe even (gasp) context size: oh, the horrors of giving up on 512 tokens and increasing them to 640 or maybe even 768. And I'm still thinking about how to reMEGAfy the model. Failing that, I think I'll try step2 in general and then for Q/K only. It will be interesting to see whether adding them under attn helps.
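For the batch/context part, a back-of-the-envelope with hypothetical numbers (not my real config), keeping tokens per optimizer step roughly constant while the context grows:

```python
# Hypothetical numbers: trade micro-batch for context while keeping tokens
# per optimizer step roughly constant via gradient accumulation.
ctx_old, ctx_new = 512, 768
micro_bs_old = 24                         # assumed to fit in memory at 512 tokens
tokens_per_step = ctx_old * micro_bs_old  # 12288 tokens per optimizer step

micro_bs_new = max(1, tokens_per_step // ctx_new)              # 16 at 768 tokens
accum_steps = -(-tokens_per_step // (ctx_new * micro_bs_new))  # ceil -> 1

# Attention memory grows quadratically with context, so in practice the
# micro-batch may need to shrink further and accum_steps go up accordingly.
print(micro_bs_new, accum_steps)  # -> 16 1
```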
(Here: I'd say chill, but -30C is chilled enough.)