DEV Community

Sirajuddin Shaik

Mamba/SSM Basics

State Space Models offer linear-time sequence modeling with content-aware selective filtering, challenging Transformers for long-context inference.

Why This Matters

State Space Models (SSMs) provide a principled alternative to Transformers for long-sequence modeling. In production systems handling long contexts (e.g., code generation, genomic analysis), Transformer attention's quadratic cost becomes a bottleneck. Mamba achieves linear-time inference with constant-memory state, making it viable for million-token contexts where attention-based models are prohibitively expensive.

Core Idea

SSMs originate from continuous-time control theory: a latent state evolves over time driven by input, and observations are linear projections of that state. Mamba's key innovation is making the SSM parameters input-selective: the model learns to gate which information enters and exits the state, mimicking attention's ability to focus on relevant tokens without the $O(n^2)$ cost.

Technical Details

The continuous-time SSM is defined as:

$$x'(t) = Ax(t) + Bu(t), \quad y(t) = Cx(t) + Du(t)$$

where $x(t) \in \mathbb{R}^N$ is the latent state, $u(t)$ is the input, $A \in \mathbb{R}^{N \times N}$, $B \in \mathbb{R}^{N \times 1}$, $C \in \mathbb{R}^{1 \times N}$, and $D$ is a scalar skip term (omitted below). Using zero-order hold discretization with step $\Delta$:

$$\bar{A} = \exp(\Delta A), \quad \bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B$$
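As a minimal sketch of the zero-order hold formulas above, the following assumes $A$ is diagonal, so the matrix exponential and the inverse reduce to elementwise operations. The function name `zoh_discretize_diagonal` and the example values are hypothetical, chosen only for illustration.

```python
import math

def zoh_discretize_diagonal(a, b, delta):
    """Zero-order hold for a diagonal A (assumption: A is diagonal).

    a: diagonal entries of A, b: entries of B, delta: positive step size.
    Returns (a_bar, b_bar) as lists of the same length.
    """
    a_bar, b_bar = [], []
    for a_i, b_i in zip(a, b):
        z = delta * a_i
        a_bar.append(math.exp(z))
        # (exp(z) - 1) / z tends to 1 as z -> 0; expm1 stays accurate for small z.
        scale = math.expm1(z) / z if z != 0.0 else 1.0
        b_bar.append(scale * delta * b_i)
    return a_bar, b_bar

# Illustrative values only: two stable (negative) modes, step 0.1.
a_bar, b_bar = zoh_discretize_diagonal(a=[-1.0, -0.5], b=[1.0, 1.0], delta=0.1)
```

With negative entries in `a`, every entry of `a_bar` lies strictly between 0 and 1, so the discrete state decays rather than grows.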

The recurrent update becomes:

$$x_k = \bar{A}x_{k-1} + \bar{B}u_k, \quad y_k = Cx_k$$
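To make the recurrence concrete, here is a small sketch that runs it sequentially, again assuming a diagonal $\bar{A}$. Feeding in a unit impulse shows that, while the parameters are fixed, the outputs trace the kernel $C\bar{A}^k\bar{B}$. The helper `run_ssm` and all numbers are hypothetical.

```python
def run_ssm(a_bar, b_bar, c, inputs):
    """Sequential update x_k = a_bar * x_{k-1} + b_bar * u_k, y_k = c . x_k."""
    state = [0.0] * len(a_bar)
    outputs = []
    for u in inputs:
        state = [a * x + b * u for a, x, b in zip(a_bar, state, b_bar)]
        outputs.append(sum(c_i * x_i for c_i, x_i in zip(c, state)))
    return outputs

# Illustrative, already-discretized parameters (state size N = 2).
a_bar = [0.9, 0.5]
b_bar = [0.1, 0.2]
c = [1.0, 1.0]

# Response to a unit impulse at k = 0.
impulse_response = run_ssm(a_bar, b_bar, c, [1.0, 0.0, 0.0, 0.0])

# Closed-form kernel: K_k = sum_i c_i * a_bar_i**k * b_bar_i.
kernel = [
    sum(c_i * (a_i ** k) * b_i for c_i, a_i, b_i in zip(c, a_bar, b_bar))
    for k in range(4)
]
```

The two lists agree up to floating-point error, which is why a non-selective SSM can also be evaluated as a convolution.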

Mamba's selective mechanism makes $B$, $C$, and $\Delta$ functions of the input $u_k$:

$$B_k = \text{Linear}_B(u_k), \quad C_k = \text{Linear}_C(u_k), \quad \Delta_k = \text{softplus}(\text{Linear}_\Delta(u_k))$$
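The sketch below illustrates how input-dependent parameters might be computed and what $\Delta_k$ does to the state. It assumes a token vector `u`, bias-free linear maps for $B_k$ and $C_k$, and a single scalar $\Delta_k$; the weights and the helper names (`dot`, `selective_params`) are hypothetical, not Mamba's actual parameterization.

```python
import math

def softplus(z):
    # log(1 + exp(z)), written so that large |z| does not overflow.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def dot(w, u):
    return sum(w_i * u_i for w_i, u_i in zip(w, u))

def selective_params(u, w_b, w_c, w_delta, bias_delta):
    """Compute B_k, C_k (length-N lists) and a positive scalar delta_k from token u."""
    b_k = [dot(row, u) for row in w_b]
    c_k = [dot(row, u) for row in w_c]
    delta_k = softplus(dot(w_delta, u) + bias_delta)
    return b_k, c_k, delta_k

# Hypothetical weights for a 2-dimensional token and state size N = 2.
w_b = [[0.5, -0.2], [0.1, 0.3]]
w_c = [[1.0, 0.0], [0.0, 1.0]]
w_delta = [1.0, 0.0]

a = -1.0  # one diagonal entry of A
_, _, delta_small = selective_params([-5.0, 0.0], w_b, w_c, w_delta, 0.0)
_, _, delta_large = selective_params([5.0, 0.0], w_b, w_c, w_delta, 0.0)

# Fraction of the previous state that survives one step: exp(delta * a).
keep_small = math.exp(delta_small * a)
keep_large = math.exp(delta_large * a)
```

A small $\Delta_k$ leaves the state almost untouched (the token is largely ignored), while a large $\Delta_k$ mostly overwrites the state with the current input.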

A work-efficient parallel scan computes this recurrence with $O(n)$ total work and $O(\log n)$ sequential depth during training. Inference is $O(1)$ per token with fixed state size $N$, yielding constant-memory decoding regardless of sequence length.
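The recurrence can be scanned in parallel because each step is an affine map $x \mapsto a x + b$, and composing affine maps is associative. The sketch below, for a scalar state, uses a simple doubling scan; this variant does $O(n \log n)$ work in $O(\log n)$ rounds, and work-efficient variants reduce the work further. It is pure Python for illustration only and is not Mamba's CUDA kernel; the function names are hypothetical.

```python
def combine(earlier, later):
    """Compose two affine maps x -> a*x + b, applying `earlier` first."""
    a1, b1 = earlier
    a2, b2 = later
    return (a2 * a1, a2 * b1 + b2)

def scan_doubling(elems):
    """Inclusive scan over (a_k, b_k) pairs using log2(n) doubling rounds."""
    elems = list(elems)
    n = len(elems)
    step = 1
    while step < n:
        # Each position is updated independently of the others in this round,
        # so on parallel hardware the whole round is a single step.
        elems = [
            combine(elems[i - step], elems[i]) if i >= step else elems[i]
            for i in range(n)
        ]
        step *= 2
    return elems

def scan_sequential(elems):
    """Reference: x_k = a_k * x_{k-1} + b_k with x_{-1} = 0."""
    x, out = 0.0, []
    for a, b in elems:
        x = a * x + b
        out.append(x)
    return out

# Pairs (a_bar_k, b_bar_k * u_k); values are illustrative.
pairs = [(0.9, 1.0), (0.8, 2.0), (0.7, 3.0), (0.6, 4.0), (0.5, 5.0)]
parallel_states = [b for _, b in scan_doubling(pairs)]
sequential_states = scan_sequential(pairs)
```

Because the initial state is zero, the `b` component of each scanned pair is exactly the hidden state at that position.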

How It Works

  1. Project input: Map token $u_k$ to an expanded channel dimension $D$; each channel carries its own $N$-dimensional state, for $D \cdot N$ state entries in total.
  2. Generate selective parameters: Compute input-dependent $B_k$, $C_k$, $\Delta_k$ from $u_k$.
  3. Discretize: Convert continuous $(A, B_k)$ to discrete $(\bar{A}_k, \bar{B}_k)$ using $\Delta_k$.
  4. Recurrent scan: Apply parallel scan (training) or sequential update (inference) to compute hidden states $x_k$.
  5. Output projection: Compute $y_k = C_k x_k$, multiply by a SiLU-gated branch, and project to the output dimension.
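Steps 2 through 5 can be sketched end to end for a single channel with a scalar input per token. This is a simplified illustration under several assumptions: the input and output projections are omitted, $A$ is diagonal with nonzero entries, and all weights (`w_b`, `w_c`, `w_delta`, `w_gate`) are hypothetical scalars or lists rather than learned layers.

```python
import math

def softplus(z):
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def silu(z):
    return z / (1.0 + math.exp(-z))

def selective_ssm_channel(inputs, a_diag, w_b, w_c, w_delta, bias_delta, w_gate):
    """Sequential selective SSM for one channel; returns (outputs, final_state)."""
    state = [0.0] * len(a_diag)
    outputs = []
    for u in inputs:
        # Step 2: input-dependent parameters.
        b_k = [w * u for w in w_b]
        c_k = [w * u for w in w_c]
        delta_k = softplus(w_delta * u + bias_delta)
        # Step 3: zero-order hold for diagonal A (entries assumed nonzero).
        a_bar = [math.exp(delta_k * a) for a in a_diag]
        b_bar = [math.expm1(delta_k * a) / a * b for a, b in zip(a_diag, b_k)]
        # Step 4: sequential state update (the inference path).
        state = [ab * x + bb * u for ab, x, bb in zip(a_bar, state, b_bar)]
        # Step 5: read out with C_k, then apply a SiLU gate computed from the input.
        y = sum(c * x for c, x in zip(c_k, state))
        outputs.append(y * silu(w_gate * u))
    return outputs, state

outputs, final_state = selective_ssm_channel(
    inputs=[0.5, -1.0, 2.0],
    a_diag=[-1.0, -0.5],
    w_b=[1.0, 0.5],
    w_c=[0.3, -0.2],
    w_delta=1.0,
    bias_delta=0.0,
    w_gate=1.0,
)
```

The state stays at length `len(a_diag)` no matter how many tokens are processed, which is the property that makes decoding memory constant.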

Key Insights

  • Selectivity is essential: Non-selective SSMs (S4) apply the same dynamics to every token and so struggle with content-based in-context retrieval; making $B$, $C$, $\Delta$ input-dependent enables content-aware filtering.
  • Structured $A$ keeps the recurrence cheap: S4 uses a diagonal-plus-low-rank $A$, while Mamba uses diagonal $A$ matrices exclusively, so the recurrence costs $O(n)$ in sequence length.
  • Hardware-aware design: The scan kernel is IO-bound, not compute-bound, so Mamba's CUDA kernel fuses discretization, scan, and output projection to minimize memory reads.
  • Constant decoding memory: Unlike a KV cache, which grows linearly with sequence length, the SSM state has fixed size $O(ND)$, so generation memory stays constant.
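A back-of-the-envelope comparison makes the memory point concrete. The configuration below (24 layers, model width 1024, expanded width 2048, state size 16, 2 bytes per value) is hypothetical and not taken from any published model, and the formulas ignore details such as grouped-query attention or Mamba's small convolution state.

```python
def kv_cache_bytes(seq_len, n_layers, d_model, bytes_per_value=2):
    # Keys and values: two (seq_len, d_model) tensors per layer.
    return 2 * n_layers * seq_len * d_model * bytes_per_value

def ssm_state_bytes(n_layers, d_inner, d_state, bytes_per_value=2):
    # One (d_inner, d_state) state per layer, independent of sequence length.
    return n_layers * d_inner * d_state * bytes_per_value

# Hypothetical configuration, for illustration only.
N_LAYERS, D_MODEL, D_INNER, D_STATE = 24, 1024, 2048, 16

kv_1k = kv_cache_bytes(1_000, N_LAYERS, D_MODEL)
kv_1m = kv_cache_bytes(1_000_000, N_LAYERS, D_MODEL)
ssm = ssm_state_bytes(N_LAYERS, D_INNER, D_STATE)

print(f"KV cache @ 1K tokens: {kv_1k / 1e6:.1f} MB")
print(f"KV cache @ 1M tokens: {kv_1m / 1e9:.1f} GB")
print(f"SSM state (any length): {ssm / 1e6:.2f} MB")
```

Under these assumptions the KV cache grows a thousandfold between the two context lengths, while the SSM state does not change.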
