TL;DR: The Token-and-Duration Transducer (TDT) extends RNN-T by jointly predicting what token to emit and how many frames that token covers. This lets the model skip multiple encoder frames per step during inference instead of advancing one at a time, yielding up to 2.82x faster decoding with comparable or better accuracy.
Word Error Rate (WER) is a useful metric to optimise, but if your model takes 10 seconds to transcribe 1 second of audio, nobody's shipping it. The Hugging Face Open ASR Leaderboard tracks both accuracy and speed. At the time of writing, among the top 10 models on that leaderboard, Nvidia's Parakeet TDT models are more than 3x ahead of the nearest competition in RTFx (inverse real-time factor, i.e. throughput: how many seconds of audio the model can process per second of wall-clock time).
These models are significantly faster than the competition while maintaining competitive WERs. The mechanism? A modification to the RNN-Transducer called the Token-and-Duration Transducer (TDT). In this post, we'll first look at how RNN-T and TDT work at inference time to build intuition for why TDT is faster, then circle back to explain how each model is trained.
Part 1: Inference - Decoding with Frame Skipping
RNN-T Architecture
Without going into too much detail, there are a few ways to train a speech-to-text model: CTC, AED, decoder-only, or RNN-T/TDT. Each of these has pros and cons; for a full comparison, see Desh's Analysis.
RNN-T hits a useful middle ground: it has enough modeling capacity to capture label dependencies (unlike CTC), its autoregressive component is lightweight (unlike AED/decoder-only), and it can be trained end-to-end with a well-understood loss function. RNN-T is already fast - much faster than AED or decoder-only models, and only somewhat slower than CTC. But there's still room to speed it up.
An RNN-T consists of three components:
- Encoder: Maps audio $x$ to hidden representations $h_t$ for each frame $t$. Typically, this is a large transformer.
- Predictor (a.k.a. decoder): A small autoregressive network that maps the previous non-blank tokens $y_1, \dots, y_u$ to representations $g_u$. The name "RNN-T" implies an RNN here, but any autoregressive network can be substituted. During training, the predictor sees the ground-truth label prefix (teacher forcing); during inference, it autoregresses on its own previous predictions.
- Joint network: A small network (often a single linear layer) that combines encoder and predictor outputs to produce logits over the vocabulary $\mathcal{V} \cup \{\varnothing\}$ (where $\varnothing$ is the blank symbol):
$$z_{t,u} = \operatorname{joint}(h_t, g_u), \qquad P(k \mid t, u) = \operatorname{softmax}(z_{t,u})_k$$
Here $t$ indexes the encoder time-step, $u$ indexes how many output labels have been emitted so far (position in the target sequence), and $k$ denotes any candidate vocabulary symbol (including blank) when we write $P(k \mid t, u)$.
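To make the three components concrete, here is a minimal NumPy sketch of a joint network. It assumes the common "project both inputs, add, tanh, project to the vocabulary" design rather than the bare single linear layer mentioned above, and every size and weight is a hypothetical placeholder. The point is only that one encoder frame plus one predictor state yields one normalised distribution over the vocabulary plus blank.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration.
ENC_DIM, PRED_DIM, JOINT_DIM = 8, 6, 16
VOCAB_SIZE = 10           # |V| real tokens
BLANK = VOCAB_SIZE        # blank is the extra, last output index

# Random weights standing in for trained parameters.
W_enc = rng.normal(size=(JOINT_DIM, ENC_DIM))
W_pred = rng.normal(size=(JOINT_DIM, PRED_DIM))
W_out = rng.normal(size=(VOCAB_SIZE + 1, JOINT_DIM))

def softmax(z):
    z = z - z.max()       # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def joint(h_t, g_u):
    """Combine one encoder frame h_t and one predictor state g_u into a
    distribution over V + 1 symbols (vocabulary plus blank)."""
    hidden = np.tanh(W_enc @ h_t + W_pred @ g_u)
    return softmax(W_out @ hidden)

h_t = rng.normal(size=ENC_DIM)    # stand-in for the encoder output at frame t
g_u = rng.normal(size=PRED_DIM)   # stand-in for the predictor output after u labels
p = joint(h_t, g_u)
print(p.shape, p.sum(), p[BLANK])
```

Because the joint is evaluated separately for every $(t, u)$ pair, it is kept deliberately small: it runs once per decoding step at inference and once per lattice node during training.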
RNN-T Inference
At inference time, the model decodes greedily by stepping through encoder frames:
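```python
import numpy as np

# RNN-T Greedy Decoding (simplified)
def rnnt_greedy_decode(encoder, predictor, joint, blank_id):
    """encoder: sequence of T frame vectors; predictor, joint: callables."""
    T = len(encoder)
    t, u = 0, 0
    output = []
    while t < T:
        logits = joint(encoder[t], predictor(output))
        token = int(np.argmax(logits))
        if token == blank_id:
            t += 1  # advance ONE frame
        else:
            output.append(token)
            u += 1  # stay at same t, advance u
    return output
```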
At each step, the model either emits blank (advance one frame) or emits a token (stay at the same frame, advance the label index). The key to speeding this up is to allow frame skipping.
For a 10-second utterance at 80ms frame rate (after subsampling), that's ~125 sequential joint network calls at minimum. Most of those will be blanks - in typical speech, tokens are sparse relative to frames. The model spends most of its time predicting "nothing is happening" one frame at a time. The joint network is cheap per call, but the sequential one-frame-at-a-time structure leaves performance on the table.
TDT addresses this.
TDT: The Key Modification
The core idea of TDT (Xu et al., 2023): instead of predicting just a token at each step, jointly predict the token and how many frames it covers.
In standard RNN-T, the joint network outputs a single distribution over symbols (vocabulary + blank). In TDT, the joint network outputs two independent distributions:
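- Token distribution: $P_T(v \mid t, u)$ for $v \in \mathcal{V} \cup \{\varnothing\}$ - same as RNN-T
- Duration distribution: $P_D(d \mid t, u)$ for $d \in \mathcal{D}$ - probability over a set of allowed durations

where $\mathcal{D}$ is a predefined set of durations. A typical choice is $\mathcal{D} = \{0, 1, 2, 3, 4\}$, though the set can be configured - for example, $\{1, 2, 3, 4\}$ (omitting 0) is also valid.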
The two heads share the same encoder and predictor representations but are independently normalized (separate softmax operations):
```python
# TDT Joint Network Output
logits = joint(encoder[t], predictor(output))  # shape: [V + 1 + |D|]
# Split into token and duration logits
token_logits = logits[:V + 1]     # shape: [V + 1]
duration_logits = logits[V + 1:]  # shape: [|D|]
# Independent softmax over each head
token_probs = softmax(token_logits)
duration_probs = softmax(duration_logits)
```
TDT Inference
The inference speedup is immediate. Compare with the RNN-T loop above:
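```python
import numpy as np

# TDT Greedy Decoding (simplified)
def tdt_greedy_decode(encoder, predictor, joint, blank_id, vocab_size, durations):
    """durations: the allowed duration set D, e.g. [0, 1, 2, 3, 4]."""
    T = len(encoder)
    t = 0
    output = []
    while t < T:
        logits = joint(encoder[t], predictor(output))
        token_logits = logits[:vocab_size + 1]     # token head (vocab + blank)
        duration_logits = logits[vocab_size + 1:]  # duration head
        token = int(np.argmax(token_logits))
        duration = durations[int(np.argmax(duration_logits))]  # index -> frames
        if token == blank_id:
            t += max(1, duration)  # skip MULTIPLE frames!
        else:
            output.append(token)
            t += duration  # can also skip frames on token emission
    return output
```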
Instead of advancing one frame at a time, the model can skip over stretches of silence or steady-state audio. If the model predicts blank with duration 4, it skips 4 frames in one step - reducing joint network calls for that stretch proportionally.
Let's trace through a concrete example. Suppose we have 8 encoder frames ($T = 8$), target "hi" → tokens [h, i], and durations $\mathcal{D} = \{0, 1, 2, 3\}$:
```
t=0: joint(enc[0], pred([]))
     → token=h (p=0.8), duration=0 (p=0.7)
     → emit 'h', stay at t=0
t=0: joint(enc[0], pred([h]))
     → token=i (p=0.6), duration=2 (p=0.5)
     → emit 'i', jump to t=2
t=2: joint(enc[2], pred([h, i]))
     → token=blank (p=0.9), duration=3 (p=0.6)
     → skip to t=5
t=5: joint(enc[5], pred([h, i]))
     → token=blank (p=0.95), duration=3 (p=0.8)
     → skip to t=8 → DONE!
```
4 joint network calls instead of 8+ for standard RNN-T. That's the speedup.
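The trace above can be replayed in a few lines of Python. This is a sketch: a scripted lookup table (hypothetical, copied from the trace) stands in for the joint network, and the RNN-T count assumes an idealised greedy RNN-T run that emits exactly one blank per frame plus one call per emitted token.

```python
# Scripted (token, duration) predictions keyed by (t, number of tokens emitted),
# copied from the trace above; they stand in for a real joint network.
BLANK = "<b>"
SCRIPT = {
    (0, 0): ("h", 0),
    (0, 1): ("i", 2),
    (2, 2): (BLANK, 3),
    (5, 2): (BLANK, 3),
}

def tdt_greedy(T, script):
    t, output, calls = 0, [], 0
    while t < T:
        token, duration = script[(t, len(output))]  # one joint-network call
        calls += 1
        if token == BLANK:
            t += max(1, duration)  # blank must advance at least one frame
        else:
            output.append(token)
            t += duration          # a token may advance 0 or more frames
    return output, calls

def rnnt_min_calls(T, num_tokens):
    # Idealised RNN-T greedy run: one blank call per frame plus one call per token.
    return T + num_tokens

out, calls = tdt_greedy(T=8, script=SCRIPT)
print(out, calls, rnnt_min_calls(8, len(out)))  # ['h', 'i'] 4 10
```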
The TDT paper reports up to 2.82x faster inference than standard RNN-T on speech recognition tasks, with comparable or better accuracy. The speedup is more pronounced on longer utterances with more silence.
Part 2: Training - Mechanics of Forward-Backward
Now that we've seen what these models do at inference time, let's understand how they're trained. This requires a bit more machinery.
The Lattice and Alignments (Standard RNN-T)
During training, we have the audio $x$ and we know the correct transcription $y$, but we typically don't know the correct frame-to-token alignment.
Suppose we have $T = 8$ encoder frames and the target transcription $y$ = "the quick brown fox" ($U = 4$). There are many potential ways to get the exact same transcript, for example:
Path A (early speech): ∅, the, quick, brown, fox, ∅, ∅, ∅, ∅, ∅, ∅, ∅
(orange) → all tokens emitted at t=1, rest is silence
Path B (spread out): the, ∅, ∅, quick, ∅, brown, ∅, ∅, fox, ∅, ∅, ∅
(pink) → tokens spread across the utterance
Path C (late speech): ∅, ∅, ∅, ∅, ∅, the, quick, brown, ∅, ∅, fox, ∅
(blue) → speech starts late, around t=5
Remember that each time we output a blank symbol $\varnothing$, we increment $t$ (the time-frame of the encoder), and each time we output a token, we feed it back into the predictor to get the next predictor output (increment $u$ by one).
Our goal now is to maximise the chance of the correct transcript, irrespective of the alignment (which we don't know). RNN-T's solution is to maximise the probability summed over all possible alignments. We visualise this by constructing a lattice that encodes every possible frame-to-token alignment.
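The joint network produces a probability distribution $P(k \mid t, u)$ at every node $(t, u)$ in a $T \times (U+1)$ grid (the lattice), where $T$ is the number of encoder frames and $U$ is the number of target tokens. In the above example, for $t = 2$ and $u = 2$, we evaluate:
$$P(k \mid 2, 2) = \operatorname{softmax}\big(\operatorname{joint}(h_2, g_2)\big)_k$$
i.e. the joiner called on $h_2$, the 3rd frame of the encoder output, and on $g_2$, the predictor output after the first 2 tokens ("the quick"). This gives us a probability distribution over the entire vocabulary, $P(v \mid 2, 2)$ for $v \in \mathcal{V}$, plus blank, $P(\varnothing \mid 2, 2)$. For training, we only care about the probability of the next correct token (in this case $P(\text{brown} \mid 2, 2)$) or blank, $P(\varnothing \mid 2, 2)$ - so we just show these two transitions in the lattice: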
- Emit blank $\varnothing$: moves from $(t, u) \to (t+1, u)$ - a step right along the time axis.
- Emit the next token $y_{u+1}$: moves from $(t, u) \to (t, u+1)$ - a step up along the label axis.
Every valid path from bottom-left [start] to top-right [end] emits exactly the target sequence and is a different valid alignment. Different paths through this lattice correspond to different timings of the same transcription.
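To get the probability of a given path/alignment, we take the product of all token/blank probabilities along that path. E.g. for Path A:
$$P(\text{Path A}) = P(\varnothing \mid 0, 0)\,P(\text{the} \mid 1, 0)\,P(\text{quick} \mid 1, 1)\,P(\text{brown} \mid 1, 2)\,P(\text{fox} \mid 1, 3) \prod_{t=1}^{7} P(\varnothing \mid t, 4)$$
where (as described above) $t$ is the time-frame index of the encoder output, and $u$ is the amount of the transcript that the predictor has seen so far.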
The total probability of $y$ (the correct transcription) is defined as the sum over all such paths:
$$P(y \mid x) = \sum_{\pi \in \mathcal{A}(y)} P(\pi)$$
where $\mathcal{A}(y)$ is the set of all valid alignments for $y$.
This probability is the objective we will try to maximise in training. Or more accurately, we will try to minimise the negative log-likelihood:
$$\mathcal{L} = -\ln P(y \mid x)$$
So, our loss is completely agnostic to the alignment the model wants to use; we just want to maximise the total probability mass flowing through this lattice.
Now we need to efficiently calculate this loss, $\mathcal{L}$, and the relevant gradients.
RNN-T Training: The Forward-Backward Algorithm
As long as we stay on this training lattice, we will produce the correct transcript. The probability of staying on this lattice is what we will try to maximise - so we want to boost the chance of the transitions on this lattice (scaled by the impact each one has on the final probability).
It's worth noting here that the output of the joiner is normalised, so increasing the chance of e.g. the correct token "brown" at a node will implicitly decrease the chance of all other tokens at that node, e.g. "apple".
So how do we get this probability?
If we start from no transcript - with probability $\alpha(0, 0) = 1$ (we must start with nothing yet transcribed) - we can get the chance of moving in either valid direction:
$$\alpha(0, 0)\,P(\text{the} \mid 0, 0) \qquad \text{and} \qquad \alpha(0, 0)\,P(\varnothing \mid 0, 0)$$
So, the chance of going up in the lattice - emitting "the" - is, say, $0.4$. The chance of emitting $\varnothing$ is, say, $0.5$. This is quite good: it means the chance of emitting any other token is only $0.1$, indicating a well-trained model.
Here we'll keep $\alpha(t, u)$ as the probability of getting to a node $(t, u)$ (from the start). So, what's the chance of progressing any further through this lattice?
To get to node $(2, 1)$, i.e. two blanks and one correct token, we have three possible paths:
Path 1: "the", ∅, ∅
(↑, →, →)
P_1 = 0.4 * 0.1 * 0.1 = 0.004
Path 2: ∅, "the", ∅
(→, ↑, →)
P_2 = 0.5 * 0.5 * 0.1 = 0.025
Path 3: ∅, ∅, "the"
(→, →, ↑)
P_3 = 0.5 * 0.4 * 0.7 = 0.14
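So the sum over all paths to node $(2, 1)$ is $0.004 + 0.025 + 0.14 = 0.169$. The remaining $83.1\%$ of the probability mass after three steps either sits at another lattice node (e.g. three blanks, or two correct tokens and one blank) or has already gone wrong and left the training lattice, e.g. by outputting ["then", ∅, ∅] or [∅, "apple", ∅].
More generally, we define $\alpha(t, u)$, a.k.a. the forward variable, as the sum over all the correct paths to a given node:
$$\alpha(t, u) = \sum_{\pi \,:\, (0,0) \to (t,u)} \; \prod_{(k,\, t',\, u') \in \pi} P(k \mid t', u')$$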
It's also useful to think of this as the total amount of probability mass that flows through the lattice to a given node.
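Now if we enumerate all paths to the [end] node (the final blank out of node $(T-1, U)$) and sum the probabilities, we get the full transcript probability:
$$P(y \mid x) = \alpha(T-1, U)\,P(\varnothing \mid T-1, U)$$
The problem with this is that there are far too many paths to enumerate. Even for the small example above, with $T = 8$ and $U = 4$, there are 330 potential paths through the lattice: $\binom{11}{4} = 330$ ways to place the 4 tokens among the 11 steps that precede the final blank.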
To solve this issue, we notice that to get the probability mass that arrives at a given node, we only care about the mass that arrives at the adjacent predecessor nodes (i.e. one blank backwards, or one correct token backwards). We don't care about individual paths leading up to these predecessor nodes, just the total sum over all possible paths to them - the total probability mass that arrives there. This gives us the following recursion:
$$\alpha(t, u) = \alpha(t-1, u)\,P(\varnothing \mid t-1, u) + \alpha(t, u-1)\,P(y_u \mid t, u-1)$$
with $\alpha(0, 0) = 1$ (as the chance of starting at $(0, 0)$ is $1$), and any term whose predecessor falls outside the lattice taken as zero. Each term above says: the mass arriving at $(t, u)$ is the mass at the predecessor node, times the probability of the transition from the predecessor to $(t, u)$. This means that we get to skip enumerating every possible path and just run through each node in the lattice with this sum - all the way to the [end] node.
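The recursion is only a few lines of code. The sketch below works in probability space and stores the blank and next-token probabilities in two hypothetical NumPy arrays indexed by `(t, u)`. It reproduces the 0.169 total for node (2, 1) from the worked example, and, with every transition probability set to 1, turns alpha into a path counter that recovers the 330 alignments for T = 8 and U = 4.

```python
import numpy as np
from math import comb

def rnnt_forward(blank, emit):
    """Forward variable alpha[t, u] of an RNN-T lattice (probability space).

    blank[t, u]: P(blank | t, u); emit[t, u]: P(y_{u+1} | t, u).
    Both have shape (T, U + 1); emit's last column is never used.
    """
    T, U1 = blank.shape
    alpha = np.zeros((T, U1))
    alpha[0, 0] = 1.0  # every path starts at (0, 0)
    for t in range(T):
        for u in range(U1):
            if (t, u) == (0, 0):
                continue
            from_left = alpha[t - 1, u] * blank[t - 1, u] if t > 0 else 0.0  # blank step
            from_below = alpha[t, u - 1] * emit[t, u - 1] if u > 0 else 0.0  # token step
            alpha[t, u] = from_left + from_below
    return alpha

# Worked example: only the transition probabilities quoted in the text matter here.
blank = np.zeros((3, 2))
emit = np.zeros((3, 2))
blank[0, 0], blank[1, 0], blank[0, 1], blank[1, 1] = 0.5, 0.4, 0.1, 0.1
emit[0, 0], emit[1, 0], emit[2, 0] = 0.4, 0.5, 0.7
alpha = rnnt_forward(blank, emit)
print(round(alpha[2, 1], 6))  # 0.169

# With every probability set to 1, alpha counts paths instead of summing mass.
ones = np.ones((8, 5))                    # T = 8 frames, U = 4 tokens
n_paths = rnnt_forward(ones, ones)[7, 4]  # the final blank into [end] is forced
print(int(n_paths), comb(11, 4))          # 330 330
```

Each node is visited once and does constant work, which is where the $O(TU)$ cost comes from.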
Now that we've efficiently calculated the total probability $P(y \mid x)$, we need to calculate how much each transition affects this final sum - the gradient of $P(y \mid x)$ with respect to $P(k \mid t, u)$. This will tell us how much to update the model weights at each step. Specifically, if we make a small change in a transition probability $P(k \mid t, u)$, what will be the effect on the total probability $P(y \mid x)$?
Gradient Calculation:
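So let's work this out for a given transition, e.g. the probability of a blank transition at node $(3, 1)$: $P(\varnothing \mid 3, 1)$. Reaching this node means that the model has already output the correct first token ("the") as well as 3 blanks, in some order.
For some path $\pi_k$ through the lattice that uses our transition $\tau^*$ (the blank at $(3, 1)$), we have:
$$P(\pi_k) = \prod_{\tau \in \pi_k} p_\tau$$
where $p_\tau$ is the probability of a transition $\tau$ along this path. This means that a small change in our transition will affect the total path probability by:
$$\frac{\partial P(\pi_k)}{\partial p_{\tau^*}} = \prod_{\tau \in \pi_k,\ \tau \neq \tau^*} p_\tau$$
So, the effect of changing our transition probability on a path that uses it is just the probability of the rest of that path (excluding the given transition). Changing this probability won't affect paths that don't use this transition at all. We also know that the total probability of the correct transcript is the sum over all possible correct paths:
$$P(y \mid x) = \sum_k P(\pi_k)$$
so naturally, the effect of changing this transition on $P(y \mid x)$ is the sum of its effects on each path that uses it:
$$\frac{\partial P(y \mid x)}{\partial p_{\tau^*}} = \sum_{k \,:\, \tau^* \in \pi_k} \; \prod_{\tau \in \pi_k,\ \tau \neq \tau^*} p_\tau$$
Each such path splits into a prefix that reaches the transition (arriving at node $(3, 1)$) and a suffix that leaves the transition (from node $(4, 1)$) and reaches the [end] state. Since every prefix can be combined with every suffix, the sum factorises:
$$\frac{\partial P(y \mid x)}{\partial P(\varnothing \mid 3, 1)} = \Bigg(\sum_{\text{prefixes } (0,0) \to (3,1)} \; \prod_{\tau} p_\tau\Bigg) \cdot \Bigg(\sum_{\text{suffixes } (4,1) \to \text{[end]}} \; \prod_{\tau} p_\tau\Bigg)$$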
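All this is saying is: the amount that the total probability changes is the amount of probability mass that reaches a given transition, multiplied by the amount of probability mass that gets from the end of that transition to the [end] state.
But we've already done the maths for the first part! The sum over paths that reach a given node is just $\alpha(t, u)$, and the second part looks very similar - we'll call this the backward variable $\beta(t, u)$.
The backward variable $\beta(t, u)$ is the mirror image of $\alpha(t, u)$: it represents the total probability mass that flows from node $(t, u)$ to the final state - "how much probability mass will still reach the target from here."
Here, in a symmetric way, we set $\beta(T-1, U) = P(\varnothing \mid T-1, U)$ (the final blank into [end]), but now walk backwards through the lattice, again treating terms outside the lattice as zero:
$$\beta(t, u) = \beta(t+1, u)\,P(\varnothing \mid t, u) + \beta(t, u+1)\,P(y_{u+1} \mid t, u)$$
Effectively, $\beta(t, u)$ sums, over each next node, "the amount of probability mass that will reach the target from that node" × "the probability of getting there from the current node".
This simplifies life a lot for our example:
$$\frac{\partial P(y \mid x)}{\partial P(\varnothing \mid 3, 1)} = \alpha(3, 1)\,\beta(4, 1)$$
Or more generally:
$$\frac{\partial P(y \mid x)}{\partial P(k \mid t, u)} = \alpha(t, u)\,\beta(t', u')$$
where $k$ is a transition at a given node $(t, u)$ pointing to another node $(t', u')$: $(t+1, u)$ for a blank, or $(t, u+1)$ for the next token $y_{u+1}$ (with $\beta = 1$ at the [end] state).
The forward variable gets the mass to the transition, and the backward variable represents the mass that will eventually arrive at the target from that point. The full loss gradient normalizes by the total likelihood (see the original Graves 2012 paper for the complete derivation):
$$\frac{\partial \mathcal{L}}{\partial P(k \mid t, u)} = -\frac{\alpha(t, u)\,\beta(t', u')}{P(y \mid x)}$$
This gives a nice result! Multiplying through by $P(k \mid t, u)$, the gradient with respect to a transition's log-probability is $-\alpha(t, u)\,P(k \mid t, u)\,\beta(t', u') / P(y \mid x)$: minus the proportion of the total probability mass that flows through that transition. Because of this normalization, even early in training, when $P(y \mid x)$ is small, the gradient is still significant for the correct transitions. This also explains why lattice paths tend to collapse to a small number of dominant alignments later in training - the highest-probability paths receive the largest gradients, incentivizing further path concentration.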
n.b. This "path collapse" is a key insight behind k2's pruned RNN-T loss (Kuang et al., 2022), which simplifies the gradient computation significantly by only considering paths near (in time) to the high-probability alignments.
The forward-backward algorithm computes all of this in $O(TU)$ time: one pass over the lattice for $\alpha$ and one for $\beta$.
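Putting $\alpha$, $\beta$ and the gradient formula together: the sketch below runs both passes on a small random lattice, checks $P(y \mid x)$ against brute-force enumeration of every alignment, and computes $\alpha(t, u) \cdot \beta(\text{next node})$ for every transition. It treats each transition probability as a free parameter (no softmax coupling), matching the derivation above. The array layout is a hypothetical choice for illustration, not any library's format.

```python
import itertools
import numpy as np

def forward_backward(blank, emit):
    """alpha, beta and P(y|x) for an RNN-T lattice (probability space).

    blank[t, u] = P(blank | t, u), emit[t, u] = P(y_{u+1} | t, u); shape (T, U + 1).
    """
    T, U1 = blank.shape
    U = U1 - 1
    alpha, beta = np.zeros((T, U1)), np.zeros((T, U1))
    alpha[0, 0] = 1.0
    for t in range(T):
        for u in range(U1):
            if (t, u) != (0, 0):
                alpha[t, u] = ((alpha[t - 1, u] * blank[t - 1, u] if t > 0 else 0.0)
                               + (alpha[t, u - 1] * emit[t, u - 1] if u > 0 else 0.0))
    for t in reversed(range(T)):
        for u in reversed(range(U1)):
            if (t, u) == (T - 1, U):
                beta[t, u] = blank[t, u]  # final blank into [end]
                continue
            right = beta[t + 1, u] * blank[t, u] if t + 1 < T else 0.0
            up = beta[t, u + 1] * emit[t, u] if u < U else 0.0
            beta[t, u] = right + up
    return alpha, beta, alpha[T - 1, U] * blank[T - 1, U]

def rnnt_grads(alpha, beta):
    """dP(y|x)/dP(k|t,u) = alpha(t, u) * beta(next node), for blank and for y_{u+1}."""
    T, U1 = alpha.shape
    g_blank, g_emit = np.zeros((T, U1)), np.zeros((T, U1))
    for t in range(T):
        for u in range(U1):
            if t + 1 < T:
                g_blank[t, u] = alpha[t, u] * beta[t + 1, u]
            elif u == U1 - 1:
                g_blank[t, u] = alpha[t, u]  # the [end] state has beta = 1
            if u + 1 < U1:
                g_emit[t, u] = alpha[t, u] * beta[t, u + 1]
    return g_blank, g_emit

def brute_force(blank, emit):
    """P(y|x) by enumerating every alignment (exponential; tiny lattices only)."""
    T, U1 = blank.shape
    U = U1 - 1
    n_steps = T - 1 + U  # all steps before the final blank
    total = 0.0
    for token_steps in itertools.combinations(range(n_steps), U):
        t = u = 0
        p = 1.0
        for step in range(n_steps):
            if step in token_steps:
                p *= emit[t, u]
                u += 1
            else:
                p *= blank[t, u]
                t += 1
        total += p * blank[T - 1, U]
    return total

rng = np.random.default_rng(0)
T, U = 4, 2
blank = rng.uniform(0.1, 0.9, size=(T, U + 1))
emit = rng.uniform(0.1, 0.9, size=(T, U + 1))
alpha, beta, P = forward_backward(blank, emit)
g_blank, g_emit = rnnt_grads(alpha, beta)
print(P, beta[0, 0], brute_force(blank, emit))  # three routes to the same P(y|x)
loss_grad_emit = -g_emit / P  # dL/dP(y_{u+1}|t,u) for L = -ln P(y|x)
```

Because each path uses any given transition at most once, $P(y \mid x)$ is linear in each individual transition probability, so a one-sided finite difference on the brute-force sum matches `g_blank` and `g_emit` up to rounding.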
TDT Training: The Modified Forward-Backward
Training TDT requires modifying the forward-backward algorithm to account for the duration variable. The loss is still the negative log-likelihood $\mathcal{L} = -\ln P(y \mid x)$, but the lattice transitions are now richer. Recall that in TDT, transitions can skip multiple frames:
- Blank with duration $d$: $(t, u) \to (t+d, u)$ - advances by $d$ frames
- Token $y_{u+1}$ with duration $d$: $(t, u) \to (t+d, u+1)$ - advances by $d$ frames and emits a token

Note the asymmetry: blanks must have $d \geq 1$ (you must advance at least one frame when emitting nothing), but tokens can have $d = 0$ if $0 \in \mathcal{D}$ (emitting a token without advancing - useful for fast speech or multi-token emissions at a single frame). If $\mathcal{D}$ doesn't include 0, every emission also advances at least one frame.
We now have two independent distributions predicted from each node, and each transition's probability is their product:
$$P(v, d \mid t, u) = P_T(v \mid t, u)\,P_D(d \mid t, u)$$
Modified Forward Variable
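The forward variable now has a more complex recurrence. At each position $(t, u)$, we must sum over all durations that could have led here:
$$\alpha(t, u) = \sum_{d \in \mathcal{D},\, d \geq 1} \alpha(t-d, u)\,P_T(\varnothing \mid t-d, u)\,P_D(d \mid t-d, u) + \sum_{d \in \mathcal{D}} \alpha(t-d, u-1)\,P_T(y_u \mid t-d, u-1)\,P_D(d \mid t-d, u-1)$$
The key difference from standard RNN-T: instead of looking back exactly 1 step, we look back $d$ steps for each duration $d$ in $\mathcal{D}$. This makes the forward pass $O(TU|\mathcal{D}|)$ instead of $O(TU)$ - a constant factor increase since $|\mathcal{D}|$ is typically small (4–5 elements).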
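A sketch of the duration-aware forward pass, checked against brute-force enumeration of every TDT alignment on a tiny random lattice. The array layout is hypothetical, and we assume (by analogy with the RNN-T [end] node) that a complete alignment must finish with a blank that lands exactly on frame $T$; real implementations may treat the final frame differently.

```python
import numpy as np

def tdt_forward(blank, emit, dur, durations):
    """P(y|x) for a TDT lattice via the duration-aware forward recursion.

    blank[t, u] = P_T(blank | t, u), emit[t, u] = P_T(y_{u+1} | t, u),
    dur[t, u, i] = P_D(durations[i] | t, u), for frames t < T and u <= U.
    Assumption: an alignment must end with a blank landing exactly on t = T.
    """
    T, U1 = blank.shape
    U = U1 - 1
    alpha = np.zeros((T, U1))
    alpha[0, 0] = 1.0
    for t in range(T):
        for u in range(U1):
            if (t, u) == (0, 0):
                continue
            for i, d in enumerate(durations):
                s = t - d  # frame we jumped from
                if s < 0:
                    continue
                if d > 0:  # blank: same u, needs d >= 1
                    alpha[t, u] += alpha[s, u] * blank[s, u] * dur[s, u, i]
                if u > 0:  # token y_u: any d in D, including 0
                    alpha[t, u] += alpha[s, u - 1] * emit[s, u - 1] * dur[s, u - 1, i]
    # Final blank from (T - d, U) into the [end] state at t = T.
    return sum(alpha[T - d, U] * blank[T - d, U] * dur[T - d, U, i]
               for i, d in enumerate(durations) if 0 < d <= T)

def tdt_brute_force(blank, emit, dur, durations):
    """Enumerate every alignment recursively (exponential; tiny lattices only)."""
    T, U1 = blank.shape
    U = U1 - 1

    def walk(t, u):
        total = 0.0
        for i, d in enumerate(durations):
            if d > 0 and t + d == T and u == U:  # final blank into [end]
                total += blank[t, u] * dur[t, u, i]
            elif d > 0 and t + d < T:            # ordinary blank
                total += blank[t, u] * dur[t, u, i] * walk(t + d, u)
            if u < U and t + d < T:              # token y_{u+1}
                total += emit[t, u] * dur[t, u, i] * walk(t + d, u + 1)
        return total

    return walk(0, 0)

rng = np.random.default_rng(0)
durations, T, U = [0, 1, 2, 3], 5, 2
blank = rng.uniform(0.1, 0.9, size=(T, U + 1))
emit = rng.uniform(0.1, 0.9, size=(T, U + 1)) * (1 - blank)   # blank + emit < 1
dur = rng.dirichlet(np.ones(len(durations)), size=(T, U + 1))  # rows sum to 1
fwd = tdt_forward(blank, emit, dur, durations)
bf = tdt_brute_force(blank, emit, dur, durations)
print(fwd, bf)  # equal up to floating-point rounding
```

The only structural change from the RNN-T forward sketch is the inner loop over `durations`, which is exactly the $|\mathcal{D}|$ factor in the complexity.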
Backward Variable and Gradients
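The backward variable follows the same pattern but in reverse:
$$\beta(t, u) = \sum_{d \in \mathcal{D},\, d \geq 1} P_T(\varnothing \mid t, u)\,P_D(d \mid t, u)\,\beta(t+d, u) + \sum_{d \in \mathcal{D}} P_T(y_{u+1} \mid t, u)\,P_D(d \mid t, u)\,\beta(t+d, u+1)$$
The gradient computation uses both $\alpha$ and $\beta$ in the standard way, summing over each possible duration for a given token prediction and scaling by the duration probabilities $P_D(d \mid t, u)$. For the token logit $z^{\text{tok}}_{t,u}(v)$, the gradient at position $(t, u)$ is:
$$\frac{\partial \mathcal{L}}{\partial z^{\text{tok}}_{t,u}(v)} = \frac{\alpha(t, u)\,P_T(v \mid t, u)}{P(y \mid x)} \Bigg[\beta(t, u) - \sum_{(t', u') \in \mathcal{S}(t, u, v)} P_D(t' - t \mid t, u)\,\beta(t', u')\Bigg]$$
where $\mathcal{S}(t, u, v)$ is the set of states reachable from $(t, u)$ by emitting $v$:
$$\mathcal{S}(t, u, \varnothing) = \{(t+d, u) : d \in \mathcal{D},\ d \geq 1\}, \qquad \mathcal{S}(t, u, y_{u+1}) = \{(t+d, u+1) : d \in \mathcal{D}\}$$
and $\mathcal{S}(t, u, v) = \emptyset$ for any other symbol $v$. In this case (with $\mathcal{D} = \{0, 1, 2, 3\}$), to count the paths affected by the chance of predicting e.g. "fox", we have 4 possible lattice transitions to count, and 3 possible transitions for the blank token $\varnothing$.
For the duration logits $z^{\text{dur}}_{t,u}(d)$, the gradient at position $(t, u)$ accounts for all transitions that use duration $d$, either the correct token or a blank transition:
$$\frac{\partial \mathcal{L}}{\partial z^{\text{dur}}_{t,u}(d)} = \frac{\alpha(t, u)\,P_D(d \mid t, u)}{P(y \mid x)} \Big[\beta(t, u) - P_T(y_{u+1} \mid t, u)\,\beta(t+d, u+1) - P_T(\varnothing \mid t, u)\,\beta(t+d, u)\Big], \quad d \geq 1$$
or for $d = 0$ (blank not allowed at zero duration):
$$\frac{\partial \mathcal{L}}{\partial z^{\text{dur}}_{t,u}(0)} = \frac{\alpha(t, u)\,P_D(0 \mid t, u)}{P(y \mid x)} \Big[\beta(t, u) - P_T(y_{u+1} \mid t, u)\,\beta(t, u+1)\Big]$$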
This too is somewhat intuitive. It represents the sum over all valid paths that use this duration. Now we're done! This is all the maths required to understand the efficient TDT training mechanics. For the full derivation see the TDT paper.
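To check the two gradient formulas, the sketch below computes the TDT loss from raw token and duration logits (separate softmaxes, as in the two-headed joint), evaluates the analytic $\alpha$/$\beta$ gradients, and returns both so they can be compared against central finite differences. It is a slow, pure-NumPy reference for a single utterance under the same assumed end condition as before (the last transition is a blank landing exactly on frame $T$), not an efficient or official implementation.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def tdt_loss_and_grads(z_tok, z_dur, y, durations, blank_id):
    """-ln P(y|x) for one utterance, plus gradients w.r.t. both logit heads.

    z_tok: (T, U+1, V+1) token logits; z_dur: (T, U+1, |D|) duration logits.
    Assumption: every alignment ends with a blank landing exactly on t = T.
    """
    T, U1, _ = z_tok.shape
    U = U1 - 1
    PT, PD = softmax(z_tok), softmax(z_dur)  # independent normalization
    pb = PT[:, :, blank_id]                  # P_T(blank | t, u)
    pe = np.zeros((T, U1))
    pe[:, :U] = PT[:, np.arange(U), y]       # P_T(y_{u+1} | t, u)

    def next_beta(t, u, d, beta):
        """beta at the destination of a blank / a y_{u+1} with duration d."""
        b = beta[t + d, u] if d > 0 and t + d <= T else 0.0
        k = beta[t + d, u + 1] if u < U and t + d < T else 0.0
        return b, k

    # Backward pass: beta[T, U] = 1 is the [end] state; beta[T, u < U] stays 0.
    beta = np.zeros((T + 1, U1))
    beta[T, U] = 1.0
    for t in reversed(range(T)):
        for u in reversed(range(U1)):
            for i, d in enumerate(durations):
                b, k = next_beta(t, u, d, beta)
                beta[t, u] += PD[t, u, i] * (pb[t, u] * b + pe[t, u] * k)

    # Forward pass over the real frames t < T.
    alpha = np.zeros((T, U1))
    alpha[0, 0] = 1.0
    for t in range(T):
        for u in range(U1):
            for i, d in enumerate(durations):
                s = t - d
                if s < 0 or (t, u) == (0, 0):
                    continue
                if d > 0:
                    alpha[t, u] += alpha[s, u] * pb[s, u] * PD[s, u, i]
                if u > 0:
                    alpha[t, u] += alpha[s, u - 1] * pe[s, u - 1] * PD[s, u - 1, i]

    P = beta[0, 0]
    g_tok, g_dur = np.zeros_like(z_tok), np.zeros_like(z_dur)
    for t in range(T):
        for u in range(U1):
            scale = alpha[t, u] / P
            B = np.zeros(z_tok.shape[-1])  # duration-weighted beta per symbol
            for i, d in enumerate(durations):
                b, k = next_beta(t, u, d, beta)
                B[blank_id] += PD[t, u, i] * b
                if u < U:
                    B[y[u]] += PD[t, u, i] * k
                C_i = pb[t, u] * b + pe[t, u] * k  # mass continuing with duration d
                g_dur[t, u, i] = scale * PD[t, u, i] * (beta[t, u] - C_i)
            g_tok[t, u] = scale * PT[t, u] * (beta[t, u] - B)
    return -np.log(P), g_tok, g_dur

rng = np.random.default_rng(0)
durations, T, V = [0, 1, 2, 3], 5, 3
y = np.array([1, 2])  # target token ids; the blank id is V
z_tok = rng.normal(size=(T, len(y) + 1, V + 1))
z_dur = rng.normal(size=(T, len(y) + 1, len(durations)))
loss, g_tok, g_dur = tdt_loss_and_grads(z_tok, z_dur, y, durations, blank_id=V)
print(loss)
```

Symbols that are neither blank nor $y_{u+1}$ get $B_v = 0$, so raising their logits always increases the loss, which is the "leaving the lattice" penalty in gradient form.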
Some More Training Tricks
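Working in Log-space: As is usual in machine learning when working with probabilities, we work in log-space: products of probabilities along a path become sums of log-probabilities, and the sums in the $\alpha$/$\beta$ recursions become log-sum-exp operations. This avoids the numerical underflow that long products of raw probabilities would cause.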
The Sigma Trick - Logit Under-Normalization: Every transition in the lattice, whether blank or token, gets penalized by $\sigma$ (typically 0.05) in log-space. Since this penalty is applied per transition, paths with more steps accumulate a larger total penalty. This biases the model toward using fewer, larger-duration steps rather than many duration-1 steps.
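A toy illustration of the effect described above: two hypothetical alignments with identical raw probability, one covering 8 frames with 2 blanks of duration 4 and one with 8 blanks of duration 1, scored with a per-transition penalty. This models only the log-space bookkeeping, not exactly where $\sigma$ is applied inside a real TDT loss implementation.

```python
import math

SIGMA = 0.05  # per-transition penalty (the typical value quoted in the text)

def penalized_log_prob(step_log_probs, sigma=SIGMA):
    """Log-probability of one alignment when every transition pays -sigma."""
    return sum(lp - sigma for lp in step_log_probs)

# Two hypothetical alignments over the same 8 frames, both with raw probability 0.25.
few_steps = [math.log(0.5)] * 2                 # 2 blanks of duration 4
many_steps = [math.log(0.25 ** (1 / 8))] * 8    # 8 blanks of duration 1
print(penalized_log_prob(few_steps), penalized_log_prob(many_steps))
```

The gap between the two scores is $6\sigma$ (six extra transitions), so after the penalty the model is nudged towards the alignment with fewer, longer steps.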
The Omega Trick - Sampled RNN-T Loss: with probability $\omega$, the loss falls back to the standard RNN-T loss (ignoring durations entirely). This acts as a regularizer, ensuring the token predictions remain well-calibrated even without duration information. This is important for batched inference, where every utterance in the batch may have to advance its encoder frame by the same amount (e.g. the minimum predicted duration across the batch).
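The sampling step itself is simple. A sketch, where the two loss functions are hypothetical callables supplied by the caller and `omega` is a free parameter (the text does not give a value):

```python
import random

def sampled_transducer_loss(batch, tdt_loss, rnnt_loss, omega, rng=random):
    """With probability omega, fall back to the plain RNN-T loss for this step;
    otherwise use the TDT loss. Both loss functions are caller-supplied."""
    if rng.random() < omega:
        return rnnt_loss(batch)
    return tdt_loss(batch)

# Hypothetical stand-ins that just report which loss was chosen.
rng = random.Random(0)
picks = [sampled_transducer_loss(None, lambda b: "tdt", lambda b: "rnnt", omega=0.25, rng=rng)
         for _ in range(10_000)]
frac = picks.count("rnnt") / len(picks)
print(frac)  # fraction of steps that used the RNN-T loss
```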
Practical Considerations and Pitfalls
Training Memory
TDT has the same memory footprint challenge as standard RNN-T: the joint network output is a 4D tensor of shape $(B, T, U+1, V+1+|\mathcal{D}|)$. For large vocabularies and long sequences, this can be enormous. The standard mitigation is fused loss computation - instead of materializing the full joint tensor, compute the loss and gradients in a fused kernel that only materializes one slice at a time. Also, it's typically important to keep the vocab size small - the above example uses full words, but a smaller vocab of sub-words is usually preferable.
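A back-of-the-envelope calculation shows why fused losses matter. The batch size, lengths and 1024-token vocabulary below are hypothetical, and float32 storage is assumed; the size grows linearly in every dimension, including the vocabulary size.

```python
def joint_tensor_bytes(batch, frames, target_len, vocab, num_durations, bytes_per_elem=4):
    """Bytes needed to materialise the full TDT joint output of shape
    (B, T, U + 1, V + 1 + |D|); bytes_per_elem=4 assumes float32 logits."""
    return batch * frames * (target_len + 1) * (vocab + 1 + num_durations) * bytes_per_elem

# Hypothetical batch: 16 utterances of 1000 encoder frames (80 s at an 80 ms
# frame rate), 200 target tokens each, a 1024-token subword vocabulary, |D| = 5.
size = joint_tensor_bytes(16, 1000, 200, vocab=1024, num_durations=5)
print(f"{size / 2**30:.1f} GiB")  # 12.3 GiB
```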
Duration Set Design
The choice of duration set matters. The paper uses $\mathcal{D} = \{0, 1, 2, 3, 4\}$ as the default. Some considerations:
- Must include 1: Duration 1 is needed to recover standard single-frame advancement. Duration 0 is optional - it allows token emission without frame advancement (useful for fast speech), but some configurations omit it.
- Larger durations = more skipping: The model learns when to use large skips vs. small ones. In practice, the model is conservative enough that large durations don't cause problems.
- More durations = slightly slower training: The forward-backward complexity scales linearly with , though with typical set sizes (4–5 elements) this is a small constant factor.
Comparison with Multi-Blank Transducer
TDT is related to but distinct from the Multi-Blank Transducer, which adds multiple blank symbols (big-blank-2, big-blank-3, etc.) that skip different numbers of frames. The key difference:
| | Multi-Blank | TDT |
|---|---|---|
| Duration prediction | Implicit (via blank type) | Explicit (separate head) |
| Token durations | Always 0 (no frame skip on token) | Variable (tokens can skip frames too) |
| Vocab size increase | One extra blank symbol per supported duration | No vocab increase; separate duration head |
| Independence | Token and duration coupled | Token and duration independently normalized |
TDT's independent normalization means the model doesn't need to use vocabulary capacity on multiple blank symbols, and the duration prediction can be more fine-grained.
Summary
TDT extends RNN-T by jointly predicting tokens and their durations. The key ideas are:
- Two-headed joint network: independently predict token and duration distributions
- Variable-stride lattice: transitions can skip multiple frames, not just one
- Modified forward-backward: same algorithm structure, just summing over durations at each step
- Training tricks: logit under-normalization ($\sigma$) and sampled RNN-T loss ($\omega$) for stable training

The result: models that are up to 2.82x faster at inference with comparable or better accuracy than standard transducers - and RNN-T was already fast to begin with. This is how Nvidia's Parakeet-TDT models dominate the RTFx column at the top of the Hugging Face leaderboard.
The NeMo toolkit has a full implementation, and pretrained Parakeet-TDT checkpoints are available on Hugging Face.
References:
- Xu et al., "Efficient Sequence Transduction by Jointly Predicting Tokens and Durations", 2023
- Graves, "Sequence Transduction with Recurrent Neural Networks", 2012
- Xu et al., "Multi-blank Transducers for Speech Recognition", 2022
- Kuang et al., "Pruned RNN-T for fast, memory-efficient ASR training", 2022