Greetings all! In my previous post I covered Binary Plane Encoding, a 3-channel grid representation for Snake that doubled the best published score. Three binary channels: head, body, apple. For details, check my previous post.
But there was a fourth channel I left out: direction. The snake's current heading, encoded as a uint8 (0 = up, 1 = right, 2 = down, 3 = left), is painted uniformly across a 20×20 plane, because every channel in the input tensor has to share the same 20×20 shape. That's 400 bytes (3,200 bits) carrying exactly 2 bits of information: a 1,600× overhead at the channel level.
Worse, that one integer channel with its 2 bits was blocking the entire state from being bit-packed. The other three grid channels are binary, so they could be packed at 1 bit per element. The direction channel, with its (scoffs) 2 bits, can't. So the replay buffer has to store the state as uint8 instead of packed bits. One channel and 2 bits of information were holding back one more step of memory optimisation, forcing 1,600 bytes per state instead of 250 (20 × 20 grid × 4 channels × 1 byte per element = 1,600 bytes, vs 20 × 20 grid × 5 channels × 1 bit per element ÷ 8 = 250 bytes).
This follow-up post is about fixing that, and the pitfalls along the way.
The First Attempt
Four cardinal directions. Two bits encode four states. So the intuitive replacement is two binary channels instead of one integer channel: one bit for North/South, one bit for East/West. Compact, geometric, obvious.
Except it doesn't work. Walk through it:
North and West both map to (0, 0). Collision.
The failure is subtle because the scheme seems right. Four directions, four possible bit combinations: it should be a clean fit. But the scheme asks "is there a north/south component?" and "is there an east/west component?", and cardinal movement is strictly one-dimensional: the perpendicular component is always exactly zero. What does the E/W bit say when the snake is moving north? It's not moving east. It's also not moving west. A single bit only has two values, so "not moving horizontally at all" has to share 0 with "moving west". "Not moving east" and "moving west" become indistinguishable.
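To make the collision concrete, here is a small sketch of the naive scheme. It assumes the convention implied above (the N/S bit is 1 only when moving south, the E/W bit is 1 only when moving east) and the post's uint8 direction codes; the function name is hypothetical.

```python
# Direction codes as in the post: 0 = up (north), 1 = right (east),
# 2 = down (south), 3 = left (west).
NAMES = {0: "north", 1: "east", 2: "south", 3: "west"}


def naive_bits(direction: int) -> tuple[int, int]:
    """Hypothetical naive encoding: (N/S bit, E/W bit).

    Assumed convention: N/S bit = 1 only when moving south,
    E/W bit = 1 only when moving east.
    """
    ns = int(direction == 2)
    ew = int(direction == 1)
    return ns, ew


codes = {NAMES[d]: naive_bits(d) for d in range(4)}
print(codes)  # {'north': (0, 0), 'east': (0, 1), 'south': (1, 0), 'west': (0, 0)}
print(len(set(codes.values())))  # 3 distinct codes for 4 directions
```

Whichever polarity you pick for each bit, one value of each bit has to cover both "that direction" and "no movement on this axis", so two of the four headings always end up sharing a code.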
Two bits should be enough for four directions. They are. Just not those two bits.
Ask Better Questions
The collision happens because the N/S + E/W scheme asks the wrong questions for cardinal movement. The fix isn't more bits. It's better questions.
The correct encoding uses two bits derived geometrically:
Axis bit: which axis is the snake travelling along? (0 = vertical, 1 = horizontal)
Sign bit: which direction on that axis? (0 = negative, 1 = positive)
All four directions get unique codes. The axis bit answers "which axis?" and the sign bit answers "which end?" Both questions always have exactly one answer for cardinal movement. No ambiguity, no collisions. The specific sign convention (whether north is positive or negative) doesn't matter as long as it's internally consistent. The CNN will learn whatever mapping you give it.
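A quick check of the axis + sign scheme, using the same 0 = up, 1 = right, 2 = down, 3 = left codes. The two expressions match the ones used in snake_cnn_env.py in the "Two Lines of Code" section; under screen coordinates (row index grows downward) they make up and left negative, right and down positive. Both helper names are hypothetical, and `decode` is only there to show that the mapping is invertible.

```python
NAMES = {0: "up", 1: "right", 2: "down", 3: "left"}


def axis_sign_bits(direction: int) -> tuple[int, int]:
    """Hypothetical helper: (axis bit, sign bit) for a cardinal direction.

    axis: 0 = vertical (up/down), 1 = horizontal (left/right)
    sign: 0 = negative (up/left), 1 = positive (right/down), i.e. screen
    coordinates where the row index grows downward.
    """
    axis = int(direction % 2 == 1)
    sign = int(0 < direction < 3)
    return axis, sign


def decode(axis: int, sign: int) -> int:
    """Hypothetical inverse: recover the 0-3 direction code from the two bits."""
    return axis + 2 * (axis ^ sign)


codes = {NAMES[d]: axis_sign_bits(d) for d in range(4)}
print(codes)  # {'up': (0, 0), 'right': (1, 1), 'down': (0, 1), 'left': (1, 0)}
assert len(set(codes.values())) == 4                           # no collisions
assert all(decode(*axis_sign_bits(d)) == d for d in range(4))  # lossless
```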
The first attempt was asking the wrong questions. Once you ask the right ones, two bits is plenty.
For anyone wondering about diagonal games (8 directions), the axis + sign scheme breaks because a diagonal lies on both axes simultaneously. The general solution there is a 4-channel multi-hot encoding: one binary plane per cardinal direction, with two planes active for a diagonal. But for Snake, which is cardinal-only, the 2-channel scheme is the right choice. Don't build generality you don't need.
The Memory Maths
This is where the change pays off. The state goes from (4, 20, 20) with one integer channel to (5, 20, 20) with all binary channels. Yes, adding a channel saves memory. That sounds backwards, but the maths checks out.
Before (4-channel, uint8 storage): 4 × 20 × 20 = 1,600 elements at 1 byte each = 1,600 bytes per state. A 1-million-transition replay buffer (storing both state and next state): 3.2 GB.
After (5-channel, binary bit-packed): 5 × 20 × 20 = 2,000 elements. Every value is now 0 or 1, so each element can be packed at 1 bit, 8 elements per byte. ⌈2,000 / 8⌉ = 250 bytes per state. The same buffer: 500 MB.
6.4× reduction. Adding one channel, removing 2.7 GB.
To put this in perspective: the grid encoding stored naively as float32 (before any compression) would be 6,400 bytes per state, or 12.8 GB for a 1M-transition buffer. The first post's uint8 storage cut that to 3.2 GB (4× reduction). This post's binary bit-packing cuts it again to 500 MB. Across both changes, that's a 25.6× total reduction from the uncompressed float32 starting point.
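The arithmetic above fits in a few lines. This sketch assumes a 1M-transition buffer that stores both state and next state, and uses decimal units (1 GB = 10^9 bytes) to match the figures in the text; the helper name is hypothetical.

```python
import math

TRANSITIONS = 1_000_000
STATES_PER_TRANSITION = 2  # state and next state are both stored


def buffer_bytes(channels: int, bits_per_element: int, h: int = 20, w: int = 20) -> int:
    """Hypothetical helper: replay-buffer bytes for (channels, h, w) states."""
    per_state = math.ceil(channels * h * w * bits_per_element / 8)
    return per_state * STATES_PER_TRANSITION * TRANSITIONS


float32_bytes = buffer_bytes(4, 32)  # 4 channels stored as float32
uint8_bytes = buffer_bytes(4, 8)     # 4 channels stored as uint8
packed_bytes = buffer_bytes(5, 1)    # 5 binary channels, bit-packed

for name, size in [("float32", float32_bytes), ("uint8", uint8_bytes), ("bit-packed", packed_bytes)]:
    print(f"{name:>10}: {size / 1e9:.1f} GB")
print(f"uint8 -> packed: {uint8_bytes / packed_bytes:.1f}x, "
      f"float32 -> packed: {float32_bytes / packed_bytes:.1f}x")
```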
And compared to the pixel-based approaches from the first post? Wei et al.'s RGB inputs would need approximately 49 GB for the same buffer. Binary Plane Encoding with binary cardinal directions brings that to 500 MB. Nearly a 98× difference. A 1-million-transition replay buffer now fits comfortably in the VRAM of a gaming laptop. Hell, it even fits in the L3 cache of some EPYC CPUs (AMD's Genoa-X packs up to 1,152 MB of L3). With pixel inputs, it wouldn't fit on most workstations.
Two Lines of Code
The implementation change is in snake_cnn_env.py. Replace the single integer direction plane with two binary planes:
```python
# Before: one integer channel
# grid[3] = self._direction  # 0, 1, 2, or 3
# After: two binary channels
grid[3] = float(self._direction % 2 == 1)  # axis: 0=vertical, 1=horizontal
grid[4] = float(0 < self._direction < 3)   # sign: 0=negative, 1=positive
```
Update input_channels from 4 to 5 in the model config. Done. We now store 5 channels instead of 4, but each element is 1 bit instead of 8. One extra channel, far less storage.
One real cost: changing input_channels changes the shape of the first convolutional weight tensor. Existing checkpoints can't be loaded into a 5-channel model. This requires a fresh training run, so schedule the change at a natural break point, not mid-experiment.
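For context, here is a hypothetical, self-contained version of the full 5-channel state build. Only the two direction lines come from the snippet above; the function name, the argument layout (head and apple as (row, col) tuples, body as a list of them) and the uint8 dtype are assumptions for illustration, not the actual snake_cnn_env.py code.

```python
import numpy as np

GRID = 20


def encode_state(head, body, apple, direction: int) -> np.ndarray:
    """Hypothetical 5-channel Binary Plane Encoding.

    Channels: 0 = head, 1 = body, 2 = apple, 3 = axis bit, 4 = sign bit.
    direction: 0 = up, 1 = right, 2 = down, 3 = left.
    """
    grid = np.zeros((5, GRID, GRID), dtype=np.uint8)
    grid[0][head] = 1
    for segment in body:
        grid[1][segment] = 1
    grid[2][apple] = 1
    # The direction bits are painted uniformly across the whole 20 x 20 plane.
    grid[3] = int(direction % 2 == 1)  # axis: 0 = vertical, 1 = horizontal
    grid[4] = int(0 < direction < 3)   # sign: 0 = negative, 1 = positive
    return grid


state = encode_state(head=(10, 10), body=[(10, 9), (10, 8)], apple=(3, 4), direction=3)
assert state.max() <= 1  # every element is 0 or 1, so it can be bit-packed
print(np.packbits(state).nbytes)  # 250
```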
torch.unpackbits Doesn't Exist
The CPU side of bit-packing is trivial. np.packbits and np.unpackbits have existed in NumPy since 2010. Pack on write, unpack on read. Done.
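A minimal sketch of "pack on write, unpack on read" for the state storage of a replay buffer. The class name, its methods and the fixed (5, 20, 20) shape are hypothetical; a real buffer would also hold actions, rewards and done flags. np.packbits flattens the state and packs 8 elements per byte (MSB first by default), and np.unpackbits with count= drops any padding bits.

```python
import numpy as np


class PackedStateStore:
    """Hypothetical ring buffer holding bit-packed binary states."""

    def __init__(self, capacity: int, state_shape=(5, 20, 20)):
        self.state_shape = state_shape
        self.n_elems = int(np.prod(state_shape))
        packed_size = (self.n_elems + 7) // 8  # ceil(n_elems / 8) bytes per state
        self.data = np.zeros((capacity, packed_size), dtype=np.uint8)
        self.capacity = capacity
        self.pos = 0

    def add(self, state: np.ndarray) -> None:
        # Pack on write: each 0/1 element becomes one bit.
        self.data[self.pos] = np.packbits(state.astype(np.uint8))
        self.pos = (self.pos + 1) % self.capacity

    def get(self, idx: np.ndarray) -> np.ndarray:
        # Unpack on read: (B, packed_size) -> (B, n_elems) -> (B, *state_shape).
        bits = np.unpackbits(self.data[idx], axis=1, count=self.n_elems)
        return bits.reshape(len(idx), *self.state_shape)


rng = np.random.default_rng(1)
store = PackedStateStore(capacity=4)
states = rng.integers(0, 2, size=(4, 5, 20, 20), dtype=np.uint8)  # random 0/1 stand-ins
for s in states:
    store.add(s)
batch = store.get(np.array([0, 2]))
assert np.array_equal(batch, states[[0, 2]])  # lossless round trip
print(store.data.nbytes)  # 1000 bytes for 4 states
```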
So just implement it on the GPU side, right? WRONG. The natural PyTorch equivalent would be torch.unpackbits, which... doesn't exist. The function is absent from the stable API entirely, and calling torch.unpackbits raises an AttributeError. This is a genuine gap in PyTorch that anyone implementing binary storage on CUDA will hit.
The community workaround I found uses bitmasks:
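```python
import torch
x = torch.tensor([0b10110000, 0b00000001], dtype=torch.uint8)  # 1-D packed bytes, shape (N,)
mask = 2 ** torch.arange(8, dtype=torch.uint8, device=x.device).reshape(1, 8)  # [[1, 2, ..., 128]]
unpacked = (x.unsqueeze(-1) & mask).bool().int().flip(dims=[1])  # (N, 8), MSB first
```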
This works for a flat 1D input. The & keeps each bit at its positional value (1, 2, 4, ..., 128), .bool().int() collapses those values to 0 or 1, and .flip() reverses the bit order to match the MSB-first convention. Four operations, correct output.
But I don't need to preserve the masked values; I just need 0s and 1s. I thought I could do better, and I wouldn't be a programmer if I didn't try for no other reason than... (shrugs) I wanted to.
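```python
import torch

def unpack_bits(packed: torch.Tensor, n_elems: int) -> torch.Tensor:  # (B, packed_size) uint8 -> (B, n_elems)
    shifts = torch.arange(7, -1, -1, device=packed.device, dtype=torch.uint8)  # [7, 6, ..., 0]
    unpacked = (packed.unsqueeze(-1) >> shifts) & 1  # (B, packed_size, 8)
    return unpacked.reshape(packed.shape[0], -1)[:, :n_elems]  # drop padding bits
```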
Each packed byte is broadcast against 8 shift values [7, 6, 5, 4, 3, 2, 1, 0], right-shifting to move each successive bit into the least significant position. Bitwise & with 1 isolates it. Two operations instead of four. No .bool().int() is needed because (byte >> shift) & 1 always yields 0 or 1 directly. No .flip() is needed because the descending shift range already produces MSB-first order. Fewer intermediate tensors in VRAM during sampling.
The mask approach also has a shape bug: it's written for a 1D input (flat array of bytes) and breaks on a batched 2D input (B, packed_size). The shift approach handles batched GPU sampling correctly from the start.
Both are fully device-resident, with no CPU-GPU transfer. But two operations beat four, and allocating fewer intermediate tensors matters when batch size and state shape are large. Will cutting two ops make a measurable difference? Probably not, but I saw the OPportunity and took it. And yes, I said that just for the joke.
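Since PyTorch has no unpackbits to compare against, one way to sanity-check the shift-and-mask logic is to run the same expression in NumPy, where np.unpackbits does exist. This is a CPU stand-in for the torch version (same shapes, same descending shifts), not the GPU code itself; the function name is hypothetical.

```python
import numpy as np


def unpack_by_shifts(packed: np.ndarray, n_elems: int) -> np.ndarray:
    """NumPy mirror of the shift-based unpack: (B, packed_size) uint8 -> (B, n_elems)."""
    shifts = np.arange(7, -1, -1, dtype=np.uint8)         # [7, 6, ..., 0]: MSB first
    bits = (packed[..., np.newaxis] >> shifts) & 1         # (B, packed_size, 8)
    return bits.reshape(packed.shape[0], -1)[:, :n_elems]  # drop padding bits


rng = np.random.default_rng(2)
n_elems = 2000  # 5 x 20 x 20
states = rng.integers(0, 2, size=(3, n_elems), dtype=np.uint8)
packed = np.packbits(states, axis=1)                        # (3, 250)

ours = unpack_by_shifts(packed, n_elems)
reference = np.unpackbits(packed, axis=1, count=n_elems)
assert np.array_equal(ours, reference) and np.array_equal(ours, states)

# A length that isn't a multiple of 8 exercises the padding slice:
# 13 bits -> 2 bytes, 3 padding bits dropped on unpack.
odd = rng.integers(0, 2, size=(2, 13), dtype=np.uint8)
assert np.array_equal(unpack_by_shifts(np.packbits(odd, axis=1), 13), odd)
```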
So two lines of code changed the state representation to allow bit-packing, cutting per-state storage 6.4× with no loss of information.
What's Next
This is part of an ongoing series building Rainbow DQN incrementally and measuring each component on Snake. The state representation work runs in parallel to the algorithm comparison. It doesn't change which Rainbow components help or hurt, but a 6.4× memory reduction means larger buffers, more parallel environments, or training on hardware that previously couldn't fit the buffer.
The algorithm results are the next post.
If you've hit the torch.unpackbits gap yourself, or found a cleaner solution than bitwise shifts for GPU-side bit unpacking, I'd like to hear about it in the comments.
This work is part of ongoing research and the findings are planned to be submitted as a peer-reviewed paper.
If you missed the first post in this series:


