<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Pranav Sateesh</title>
    <description>The latest articles on DEV Community by Pranav Sateesh (@stprnvsh).</description>
    <link>https://dev.to/stprnvsh</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3691754%2F81eb39f8-22ba-4058-a1af-217e5aa8e8ae.png</url>
      <title>DEV Community: Pranav Sateesh</title>
      <link>https://dev.to/stprnvsh</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/stprnvsh"/>
    <language>en</language>
    <item>
      <title>Mosaic: Sharding Attention Across GPUs When Your Sequence Doesn't Fit</title>
      <dc:creator>Pranav Sateesh</dc:creator>
      <pubDate>Mon, 05 Jan 2026 06:53:43 +0000</pubDate>
      <link>https://dev.to/stprnvsh/mosaic-sharding-attention-across-gpus-when-your-sequence-doesnt-fit-4d92</link>
      <guid>https://dev.to/stprnvsh/mosaic-sharding-attention-across-gpus-when-your-sequence-doesnt-fit-4d92</guid>
      <description>&lt;p&gt;&lt;em&gt;How we built a lightweight library to distribute 150,000-token attention across multiple GPUs&lt;/em&gt;&lt;/p&gt;




&lt;h2&gt;
  
  
  The Problem: Attention Doesn't Scale
&lt;/h2&gt;

&lt;p&gt;You've probably heard that transformers have a "quadratic attention bottleneck." Here's what that actually means in practice.&lt;/p&gt;

&lt;p&gt;Attention computes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Attention(Q, K, V) = softmax(QKᵀ / √d) × V
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The killer is &lt;strong&gt;QKᵀ&lt;/strong&gt; — a matrix of shape &lt;code&gt;(sequence_length × sequence_length)&lt;/code&gt;. For a 150,000-token sequence:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Memory = 150,000² × 4 bytes = 90 billion bytes = 84 GB
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That's just for the attention weights. One layer. One head. An A100 has 80GB total.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;You can't fit it.&lt;/strong&gt;&lt;/p&gt;
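
&lt;p&gt;A quick back-of-the-envelope helper (not part of Mosaic) if you want to sanity-check the arithmetic for other sequence lengths:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def attn_matrix_gib(seq_len, bytes_per_elem=4):
    """Memory for one dense (seq_len x seq_len) attention-score matrix, in GiB."""
    return seq_len ** 2 * bytes_per_elem / 2 ** 30

print(attn_matrix_gib(150_000))   # ~83.8 GiB -- one layer, one head, fp32
print(attn_matrix_gib(8_192))     # ~0.25 GiB -- why short contexts are fine
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;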

&lt;h2&gt;
  
  
  Existing Solutions (and Their Limits)
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;FlashAttention&lt;/strong&gt; reduces memory from O(n²) to O(n) by computing attention in tiles without materializing the full matrix. But it still requires the entire sequence on one GPU.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Ring Attention&lt;/strong&gt; (from ring-flash-attn) shards the sequence across GPUs. Each GPU holds a chunk of Q and passes K, V around in a ring. Beautiful for 1D sequences.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;The gap:&lt;/strong&gt; What about models with multiple attention patterns? &lt;/p&gt;

&lt;p&gt;Consider a tabular transformer with shape &lt;code&gt;(batch, rows, features, embed)&lt;/code&gt;:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Attention over &lt;strong&gt;features&lt;/strong&gt; (axis 2): 5 tokens — fits easily&lt;/li&gt;
&lt;li&gt;Attention over &lt;strong&gt;rows&lt;/strong&gt; (axis 1): 150,000 tokens — needs sharding&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;No library handled this cleanly. You'd write custom code for each axis, manage different process groups, handle the tensor reshaping yourself.&lt;/p&gt;

&lt;h2&gt;
  
  
  Mosaic: Multi-Axis Attention Sharding
&lt;/h2&gt;

&lt;p&gt;Mosaic is a thin coordination layer that routes different attention axes to appropriate backends:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;mosaic&lt;/span&gt;

&lt;span class="c1"&gt;# Small axis: run locally
&lt;/span&gt;&lt;span class="n"&gt;feature_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mosaic&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;MultiAxisAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;96&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;attention_axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;    &lt;span class="c1"&gt;# features dimension
&lt;/span&gt;    &lt;span class="n"&gt;backend&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;local&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;      &lt;span class="c1"&gt;# no communication needed
&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Large axis: shard across GPUs
&lt;/span&gt;&lt;span class="n"&gt;row_attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mosaic&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;MultiAxisAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;96&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;attention_axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;    &lt;span class="c1"&gt;# rows dimension  
&lt;/span&gt;    &lt;span class="n"&gt;backend&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;ring&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;       &lt;span class="c1"&gt;# ring attention across GPUs
&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That's it. Mosaic handles:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Permuting the attention axis to the sequence position (sketched below)&lt;/li&gt;
&lt;li&gt;Reshaping for QKV projection&lt;/li&gt;
&lt;li&gt;Dispatching to the right backend&lt;/li&gt;
&lt;li&gt;Restoring the original tensor shape&lt;/li&gt;
&lt;/ol&gt;
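
&lt;p&gt;Step 1 is where the multi-axis part lives. Here's a rough sketch of what that permutation helper could look like — the helper name is taken from the &lt;code&gt;forward()&lt;/code&gt; excerpt later in this post, but the details are our guess, not Mosaic's exact code:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

def _permute_to_seq(x, attention_axis):
    # Move the attention axis to the second-to-last position (..., seq, embed)
    # and remember how to undo it afterwards.
    perm = [d for d in range(x.dim()) if d != attention_axis]
    perm.insert(x.dim() - 2, attention_axis)
    inv_perm = [perm.index(d) for d in range(x.dim())]
    return x.permute(perm).contiguous(), inv_perm

# (batch, rows, features, embed) with attention_axis=1 becomes
# (batch, features, rows, embed); inv_perm restores the original layout.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;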

&lt;h2&gt;
  
  
  How Ring Attention Works
&lt;/h2&gt;

&lt;p&gt;The key insight: you don't need all of K and V at once. You can compute partial attention scores, accumulate them, and normalize at the end.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;4 GPUs, sequence split into 4 chunks:

Initial state:
  GPU 0: Q₀, K₀, V₀
  GPU 1: Q₁, K₁, V₁
  GPU 2: Q₂, K₂, V₂
  GPU 3: Q₃, K₃, V₃

Step 1: Each GPU computes attention with its local K, V
  GPU 0: score₀₀ = Q₀ @ K₀ᵀ
  ...

Step 2: Pass K, V to the next GPU in the ring
  GPU 0 receives K₃, V₃ from GPU 3
  GPU 0 sends K₀, V₀ to GPU 1

Step 3: Compute attention with received K, V
  GPU 0: score₀₃ = Q₀ @ K₃ᵀ
  Accumulate with score₀₀

Repeat for all chunks...

Final: Each GPU has complete attention output for its Q chunk
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Memory per GPU:&lt;/strong&gt; O(n²/p) where p = number of GPUs&lt;/p&gt;

&lt;p&gt;With 8 GPUs, you've reduced memory 8×. A 150k sequence now needs ~10GB per GPU instead of 84GB.&lt;/p&gt;
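
&lt;p&gt;For intuition, here is a deliberately naive sketch of that accumulation loop using plain &lt;code&gt;torch.distributed&lt;/code&gt; point-to-point ops and an online softmax. It is illustrative only — Mosaic's &lt;code&gt;ring&lt;/code&gt; backend delegates to ring-flash-attn's fused kernels rather than doing this in Python:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch
import torch.distributed as dist

def ring_attention_sketch(q, k, v):
    """q, k, v: local chunks of shape (batch, heads, seq_chunk, head_dim),
    same dtype, living on this rank's GPU (NCCL)."""
    world, rank = dist.get_world_size(), dist.get_rank()
    scale = q.shape[-1] ** -0.5

    # Running statistics for a numerically stable online softmax.
    acc = torch.zeros_like(q)                                  # running weighted sum of V
    row_max = torch.full(q.shape[:-1] + (1,), float("-inf"),
                         dtype=q.dtype, device=q.device)
    row_sum = torch.zeros_like(row_max)

    k_buf, v_buf = k.contiguous(), v.contiguous()
    for step in range(world):
        scores = (q @ k_buf.transpose(-2, -1)) * scale         # (b, h, q_chunk, k_chunk)
        blk_max = scores.amax(dim=-1, keepdim=True)
        new_max = torch.maximum(row_max, blk_max)
        corr = torch.exp(row_max - new_max)                    # rescale old accumulator
        probs = torch.exp(scores - new_max)
        acc = acc * corr + probs @ v_buf
        row_sum = row_sum * corr + probs.sum(dim=-1, keepdim=True)
        row_max = new_max

        if step &lt; world - 1:
            # Pass K, V to the next rank in the ring; receive from the previous rank.
            recv_k, recv_v = torch.empty_like(k_buf), torch.empty_like(v_buf)
            ops = [dist.P2POp(dist.isend, k_buf, (rank + 1) % world),
                   dist.P2POp(dist.irecv, recv_k, (rank - 1) % world),
                   dist.P2POp(dist.isend, v_buf, (rank + 1) % world),
                   dist.P2POp(dist.irecv, recv_v, (rank - 1) % world)]
            for req in dist.batch_isend_irecv(ops):
                req.wait()
            k_buf, v_buf = recv_k, recv_v

    return acc / row_sum                                       # normalize at the very end
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;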

&lt;h2&gt;
  
  
  Beyond 1D: Mesh2D Attention
&lt;/h2&gt;

&lt;p&gt;For very long sequences, even ring attention isn't enough. Mesh2D shards both Q and K:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;4 GPUs in 2×2 mesh:

         K₀      K₁
       ┌──────┬──────┐
    Q₀ │GPU 0 │GPU 1 │
       ├──────┼──────┤
    Q₁ │GPU 2 │GPU 3 │
       └──────┴──────┘

Each GPU computes one tile of QKᵀ
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Memory per GPU:&lt;/strong&gt; O(n²/p²)&lt;/p&gt;

&lt;p&gt;With 64 GPUs in an 8×8 mesh, memory drops 64× per GPU.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;attn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mosaic&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;MultiAxisAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;attention_axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;backend&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;mesh2d&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;mesh_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
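
&lt;p&gt;Under the hood, a 2D mesh is just two families of process groups — one per mesh row and one per mesh column. A hedged sketch of how those groups could be carved out of the global group (ranks laid out row-major; not necessarily how Mosaic builds its groups):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch.distributed as dist

def build_mesh_groups(mesh_rows, mesh_cols):
    """Every rank must call this with identical arguments (new_group is collective)."""
    rank = dist.get_rank()
    row_groups = [dist.new_group([r * mesh_cols + c for c in range(mesh_cols)])
                  for r in range(mesh_rows)]
    col_groups = [dist.new_group([r * mesh_cols + c for r in range(mesh_rows)])
                  for c in range(mesh_cols)]
    my_row, my_col = divmod(rank, mesh_cols)
    # e.g. gather K/V chunks along this rank's row group and combine partial outputs
    # along its column group (the exact roles depend on how Q and K are laid out).
    return row_groups[my_row], col_groups[my_col]
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;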



&lt;h2&gt;
  
  
  Composed Strategies
&lt;/h2&gt;

&lt;p&gt;Real clusters have topology. GPUs within a node communicate via fast NVLink (900 GB/s). GPUs across nodes use slower InfiniBand (200 GB/s).&lt;/p&gt;

&lt;p&gt;Mosaic's &lt;code&gt;ComposedAttention&lt;/code&gt; exploits this:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# 4 nodes × 8 GPUs = 32 total
&lt;/span&gt;&lt;span class="n"&gt;composed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mosaic&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ComposedAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;mesh_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;       &lt;span class="c1"&gt;# (nodes, gpus_per_node)
&lt;/span&gt;    &lt;span class="n"&gt;head_parallel&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;      &lt;span class="c1"&gt;# Split heads across nodes (slow link)
&lt;/span&gt;    &lt;span class="n"&gt;seq_parallel&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;ring&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;      &lt;span class="c1"&gt;# Ring within nodes (fast link)
&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Or use &lt;code&gt;HierarchicalAttention&lt;/code&gt; for explicit control:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;hier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mosaic&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HierarchicalAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;intra_node_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;intra_node_strategy&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;local&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="c1"&gt;# Compute locally within node
&lt;/span&gt;    &lt;span class="n"&gt;inter_node_strategy&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;ring&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;    &lt;span class="c1"&gt;# Ring between node leaders
&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  The Implementation
&lt;/h2&gt;

&lt;p&gt;Mosaic is ~800 lines of Python. Here's the core pattern:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiAxisAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# 1. Move attention axis to seq position
&lt;/span&gt;        &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;inv_perm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_permute_to_seq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# 2. Flatten batch dims, project QKV
&lt;/span&gt;        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embed_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;qkv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;qkv_proj&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;head_dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;qkv&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;permute&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;unbind&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# 3. Dispatch to backend
&lt;/span&gt;        &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_attn_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# local, ring, or mesh2d
&lt;/span&gt;
        &lt;span class="c1"&gt;# 4. Project output, restore shape
&lt;/span&gt;        &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;out_proj&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(...))&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;permute&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inv_perm&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The backends wrap existing libraries:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;local&lt;/code&gt;: &lt;code&gt;F.scaled_dot_product_attention&lt;/code&gt; (FlashAttention)&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;ring&lt;/code&gt;: &lt;code&gt;ring_flash_attn_func&lt;/code&gt; from ring-flash-attn&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;mesh2d&lt;/code&gt;: Custom all-gather + SDPA&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;All use FlashAttention kernels for the actual attention computation.&lt;/p&gt;
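
&lt;p&gt;Because each backend is just a callable, dispatch can be resolved once at construction time — roughly like this (illustrative names, not Mosaic's exact internals; note that SDPA expects &lt;code&gt;(batch, heads, seq, head_dim)&lt;/code&gt; while the flash-attn family expects &lt;code&gt;(batch, seq, heads, head_dim)&lt;/code&gt;, so a real implementation also normalizes layouts):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch.nn.functional as F

def resolve_backend(name):
    if name == "local":
        # SDPA dispatches to a FlashAttention kernel when shapes/dtypes allow it.
        return F.scaled_dot_product_attention
    if name == "ring":
        from ring_flash_attn import ring_flash_attn_func   # optional dependency
        return ring_flash_attn_func
    raise ValueError(f"unsupported backend: {name}")

# Bound once in __init__, so forward() just calls self._attn_fn(q, k, v):
# self._attn_fn = resolve_backend(backend)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;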

&lt;h2&gt;
  
  
  Usage
&lt;/h2&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;pip &lt;span class="nb"&gt;install &lt;/span&gt;git+https://github.com/stprnvsh/mosaic.git

&lt;span class="c"&gt;# With ring attention support&lt;/span&gt;
pip &lt;span class="nb"&gt;install &lt;/span&gt;flash-attn ring-flash-attn
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Single node:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;torchrun &lt;span class="nt"&gt;--nproc_per_node&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;4 train.py
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Multi-node:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="c"&gt;# Node 0&lt;/span&gt;
torchrun &lt;span class="nt"&gt;--nnodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;2 &lt;span class="nt"&gt;--nproc_per_node&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;8 &lt;span class="nt"&gt;--node_rank&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;0 &lt;span class="se"&gt;\&lt;/span&gt;
    &lt;span class="nt"&gt;--master_addr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;192.168.1.100 &lt;span class="nt"&gt;--master_port&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;29500 train.py

&lt;span class="c"&gt;# Node 1  &lt;/span&gt;
torchrun &lt;span class="nt"&gt;--nnodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;2 &lt;span class="nt"&gt;--nproc_per_node&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;8 &lt;span class="nt"&gt;--node_rank&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;1 &lt;span class="se"&gt;\&lt;/span&gt;
    &lt;span class="nt"&gt;--master_addr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;192.168.1.100 &lt;span class="nt"&gt;--master_port&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;29500 train.py
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Training script:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;mosaic&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.distributed&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;dist&lt;/span&gt;

&lt;span class="n"&gt;dist&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;init_process_group&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;nccl&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ctx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mosaic&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;init&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sp_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;dist&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_world_size&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;MyModel&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ctx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Data is pre-sharded: each GPU has seq_total / world_size tokens
&lt;/span&gt;&lt;span class="n"&gt;x_local&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;load_my_shard&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_local&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Communication handled by Mosaic
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  When to Use What
&lt;/h2&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Sequence&lt;/th&gt;
&lt;th&gt;GPUs&lt;/th&gt;
&lt;th&gt;Backend&lt;/th&gt;
&lt;th&gt;Memory/GPU&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&amp;lt; 10k&lt;/td&gt;
&lt;td&gt;1&lt;/td&gt;
&lt;td&gt;&lt;code&gt;local&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;O(n²)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;10k–100k&lt;/td&gt;
&lt;td&gt;2–8&lt;/td&gt;
&lt;td&gt;&lt;code&gt;ring&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;O(n²/p)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;100k–1M&lt;/td&gt;
&lt;td&gt;8–64&lt;/td&gt;
&lt;td&gt;&lt;code&gt;mesh2d&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;O(n²/p²)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&amp;gt; 1M&lt;/td&gt;
&lt;td&gt;64+&lt;/td&gt;
&lt;td&gt;&lt;code&gt;composed&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;O(n²/(p²·h))&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h2&gt;
  
  
  Performance
&lt;/h2&gt;

&lt;p&gt;We optimized for zero overhead:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;FlashAttention everywhere&lt;/strong&gt; — All backends use &lt;code&gt;F.scaled_dot_product_attention&lt;/code&gt; for fused GEMM + softmax&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Pre-selected dispatch&lt;/strong&gt; — Backend function bound at init, no branching in forward&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;View not copy&lt;/strong&gt; — &lt;code&gt;x.view()&lt;/code&gt; instead of &lt;code&gt;x.reshape()&lt;/code&gt; when contiguous&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Pre-allocated collectives&lt;/strong&gt; — &lt;code&gt;all_gather&lt;/code&gt; into pre-sized tensors, no &lt;code&gt;torch.cat&lt;/code&gt; (sketched after this list)
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Module-level imports&lt;/strong&gt; — No import overhead per forward pass&lt;/li&gt;
&lt;/ol&gt;
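
&lt;p&gt;Point 4 in practice — a sketch of the pattern, not Mosaic's code: PyTorch's &lt;code&gt;all_gather_into_tensor&lt;/code&gt; writes every rank's chunk straight into one pre-sized buffer, so there is no per-step list of tensors to &lt;code&gt;torch.cat&lt;/code&gt;:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch
import torch.distributed as dist

def gather_chunks(local_chunk, group=None):
    world = dist.get_world_size(group)
    # One flat, pre-sized output; the collective writes each rank's chunk in place.
    out = torch.empty((world,) + tuple(local_chunk.shape),
                      dtype=local_chunk.dtype, device=local_chunk.device)
    dist.all_gather_into_tensor(out, local_chunk.contiguous(), group=group)
    return out   # (world, *chunk_shape); view/reshape as needed, no extra copy
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;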

&lt;h2&gt;
  
  
  What Mosaic Is Not
&lt;/h2&gt;

&lt;p&gt;Mosaic doesn't:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Auto-parallelize your model (use nnScaler for that)&lt;/li&gt;
&lt;li&gt;Handle data parallelism (use PyTorch DDP/FSDP)&lt;/li&gt;
&lt;li&gt;Manage model sharding (use FSDP or Megatron)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;It does one thing: &lt;strong&gt;route multi-axis attention to the right sharding backend&lt;/strong&gt;.&lt;/p&gt;

&lt;h2&gt;
  
  
  The Origin Story
&lt;/h2&gt;

&lt;p&gt;This came from profiling nanoTabPFN, a transformer for tabular data. The model has attention over both rows (150k) and features (5). Standard ring attention doesn't understand "rows" vs "features" — it just sees a sequence dimension.&lt;/p&gt;

&lt;p&gt;We needed:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Local attention for small axes&lt;/li&gt;
&lt;li&gt;Ring attention for large axes
&lt;/li&gt;
&lt;li&gt;Clean axis routing without rewriting the model&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Mosaic is the result.&lt;/p&gt;




&lt;p&gt;&lt;strong&gt;Code:&lt;/strong&gt; &lt;a href="https://github.com/stprnvsh/mosaic" rel="noopener noreferrer"&gt;github.com/stprnvsh/mosaic&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Dependencies:&lt;/strong&gt; PyTorch 2.0+, NCCL, optionally flash-attn + ring-flash-attn&lt;/p&gt;

</description>
      <category>architecture</category>
      <category>deeplearning</category>
      <category>llm</category>
      <category>performance</category>
    </item>
    <item>
      <title>Our GPU Was Idle 77% of the Time. Here's How We Fixed It</title>
      <dc:creator>Pranav Sateesh</dc:creator>
      <pubDate>Sat, 03 Jan 2026 19:00:16 +0000</pubDate>
      <link>https://dev.to/stprnvsh/our-gpu-was-idle-77-of-the-time-heres-how-we-fixed-it-56oj</link>
      <guid>https://dev.to/stprnvsh/our-gpu-was-idle-77-of-the-time-heres-how-we-fixed-it-56oj</guid>
      <description>&lt;p&gt;&lt;em&gt;A practical guide to eliminating data transfer bottlenecks in PyTorch — achieving 1.5x speedup with pinned memory, CUDA streams, and GPU Direct Storage.&lt;/em&gt;&lt;/p&gt;




&lt;p&gt;We assumed the GPU was our bottleneck. We were wrong.&lt;/p&gt;

&lt;p&gt;While training a transformer model, I noticed something strange in the profiler output: the CPU was spending &lt;strong&gt;77% of its time&lt;/strong&gt; on &lt;code&gt;cudaMemcpyAsync&lt;/code&gt;. Our expensive A100 GPU wasn't compute-bound; it was &lt;em&gt;starving for data&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;This post covers how we diagnosed the problem, fixed it with three increasingly aggressive optimizations, and hit the next wall. If you're training models on large datasets and haven't profiled your data pipeline, you might be leaving significant performance on the table.&lt;/p&gt;




&lt;h2&gt;
  
  
  The Setup
&lt;/h2&gt;

&lt;p&gt;We're training &lt;a href="https://github.com/stprnvsh/nanoTabPFN" rel="noopener noreferrer"&gt;nanoTabPFN&lt;/a&gt;, a transformer for tabular data. Training data lives in HDF5 files: 30,000 samples, each with 5,000 rows and 5 features. Hardware: NVIDIA A100-SXM4-80GB.&lt;/p&gt;

&lt;p&gt;The original data loading code was textbook PyTorch:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;h5py&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;File&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;filename&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;r&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;step&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_numpy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;X&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;ptr&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;end&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_numpy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;y&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;ptr&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;end&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="k"&gt;yield&lt;/span&gt; &lt;span class="nf"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Simple. Correct. And devastatingly slow.&lt;/p&gt;




&lt;h2&gt;
  
  
  Profile First, Optimize Later
&lt;/h2&gt;

&lt;p&gt;Before touching any code, we ran PyTorch's built-in profiler:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.profiler&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;profile&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ProfilerActivity&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nf"&gt;profile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;activities&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ProfilerActivity&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CPU&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ProfilerActivity&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CUDA&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="n"&gt;record_shapes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;profile_memory&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;prof&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prior&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prof&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;key_averages&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;table&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sort_by&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cpu_time_total&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The results were shocking:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Operation&lt;/th&gt;
&lt;th&gt;CPU Time&lt;/th&gt;
&lt;th&gt;% of Total&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;cudaMemcpyAsync&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;44,084ms&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;76.78%&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;cudaMalloc&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;7,081ms&lt;/td&gt;
&lt;td&gt;12.33%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;cudaLaunchKernel&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;645ms&lt;/td&gt;
&lt;td&gt;1.12%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;aten::bmm&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;180ms&lt;/td&gt;
&lt;td&gt;0.31%&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The GPU was doing matrix multiplications in &lt;em&gt;milliseconds&lt;/em&gt; while the CPU spent &lt;strong&gt;44 seconds&lt;/strong&gt; copying data.&lt;/p&gt;




&lt;h2&gt;
  
  
  Understanding the Problem
&lt;/h2&gt;

&lt;p&gt;The &lt;code&gt;.to(device)&lt;/code&gt; call in PyTorch is synchronous by default. Here's the hidden pipeline:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;h5py reads from disk&lt;/strong&gt; → CPU memory (pageable)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;PyTorch allocates&lt;/strong&gt; → CPU staging buffer&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;cudaMemcpy&lt;/strong&gt; → GPU memory &lt;em&gt;(blocks until complete)&lt;/em&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;GPU computes&lt;/strong&gt; → then the CPU loops back to step 1 for the next batch&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The GPU sits idle during steps 1–3. With 5,000-row samples at float32, each batch transfer is ~120MB. That's 12GB of sequential transfers over 100 steps.&lt;/p&gt;
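
&lt;p&gt;You can see the cost of a single blocking copy in isolation with a pair of CUDA events (here a ~120 MB float32 tensor, matching the per-batch transfer size above; assumes a CUDA device is available):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

x_cpu = torch.empty(30_000_000, dtype=torch.float32)   # ~120 MB, pageable memory

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
x_gpu = x_cpu.to("cuda")        # synchronous: the CPU blocks until the copy completes
end.record()
torch.cuda.synchronize()
print(f"H2D copy: {start.elapsed_time(end):.1f} ms")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;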




&lt;h2&gt;
  
  
  Fix #1: Pinned Memory + Non-blocking Transfers
&lt;/h2&gt;

&lt;p&gt;The first optimization: use page-locked (pinned) memory with async transfers.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Before: synchronous, pageable memory
&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_numpy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_np&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# After: pinned memory, async transfer
&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_numpy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_np&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;pin_memory&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;non_blocking&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Why this works:&lt;/strong&gt; Pinned memory is DMA-accessible — the GPU can read it directly without CPU intervention. Combined with &lt;code&gt;non_blocking=True&lt;/code&gt;, the transfer happens in the background while the CPU continues working.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Impact:&lt;/strong&gt; &lt;code&gt;cudaMemcpyAsync&lt;/code&gt; time dropped from 44s to ~4s.&lt;/p&gt;
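
&lt;p&gt;One refinement worth considering (our suggestion, not required for the numbers above): calling &lt;code&gt;pin_memory()&lt;/code&gt; every step allocates a fresh pinned buffer each time, so for fixed-size batches it's cheaper to allocate one pinned staging buffer up front and reuse it:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
import torch

batch_shape = (4, 5000, 5)                     # example shape; use your real batch shape
staging = torch.empty(batch_shape, dtype=torch.float32, pin_memory=True)

def to_gpu(x_np, device="cuda"):
    staging.copy_(torch.from_numpy(x_np))      # pageable host to pinned host (cheap memcpy)
    # Only safe to overwrite `staging` after the previous transfer has completed;
    # with double buffering (Fix #2 below), keep one staging buffer per in-flight batch.
    return staging.to(device, non_blocking=True)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;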




&lt;h2&gt;
  
  
  Fix #2: CUDA Streams for True Overlap
&lt;/h2&gt;

&lt;p&gt;Non-blocking transfers alone aren't enough. By default, operations on the same CUDA stream are serialized. We need a separate stream for data transfer:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PriorDumpDataLoader&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;...):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transfer_stream&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Stream&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__iter__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Pre-fill buffer with first batches
&lt;/span&gt;        &lt;span class="n"&gt;vram_buffer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_load_to_vram&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prefetch&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;

        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;step&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;vram_buffer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;pop&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Already in VRAM
&lt;/span&gt;
            &lt;span class="c1"&gt;# Prefetch next batch on separate stream
&lt;/span&gt;            &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stream&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transfer_stream&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
                &lt;span class="n"&gt;next_batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_load_to_vram&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;vram_buffer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;next_batch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="c1"&gt;# Sync before yielding
&lt;/span&gt;            &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;current_stream&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;wait_stream&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transfer_stream&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="k"&gt;yield&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is &lt;strong&gt;double buffering&lt;/strong&gt;: while the GPU processes batch N, the CPU+DMA engine load batch N+1. The GPU never waits.&lt;/p&gt;




&lt;h2&gt;
  
  
  Fix #3: GPU Direct Storage (GDS)
&lt;/h2&gt;

&lt;p&gt;The ultimate optimization: bypass the CPU entirely.&lt;/p&gt;

&lt;p&gt;NVIDIA's GPUDirect Storage reads directly from NVMe to GPU memory:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;kvikio&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;cupy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;cp&lt;/span&gt;

&lt;span class="c1"&gt;# Allocate GPU buffer
&lt;/span&gt;&lt;span class="n"&gt;x_gpu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cp&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;empty&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;features&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;cp&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Direct read: NVMe → GPU (no CPU copy)
&lt;/span&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;kvikio&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;CuFile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;data.bin&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;r&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;pread&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_gpu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;file_offset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;offset&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Zero-copy to PyTorch
&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;as_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x_gpu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cuda&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;The catch:&lt;/strong&gt; GDS requires raw binary files. HDF5 has headers that need CPU parsing. We added automatic conversion on first run:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;convert_h5_to_raw&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h5_filename&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;h5py&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;File&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h5_filename&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;r&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;X&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;][:].&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;y&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;][:].&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tofile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;base&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;_X.bin&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tofile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;base&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;_y.bin&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
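
&lt;p&gt;For completeness, the &lt;code&gt;file_offset&lt;/code&gt; passed to &lt;code&gt;pread&lt;/code&gt; would come straight from the raw layout that &lt;code&gt;tofile()&lt;/code&gt; produces: samples are stored back-to-back as float32, so the byte offset of sample &lt;code&gt;i&lt;/code&gt; (and of any batch starting there) is just &lt;code&gt;i × rows × features × 4&lt;/code&gt;. A small helper, assuming that layout:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

ITEMSIZE = np.dtype(np.float32).itemsize   # 4 bytes

def sample_offset(i, rows=5000, features=5):
    """Byte offset of sample i in the raw _X.bin dump written by convert_h5_to_raw."""
    return i * rows * features * ITEMSIZE
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;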






&lt;h2&gt;
  
  
  Results
&lt;/h2&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Metric&lt;/th&gt;
&lt;th&gt;Baseline&lt;/th&gt;
&lt;th&gt;Optimized&lt;/th&gt;
&lt;th&gt;Speedup&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Total time (100 steps)&lt;/td&gt;
&lt;td&gt;68.75s&lt;/td&gt;
&lt;td&gt;45.30s&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;1.52x&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;cudaMemcpyAsync&lt;/code&gt; CPU&lt;/td&gt;
&lt;td&gt;44,084ms&lt;/td&gt;
&lt;td&gt;268ms&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;164x&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Steps/sec&lt;/td&gt;
&lt;td&gt;1.5&lt;/td&gt;
&lt;td&gt;2.2&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;1.47x&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;Memory transfer overhead dropped from 77% to &amp;lt;1% of CPU time.&lt;/p&gt;




&lt;h2&gt;
  
  
  The New Bottleneck
&lt;/h2&gt;

&lt;p&gt;With data loading solved, the profile looks completely different:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Operation&lt;/th&gt;
&lt;th&gt;CPU Time&lt;/th&gt;
&lt;th&gt;% of Total&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;Command Buffer Full&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;23,450ms&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;46.91%&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;cudaLaunchKernel&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;10,733ms&lt;/td&gt;
&lt;td&gt;21.47%&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;cudaMalloc&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;5,607ms&lt;/td&gt;
&lt;td&gt;11.22%&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The GPU is now &lt;strong&gt;saturated&lt;/strong&gt;. "Command Buffer Full" means the GPU can't keep up with kernel submissions. This is exactly what we want — the GPU is the bottleneck, not data loading.&lt;/p&gt;

&lt;p&gt;The remaining compute bottleneck is attention (&lt;code&gt;aten::bmm&lt;/code&gt; at 45% CUDA time). With 5,000-row sequences, attention's O(n²) scaling dominates. Flash Attention is the next optimization.&lt;/p&gt;
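
&lt;p&gt;As a preview of that next step (a sketch — the repo's &lt;code&gt;--flash&lt;/code&gt; flag is the real implementation): swapping the explicit &lt;code&gt;bmm&lt;/code&gt; + softmax for PyTorch's fused SDPA avoids materializing the n × n score matrix at all:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch
import torch.nn.functional as F

q = torch.randn(4, 8, 5000, 64, device="cuda", dtype=torch.float16)  # (batch, heads, rows, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# Fused attention kernel: no 5000 x 5000 score matrix is ever materialized.
out = F.scaled_dot_product_attention(q, k, v)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;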




&lt;h2&gt;
  
  
  Key Takeaways
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;Async is not automatic.&lt;/strong&gt; &lt;code&gt;non_blocking=True&lt;/code&gt; does nothing without pinned memory and proper stream management.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Pinned memory matters.&lt;/strong&gt; 10x+ difference for large transfers.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;GDS has constraints.&lt;/strong&gt; True zero-copy requires raw binary files, GDS-compatible NVMe, and proper alignment.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Know when to stop.&lt;/strong&gt; Once you're GPU-bound, data loading optimizations won't help. Move to model architecture changes.&lt;/p&gt;




&lt;h2&gt;
  
  
  Quick Reference
&lt;/h2&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Technique&lt;/th&gt;
&lt;th&gt;What it does&lt;/th&gt;
&lt;th&gt;When to use&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;pin_memory()&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Page-locked CPU memory&lt;/td&gt;
&lt;td&gt;Always for GPU training&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;non_blocking=True&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Async H2D transfer&lt;/td&gt;
&lt;td&gt;With CUDA streams&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;CUDA Streams&lt;/td&gt;
&lt;td&gt;Parallel transfer/compute&lt;/td&gt;
&lt;td&gt;Large batch sizes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Double buffering&lt;/td&gt;
&lt;td&gt;Prefetch next batch&lt;/td&gt;
&lt;td&gt;I/O-bound workloads&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;GDS (kvikio)&lt;/td&gt;
&lt;td&gt;Disk → GPU direct&lt;/td&gt;
&lt;td&gt;Large sequential reads&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;




&lt;h2&gt;
  
  
  Code
&lt;/h2&gt;

&lt;p&gt;All code is available at &lt;a href="https://github.com/stprnvsh/nanoTabPFN" rel="noopener noreferrer"&gt;github.com/stprnvsh/nanoTabPFN&lt;/a&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="c"&gt;# Baseline&lt;/span&gt;
python train.py &lt;span class="nt"&gt;--profile&lt;/span&gt; &lt;span class="nt"&gt;--steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;100 &lt;span class="nt"&gt;--batch-size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;6

&lt;span class="c"&gt;# Optimized with GDS&lt;/span&gt;
python train_optimized.py &lt;span class="nt"&gt;--gds-bin&lt;/span&gt; &lt;span class="nt"&gt;--batch-size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;4 &lt;span class="nt"&gt;--steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;200

&lt;span class="c"&gt;# With Flash Attention&lt;/span&gt;
python train_optimized.py &lt;span class="nt"&gt;--flash&lt;/span&gt; &lt;span class="nt"&gt;--gds-bin&lt;/span&gt; &lt;span class="nt"&gt;--batch-size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;8 &lt;span class="nt"&gt;--steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;200
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



</description>
      <category>deeplearning</category>
      <category>performance</category>
      <category>python</category>
      <category>tutorial</category>
    </item>
  </channel>
</rss>
