
Denis
Parallel RNNs?

Did you check out the recent ICLR results? I got intrigued by a rather provocative paper from Apple - ParaRNN - which claims to parallelize RNNs over the sequence length, even though that is supposedly their main weakness, the very reason transformers replaced them (in most tasks).

So let's dig into all of it at the lowest level possible. If you know what an RNN is and what a derivative is, this article is for you.

1. The DEER algorithm

DEER = Deep Equilibrium Evaluation of Recurrence (Lim et al., 2024). The base algorithm on which ParaRNN is built.

1.1. Formulation as a root-finding problem

Consider an RNN with transition function $f: \mathbb{R}^D \to \mathbb{R}^D$, initial state $\mathbf{s}_0$, and unknown states $\mathbf{s}_{1:T}$. Let's introduce the residual:

$$\mathbf{r}(\mathbf{s}_{1:T}) := [\mathbf{s}_1 - f(\mathbf{s}_0),\ \mathbf{s}_2 - f(\mathbf{s}_1),\ \ldots,\ \mathbf{s}_T - f(\mathbf{s}_{T-1})] \in \mathbb{R}^{T \times D}$$

The true trajectory $\mathbf{s}^{*}_{1:T}$ is the **unique solution** of the equation with zero residual:

$$\mathbf{r}(\mathbf{s}^{*}_{1:T}) = \mathbf{0}$$

When we say "apply an RNN to a sequence", we mean the standard procedure: take the initial state $\mathbf{s}_0$, apply the transition function $f$, get $\mathbf{s}_1$, then apply $f$ again to get $\mathbf{s}_2$, and so on:

$$\mathbf{s}_1 = f(\mathbf{s}_0), \quad \mathbf{s}_2 = f(\mathbf{s}_1), \quad \ldots, \quad \mathbf{s}_T = f(\mathbf{s}_{T-1})$$

Accordingly, on this trajectory $\mathbf{r}$ turns out to be a vector with all elements equal to 0: when the recurrence holds we have $\mathbf{s}_1 = f(\mathbf{s}_0)$ and therefore $\mathbf{s}_1 - f(\mathbf{s}_0) = 0$, and the same at every later step.
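To make the residual concrete, here is a minimal NumPy sketch (my own illustration, not code from the paper), assuming a toy tanh cell as the transition function $f$: running the recurrence sequentially gives a trajectory whose residual is exactly zero, while an arbitrary guess for the states does not.

```python
import numpy as np

D, T = 4, 6
rng = np.random.default_rng(0)
W = rng.normal(size=(D, D)) / np.sqrt(D)  # hypothetical transition weights

def f(s):
    """Toy transition function f: R^D -> R^D (stand-in for the real RNN cell)."""
    return np.tanh(W @ s)

def residual(s_all, s0):
    """r_t = s_t - f(s_{t-1}) for t = 1..T, stacked into a (T, D) array."""
    prev = np.vstack([s0[None, :], s_all[:-1]])   # s_0, s_1, ..., s_{T-1}
    return s_all - np.array([f(p) for p in prev])

# The true trajectory: run the recurrence sequentially.
s0 = rng.normal(size=D)
true_states = np.zeros((T, D))
s = s0
for t in range(T):
    s = f(s)
    true_states[t] = s

print(np.abs(residual(true_states, s0)).max())              # ~1e-16: residual vanishes on the true trajectory
print(np.abs(residual(rng.normal(size=(T, D)), s0)).max())  # far from 0 for a random guess
```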


1.2. Newton's iterations

So we need to find a solution to the equation $\mathbf{r}(\mathbf{s}) = \mathbf{0}$, or, in the multidimensional case, a vector that solves a system of equations. But let's start with the scalar case.

Scalar case: one equation in one variable

Suppose we have a smooth function $r: \mathbb{R} \to \mathbb{R}$ and we want to find $s^*$ such that $r(s^*) = 0$. Geometrically, we want to find the point where the graph of the function crosses the x-axis.

The idea of Newton's method rests on a simple thought: in a small neighborhood of a point, a smooth function is almost indistinguishable from its tangent line. If we're standing at our current approximation $s^{(i)}$ (which, in general, is not a root, so $r(s^{(i)}) \neq 0$), we can pretend that $r$ is its tangent line at this point, and for such a linear function it's easy to find analytically where it crosses the axis.

The tangent to $r$ at the point $s^{(i)}$ is the first-order term of the Taylor expansion:

$$r(s) \approx r(s^{(i)}) + r'(s^{(i)})\,(s - s^{(i)})$$

The Taylor expansion is a way to approximate any smooth function near a point $s_0$ by a polynomial: $r(s) = r(s_0) + r'(s_0)(s - s_0) + \frac{r''(s_0)}{2!}(s - s_0)^2 + \frac{r'''(s_0)}{3!}(s - s_0)^3 + \ldots$, where each subsequent term refines the approximation by adding information about an ever finer feature of the function's shape (slope, curvature, etc.). The logical meaning: if a function is smooth, then its behavior in a neighborhood of a point is fully encoded in the values of its derivatives at that single point - by measuring a few numbers at $s_0$, we can reconstruct the function's values nearby. The divisor $k!$ arises naturally from the requirement that at $s_0$ all derivatives of the polynomial coincide with those of the function itself (it cancels with the factorial that pops out when differentiating $(s-s_0)^k$ exactly $k$ times).
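As a quick numerical sanity check (my own, not from the original post), here is how well the first-order term alone approximates a smooth function near a point; the error of the linear approximation shrinks roughly quadratically as we move closer:

```python
import numpy as np

r  = np.tanh                          # any smooth scalar function will do
dr = lambda s: 1 - np.tanh(s) ** 2    # its derivative

s0 = 0.5
for delta in (0.1, 0.01, 0.001):
    exact  = r(s0 + delta)
    linear = r(s0) + dr(s0) * delta   # first-order Taylor approximation
    print(delta, abs(exact - linear)) # error ~ delta^2
```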


We set this linear approximation equal to zero and find where it crosses the axis:

$$r(s^{(i)}) + r'(s^{(i)})\,(s - s^{(i)}) = 0$$

Solve for $s$ - this is just school algebra:

$$s = s^{(i)} - \frac{r(s^{(i)})}{r'(s^{(i)})}$$

And we declare this to be our next approximation:

$$\boxed{\; s^{(i+1)} = s^{(i)} - \frac{r(s^{(i)})}{r'(s^{(i)})} \;}$$

The original post illustrates the step $s^{(i)} \to s^{(i+1)}$ graphically: the estimated root (the intersection with the x-axis we're after) moves from 2 to 1, which shows improvement, since the true root sits at 0.

We can rewrite this in terms of the increment $\Delta s^{(i+1)} := s^{(i+1)} - s^{(i)}$, which will be more convenient when generalizing:

$$r'(s^{(i)})\,\Delta s^{(i+1)} = -r(s^{(i)})$$

That is, "find an increment such that the linear correction $r' \cdot \Delta s$ cancels out the current residual $r$".
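In code, the scalar iteration is just the boxed update applied a few times. A minimal sketch (the example function and starting point are arbitrary):

```python
def newton_scalar(r, dr, s, n_iters=8):
    """Scalar Newton's method: s <- s - r(s) / r'(s)."""
    for _ in range(n_iters):
        s = s - r(s) / dr(s)
    return s

# Example: find the root of r(s) = s^3 - 2, i.e. the cube root of 2.
r  = lambda s: s**3 - 2
dr = lambda s: 3 * s**2

root = newton_scalar(r, dr, s=1.5)
print(root, 2 ** (1 / 3))  # both ~1.2599
```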

Multidimensional case: N equations in N variables

Now let's generalize. Instead of one function of one variable, we have $\mathbf{r}: \mathbb{R}^N \to \mathbb{R}^N$ - a vector-valued function of a vector argument, and we're looking for a vector $\mathbf{s}^* \in \mathbb{R}^N$ such that $\mathbf{r}(\mathbf{s}^*) = \mathbf{0}$.

The logic remains literally the same, only the objects change:

| Scalar | Vector |
| --- | --- |
| function $r(s)$ | vector function $\mathbf{r}(\mathbf{s})$ |
| derivative $r'(s)$ - a number | Jacobian $J(\mathbf{s})$ - an $N \times N$ matrix |
| tangent (line) | tangent hyperplane |
| division by $r'$ | multiplication by $J^{-1}$ (i.e. solving a linear system) |

Here $J(\mathbf{s}) = \frac{\partial \mathbf{r}}{\partial \mathbf{s}}$ is the Jacobian, the multidimensional analog of the ordinary derivative (more on this below).

The Jacobian $J(\mathbf{s})$ is just a matrix of all partial derivatives: at position $(i, j)$ stands $\partial r_i / \partial s_j$. It plays the role of a derivative - it shows how a small change in $\mathbf{s}$ affects $\mathbf{r}$ to first order.

Linearization of $\mathbf{r}$ around the point $\mathbf{s}^{(i)}$:

$$\mathbf{r}(\mathbf{s}) \approx \mathbf{r}(\mathbf{s}^{(i)}) + J(\mathbf{s}^{(i)})\,(\mathbf{s} - \mathbf{s}^{(i)})$$

We set the linear approximation equal to the zero vector:

$$\mathbf{r}(\mathbf{s}^{(i)}) + J(\mathbf{s}^{(i)})\,(\mathbf{s} - \mathbf{s}^{(i)}) = \mathbf{0}$$

And by denoting the increment $\Delta\mathbf{s}^{(i+1)} := \mathbf{s} - \mathbf{s}^{(i)}$, we get a linear system for $\Delta\mathbf{s}^{(i+1)}$:

$$\boxed{\; J(\mathbf{s}^{(i)})\,\Delta\mathbf{s}^{(i+1)} = -\mathbf{r}(\mathbf{s}^{(i)}) \;}$$

Having solved the system and obtained $\Delta\mathbf{s}^{(i+1)}$, we update the approximation:

$$\mathbf{s}^{(i+1)} = \mathbf{s}^{(i)} + \Delta\mathbf{s}^{(i+1)}$$

This is often written compactly using the inverse matrix:

$$\mathbf{s}^{(i+1)} = \mathbf{s}^{(i)} - J(\mathbf{s}^{(i)})^{-1}\,\mathbf{r}(\mathbf{s}^{(i)})$$

  • this is the same formula, just shorter. The form with $J^{-1}$ is purely notational: in practice, no one ever computes the inverse matrix, because it's both expensive and numerically unstable. Instead, one solves the system $J \Delta\mathbf{s} = -\mathbf{r}$ directly - for example, via LU decomposition, or via forward substitution if $J$ has a special structure (which is what happens in our case).
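Here is a generic dense-Jacobian sketch of one multidimensional Newton step (my own example, not yet exploiting any structure). Note that it calls a linear solver - which uses LU decomposition under the hood - instead of ever forming $J^{-1}$:

```python
import numpy as np

def newton_step(residual, jacobian, s):
    """One Newton step: solve J(s) * delta = -r(s), then update s."""
    J = jacobian(s)                  # (N, N) matrix
    r = residual(s)                  # (N,) vector
    delta = np.linalg.solve(J, -r)   # solve the linear system; never compute inv(J)
    return s + delta

# Example: r(s) = [s0^2 + s1^2 - 1, s0 - s1] = 0  (intersect a circle with a line).
residual = lambda s: np.array([s[0]**2 + s[1]**2 - 1.0, s[0] - s[1]])
jacobian = lambda s: np.array([[2 * s[0], 2 * s[1]], [1.0, -1.0]])

s = np.array([1.0, 0.0])
for _ in range(6):
    s = newton_step(residual, jacobian, s)
print(s)  # ~[0.7071, 0.7071]
```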

1.3. Application to our RNN problem

For RNNs everything follows exactly this pattern, only the dimensions are specific:

  • $\mathbf{s} = (\mathbf{s}_1, \ldots, \mathbf{s}_T) \in \mathbb{R}^{TD}$ - all hidden states glued into one long vector of length $TD$.
  • $\mathbf{r}: \mathbb{R}^{TD} \to \mathbb{R}^{TD}$ - the vector of all one-step residuals, of the same length.
  • $J(\mathbf{s}) \in \mathbb{R}^{TD \times TD}$ - the Jacobian of the residual with respect to the state.

We apply the same Newton step:

$$J(\mathbf{s}^{(i)})\,\Delta\mathbf{s}^{(i+1)} = -\mathbf{r}(\mathbf{s}^{(i)})$$

And here a legitimate question arises: "But isn't solving a linear system of size $TD \times TD$ the same sequential problem? Where's the parallelization?"

If $J$ were an arbitrary dense matrix, then yes - a naive solution would cost $O((TD)^3)$, and there'd be no benefit. But $J$ is not arbitrary. Due to the Markov property of the RNN (each step $f$ sees only the previous state $\mathbf{s}_{t-1}$, not the full history), the overwhelming majority of blocks in the Jacobian are zero. Specifically: in block-row $t$, nonzero entries appear only in block-columns $t$ and $t-1$. This gives us a block bidiagonal structure:

$$J(\mathbf{s}) = \begin{pmatrix} I_D & 0 & 0 & \cdots & 0 \\ -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_1) & I_D & 0 & \cdots & 0 \\ 0 & -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_2) & I_D & \cdots & 0 \\ \vdots & & \ddots & \ddots & \vdots \\ 0 & 0 & \cdots & -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_{T-1}) & I_D \end{pmatrix}$$

What is a Jacobian, generally

When we have an ordinary function of one variable $r: \mathbb{R} \to \mathbb{R}$, its derivative $r'(s)$ is a single number that tells us "how fast the output changes for a small change in the input". It plays the role of a local coefficient of proportionality: if we shift $s$ by a small $\delta$, then $r$ changes by approximately $r'(s) \cdot \delta$.

Now imagine a function whose input and output are both vectors. Say $\mathbf{r}: \mathbb{R}^N \to \mathbb{R}^M$: we feed in a vector of $N$ numbers and get out a vector of $M$ numbers. The notion of "derivative" gets more complex here, because now we have to answer $M \times N$ questions at once: "how does the $i$-th output component change when the $j$-th input component changes?". The answers to all these questions naturally collect into a matrix of size $M \times N$ - and this is the Jacobian:

$$J(\mathbf{s}) = \frac{\partial \mathbf{r}}{\partial \mathbf{s}} = \begin{pmatrix} \frac{\partial r_1}{\partial s_1} & \frac{\partial r_1}{\partial s_2} & \cdots & \frac{\partial r_1}{\partial s_N} \\ \frac{\partial r_2}{\partial s_1} & \frac{\partial r_2}{\partial s_2} & \cdots & \frac{\partial r_2}{\partial s_N} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial r_M}{\partial s_1} & \frac{\partial r_M}{\partial s_2} & \cdots & \frac{\partial r_M}{\partial s_N} \end{pmatrix}$$

At position $(i, j)$ stands the number $\partial r_i / \partial s_j$ - the partial derivative of the $i$-th output component with respect to the $j$-th input component. So the Jacobian is literally a complete sensitivity map: each cell answers a specific question - "how sensitively does this output coordinate respond to this input coordinate?".
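To demystify the "sensitivity map" picture, here is a finite-difference sketch of the Jacobian (my own illustration; in practice autodiff would compute it, but the idea is the same): perturb each input coordinate a little and record how every output coordinate responds.

```python
import numpy as np

def numerical_jacobian(r, s, eps=1e-6):
    """Finite-difference Jacobian: J[i, j] = d r_i / d s_j."""
    s = np.asarray(s, dtype=float)
    r0 = r(s)
    J = np.zeros((r0.size, s.size))
    for j in range(s.size):
        s_pert = s.copy()
        s_pert[j] += eps                  # nudge the j-th input coordinate
        J[:, j] = (r(s_pert) - r0) / eps  # response of every output coordinate
    return J

# Example: r(s) = (s0 * s1, sin(s0)) has Jacobian [[s1, s0], [cos(s0), 0]].
r = lambda s: np.array([s[0] * s[1], np.sin(s[0])])
print(numerical_jacobian(r, np.array([1.0, 2.0])))  # ~[[2, 1], [0.5403, 0]]
```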


Recall the definition of the residual:

$$\mathbf{r}_t(\mathbf{s}_{1:T}) = \mathbf{s}_t - f(\mathbf{s}_{t-1})$$

This expression depends only on two variables: on $\mathbf{s}_t$ (through the first term) and on $\mathbf{s}_{t-1}$ (through the second). All other $\mathbf{s}_k$ simply don't appear in the formula. And the derivative with respect to a variable that doesn't appear in the formula equals zero.

Let's go case by case to see what block $\partial \mathbf{r}_t / \partial \mathbf{s}_k$ we get for different $k$:

Case 1: $k = t$. We take the derivative of $\mathbf{s}_t - f(\mathbf{s}_{t-1})$ with respect to $\mathbf{s}_t$. Only the first term depends on $\mathbf{s}_t$, and its derivative with respect to itself is the identity matrix. We get:

$$\frac{\partial \mathbf{r}_t}{\partial \mathbf{s}_t} = I_D$$

Case 2: $k = t - 1$. We take the derivative with respect to $\mathbf{s}_{t-1}$. The first term doesn't depend on it; the second is $-f(\mathbf{s}_{t-1})$, and its derivative is $-\partial f/\partial \mathbf{s}$ evaluated at the point $\mathbf{s}_{t-1}$:

$$\frac{\partial \mathbf{r}_t}{\partial \mathbf{s}_{t-1}} = -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_{t-1})$$

Case 3: all other $k$ (i.e., $k \neq t$ and $k \neq t - 1$). The variable $\mathbf{s}_k$ simply doesn't appear in the formula for $\mathbf{r}_t$. Hence:

$$\frac{\partial \mathbf{r}_t}{\partial \mathbf{s}_k} = 0_{D \times D} \quad (\text{zero matrix})$$

That's it. Out of $T^2$ blocks, exactly $T + (T-1) = 2T - 1$ are nonzero: $T$ identity matrices on the main diagonal and $T-1$ transition Jacobians on the subdiagonal. Everything else is zero. If we write out the whole matrix:

$$J(\mathbf{s}) = \begin{pmatrix} I_D & 0 & 0 & \cdots & 0 \\ -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_1) & I_D & 0 & \cdots & 0 \\ 0 & -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_2) & I_D & \cdots & 0 \\ \vdots & & \ddots & \ddots & \vdots \\ 0 & 0 & \cdots & -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_{T-1}) & I_D \end{pmatrix}$$

Where the Markov property comes in

The key reason this structure appears is the Markov property of the RNN. The transition function $f$ at each step looks only at the previous state $\mathbf{s}_{t-1}$, not at the entire history $\mathbf{s}_1, \ldots, \mathbf{s}_{t-1}$. Because of this, the residual $\mathbf{r}_t$ turns out to be a "local" object: it depends only on two adjacent states - the current $\mathbf{s}_t$ and the previous $\mathbf{s}_{t-1}$.

How much memory do we actually need

Although formally the Jacobian has $TD \times TD$ cells, we only need to store the nonzero blocks. These are:

  • $T$ identity matrices $I_D$ - but we don't even need to store these; we know they're $I_D$ and can substitute on the fly;
  • $T - 1$ transition Jacobians $\partial f/\partial \mathbf{s}(\mathbf{s}_t)$ of size $D \times D$ - that's $(T-1) \cdot D^2$ numbers, which for $T = 1000$, $D = 256$ gives about 65 million numbers instead of 65 billion. Already feasible (see the quick check right after this list).
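A quick back-of-the-envelope check of these numbers:

```python
T, D = 1000, 256
dense_cells  = (T * D) ** 2        # full TD x TD Jacobian
sparse_cells = (T - 1) * D ** 2    # only the subdiagonal transition blocks
print(f"{dense_cells:.3e}")        # ~6.554e+10  (about 65 billion numbers)
print(f"{sparse_cells:.3e}")       # ~6.547e+07  (about 65 million numbers)
```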

How the Jacobian's structure gives us parallelism

Now the main question: why does this structure let us solve the system $J \Delta\mathbf{s} = -\mathbf{r}$ in parallel? Here it's important to distinguish two levels:

Level 1: the structure lets us solve the system via forward substitution. Take the system $J \Delta\mathbf{s} = -\mathbf{r}$ and write it out row by row. The first block-row of the matrix $J$ is $(I_D, 0, 0, \ldots, 0)$, so the first equation of the system is:

$$I_D \cdot \Delta\mathbf{s}_1 = -\mathbf{r}_1$$

  • that is, simply $\Delta\mathbf{s}_1 = -\mathbf{r}_1$. We got the first chunk of the answer essentially for free.

The second block-row of $J$ is $(-\partial f/\partial \mathbf{s}(\mathbf{s}_1), I_D, 0, \ldots, 0)$, so the second equation:

$$-\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_1) \cdot \Delta\mathbf{s}_1 + I_D \cdot \Delta\mathbf{s}_2 = -\mathbf{r}_2$$

From which:

$$\Delta\mathbf{s}_2 = \frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_1) \cdot \Delta\mathbf{s}_1 - \mathbf{r}_2$$

And in general, for any $t > 1$:

$$\boxed{\; \Delta\mathbf{s}_t = \frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_{t-1}) \, \Delta\mathbf{s}_{t-1} - \mathbf{r}_t \;}$$

This is the linear recurrence that our gigantic $TD \times TD$ system has turned into. Note that in general solving a linear system costs $O(N^3)$ - but here we avoided inverting any matrix, thanks to $J$ being block bidiagonal. The system is solved by a simple top-to-bottom sweep through the equations. This is what's called forward substitution.

If things ended here, we'd have only a sequential algorithm taking $O(T)$ steps - each $\Delta\mathbf{s}_t$ depends on $\Delta\mathbf{s}_{t-1}$, and we have to traverse the recurrence strictly in order. Same as just running the RNN sequentially. Parallelism is born at the next level.
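Here is Level 1 as code - a minimal forward-substitution sketch (my own illustration), sequential on purpose, checked against a dense solve of the full $TD \times TD$ system:

```python
import numpy as np

def forward_substitution(A, r):
    """
    Solve the block-bidiagonal Newton system row by row:
        delta_1 = -r_1
        delta_t = A_t @ delta_{t-1} - r_t     for t = 2..T
    A: (T, D, D) transition Jacobians A_t = df/ds(s_{t-1}) (A[0] is unused),
    r: (T, D) residuals.
    """
    T, D = r.shape
    delta = np.zeros_like(r)
    delta[0] = -r[0]
    for t in range(1, T):                      # strictly sequential sweep
        delta[t] = A[t] @ delta[t - 1] - r[t]
    return delta

# Sanity check against a dense solve of the full TD x TD system.
T, D = 5, 3
rng = np.random.default_rng(0)
A = rng.normal(size=(T, D, D)) * 0.1
r = rng.normal(size=(T, D))

J = np.eye(T * D)
for t in range(1, T):
    J[t*D:(t+1)*D, (t-1)*D:t*D] = -A[t]        # subdiagonal blocks -df/ds

dense = np.linalg.solve(J, -r.ravel()).reshape(T, D)
print(np.allclose(forward_substitution(A, r), dense))  # True
```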

Level 2: the recurrence is LINEAR, and therefore associative. This is the main trick. Let's note the principal difference between two situations:

  • Original RNN: $\mathbf{s}_t = f(\mathbf{s}_{t-1}, x_t)$ - the function $f$ is nonlinear, so this kind of recurrence cannot be parallelized: we have to compute each step honestly in sequence.
  • Recurrence for $\Delta\mathbf{s}$: $\Delta\mathbf{s}_t = A_t \cdot \Delta\mathbf{s}_{t-1} + b_t$ (where $A_t = \partial f/\partial \mathbf{s}(\mathbf{s}_{t-1})$, $b_t = -\mathbf{r}_t$) - this one is linear. That means we can derive a closed-form expression from it:

$$\Delta\mathbf{s}_t = A_t A_{t-1} \cdots A_2 \cdot \Delta\mathbf{s}_1 + (A_t \cdots A_3 \cdot b_2) + (A_t \cdots A_4 \cdot b_3) + \ldots + b_t$$

All these matrix products $A_t A_{t-1} \cdots$ can be computed in any order (matrix multiplication is associative: $(AB)C = A(BC)$). Which means we can build a computation tree where we first compute all pairwise products $A_2 A_1, A_4 A_3, A_6 A_5, \ldots$ in parallel, then all 4-tuples $A_4 A_3 A_2 A_1, A_8 A_7 A_6 A_5, \ldots$, and so on. In $\log_2 T$ tree levels we get all the cumulative products we need, and from them we assemble all $\Delta\mathbf{s}_t$ simultaneously.

This is the parallel scan (also known as parallel prefix sum in its generalized form). By analogy with ordinary addition: if you need to sum a billion numbers, sequentially that's a billion steps, but with a pairwise tree it's only $\log_2(10^9) \approx 30$ levels. The same trick works for any associative operation, and the composition of linear maps (i.e., multiplication of their matrices) is associative.
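To make the scan concrete, here is a sketch of an associative scan over the affine maps $(A_t, b_t)$, written in a Hillis–Steele style with $O(\log T)$ rounds (my own NumPy illustration; on real hardware each round's work would run in parallel via a GPU scan primitive such as jax.lax.associative_scan or a custom kernel - here the parallelism is only conceptual):

```python
import numpy as np

def parallel_scan(A, b):
    """
    Inclusive scan of the affine maps x -> A_t @ x + b_t.
    Returns delta with delta[t] = A[t] @ delta[t-1] + b[t] and delta[0] = b[0].
    Each round combines prefixes at distance `offset`; there are O(log T) rounds,
    and all the work inside a round is independent (parallel on real hardware).
    """
    A, b = A.copy(), b.copy()
    T = b.shape[0]
    offset = 1
    while offset < T:
        A_new, b_new = A.copy(), b.copy()
        # prefix[t] <- prefix[t] composed with prefix[t - offset] (the latter applied first)
        A_new[offset:] = np.einsum('tij,tjk->tik', A[offset:], A[:-offset])
        b_new[offset:] = np.einsum('tij,tj->ti', A[offset:], b[:-offset]) + b[offset:]
        A, b = A_new, b_new
        offset *= 2
    return b   # offset part of the cumulative affine map = delta_t (since delta_0 = 0)

# Check against the sequential recurrence.
T, D = 8, 3
rng = np.random.default_rng(1)
A = rng.normal(size=(T, D, D)) * 0.1   # A[0] is never actually used by the recurrence
b = rng.normal(size=(T, D))

seq = np.zeros((T, D))
seq[0] = b[0]
for t in range(1, T):
    seq[t] = A[t] @ seq[t - 1] + b[t]

print(np.allclose(parallel_scan(A, b), seq))  # True
```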

Bottom line on complexity: one Newton step runs in $O(\log T)$ parallel depth (instead of $O(T)$ sequential steps), and the entire RNN application takes $O(\text{iters} \cdot \log T)$, where iters is the number of Newton iterations.
