<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Josef Albers</title>
    <description>The latest articles on DEV Community by Josef Albers (@josef_albers_fc59b610c5de).</description>
    <link>https://dev.to/josef_albers_fc59b610c5de</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F1864842%2F774d37ec-856d-430b-b0d2-967179a9a445.png</url>
      <title>DEV Community: Josef Albers</title>
      <link>https://dev.to/josef_albers_fc59b610c5de</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/josef_albers_fc59b610c5de"/>
    <language>en</language>
    <item>
      <title>Zig's Power in Action: C Integration and WASM Compilation for Terrain Generation</title>
      <dc:creator>Josef Albers</dc:creator>
      <pubDate>Thu, 15 Aug 2024 15:09:36 +0000</pubDate>
      <link>https://dev.to/josef_albers_fc59b610c5de/zigs-power-in-action-c-integration-and-wasm-compilation-for-terrain-generation-12pj</link>
      <guid>https://dev.to/josef_albers_fc59b610c5de/zigs-power-in-action-c-integration-and-wasm-compilation-for-terrain-generation-12pj</guid>
      <description>&lt;p&gt;Zig's ability to directly interface with C libraries and compile to WebAssembly (WASM) opens doors for diverse applications. This post showcases these capabilities through TerrainZigger, a 3D terrain generator project.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Key Points:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Seamless C Interoperability:&lt;/strong&gt; Zig's &lt;code&gt;@cImport&lt;/code&gt; lets you import and call C libraries directly, tapping into the rich ecosystem of existing C code. TerrainZigger demonstrates this by integrating Raylib for rendering.&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight zig"&gt;&lt;code&gt;&lt;span class="k"&gt;const&lt;/span&gt; &lt;span class="n"&gt;ray&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;@cImport&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt; &lt;span class="nb"&gt;@cInclude&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="s"&gt;"raylib.h"&lt;/span&gt;&lt;span class="p"&gt;);&lt;/span&gt; &lt;span class="p"&gt;});&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Effortless WASM Compilation:&lt;/strong&gt; Zig's &lt;code&gt;build-exe&lt;/code&gt; command compiles directly to WASM, making Zig code callable from JavaScript and easy to embed in web pages. TerrainZigger exemplifies this with a playable demo on itch.io.&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;zig build-exe terrain_zigger.zig &lt;span class="nt"&gt;-target&lt;/span&gt; wasm32-freestanding &lt;span class="nt"&gt;-O&lt;/span&gt; ReleaseSmall &lt;span class="nt"&gt;-fno-entry&lt;/span&gt; &lt;span class="nt"&gt;--export&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;generate_terrain_wasm &lt;span class="nt"&gt;--export&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;get_terrain_height_wasm &lt;span class="o"&gt;&amp;amp;&amp;amp;&lt;/span&gt; python &lt;span class="nt"&gt;-m&lt;/span&gt; http.server &amp;amp; open http://localhost:8000/
&lt;span class="nb"&gt;kill&lt;/span&gt; &lt;span class="si"&gt;$(&lt;/span&gt;lsof &lt;span class="nt"&gt;-t&lt;/span&gt; &lt;span class="nt"&gt;-i&lt;/span&gt;:8000&lt;span class="si"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Performance and Control:&lt;/strong&gt; Zig's emphasis on low-level control and performance is ideal for computationally demanding tasks such as terrain generation.&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;zig build-exe walk.zig &lt;span class="nt"&gt;-I&lt;/span&gt;&lt;span class="nb"&gt;.&lt;/span&gt; &lt;span class="nt"&gt;-lc&lt;/span&gt; &lt;span class="si"&gt;$(&lt;/span&gt;pkg-config &lt;span class="nt"&gt;--libs&lt;/span&gt; &lt;span class="nt"&gt;--cflags&lt;/span&gt; raylib&lt;span class="si"&gt;)&lt;/span&gt; &lt;span class="nt"&gt;-O&lt;/span&gt; Debug
leaks &lt;span class="nt"&gt;-atExit&lt;/span&gt; &lt;span class="nt"&gt;--&lt;/span&gt; ./walk
&lt;/code&gt;&lt;/pre&gt;

&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;TerrainZigger&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;GitHub Repo:&lt;/strong&gt; &lt;a href="https://github.com/JosefAlbers/TerrainZigger" rel="noopener noreferrer"&gt;https://github.com/JosefAlbers/TerrainZigger&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Playable Demo:&lt;/strong&gt; &lt;a href="https://albersj66.itch.io/terrainzigger" rel="noopener noreferrer"&gt;https://albersj66.itch.io/terrainzigger&lt;/a&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Conclusion&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Zig's seamless C interoperability and WASM compilation let developers build high-performance applications across platforms, both natively and in web browsers. Whether for games, simulations, or interactive projects, Zig offers the tools to bring ideas to life.&lt;/p&gt;

</description>
      <category>zig</category>
      <category>javascript</category>
      <category>raylib</category>
      <category>discuss</category>
    </item>
    <item>
      <title>Porting Phi-3-Vision to MLX: A Python Hobbyist's Journey</title>
      <dc:creator>Josef Albers</dc:creator>
      <pubDate>Wed, 07 Aug 2024 01:38:29 +0000</pubDate>
      <link>https://dev.to/josef_albers_fc59b610c5de/porting-phi-3-vision-to-mlx-a-python-hobbyists-journey-2l2k</link>
      <guid>https://dev.to/josef_albers_fc59b610c5de/porting-phi-3-vision-to-mlx-a-python-hobbyists-journey-2l2k</guid>
      <description>&lt;p&gt;Hey fellow devs! 👋&lt;/p&gt;

&lt;p&gt;I've just published a comprehensive series on Medium detailing my journey of porting Phi-3-Vision, a powerful vision-language model, from Hugging Face to Apple's MLX framework. As a Python hobbyist, I wanted to share my experience and hopefully inspire others to dive into AI model optimization.&lt;/p&gt;

&lt;h2&gt;
  
  
  📚 Series Overview:
&lt;/h2&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Basic Implementation&lt;/strong&gt;: Getting Phi-3-Vision up and running in MLX.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Su-scaled Rotary Position Embeddings (SuRoPE)&lt;/strong&gt;: Implementing 128K context support.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Batching&lt;/strong&gt;: Optimizing for multiple inputs.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Caching&lt;/strong&gt;: Speeding up text generation.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Choice Selection&lt;/strong&gt;: Implementing constrained output.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Constrained Decoding&lt;/strong&gt;: Guiding the model's output structure.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;LoRA Training&lt;/strong&gt;: Fine-tuning the model efficiently.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Agent and Toolchain System&lt;/strong&gt;: Building flexible AI workflows.&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  🤔 Why This Matters:
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;Run advanced AI models efficiently on Apple Silicon&lt;/li&gt;
&lt;li&gt;Learn about model optimization techniques&lt;/li&gt;
&lt;li&gt;Understand the internals of vision-language models&lt;/li&gt;
&lt;li&gt;Explore the capabilities of MLX for AI development&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  🔗 Read the Full Series:
&lt;/h2&gt;

&lt;p&gt;&lt;a href="https://medium.com/@albersj66" rel="noopener noreferrer"&gt;https://medium.com/@albersj66&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  💻 GitHub Repository:
&lt;/h2&gt;

&lt;p&gt;I've open-sourced all the code and markdown files used in this series. You can find them in my GitHub repository:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://github.com/JosefAlbers/Phi-3-Vision-MLX" rel="noopener noreferrer"&gt;https://github.com/JosefAlbers/Phi-3-Vision-MLX&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Feel free to explore, experiment, and contribute!&lt;/p&gt;

&lt;h2&gt;
  
  
  💬 Let's Discuss:
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;Have you worked with MLX or other AI frameworks on Apple Silicon?&lt;/li&gt;
&lt;li&gt;What challenges have you faced in porting or optimizing AI models?&lt;/li&gt;
&lt;li&gt;Any specific parts of the series you'd like to dive deeper into?&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;I'm excited to hear your thoughts and experiences! Let's learn from each other and push the boundaries of what's possible with AI on consumer hardware.&lt;/p&gt;

</description>
      <category>discuss</category>
      <category>machinelearning</category>
      <category>applesilicon</category>
      <category>mlx</category>
    </item>
    <item>
      <title>Part 2: Implementing Su-scaled Rotary Position Embeddings (RoPE) for Phi-3-Vision</title>
      <dc:creator>Josef Albers</dc:creator>
      <pubDate>Thu, 01 Aug 2024 12:49:51 +0000</pubDate>
      <link>https://dev.to/josef_albers_fc59b610c5de/part-2-implementing-su-scaled-rotary-position-embeddings-rope-for-phi-3-vision-1oh0</link>
      <guid>https://dev.to/josef_albers_fc59b610c5de/part-2-implementing-su-scaled-rotary-position-embeddings-rope-for-phi-3-vision-1oh0</guid>
      <description>&lt;h2&gt;
  
  
  Introduction
&lt;/h2&gt;

&lt;p&gt;Welcome to Part 2 of our Phi-3-Vision porting series. In Part 1, we created a basic implementation of the model in MLX, but noted that it struggles with longer sequences. Today, we'll address this limitation by implementing Su-scaled Rotary Position Embeddings (RoPE), significantly enhancing our model's ability to handle long contexts of up to 128K tokens.&lt;/p&gt;

&lt;p&gt;The full implementation of what we'll cover today is available at &lt;a href="https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py" rel="noopener noreferrer"&gt;https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  1. Understanding Rotary Position Embeddings (RoPE)
&lt;/h2&gt;

&lt;p&gt;Before we delve into Su-scaled RoPE, let's first understand the basics of Rotary Position Embeddings.&lt;/p&gt;

&lt;p&gt;RoPE is a technique that injects positional information into the model's token representations without adding extra tokens or increasing the model's parameter count. The key idea is to apply a rotation to each token's embedding based on its position in the sequence.&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Frequency Calculation&lt;/strong&gt;: For each dimension d in the embedding space, RoPE calculates a frequency:&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;inv_freq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;base&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;


&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_rope_inv.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_rope_inv.png"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Position-Frequency Interaction&lt;/strong&gt;: These frequencies are then multiplied by the token positions to create unique sinusoidal patterns for each position.&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;freqs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inv_freq&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;position_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;


&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_rope_int.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_rope_int.png"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Rotation Application&lt;/strong&gt;: The resulting patterns are used to rotate the token embeddings in 2D planes.&lt;/p&gt;

&lt;p&gt;For a token at position &lt;code&gt;pos&lt;/code&gt;, RoPE applies the following rotation:&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;x_rotated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nf"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pos&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;freq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nf"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pos&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;freq&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
             &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nf"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pos&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;freq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nf"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pos&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;freq&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;


&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_rope_rot.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_rope_rot.png"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;/ol&gt;
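
&lt;p&gt;To see all three steps end to end, here is a tiny self-contained sketch (plain Python, independent of the model code) that rotates a single (x, y) pair of embedding dimensions for a token at a given position:&lt;/p&gt;

&lt;pre class="highlight python"&gt;&lt;code&gt;import math

base, dim = 10000.0, 96   # illustrative values; 96 is Phi-3's head dimension
d, pos = 0, 5             # first dimension pair, token at position 5
x, y = 1.0, 0.0           # toy embedding values for this pair

inv_freq = 1.0 / (base ** (d / dim))                # step 1: frequency
angle = pos * inv_freq                              # step 2: position-frequency interaction
x_rot = x * math.cos(angle) - y * math.sin(angle)   # step 3: 2D rotation
y_rot = y * math.cos(angle) + x * math.sin(angle)
print(x_rot, y_rot)
&lt;/code&gt;&lt;/pre&gt;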

&lt;p&gt;Now that we understand RoPE, let's explore how Su-scaled RoPE builds upon and enhances this concept.&lt;/p&gt;

&lt;h2&gt;
  
  
  2. Understanding Su-RoPE
&lt;/h2&gt;

&lt;p&gt;Su-RoPE extends RoPE by introducing scaling factors for different sequence length ranges.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;freq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;SU_FACTOR&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;d&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This allows the model to better generalize to sequences longer than those seen during training.&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Short and Long Factors&lt;/strong&gt;: Two sets of scaling factors are used, one for shorter sequences and one for longer sequences.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_surope_fac.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_surope_fac.png"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Adaptive Scaling&lt;/strong&gt;: The choice between short and long factors is made based on the sequence length.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_surope_inv.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fraw.githubusercontent.com%2FJosefAlbers%2FPhi-3-Vision-MLX%2Fmain%2Fassets%2Ftutorial_part2_surope_inv.png"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Scaling Factor&lt;/strong&gt;: An additional scaling factor is applied to adjust for the extended maximum position embeddings.&lt;/p&gt;&lt;/li&gt;
&lt;/ol&gt;
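
&lt;p&gt;As a concrete check of item 3, plugging in Phi-3-Vision's reported context lengths (128K extended, 4K original) yields a modest attention scaling factor:&lt;/p&gt;

&lt;pre class="highlight python"&gt;&lt;code&gt;import math

max_pos, orig_pos = 131072, 4096  # 128K extended vs. 4K original context
scaling = math.sqrt(1 + math.log(max_pos / orig_pos) / math.log(orig_pos))
print(scaling)  # ~1.19
&lt;/code&gt;&lt;/pre&gt;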

&lt;h2&gt;
  
  
  3. Implementing Su-scaled RoPE
&lt;/h2&gt;

&lt;p&gt;Now that we understand the theory behind Su-scaled RoPE, let's implement it in code. We'll create a &lt;code&gt;SuRoPE&lt;/code&gt; class that encapsulates all the functionality we've discussed:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;mlx.core&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;mlx.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;SuRoPE&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_attention_heads&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;original_max_position_embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;original_max_position_embeddings&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rope_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rope_theta&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scaling_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_position_embeddings&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;original_max_position_embeddings&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;original_max_position_embeddings&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rope_scaling&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;long_factor&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;short_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rope_scaling&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;short_factor&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;position_ids&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;position_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;position_ids&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;position_ids&lt;/span&gt;
        &lt;span class="n"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sin&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_get_cos_sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_rotate_half&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_rotate_half&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_get_cos_sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;position_ids&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;su_factor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;long_factor&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;position_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;original_max_position_embeddings&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;short_factor&lt;/span&gt;
        &lt;span class="n"&gt;position_ids_expanded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;position_ids&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt;
        &lt;span class="n"&gt;inv_freq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;su_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rope_theta&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;inv_freq_expanded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;repeat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inv_freq&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;position_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;freqs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inv_freq_expanded&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;position_ids_expanded&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;concatenate&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;freqs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;freqs&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;cos&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;expand_dims&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scaling_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;sin&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;expand_dims&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scaling_factor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;cos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sin&lt;/span&gt;

    &lt;span class="nd"&gt;@staticmethod&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_rotate_half&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;midpoint&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
        &lt;span class="n"&gt;x1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[...,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;midpoint&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[...,&lt;/span&gt; &lt;span class="n"&gt;midpoint&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;concatenate&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;x2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
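
&lt;p&gt;As a quick sanity check, here is a minimal usage sketch. The &lt;code&gt;config&lt;/code&gt; below is a stand-in with illustrative values (the real ones come from the model's &lt;code&gt;config.json&lt;/code&gt;); with &lt;code&gt;hidden_size=3072&lt;/code&gt; and 32 heads, the head dimension is 96, so each factor list needs 48 entries:&lt;/p&gt;

&lt;pre class="highlight python"&gt;&lt;code&gt;from types import SimpleNamespace

import mlx.core as mx

# Stand-in config for illustration; real values come from config.json.
config = SimpleNamespace(
    hidden_size=3072,
    num_attention_heads=32,
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    rope_theta=10000.0,
    rope_scaling={"short_factor": [1.0] * 48, "long_factor": [1.0] * 48},
)

rope = SuRoPE(config)
q = mx.zeros((1, 32, 10, 96))  # (batch, heads, seq_len, head_dim)
k = mx.zeros((1, 32, 10, 96))
q, k = rope(q, k)
print(q.shape, k.shape)        # (1, 32, 10, 96) (1, 32, 10, 96)
&lt;/code&gt;&lt;/pre&gt;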



&lt;h2&gt;
  
  
  4. Integrating Su-scaled RoPE into Phi-3-Vision
&lt;/h2&gt;

&lt;p&gt;Integrating our Su-scaled RoPE implementation into the Phi-3-Vision model is straightforward. We only need to add two lines to our Phi3Attention module:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Phi3Attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# ...
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rope&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;SuRoPE&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# ...
&lt;/span&gt;        &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rope&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="c1"&gt;# ...
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And now our ported model can handle up to 128K tokens!&lt;/p&gt;

&lt;h2&gt;
  
  
  Conclusion
&lt;/h2&gt;

&lt;p&gt;In this tutorial, we implemented Su-scaled Rotary Position Embeddings (RoPE), enabling our model to handle sequences up to 128K tokens. &lt;/p&gt;

&lt;p&gt;The full implementation is available at &lt;a href="https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py" rel="noopener noreferrer"&gt;https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Next, we'll explore efficient batching techniques to further optimize our Phi-3-Vision implementation in MLX.&lt;/p&gt;

</description>
      <category>llm</category>
      <category>mlx</category>
      <category>matplotlib</category>
      <category>ai</category>
    </item>
    <item>
      <title>Part 1: Basic Implementation of Phi-3-Vision in MLX</title>
      <dc:creator>Josef Albers</dc:creator>
      <pubDate>Wed, 31 Jul 2024 11:00:40 +0000</pubDate>
      <link>https://dev.to/josef_albers_fc59b610c5de/part-1-porting-phi-3-vision-to-mlx-50gj</link>
      <guid>https://dev.to/josef_albers_fc59b610c5de/part-1-porting-phi-3-vision-to-mlx-50gj</guid>
      <description>&lt;h2&gt;
  
  
  Introduction
&lt;/h2&gt;

&lt;p&gt;Welcome to Part 1 of the tutorial series on porting Phi-3-Vision from PyTorch to Apple's MLX framework. Our goal is to create a minimal, functional implementation of Phi-3-Vision in MLX through:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Analyzing the original PyTorch code&lt;/li&gt;
&lt;li&gt;Translating core components to MLX&lt;/li&gt;
&lt;li&gt;Building a basic MLX implementation&lt;/li&gt;
&lt;li&gt;Loading and running the ported model&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;By the end of this tutorial, we will have a basic working model capable of generating text, setting the stage for further optimizations in subsequent parts of the series.&lt;/p&gt;

&lt;p&gt;The full implementation of what we'll cover today is available at &lt;a href="https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_1.py" rel="noopener noreferrer"&gt;https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_1.py&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  1. Finding and Understanding the Model Code
&lt;/h2&gt;

&lt;p&gt;Our first task is to locate the source code for the original Phi-3-Vision implementation:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Visit the Hugging Face model hub: &lt;a href="https://huggingface.co/models" rel="noopener noreferrer"&gt;https://huggingface.co/models&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;Search for "phi-3-vision"&lt;/li&gt;
&lt;li&gt;Click on the model to access its repository&lt;/li&gt;
&lt;li&gt;Look for a file named &lt;code&gt;modeling_phi3_v.py&lt;/code&gt;
&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Now let's examine the code:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/cdn-cgi/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjzsax6junle1e1gv7b7q.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/cdn-cgi/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjzsax6junle1e1gv7b7q.gif" alt="idonteven" width="400" height="225"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Don't panic! We will break this down step by step:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Scroll to the bottom of the file to find the top-level model class (&lt;code&gt;Phi3VForCausalLM&lt;/code&gt; in our case)&lt;/li&gt;
&lt;li&gt;Look for the &lt;code&gt;forward&lt;/code&gt; method in this class&lt;/li&gt;
&lt;li&gt;Trace the flow of data through the model by following method calls&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Through this process, we can identify five key components of the model:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Main model (&lt;code&gt;Phi3VModel&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;Decoder layers (&lt;code&gt;Phi3DecoderLayer&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;Attention mechanism (&lt;code&gt;Phi3Attention&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;Feed-forward network (&lt;code&gt;Phi3MLP&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;Image embedding (&lt;code&gt;Phi3ImageEmbedding&lt;/code&gt;)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;With these components identified, we're ready to begin the translation process to MLX.&lt;/p&gt;

&lt;h2&gt;
  
  
  2. MLX-Specific Considerations
&lt;/h2&gt;

&lt;p&gt;A few differences between PyTorch and MLX are worth noting before we begin porting:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Array Creation&lt;/strong&gt;: MLX doesn't require specifying device location.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Lazy Computation&lt;/strong&gt;: MLX builds computation graphs lazily; arrays are only materialized when &lt;code&gt;mx.eval()&lt;/code&gt; is called or their values are needed.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Model Definition&lt;/strong&gt;: MLX uses &lt;code&gt;__call__&lt;/code&gt; instead of &lt;code&gt;forward&lt;/code&gt; for the model's forward pass.&lt;/li&gt;
&lt;/ol&gt;
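
&lt;p&gt;A quick illustration of all three differences (a minimal standalone sketch, not model code):&lt;/p&gt;

&lt;pre class="highlight python"&gt;&lt;code&gt;import mlx.core as mx
import mlx.nn as nn

a = mx.ones((2, 2))  # no device argument needed
b = a @ a            # builds a lazy computation graph
mx.eval(b)           # materializes the result

class Square(nn.Module):  # MLX modules define __call__, not forward
    def __call__(self, x):
        return x * x

print(Square()(b))
&lt;/code&gt;&lt;/pre&gt;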

&lt;h2&gt;
  
  
  3. Understanding the Model Structure
&lt;/h2&gt;

&lt;p&gt;Let's now examine each key component of Phi-3-Vision, translating them to MLX as we go:&lt;/p&gt;

&lt;h3&gt;
  
  
  3.1 Top-Level Model: Phi3VForCausalLM
&lt;/h3&gt;

&lt;p&gt;This class serves as the main entry point of the model. It encapsulates the core &lt;code&gt;Phi3VModel&lt;/code&gt; and adds a language modeling head.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Phi3VForCausalLM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# ...
&lt;/span&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pixel_values&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;image_sizes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pixel_values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;image_sizes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;lm_head&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This top-level class serves two main functions:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;It encapsulates the core model (&lt;code&gt;Phi3VModel&lt;/code&gt;), which produces contextualized representations of the input.&lt;/li&gt;
&lt;li&gt;It applies a linear transformation (the "language model head") to these representations, converting them into logits over the entire vocabulary. These logits represent the model's predictions for the next token in the sequence.&lt;/li&gt;
&lt;/ol&gt;
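
&lt;p&gt;Conceptually, the head is just a linear projection from hidden size to vocabulary size. Here is a sketch with illustrative dimensions (Phi-3's reported hidden size and vocabulary), using freshly initialized weights rather than the actual checkpoint:&lt;/p&gt;

&lt;pre class="highlight python"&gt;&lt;code&gt;import mlx.core as mx
import mlx.nn as nn

hidden_size, vocab_size = 3072, 32064
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

x = mx.zeros((1, 10, hidden_size))  # (batch, seq_len, hidden)
logits = lm_head(x)
print(logits.shape)                 # (1, 10, 32064)
&lt;/code&gt;&lt;/pre&gt;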

&lt;h3&gt;
  
  
  3.2 Core Model: Phi3VModel
&lt;/h3&gt;

&lt;p&gt;The Phi3VModel implements the main transformer architecture.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Phi3VModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# ...
&lt;/span&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pixel_values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;image_sizes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;embed_tokens&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;vision_embed_tokens&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pixel_values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;image_sizes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layers&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;l&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This class processes inputs through four stages:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Text Embedding&lt;/strong&gt;: Input tokens are converted to dense vector representations.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Vision Embedding&lt;/strong&gt;: If present, image inputs are processed and integrated with the text embeddings.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Transformer Layers&lt;/strong&gt;: The combined embeddings are then passed through a series of decoder layers.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Normalization&lt;/strong&gt;: The output is normalized before being returned.&lt;/li&gt;
&lt;/ol&gt;

&lt;h3&gt;
  
  
  3.3 Image Embedding: Phi3ImageEmbedding
&lt;/h3&gt;

&lt;p&gt;This component processes image inputs and integrates them with text embeddings.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Phi3ImageEmbedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# ...
&lt;/span&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;txt_embeds&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;img_embeds&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;img_sizes&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;positions&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Process images with CLIP
&lt;/span&gt;        &lt;span class="n"&gt;img_features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;img_processor&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;vision_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img_embeds&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Reshape and concatenate features
&lt;/span&gt;        &lt;span class="n"&gt;global_features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_process_global_features&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img_features&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;local_features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_process_local_features&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img_features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;img_sizes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Apply additional projections
&lt;/span&gt;        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;concatenate&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;local_features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;global_features&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;layer&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;img_projection&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;layer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Integrate with text embeddings
&lt;/span&gt;        &lt;span class="n"&gt;txt_embeds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_integrate_features&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;txt_embeds&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;positions&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;txt_embeds&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This class combines a CLIP (Contrastive Language-Image Pre-training) model with custom processing steps:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;CLIP Processing&lt;/strong&gt;: The model uses a pre-trained CLIP vision model to extract initial features from the input images.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Additional Processing&lt;/strong&gt;: After CLIP processing, the model applies additional processing steps:

&lt;ul&gt;
&lt;li&gt;It reshapes and concatenates the features for both global and local (sub-image) representations.&lt;/li&gt;
&lt;li&gt;It applies additional linear projections and non-linear activations (GELU) to further process these features.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Integration with Text Embeddings&lt;/strong&gt;: Finally, the processed image features are integrated with the text embeddings at specific positions in the input sequence.&lt;/li&gt;
&lt;/ol&gt;

&lt;h3&gt;
  
  
  3.4 Decoder Layer: Phi3DecoderLayer
&lt;/h3&gt;

&lt;p&gt;Each decoder layer is a fundamental building block of the transformer architecture.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Phi3DecoderLayer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# ...
&lt;/span&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;self_attn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;input_layernorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;
        &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mlp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;post_attention_layernorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The decoder layer applies a series of operations to its input:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Self-Attention&lt;/strong&gt;: This mechanism allows the model to weigh the importance of different parts of the input when processing each element, enabling it to capture long-range dependencies in the sequence.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Feedforward Neural Network (MLP)&lt;/strong&gt;: This subnet processes each position independently, introducing non-linearity and increasing the model's capacity to learn complex functions.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Residual Connections&lt;/strong&gt;: After both the self-attention and MLP operations, the input is added to the output. This technique helps in mitigating the vanishing gradient problem and allows for easier training of deep networks.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Layer Normalization&lt;/strong&gt;: Applied before the self-attention and MLP operations, this normalizes the inputs to each sub-layer, stabilizing the learning process and allowing for deeper networks.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The combination of these components enables each layer to refine and enrich the representations passed through the model.&lt;/p&gt;

&lt;h3&gt;
  
  
  3.5 Attention Mechanism: Phi3Attention
&lt;/h3&gt;

&lt;p&gt;The attention mechanism allows the model to weigh the importance of different parts of the input when processing each element.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Phi3Attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# ...
&lt;/span&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;L&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;
        &lt;span class="n"&gt;qkv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;qkv_proj&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;qkv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;chop&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;L&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;L&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_kv_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;L&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_kv_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;mask&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;triu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;full&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;inf&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;
        &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;
        &lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;o&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;L&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;o_proj&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;o&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;qkv&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Key aspects of this implementation:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Projection and Splitting&lt;/strong&gt;: The input is first projected into query (q), key (k), and value (v) representations using a single linear projection (&lt;code&gt;qkv_proj&lt;/code&gt;), which is then split (steps 1-3 are illustrated with a toy example after this list).&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Multi-head Reshaping&lt;/strong&gt;: The q, k, and v tensors are reshaped to separate the heads and prepare for the attention computation.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Attention Mask&lt;/strong&gt;: A causal mask is created to ensure that each position can only attend to previous positions.&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;Scaled Dot-Product Attention&lt;/strong&gt;: The core computation scales the query-key dot products, adds the mask, and softmaxes before weighting the values. Alternatively, you can use the faster, fused version of this operation available in &lt;code&gt;mlx.core.fast&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# This:
&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;mask&lt;/span&gt;
&lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;

&lt;span class="c1"&gt;# Is equivalent to:
&lt;/span&gt;&lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fast&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mask&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Output Projection&lt;/strong&gt;: The attention output is reshaped and projected back to the original dimensionality.&lt;/p&gt;&lt;/li&gt;
&lt;/ol&gt;
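
&lt;p&gt;To ground steps 1-3, here is a toy walk-through with made-up shapes (the real model uses 32 heads and a much larger hidden size; &lt;code&gt;chop&lt;/code&gt; is assumed to hold the split indices, matching how &lt;code&gt;mx.split&lt;/code&gt; is called above).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import mlx.core as mx

B, L, n_heads, head_dim = 1, 4, 2, 8
qkv = mx.random.normal((B, L, 3 * n_heads * head_dim))  # fused q/k/v projection output
chop = [n_heads * head_dim, 2 * n_heads * head_dim]     # split points for equal q/k/v sizes
q, k, v = mx.split(qkv, chop, axis=-1)
print(q.shape, k.shape, v.shape)                        # (1, 4, 16) each
# Causal mask: -inf above the diagonal, so position i attends only to j &lt;= i
mask = mx.triu(mx.full((L, L), -mx.inf), k=1)
print(mask)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;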

&lt;p&gt;The attention mechanism supports both standard multi-head attention and grouped-query attention by allowing different numbers of heads for queries (&lt;code&gt;n_heads&lt;/code&gt;) versus keys/values (&lt;code&gt;n_kv_heads&lt;/code&gt;). &lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/cdn-cgi/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fylsr57hd8kpazjoznxac.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/cdn-cgi/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fylsr57hd8kpazjoznxac.png" alt="https://arxiv.org/abs/2305.13245v3" width="800" height="387"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In the current configuration, however, these are set to the same value (32), resulting in standard multi-head attention.&lt;/p&gt;
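
&lt;p&gt;When &lt;code&gt;n_kv_heads&lt;/code&gt; is smaller than &lt;code&gt;n_heads&lt;/code&gt;, each key/value head is shared by a group of query heads. Here is a minimal sketch of that idea, using a hypothetical 32/8 split rather than Phi-3's actual setting.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import mlx.core as mx

B, L, head_dim = 1, 8, 96
n_heads, n_kv_heads = 32, 8                         # hypothetical grouped-query setting
q = mx.random.normal((B, n_heads, L, head_dim))
k = mx.random.normal((B, n_kv_heads, L, head_dim))
# Repeat each kv head so every group of 4 query heads shares one kv head:
k = mx.repeat(k, n_heads // n_kv_heads, axis=1)
print(k.shape)                                      # (1, 32, 8, 96)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;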

&lt;h3&gt;
  
  
  3.6 MLP Layer: Phi3MLP
&lt;/h3&gt;

&lt;p&gt;The MLP layer applies non-linear transformations to the attention outputs.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Phi3MLP&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# ...
&lt;/span&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gate_up_proj&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;gate&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;down_proj&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;silu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gate&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This implements a gated feedforward network:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Gated Architecture:

&lt;ul&gt;
&lt;li&gt;The input is first projected into two separate spaces: one for the 'gate' and one for the 'values'.&lt;/li&gt;
&lt;li&gt;This is achieved through a single linear projection followed by a split operation.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;Activation Function:

&lt;ul&gt;
&lt;li&gt;The gate portion uses the SiLU (Sigmoid Linear Unit) activation, also known as the swish function.&lt;/li&gt;
&lt;li&gt;SiLU is defined as f(x) = x * sigmoid(x), which has been shown to perform well in deep networks (a quick numerical check follows this list).&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;Gating Mechanism:

&lt;ul&gt;
&lt;li&gt;The activated gate is element-wise multiplied with the value portion.&lt;/li&gt;
&lt;li&gt;This allows the network to dynamically control information flow, potentially helping with gradient flow and enabling more complex functions to be learned.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;Final Projection:

&lt;ul&gt;
&lt;li&gt;The gated output is then projected back to the model's hidden size through a final linear layer.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;
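
&lt;p&gt;As a quick sanity check, &lt;code&gt;nn.silu&lt;/code&gt; matches the definition above exactly:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(nn.silu(x))         # built-in SiLU activation
print(x * mx.sigmoid(x))  # f(x) = x * sigmoid(x), same values
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;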

&lt;p&gt;This design combines the benefits of gating mechanisms (often seen in LSTMs and GRUs) with the simplicity and effectiveness of feedforward networks, potentially allowing for more expressive computations within each transformer layer.&lt;/p&gt;

&lt;h2&gt;
  
  
  4. Loading and Using the Model
&lt;/h2&gt;

&lt;p&gt;Now that we've ported our model to MLX, let's load and use it for text generation.&lt;/p&gt;

&lt;p&gt;First, we'll download the model configuration and weights from huggingface:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;model_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;snapshot_download&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;microsoft/Phi-3-vision-128k-instruct&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Next, we'll load the model configuration:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;model_path&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/config.json&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;r&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;config&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;json&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model_config&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;SimpleNamespace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now, let's load and "sanitize" the model weights:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;model_weight&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;patch_embedding.weight&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
                &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;wf&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;glob&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;glob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;model_path&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/*.safetensors&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
                &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;wf&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The expression &lt;code&gt;v.transpose(0, 2, 3, 1) if "patch_embedding.weight" in k else v&lt;/code&gt; adapts the patch-embedding convolution weights to MLX's layout: PyTorch stores them as (out_channels, in_channels, height, width), while MLX's convolutions expect the channel axis last, (out_channels, height, width, in_channels). This transposition, often called "weight sanitization", is necessary when porting the model from PyTorch to MLX, as illustrated below.&lt;/p&gt;
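
&lt;p&gt;For example, with a ViT-L/14-style patch embedding (the sizes here are an assumption, for illustration only), the transpose moves the channel axis to the end:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

w_torch = np.zeros((1024, 3, 14, 14))  # PyTorch conv weight: (out_ch, in_ch, kH, kW)
w_mlx = w_torch.transpose(0, 2, 3, 1)  # MLX conv weight:     (out_ch, kH, kW, in_ch)
print(w_mlx.shape)                     # (1024, 14, 14, 3)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;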

&lt;p&gt;With our configuration and weights ready, we can initialize and load our model:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Phi3VForCausalLM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_config&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_weight&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Here, &lt;code&gt;mx.eval&lt;/code&gt; forces MLX's lazy computation graph to materialize the weights, and &lt;code&gt;model.eval()&lt;/code&gt; switches the model into inference mode. Now that our model is loaded, let's use it to generate some text. First, we'll load the pretrained processor:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;processor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoProcessor&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;microsoft/Phi-3-vision-128k-instruct&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;trust_remote_code&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Then, we'll process our input text and generate the first token:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;processor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Hello world!&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;np&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;list_tokens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This code processes the input text "Hello world!" and generates the first token. We use the &lt;code&gt;AutoProcessor&lt;/code&gt; to tokenize the input, then pass it through the model to get logits. Taking the argmax selects the highest-probability token as the next token, i.e., greedy decoding; a sampling alternative is sketched below.&lt;/p&gt;
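
&lt;p&gt;Continuing from the snippet above, if you'd rather sample than always take the argmax, MLX's &lt;code&gt;mx.random.categorical&lt;/code&gt; draws a token directly from the logits. A minimal variant (the temperature value is an arbitrary choice):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Instead of argmax, sample from temperature-scaled logits:
temp = 0.7
token = mx.random.categorical(logits[:, -1, :] * (1 / temp))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;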

&lt;p&gt;To generate more tokens, we can use a simple loop:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;concatenate&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;list_tokens&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;token&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This loop generates five additional tokens by repeatedly feeding our model's output back as input. Note that each iteration re-runs the entire sequence through the model; the KV-cache we'll add later in this series avoids that recomputation. Finally, we decode the generated tokens:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;processor&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;list_tokens&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="c1"&gt;# Output: How are you doing?&amp;lt;|end|&amp;gt;
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And there you have it! We've successfully ported Phi-3-Vision to MLX, loaded the model, and generated text. While this implementation is basic, it demonstrates that our port is functional and capable of generating coherent text.&lt;/p&gt;

&lt;h2&gt;
  
  
  5. Initial Results and Limitations
&lt;/h2&gt;

&lt;p&gt;Our minimal implementation works for short sequences, but you'll notice it starts producing gibberish with longer contexts. This is because we haven't yet implemented position encoding, which we'll address in the next part with RoPE (Rotary Position Embedding).&lt;/p&gt;

&lt;h2&gt;
  
  
  Conclusion:
&lt;/h2&gt;

&lt;p&gt;We've successfully created a barebones implementation of Phi-3-Vision in MLX. While it can't yet handle everything the original model does, it provides a solid foundation for the optimizations we'll explore in upcoming tutorials.&lt;/p&gt;

&lt;h2&gt;
  
  
  Next Steps:
&lt;/h2&gt;

&lt;p&gt;In Part 2, we'll implement Su-scaled Rotary Position Embeddings (Su-RoPE) to enhance our model's ability to handle long sequences.&lt;/p&gt;

</description>
      <category>vlm</category>
      <category>llm</category>
      <category>mlx</category>
      <category>huggingface</category>
    </item>
    <item>
      <title>Porting Phi-3-Vision to MLX: A Python Hobbyist's Journey into Advanced AI on Apple Silicon</title>
      <dc:creator>Josef Albers</dc:creator>
      <pubDate>Wed, 31 Jul 2024 10:55:15 +0000</pubDate>
      <link>https://dev.to/josef_albers_fc59b610c5de/porting-phi-3-vision-to-mlx-a-python-hobbyists-journey-into-advanced-ai-on-apple-silicon-25b2</link>
      <guid>https://dev.to/josef_albers_fc59b610c5de/porting-phi-3-vision-to-mlx-a-python-hobbyists-journey-into-advanced-ai-on-apple-silicon-25b2</guid>
      <description>&lt;h2&gt;
  
  
  Introduction:
&lt;/h2&gt;

&lt;p&gt;Welcome to an exciting series on optimizing cutting-edge AI models for Apple Silicon! Over the next few weeks, we'll dive deep into the process of porting Phi-3-Vision, a powerful and compact vision-language model, from Hugging Face to MLX.&lt;/p&gt;

&lt;p&gt;This series is designed for AI enthusiasts, developers, and researchers interested in running advanced models efficiently on Mac devices. For those eager to get started, you can find the MLX ports of both Phi-3-Mini-128K and Phi-3-Vision in my GitHub repository: &lt;a href="https://github.com/JosefAlbers/Phi-3-Vision-MLX" rel="noopener noreferrer"&gt;https://github.com/JosefAlbers/Phi-3-Vision-MLX&lt;/a&gt;.&lt;/p&gt;

&lt;h2&gt;
  
  
  Why Phi-3-Vision?
&lt;/h2&gt;

&lt;p&gt;When Microsoft Research released Phi-3, I was immediately intrigued. Despite its relatively small size of 3.8 billion parameters, it was performing on par with or even surpassing models with 7 billion parameters. This efficiency was impressive and hinted at the potential for running sophisticated AI models on consumer-grade hardware.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/cdn-cgi/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fg3jt1udxw4ysp51dre4u.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/cdn-cgi/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fg3jt1udxw4ysp51dre4u.png" alt="Phi-3-*" width="800" height="341"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The subsequent release of Phi-3-Vision further piqued my interest. As an owner of a Mac Studio and a Python hobbyist, I saw an exciting opportunity to bring this capable vision-language model to Apple Silicon. While llama.cpp was a popular option for running large language models on Mac, its C++ codebase was way beyond my skill level, so I looked for a more accessible option. This led me to MLX, Apple's machine learning framework that not only offered a Python-friendly environment but also promised better performance than llama.cpp on Apple Silicon.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/cdn-cgi/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frqlzg7q5dy5ezq8chois.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/cdn-cgi/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frqlzg7q5dy5ezq8chois.png" alt="Phi-3-Vision" width="800" height="362"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;What made this journey even more exciting was that it marked my first foray into contributing to open source projects. As I worked on porting Phi-3-Vision, I found myself making my first pull requests to repositories like "mlx-examples" and "mlx-vlm". It was an invaluable learning experience that deepened my understanding of the MLX framework and how to apply it to real-world projects, and it connected me with the broader AI development community.&lt;/p&gt;

&lt;h2&gt;
  
  
  Useful Resources:
&lt;/h2&gt;

&lt;p&gt;Before we dive into the series, I want to highlight some excellent resources that have been invaluable in my journey:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;MLX Examples (&lt;a href="https://github.com/ml-explore/mlx-examples" rel="noopener noreferrer"&gt;https://github.com/ml-explore/mlx-examples&lt;/a&gt;): This official repository from the MLX team at Apple is a treasure trove of examples and tutorials that showcase the capabilities of the MLX framework. With a wide range of standalone examples, from basic MNIST to advanced language models and image generation, this repository is an excellent starting point for anyone looking to learn MLX. The quality and depth of the examples are truly impressive, and they demonstrate the team's commitment to making MLX accessible to developers of all levels.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://res.cloudinary.com/practicaldev/image/fetch/s--fSLA7Vfa--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://raw.githubusercontent.com/JosefAlbers/Phi-3-Vision-MLX/main/assets/tutorial_part0_awni.png" class="article-body-image-wrapper"&gt;&lt;img src="https://res.cloudinary.com/practicaldev/image/fetch/s--fSLA7Vfa--/c_limit%2Cf_auto%2Cfl_progressive%2Cq_auto%2Cw_800/https://raw.githubusercontent.com/JosefAlbers/Phi-3-Vision-MLX/main/assets/tutorial_part0_awni.png" width="558" height="794"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;I also want to give a special shoutout to awni, the repo owner, who was incredibly kind and patient with me when I made my first-ever pull request to this repository. Despite my lack of experience with Git and GitHub, awni guided me through the process and helped me navigate the pre-commit hooks and other nuances of the repository. Their patience and willingness to help a newcomer like me were truly appreciated, and I'm grateful for the opportunity to have contributed. If you're new to MLX or Git, I highly recommend checking out this repository and reaching out to awni - they're a great resource and a pleasure to work with!&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;MLX-VLM (&lt;a href="https://github.com/Blaizzy/mlx-vlm" rel="noopener noreferrer"&gt;https://github.com/Blaizzy/mlx-vlm&lt;/a&gt;): A package specifically for running Vision Language Models on Mac using MLX. This repository was particularly helpful in understanding how to handle multimodal inputs, and I found the well-organized and well-written code to be incredibly valuable in learning not only Vision Language Models (VLMs) but also the MLX framework in general. The codebase is a great example of how to structure and implement complex AI models using MLX, making it an excellent resource for anyone looking to learn from experienced developers and improve their own MLX skills. For those interested in other models, Prince Canuma has an excellent tutorial on running Google's Gemma 2 locally on Mac using MLX:&lt;/p&gt;

&lt;p&gt;&lt;iframe width="710" height="399" src="https://www.youtube.com/embed/CKznaU1HpVQ"&gt;
&lt;/iframe&gt;
&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;&lt;p&gt;Hugging Face (&lt;a href="https://huggingface.co/" rel="noopener noreferrer"&gt;https://huggingface.co/&lt;/a&gt;): A popular platform for natural language processing (NLP) and computer vision tasks, providing a vast range of pre-trained models, datasets, and tools. Hugging Face’s Transformers library is particularly useful for working with transformer-based models like Phi-3-Vision. Their documentation and community support are also top-notch, making it an excellent resource for anyone looking to learn more about NLP and computer vision.&lt;/p&gt;&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;These resources provide a great foundation for anyone looking to explore MLX and run advanced AI models on Apple Silicon.&lt;/p&gt;

&lt;h2&gt;
  
  
  What to Expect in This Series:
&lt;/h2&gt;

&lt;h3&gt;
  
  
  1. MLX vs. Hugging Face: A Code Comparison
&lt;/h3&gt;

&lt;p&gt;We'll start by comparing the original Hugging Face implementation with our MLX port, highlighting key differences in syntax and how MLX leverages Apple Silicon's unified memory architecture.&lt;/p&gt;

&lt;h3&gt;
  
  
  2. Implementing Su-RoPE for 128K Context
&lt;/h3&gt;

&lt;p&gt;We'll explore the Su-scaled Rotary Position Embedding (Su-RoPE) implementation that enables Phi-3-Vision to handle impressive 128K token contexts.&lt;/p&gt;

&lt;h3&gt;
  
  
  3. Optimizing Text Generation in MLX: From Batching to Advanced Techniques
&lt;/h3&gt;

&lt;p&gt;Learn how to implement efficient batch text generation, a crucial feature for many real-world applications. We'll also cover custom KV-Cache implementation, streaming capabilities, and other text generation optimizations.&lt;/p&gt;

&lt;h3&gt;
  
  
  4. LoRA Fine-tuning and Evaluation on MLX
&lt;/h3&gt;

&lt;p&gt;Discover how to perform Low-Rank Adaptation (LoRA) training, enabling efficient fine-tuning of Phi-3-Vision on custom datasets. We'll also develop comprehensive evaluation techniques to ensure our LoRA-adapted model meets or exceeds the original's performance.&lt;/p&gt;

&lt;h3&gt;
  
  
  5. Building a Versatile AI Agent
&lt;/h3&gt;

&lt;p&gt;In our finale, we'll create a multi-modal AI agent showcasing Phi-3-Vision's capabilities in handling both text and visual inputs.&lt;/p&gt;

&lt;h2&gt;
  
  
  Why This Series Matters:
&lt;/h2&gt;

&lt;p&gt;Phi-3-Vision represents a significant advancement in compact, high-performing vision-language models. By porting it to MLX, we're making it more accessible and efficient for a wide range of applications on Apple Silicon devices, demonstrating the potential of running advanced AI models on consumer-grade hardware.&lt;/p&gt;

&lt;h2&gt;
  
  
  Who This Series Is For:
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;AI enthusiasts and hobbyists looking to dive deeper into model optimization&lt;/li&gt;
&lt;li&gt;Researchers exploring efficient AI on consumer hardware&lt;/li&gt;
&lt;li&gt;Mac users eager to leverage their devices for AI tasks&lt;/li&gt;
&lt;li&gt;Anyone curious about the intersection of AI and Apple Silicon&lt;/li&gt;
&lt;li&gt;Beginners interested in contributing to open source AI projects&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Stay Tuned!
&lt;/h2&gt;

&lt;p&gt;Our journey into optimizing Phi-3-Vision for MLX promises to be full of insights, challenges, and exciting breakthroughs. Whether you're a fellow hobbyist looking to run advanced AI models on your Mac or simply curious about the future of AI on consumer devices, this series has something for you.&lt;/p&gt;

&lt;p&gt;Join me on this adventure in AI optimization, and let's unlock the full potential of Phi-3-Vision on Apple Silicon together!&lt;/p&gt;

</description>
      <category>python</category>
      <category>mlx</category>
      <category>vlm</category>
      <category>llm</category>
    </item>
  </channel>
</rss>
