<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Alex Xiaoli Shen</title>
    <description>The latest articles on DEV Community by Alex Xiaoli Shen (@alabebop).</description>
    <link>https://dev.to/alabebop</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3393175%2Fa56cbcd8-8c26-48f4-ad60-2e68bb6b46e4.jpg</url>
      <title>DEV Community: Alex Xiaoli Shen</title>
      <link>https://dev.to/alabebop</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/alabebop"/>
    <language>en</language>
    <item>
      <title>Hands-On Transformer Deep Dive: Part 2 — Multi-head Attention Variants with Code</title>
      <dc:creator>Alex Xiaoli Shen</dc:creator>
      <pubDate>Tue, 05 Aug 2025 14:11:15 +0000</pubDate>
      <link>https://dev.to/alabebop/hands-on-transformer-deep-dive-part-2-multi-head-attention-variants-with-code-4405</link>
      <guid>https://dev.to/alabebop/hands-on-transformer-deep-dive-part-2-multi-head-attention-variants-with-code-4405</guid>
      <description>&lt;p&gt;&lt;em&gt;This is Part 2 of the “Hands-on Transformer Deep Dive” series. We’ll walk step-by-step through modern Transformers’ algorithms and components, and build our own LLM from scratch. If you missed Part 1, check it out &lt;a href="https://dev.to/alabebop/deep-dive-into-transformers-attention-scale-mask-dropout-and-more-16fn"&gt;here&lt;/a&gt;.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;In this article, we dive deep into multi-head attention mechanism, a foundational building block of modern Transformers. We’ll look into four of its variants: MHA, MQA, GQA, and MLA, implement them from scratch with only PyTorch, and discuss their characteristics and pros and cons.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F78kjtazgur9mifclt2nw.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F78kjtazgur9mifclt2nw.png" alt="Multi-head attention variants illustrated (from DeepSeek-v2 paper)"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  Introduction
&lt;/h2&gt;

&lt;p&gt;Multi-head attention allows the model to capture complex patterns by looking at the data from multiple “perspectives”. While single-head attention computes the attention once over the whole input, multi-head attention splits the model’s total feature dimension across multiple heads and run them simultaneously. Each head learns its own query, key, and value projections. Then all heads are combined to recover the full model feature dimension.&lt;/p&gt;

&lt;p&gt;Splitting across multiple attention heads allows the model to represent various aspects of the data more effectively. Each head can focus on a smaller, specialized subspace. For example, one head might focus on nearby words to capture phrases or recognize named entities, another head might specialize in understanding relative word positions or long-range semantic links, while still another head could attend to negations or modifiers that change the sentiment or meaning. Combining all heads then recovers the full representation capacity.&lt;/p&gt;

&lt;p&gt;Multi-head attention has shown better learning dynamics and improved expressiveness compared to single-head attention. However, different practical constraints, such as memory limitation and inference speed, require trade-offs between model performance, computation cost, and flexibility.&lt;/p&gt;

&lt;p&gt;In the following sections, we’ll discuss and implement four of the most popular multi-head attention variants: the classic MHA, Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-head Latent Attention (MLA).&lt;/p&gt;

&lt;h2&gt;
  
  
  MHA: Multi-Head Attention
&lt;/h2&gt;

&lt;p&gt;The classic MHA is just as what was discussed in the introduction: the attention mechanism is split across multiple attention heads, each learning its own smaller Q, K, V projections (the dimension size is the model dimension divided by number of heads). Then all the heads are combined and the model also learns a final output projection of the original dimension size.&lt;/p&gt;

&lt;h3&gt;
  
  
  MHA Implementation
&lt;/h3&gt;

&lt;p&gt;Below is an MHA implementation with step-by-step explanation in comments:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; 
      &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;model dimension (received embed_dim: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;) must be divisible &lt;/span&gt;&lt;span class="se"&gt;\
&lt;/span&gt;&lt;span class="s"&gt;         by the number of attention heads (received num_head: f&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;num_head&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;

    &lt;span class="c1"&gt;# Initialize the query, key, value and final output projections
&lt;/span&gt;    &lt;span class="c1"&gt;#   shape: (embed_dim, embed_dim)
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Initialize the dropout layer
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Get the batch size, sequence length, embedding dimension from input x
&lt;/span&gt;    &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 1. Pass the input through query, key, and value projections
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of input x: (batch_size, seq_len, embed_dim)
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of projection layer: (embed_dim, embed_dim)
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape after projection: (batch_size, seq_len, embed_dim)
&lt;/span&gt;    &lt;span class="c1"&gt;# Step 2. Split the last dimension into multiple heads
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape after split: (batch_size, seq_len, num_heads, head_dim)
&lt;/span&gt;    &lt;span class="c1"&gt;# Step 3. Rearrange dimension 1 and 2 for parallel computation
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape after: (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;queries&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; \
      &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; \
      &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 4. Calculate attention values
&lt;/span&gt;    &lt;span class="c1"&gt;# Step 4-1. scaled dot-product attention attn_scores = QK^T/sqrt(d_k)
&lt;/span&gt;    &lt;span class="c1"&gt;#   Note: since attention is calculated per head, 
&lt;/span&gt;    &lt;span class="c1"&gt;#      we scale by head dimension instead of model dimension
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of queries: (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of keys transposed: (batch_size, num_heads, head_dim, seq_len)
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of attn_scores: (batch_size, num_heads, seq_len, seq_len)
&lt;/span&gt;    &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# Step 4-2. apply mask
&lt;/span&gt;    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-inf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="c1"&gt;# Step 4-3. softmax
&lt;/span&gt;    &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# Step 4-4. dropout
&lt;/span&gt;    &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# Step 4-5. attention values
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of attn_scores: (batch_size, num_heads, seq_len, seq_len)
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of values: (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of attn_values: (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;attn_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 5. Rearrange values dimension and reshape to concatenate heads
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape after rearrange: (batch_size, seq_len, num_head, head_dim)
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape after concatenation: (batch_size, seq_len, embed_dim)
&lt;/span&gt;    &lt;span class="n"&gt;attn_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;attn_values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 6. Go through the final output projection
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape of output: (batch_size, seq_len, embed_dim)
&lt;/span&gt;    &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_output&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;For a simple demonstration here we included the masked scaled dot-product attention code directly in the forward() method (Step 4 &amp;gt; Step 4–1 to Step 4–5). To have more flexibility choosing from different attention mechanisms, you can abstract the attention implementation away to a separate function or module, then plug it in here.&lt;/p&gt;

&lt;p&gt;Here is an example implementation putting it in a separate module. We’ll use it in the MQA, GQA and MLA implementation.&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MaskedScaledDotProductAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;d_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;attn_scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-inf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;attn_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;attn_values&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;If you’d like to learn more about implementation details of masked scaled dot-product attention, check out &lt;a href="https://dev.to/alabebop/deep-dive-into-transformers-attention-scale-mask-dropout-and-more-16fn"&gt;Part 1&lt;/a&gt; of this series.&lt;/p&gt;
&lt;h2&gt;
  
  
  MQA: Multi-Query Attention
&lt;/h2&gt;

&lt;p&gt;While classic MHA delivers richer representations and improves model performance, maintaining each head’s own set of queries, keys, and values greatly increases memory and computation overhead. For example, during the intermediate attention score computation (QK^T), the multi-head attention score tensor is of shape (batch_size, num_heads, seq_len, seq_len), making its size num_heads times as big as its single-head counterpart.&lt;/p&gt;

&lt;p&gt;This overhead is especially problematic at inference. To address this, Multi-Query Attention (MQA) was introduced by Noam Shazeer in 2019 to improve efficiency for autoregressive transformer decoders during inference (&lt;a href="https://arxiv.org/abs/1911.02150" rel="noopener noreferrer"&gt;Shazeer, 2019.&lt;/a&gt;)&lt;/p&gt;

&lt;p&gt;MQA modifies the MHA architecture by sharing keys and values across all heads, while still allowing each head to have its own queries. At inference, this significantly reduces memory needs and computation overhead, leading to faster token generation with minimal impact on model performance.&lt;/p&gt;
&lt;h3&gt;
  
  
  MQA implementation
&lt;/h3&gt;

&lt;p&gt;Below is an implementation of MQA using the above &lt;code&gt;MaskedScaledDotProductAttention&lt;/code&gt; module in the attention calculation step. The key differences from MHA are:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;While query projection has the same dimensions as MHA’s query projection &lt;em&gt;(embed_dim, embed_dim)&lt;/em&gt;, the key and value projections are initialized with only per head dimension &lt;em&gt;(embed_dim, &lt;strong&gt;head_dim&lt;/strong&gt;)&lt;/em&gt;
&lt;/li&gt;
&lt;li&gt;To share the single key head and value head across multiple query heads, we insert a dummy dimension of 1 at the position of query tensor’s num_heads dimension (aka dimension position 1). At attention computation, PyTorch automatically broadcasts this dimension &lt;em&gt;num_heads&lt;/em&gt; times, enabling simultaneous computation of shared keys/values and separate queries per head.
&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiQueryAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Model hidden dimension (embed_dim) &lt;/span&gt;&lt;span class="se"&gt;\
&lt;/span&gt;&lt;span class="s"&gt;      must be divisible by number of heads (num_heads). &lt;/span&gt;&lt;span class="se"&gt;\
&lt;/span&gt;&lt;span class="s"&gt;      Got embed_dim: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, num_heads: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;

    &lt;span class="c1"&gt;# initialize linear projections
&lt;/span&gt;    &lt;span class="c1"&gt;## Each head has its own Q so the Q projection has 
&lt;/span&gt;    &lt;span class="c1"&gt;##    the shape of full model dimensions
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;## K and V are shared across heads, 
&lt;/span&gt;    &lt;span class="c1"&gt;##    so the projection's second dimension is only head_dim 
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;## Final output projection, also has full model dimensions
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Initialize the attention module with dropout
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;MaskedScaledDotProductAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 1. Pass input x through Q projection, split heads and 
&lt;/span&gt;    &lt;span class="c1"&gt;#      rearrange dimensions 1, 2 for parallelism
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape: (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;queries&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 2. Pass input x through K and V projections, insert a dimension at
&lt;/span&gt;    &lt;span class="c1"&gt;#        position 1 to represent the shared K and V heads
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape: (batch_size, 1, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 3. Calculate attention
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape: (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;attn_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 4. Concatenate attention values across heads
&lt;/span&gt;    &lt;span class="n"&gt;attn_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;attn_values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 5. Pass attention values through the final output projection
&lt;/span&gt;    &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_output&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;h2&gt;
  
  
  GQA: Grouped-Query Attention
&lt;/h2&gt;

&lt;p&gt;While MQA greatly improved inference efficiency by saving memory and compute overheads, the simplification of sharing keys and values across all query heads can restrict each head’s expressiveness and their ability to independently “attend” to its own representation subspace. Thus, training models directly with MQA can lead to degraded performance and training instability.&lt;/p&gt;

&lt;p&gt;To overcome these limitations, Grouped-Query Attention (GQA) was introduced by Google Research in 2023 (&lt;a href="https://arxiv.org/abs/2305.13245" rel="noopener noreferrer"&gt;Ainslie et. al, 2023&lt;/a&gt;). Instead of sharing one set of key and value heads across all query heads, GQA partitions query heads into groups and let each group share one set of key and value heads. This approach maintains MQA’s efficiency while preserving more of the model’s representational capacity, making it suitable for both training and inference.&lt;/p&gt;

&lt;p&gt;GQA has been adopted in several notable LLMs, including LLaMA 2, LLaMA 3, Qwen2 and Qwen3.&lt;/p&gt;
&lt;h3&gt;
  
  
  GQA Implementation
&lt;/h3&gt;

&lt;p&gt;Below is an implementation of GQA. The key differences from MQA are:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;While the query projection has the same shape as MHA and MQA: &lt;em&gt;(embed_dim, embed_dim)&lt;/em&gt;, the key and the value projections are of shape &lt;em&gt;(embed_dim, &lt;strong&gt;num_kv_groups * head_dim&lt;/strong&gt;)&lt;/em&gt;, so that we can easily split them into groups.&lt;/li&gt;
&lt;li&gt;To share keys and values in each query group at attention calculation, we split the key and value tensors into &lt;em&gt;num_kv_groups&lt;/em&gt; groups, rearrange the dimensions to align the &lt;em&gt;num_kv_groups&lt;/em&gt; with query tensor’s &lt;em&gt;num_heads&lt;/em&gt; dimension. Then we repeat the keys and values along the &lt;em&gt;num_kv_groups&lt;/em&gt; dimension for &lt;em&gt;heads_per_group&lt;/em&gt; (= &lt;em&gt;num_heads / num_kv_groups&lt;/em&gt;) times. As &lt;em&gt;num_kv_groups * heads_per_group = num_heads&lt;/em&gt;, we can then compute attention of each query head simultaneously just like MHA and MQA.
&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;GroupedQueryAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_kv_groups&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Model dimension must &lt;/span&gt;&lt;span class="se"&gt;\
&lt;/span&gt;&lt;span class="s"&gt;        be divisible by number of heads. Got embed_dim: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, &lt;/span&gt;&lt;span class="se"&gt;\
&lt;/span&gt;&lt;span class="s"&gt;        num_heads: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;    
    &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;num_kv_groups&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Number of heads must be &lt;/span&gt;&lt;span class="se"&gt;\
&lt;/span&gt;&lt;span class="s"&gt;        divisible by number of KV groups. Got num_heads: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, &lt;/span&gt;&lt;span class="se"&gt;\
&lt;/span&gt;&lt;span class="s"&gt;        num_kv_groups: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;num_kv_groups&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_kv_groups&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_kv_groups&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;groups_per_head&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;num_groups&lt;/span&gt;

    &lt;span class="c1"&gt;# initialize Q projection
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# initialize K, V projections
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;num_kv_groups&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;num_kv_groups&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# initialize final output projection
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# initialize attention with dropout
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;MaskedScaledDotProductAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 1. pass x through Q projection, split heads and rearrange for parallelism
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape -&amp;gt; (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;queries&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 2. pass x through K and V projections, split groups
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape -&amp;gt; (batch_size, num_kv_groups, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_k&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_kv_groups&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_v&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_kv_groups&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 3. repeat keys and values heads_per_group times along the num_kv_groups dimension
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape: (batch_size, num_kv_groups, seq_len, head_dim) -&amp;gt;
&lt;/span&gt;    &lt;span class="c1"&gt;#              (batch_size, num_heads(=num_kv_groups * heads_per_group), seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;heads_per_group&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_kv_groups&lt;/span&gt;
    &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;repeat_interleave&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;heads_per_group&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;repeat_interleave&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;heads_per_group&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 4. compute attention
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape: (batch_size, num_heads, seq_len, head_eim)
&lt;/span&gt;    &lt;span class="n"&gt;attn_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 5. concatenate heads
&lt;/span&gt;    &lt;span class="c1"&gt;#    shape -&amp;gt; (batch_size, seq_len, embed_dim)
&lt;/span&gt;    &lt;span class="n"&gt;attn_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;attn_values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 6. pass attn_values through the final output projection
&lt;/span&gt;    &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_output&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;h2&gt;
  
  
  MLA: Multi-head Latent Attention
&lt;/h2&gt;

&lt;p&gt;While GQA strikes a balance between efficiency and quality, further improvements are needed to scale to even larger models and longer, more complex inputs. Multi-head Latent Attention (MLA) was introduced by DeepSeek-AI in their 2024 &lt;a href="https://arxiv.org/abs/2405.04434" rel="noopener noreferrer"&gt;DeepSeek-v2 paper&lt;/a&gt; to address this (DeepSeek-AI, 2024).&lt;/p&gt;

&lt;p&gt;MLA uses a low-rank factorization approach to jointly compress keys and values into one much smaller learned latent vector. This compression significantly reduces memory and computation needs for KV-cache, enabling efficient processing of larger and more complex inputs, boosting generation throughput, while maintaining model performance.&lt;/p&gt;

&lt;p&gt;Ablation and empirical tests on four hard benchmarks showed that, while MHA outperforms GQA and MQA, MLA performs even better than MHA and requires much smaller amount of KV-cache.&lt;/p&gt;
&lt;h3&gt;
  
  
  MLA Implementation
&lt;/h3&gt;

&lt;p&gt;Here is the full MLA formula as provided in DeepSeek-v2 paper:&lt;/p&gt;

&lt;p&gt;

&lt;/p&gt;
&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;ctQ=WDQht[qt,1C;qt,2C;…;qt,nhC]=qtC=WUQctQ[qt,1R;qt,2R;…;qt,nhR]=qtR=RoPE(WQRctQ)qt,i=[qt,iC;qt,iR]ctKV=WDKVht[kt,1C;kt,2C;…;kt,nhC]=ktC=WUKctKVktR=RoPE(WKRht)kt,i=[kt,iC;kt,iR][vt,1C;vt,2C;…;vt,nhC]=vtC=WUVctKVot,i=∑j=1tSoftmaxj(qt,i⊤kj,idh+dhR)vj,iCut=WO[ot,1;ot,2;…;ot,nh]
\begin{align}
c_t^{Q} = W^{DQ} h_t  \\  % (1)
[q_{t,1}^{C}; q_{t,2}^{C}; \ldots; q_{t,n_h}^{C}] = q_t^{C} = W^{UQ} c_t^{Q}  \\  % (2)
[q_{t,1}^{R}; q_{t,2}^{R}; \ldots; q_{t,n_h}^{R}] = q_t^{R} = \mathrm{RoPE}(W^{QR} c_t^{Q})  \\  % (3)
q_{t,i} = [ q_{t,i}^{C}; q_{t,i}^{R} ]  \\  % (4)
c_t^{KV} = W^{DKV} h_t  \\  % (5)
[k_{t,1}^{C}; k_{t,2}^{C}; \ldots; k_{t,n_h}^{C}] = k_t^{C} = W^{UK} c_t^{KV}  \\  % (6)
k_t^{R} = \mathrm{RoPE}(W^{KR} h_t)  \\  % (7)
k_{t,i} = [ k_{t,i}^{C}; k_{t,i}^{R} ]  \\  % (8)
[v_{t,1}^{C}; v_{t,2}^{C}; \ldots; v_{t,n_h}^{C}] = v_t^{C} = W^{UV} c_t^{KV}  \\  % (9)
o_{t,i} = \sum_{j=1}^{t} \mathrm{Softmax}j \left( \frac{ q{t,i}^\top k_{j,i} }{ \sqrt{d_h + d_h^{R}} } \right) v_{j,i}^{C}  \\  % (10)
u_t = W^{O} [ o_{t,1}; o_{t,2}; \ldots; o_{t,n_h} ]  % (11)
\end{align}
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mtable"&gt;&lt;span class="col-align-r"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;Q&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;D&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;Q&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;h&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen"&gt;[&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;…&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;]&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;U&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;Q&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;Q&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen"&gt;[&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;…&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;]&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathrm"&gt;RoPE&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;QR&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;Q&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mopen"&gt;[&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;]&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;K&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;DK&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;h&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen"&gt;[&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;…&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;]&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;U&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;K&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;K&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathrm"&gt;RoPE&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;K&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;h&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mopen"&gt;[&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;]&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen"&gt;[&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;v&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;v&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;…&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;v&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;]&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;v&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;U&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;K&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mop op-limits"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;j&lt;/span&gt;&lt;span class="mrel mtight"&gt;=&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="mop op-symbol large-op"&gt;∑&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathrm"&gt;Softmax&lt;/span&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;j&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;&lt;span class="mopen"&gt;&lt;span class="delimsizing mult"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="delimsizinginner delim-size4"&gt;&lt;span&gt;⎝&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="delimsizinginner delim-size4"&gt;&lt;span&gt;⎛&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord sqrt"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span class="svg-align"&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;d&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;+&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;d&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="hide-tail"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;t&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;⊤&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;j&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;&lt;span class="delimsizing mult"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="delimsizinginner delim-size4"&gt;&lt;span&gt;⎠&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="delimsizinginner delim-size4"&gt;&lt;span&gt;⎞&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;v&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;j&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;u&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;O&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;[&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;…&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;]&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="tag"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;



&lt;p&gt;Where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;hth_t&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;h&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: input token embedding at position 
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;tt&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;nhn_h&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;n&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: number of attention heads&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;WDQ,WDKVW^{DQ}, W^{DKV}&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;D&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;Q&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;DK&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: down-projection matrices for query and key-value content vectors&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;WUQ,WUK,WUVW^{UQ}, W^{UK}, W^{UV}&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;U&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;Q&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;U&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;K&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;U&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: up-projection for query, key and value from content vectors&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;WQR,WKRW^{QR}, W^{KR}&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;QR&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;K&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: linear projections generating relative queries and keys (before RoPE)&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;WOW^O&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;O&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: output linear projection matrix&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;ctQc_t^Q&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;Q&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: content query vector (down-projected from input 
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;hth_t&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;h&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
)&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;ctKVc_t^{KV}&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;K&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: content key-value vector (also down-projected from input)&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;qtC,qt,iCq_t^C, q_{t,i}^C&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: content queries of all heads / head i&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;qtR,qr,iRq_t^R, q_{r,i}^R&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;r&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: relative positional queries of all heads / head i&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;qt,iq_{t,i}&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;q&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: concatenated content and relative query vectors of head i&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;ktC,kt,iCk_t^C, k_{t,i}^C&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: content keys for all heads / head i&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;ktRk_t^R&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: relative positional keys&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;kt,ik_{t,i}&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: concatenated content and relative key vectors of head i&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;vtC,vt,iCv_t^C, v_{t,i}^C&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;v&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;v&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;C&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: content values for all heads / head i&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;dh,dhRd_h, d_h^R&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;d&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;d&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;h&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;R&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: dimensions of content and relative positional subspaces per head&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;ot,io_{t,i}&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mpunct mtight"&gt;,&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: attention output for head i at position t&lt;/li&gt;
&lt;li&gt;
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;utu_t&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;u&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
: final output&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;DeepSeek’s MLA is deeply integrated with &lt;a href="https://arxiv.org/abs/2104.09864" rel="noopener noreferrer"&gt;RoPE (Rotary Position Embedding)&lt;/a&gt;. We will do a deep dive in positional embedding in the next article and also implement full MLA with RoPE. For now we’ll just implement a simplified version without RoPE to demonstrate MLA’s idea of learning compressed latent content vectors instead of full Q, K, V projections.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiheadLatentAttentionSimplified&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q_latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kv_latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;num_head&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q_latent_dim&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;kv_latent_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;kv_latent_dim&lt;/span&gt;

    &lt;span class="c1"&gt;# Initialize projections
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_DQ&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q_latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_UQ&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q_latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_DKV&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kv_latent_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_UK&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kv_latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_UV&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kv_latent_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Initialize attention module
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attention&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;MaskedScaledDotProductAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 1. Compress and decompress Q
&lt;/span&gt;    &lt;span class="n"&gt;c_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_DQ&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (batch_size, seq_len, q_latent_dim)
&lt;/span&gt;    &lt;span class="n"&gt;q_content&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_UQ&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;c_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (batch_size, seq_len, embed_dim)
&lt;/span&gt;
    &lt;span class="c1"&gt;# Step 2. Compres K and V into one latent subspace
&lt;/span&gt;    &lt;span class="n"&gt;c_kv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_DKV&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (batch_size, seq_len, kv_latent_dim)
&lt;/span&gt;
    &lt;span class="c1"&gt;# Step 3. Decompress K and V respectively
&lt;/span&gt;    &lt;span class="n"&gt;k_content&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_UK&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;c_kv&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (batch_size, seq_len, embed_dim)
&lt;/span&gt;    &lt;span class="n"&gt;v_content&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_UV&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;c_kv&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# (batch_size, seq_len, embed_dim)
&lt;/span&gt;
    &lt;span class="c1"&gt;# Step 4. Split heads and reshape for multi-head attention
&lt;/span&gt;    &lt;span class="c1"&gt;# -&amp;gt; (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;queries&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q_content&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;keys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;k_content&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;v_content&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 5. Apply attention
&lt;/span&gt;    &lt;span class="c1"&gt;# -&amp;gt; (batch_size, num_heads, seq_len, head_dim)
&lt;/span&gt;    &lt;span class="n"&gt;attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;queries&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keys&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 6. Concatenate heads and reshape
&lt;/span&gt;    &lt;span class="n"&gt;attn_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;attn_output&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 7. Apply final output projection
&lt;/span&gt;    &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;W_output&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c_kv&lt;/span&gt; &lt;span class="c1"&gt;# return c_kv to show that it will be cached
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  References
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;Noam Shazeer. (2019). Fast Transformer Decoding: One Write-Head is All You Need. &lt;a href="https://arxiv.org/abs/1911.02150" rel="noopener noreferrer"&gt;arXiv:1911.02150&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.
&lt;a href="https://arxiv.org/abs/2305.13245" rel="noopener noreferrer"&gt;arXiv:2305.13245&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;DeepSeek-AI (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model.
&lt;a href="https://arxiv.org/abs/2405.04434" rel="noopener noreferrer"&gt;arXiv:2405.04434&lt;/a&gt;
&lt;/li&gt;
&lt;/ul&gt;

</description>
      <category>transformer</category>
      <category>multiheadattention</category>
      <category>mla</category>
      <category>gqa</category>
    </item>
    <item>
      <title>Hands-On Transformer Deep Dive: Part 1 — Masked Attention Explained &amp; Implemented</title>
      <dc:creator>Alex Xiaoli Shen</dc:creator>
      <pubDate>Wed, 30 Jul 2025 07:24:45 +0000</pubDate>
      <link>https://dev.to/alabebop/deep-dive-into-transformers-attention-scale-mask-dropout-and-more-16fn</link>
      <guid>https://dev.to/alabebop/deep-dive-into-transformers-attention-scale-mask-dropout-and-more-16fn</guid>
      <description>&lt;p&gt;&lt;em&gt;In this “Hands-On Transformer Deep Dive” series, we go step-by-step through the algorithms and components of modern Transformers, with working code and engineering insights. Follow along to deepen your understanding — and build your own Transformers from scratch.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;In this article, we dive deep into the core attention mechanism used in most of today’s Transformer models: the Masked Scaled Dot-product Attention. We’ll implement it from scratch using only PyTorch, and look into the specifics of when and where to apply the scale, mask, dropout, and why.&lt;/p&gt;

&lt;h2&gt;
  
  
  Introduction
&lt;/h2&gt;

&lt;p&gt;Transformers have become the foundation of modern generative LMs. The attention mechanism lies in its core. There are many flavors of attention mechanisms, e.g., Additive Attention (Bahadnau, 2014), Dot-product Attention (Luong, 2015), Scaled Dot-product Attention (Vaswani et al., 2017), masked attention (used for padding and causal decoding), multi-head attention (which also has multiple variants). The masked scaled dot-product attention is the foundational building block of all the autoregressive GPT-like models prevalent today.&lt;/p&gt;

&lt;h2&gt;
  
  
  Implementation
&lt;/h2&gt;

&lt;p&gt;The attention mechanism enables LLMs to learn and generate context-dependent representations by letting each token “attend” to all tokens. The formula tells us its basic working, where given queries Q, keys K, values V, the attention output is calculated as follows (d_k is the dimension of the keys):&lt;/p&gt;

&lt;p&gt;

&lt;/p&gt;
&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;Attention(Q,K,V)=Softmax(QKTdk)V
  Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;A&lt;/span&gt;&lt;span class="mord mathnormal"&gt;tt&lt;/span&gt;&lt;span class="mord mathnormal"&gt;e&lt;/span&gt;&lt;span class="mord mathnormal"&gt;n&lt;/span&gt;&lt;span class="mord mathnormal"&gt;t&lt;/span&gt;&lt;span class="mord mathnormal"&gt;i&lt;/span&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="mord mathnormal"&gt;n&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;Q&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;K&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;V&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;S&lt;/span&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="mord mathnormal"&gt;f&lt;/span&gt;&lt;span class="mord mathnormal"&gt;t&lt;/span&gt;&lt;span class="mord mathnormal"&gt;ma&lt;/span&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord sqrt"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span class="svg-align"&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;d&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;k&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="hide-tail"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;Q&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;K&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;T&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mord mathnormal"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;The implementation, however, includes a couple of more details:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;mask: to enable padding and causal attention (where a token can only “attend” to tokens that came before itself)&lt;/li&gt;
&lt;li&gt;dropout: a regularization method to prevent the model from relying too heavily on a few specific positions in the sequence&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Below is the code implementing the masked scaled dot-product attention mechanism step-by-step:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;###
# Masked Scaled Dot-product Attention Implementation
###
&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;


&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Step 1 &amp;amp; 2: dot-product and scale
&lt;/span&gt;    &lt;span class="n"&gt;d_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 3: mask
&lt;/span&gt;    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-inf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 4: softmax
&lt;/span&gt;    &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 5: dropout
&lt;/span&gt;    &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Step 6: weighted sum
&lt;/span&gt;    &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;For easy demonstration this is implemented as a function. There is also a PyTorch module implementation at the end of the article, which you can use to plug into your own PyTorch network.&lt;/p&gt;
&lt;h2&gt;
  
  
  Step-by-step Explanation and Nuances
&lt;/h2&gt;

&lt;p&gt;Following the 6 steps we can see the actual formula is:&lt;/p&gt;


&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;Attention(Q,K,V,M,D)=[D⊙Softmax(QKTdk+mask(M))]V
  Attention(Q, K, V, M, D) = \left[D \odot Softmax\left(\frac{QK^T}{\sqrt{d_k}} + mask(M)\right)\right]V
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;A&lt;/span&gt;&lt;span class="mord mathnormal"&gt;tt&lt;/span&gt;&lt;span class="mord mathnormal"&gt;e&lt;/span&gt;&lt;span class="mord mathnormal"&gt;n&lt;/span&gt;&lt;span class="mord mathnormal"&gt;t&lt;/span&gt;&lt;span class="mord mathnormal"&gt;i&lt;/span&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="mord mathnormal"&gt;n&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;Q&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;K&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;V&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;M&lt;/span&gt;&lt;span class="mpunct"&gt;,&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;D&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="minner"&gt;&lt;span class="mopen delimcenter"&gt;&lt;span class="delimsizing size3"&gt;[&lt;/span&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;D&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;⊙&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;S&lt;/span&gt;&lt;span class="mord mathnormal"&gt;o&lt;/span&gt;&lt;span class="mord mathnormal"&gt;f&lt;/span&gt;&lt;span class="mord mathnormal"&gt;t&lt;/span&gt;&lt;span class="mord mathnormal"&gt;ma&lt;/span&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;&lt;span class="mopen delimcenter"&gt;&lt;span class="delimsizing size3"&gt;(&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord sqrt"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span class="svg-align"&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;d&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;k&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="hide-tail"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;Q&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;K&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;T&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;+&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;ma&lt;/span&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;span class="mord mathnormal"&gt;k&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;M&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mclose delimcenter"&gt;&lt;span class="delimsizing size3"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose delimcenter"&gt;&lt;span class="delimsizing size3"&gt;]&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;V&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;



&lt;p&gt;Where&lt;/p&gt;

&lt;p&gt;Q: query tensor of shape (batch_size, seq_len, d_q)&lt;br&gt;
K: key tensor of shape (batch_size, seq_len, d_k)&lt;br&gt;
V: value tensor of shape (batch_size, seq_len, d_v), normally d_q, d_k, and d_v are the same&lt;br&gt;
M: mask matrix of shape (seq_len, seq_len), 0 for masking and 1 for passing through&lt;br&gt;
D: dropout, a probability p between 0 and 1, where each element has p probability to be set 0 and 1-p probability to be kept and scaled up to 1/(1-p) (to compensate for the removed elements and keep the expected sum)&lt;br&gt;
Let’s look at each step and understand the nuances about why they have to be in this order.&lt;/p&gt;
&lt;h3&gt;
  
  
  Step 1. Dot Product
&lt;/h3&gt;


&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Here we use dot product to compute the correlation between the query and key to get the raw attention score. All further operations build on these fundamental scores.&lt;/p&gt;
&lt;h3&gt;
  
  
  Step 2. Scale by sqrt(d_k)
&lt;/h3&gt;


&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;d_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
&lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Here we scale the raw attention scores to prevent the next step, softmax, from being too “peaky”. This stabilizes the gradients for models with a large dimension (d_k). As d_k grows, variances between dot product results also increases (roughly by d_k).&lt;/p&gt;

&lt;p&gt;Why do this scaling before softmax? The softmax function is highly sensitive to large input differences, where higher variance causes the largest score to dominate (i.e., its probability becomes close to 1 and others close to 0). This is a “peaky” distribution and can lead to vanishing gradients and poor gradient flow, which hurt learning. Therefore, we need to do the scale by sqrt(d_k) regularization here before moving on to softmax.&lt;/p&gt;
&lt;h3&gt;
  
  
  Step 3. Mask
&lt;/h3&gt;


&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-inf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Mask is used for padding and/or causal attention (where a token can only “see and attend” to the tokens before it). It is typically an additive mask, adding a large negative number (e.g., float(‘-inf’)) to certain positions to block them. With softmax these positions’ probability then become near-zero.&lt;/p&gt;

&lt;p&gt;Why apply the mask before softmax? Because we use softmax to calculate the probability distribution from attention scores. For the tokens the query shouldn’t “see” we need their probabilities to be 0, and for the rest we need their probabilities to sum to 1. Masking before softmax with -inf satisfies this. If we mask after softmax (e.g., with zero), the masked tokens have already contributed to the probability distribution, and it also causes the probabilities to no longer sum to 1.&lt;/p&gt;

&lt;p&gt;On the other hand, what about masking before computing the attention scores, i.e., before Step 1. dot product? It seems intuitive to not let the query “see” the keys of tokens that it shouldn’t see in the first place, right? Unfortunately, with the nature of dot product computation, masking with 0 doesn’t really remove the influence of the corresponding positions’ attention scores and their probabilities in the following softmax step would also not be zero-ed out. Also, masking with -inf makes the computation itself impossible.&lt;/p&gt;
&lt;h3&gt;
  
  
  Step 4. Softmax
&lt;/h3&gt;


&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Apply softmax to the attention scores to compute the attention weights that sum to 1. The attention weights are the weight of each key that the query should pay attention to. After dot product the attention scores is a tensor of dimension (batch_size, (num_heads,) query_len, key_len), where query_len and key_len are the same in self-attention and are often noted as seq_len. We only need to compute weights of the keys which is in the last dimension, so dim=-1 tells softmax which dimension we’re interested in.&lt;/p&gt;
&lt;h3&gt;
  
  
  Step 5. Dropout
&lt;/h3&gt;


&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Here we apply the dropout regularization for smoother gradients and better generalization. The nn.Dropout(dropout) gives us a dropout module of the dropout rate we need (which randomly zeroes out p percent of the elements and scale the rest by 1/1-p). We then pass the attn_weights through this dropout module to apply it.&lt;/p&gt;

&lt;p&gt;Why is the dropout applied after softmax? As explained in Step 3. mask section, softmax calculates the attention weights from attention scores, which is a probability distribution representing how much “attention” each key should get. If we applied dropout before softmax, setting some attention scores to 0 and scaling up the others, we’d totally mess up the probability distribution, not to mention that the elements we “dropped out” (set to zero) don’t necessarily get a zero probabilities if you consider how the softmax function works.&lt;/p&gt;
&lt;h3&gt;
  
  
  Step 6. Coupute Output Value
&lt;/h3&gt;


&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Finally, multiply the attention weights with the value tensor and we get our masked scaled dot-product attention output, hooray!&lt;/p&gt;
&lt;h2&gt;
  
  
  PyTorch Module Implementation
&lt;/h2&gt;

&lt;p&gt;Here is a PyTorch module implementation that can be plugged into your PyTorch modules.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Step 1 &amp;amp; 2: dot product and scale
&lt;/span&gt;        &lt;span class="n"&gt;d_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d_k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="c1"&gt;# Step 3 mask
&lt;/span&gt;        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-inf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="c1"&gt;# Step 4 softmax
&lt;/span&gt;        &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="c1"&gt;# Step 5 dropout
&lt;/span&gt;        &lt;span class="n"&gt;attn_weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="c1"&gt;# Step 6 output
&lt;/span&gt;        &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;attn_weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is the first of a series articles diving deep into the Transformer model architecture and algorithm implementations. Next up we’ll look into multi-head attention and its variants. Stay tuned and tell me what you think and what you’d like to read!&lt;/p&gt;

&lt;h2&gt;
  
  
  References &amp;amp; Further Readings
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;Vaswani et al., &lt;a href="https://arxiv.org/abs/1706.03762" rel="noopener noreferrer"&gt;Attention is All You Need&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;Harvard NLP, &lt;a href="https://nlp.seas.harvard.edu/annotated-transformer/" rel="noopener noreferrer"&gt;The Annotated Transformer&lt;/a&gt;
&lt;/li&gt;
&lt;/ul&gt;

</description>
      <category>genai</category>
      <category>machinelearning</category>
      <category>algorithms</category>
      <category>ai</category>
    </item>
  </channel>
</rss>
