Kazuya

AWS re:Invent 2025 - Performance engineering on Neuron: How to optimize your LLM with NKI (AIM414)

🦄 Making great presentations more accessible.
This project aims to enhance multilingual accessibility and discoverability while maintaining the integrity of the original content. Detailed transcriptions and keyframes preserve the nuances and technical insights that make each session compelling.

Overview

📖 AWS re:Invent 2025 - Performance engineering on Neuron: How to optimize your LLM with NKI (AIM414)

In this video, AWS Solutions Architects Scott Perry and Sadaf Rasool demonstrate performance optimization for LLMs on AWS Trainium using the Neuron Kernel Interface (NKI). They benchmark Qwen3 0.6B with naive attention achieving 0.35 prompts/second and ~3 second latency, then replace the attention module with an NKI-optimized flash attention kernel. Through simple code integration involving tensor reshaping and calling pre-built kernels, they achieve 2.99 prompts/second—a 6-8x performance improvement with latency reduced to ~0.25 seconds. The session showcases NKI's Python-based domain-specific language for writing hardware-aware kernels that leverage Trainium's compute engines, memory hierarchy, and optimization techniques like pipelining and fusion. They announce the newly launched open-source NKI Library containing production-ready, pre-optimized kernels for various model components.


This article is entirely auto-generated while preserving the original presentation content as much as possible. Please note that there may be typos or inaccuracies.

Main Part

Thumbnail 0

Introduction to Performance Engineering on Neuron with NKI

Welcome to this afternoon's session. This is AIM 414: Performance Engineering on Neuron—How to Optimize Your LLM with NKI. My name is Scott Perry. I'm a Principal Solutions Architect for AI/ML Performance on the Annapurna Labs team at AWS, where we work directly with customers on our AI/ML accelerators: Trainium and Inferentia. Joining me today is my colleague Sadaf. Hey everyone, this is Sadaf Rasool. I'm also a Solutions Architect for AI/ML Performance at AWS Trainium. I'm primarily focused on AI/ML performance on Neuron, and we're both super excited to have you here today for this session.

Thanks for coming out. This is a code talk, so we're going to spend most of the session digging into a terminal editor and code editor, running some scripts, and doing all of this live. Hopefully it goes well. Before we do that, we need to go through a few background slides just to give everybody an idea of what Neuron and NKI are all about. Bear with us for just a few slides, and then we'll get right into it.

Thumbnail 60

To give you some context, for many years customers have been doing machine learning workloads on AWS, and some of the early feedback we got from customers is that they wanted choice and they wanted to see improved price performance for running their ML workloads on AWS. Back in about 2019, we invested in designing a purpose-built machine learning accelerator that we would offer via EC2 instances in AWS. That first generation was called AWS Inferentia.

Thumbnail 80

This is Inferentia 1, if you're familiar with the EC2 instances. It was a smaller accelerator designed for the deep learning models of that time, like BERT and YOLO. Since then, we've released the second generation, Inferentia 2, and more interestingly, we've also released three versions of the AWS Trainium chip, with the third generation launching just yesterday. AWS Trainium is actually a more powerful chip capable of being used for distributed training as well as inference.

Thumbnail 110

Understanding NeuronCore Architecture and Memory Hierarchy

At the heart of the Trainium and Inferentia accelerator chips are the NeuronCores. In a given instance, you might have multiple chips, and within each chip, you can have multiple NeuronCores. Today we're going to be working with a Trn2 instance, which actually has eight physical cores within the chip. We're going to be using one of them for the examples here today. Within that core, this is a high-level schematic of what's actually involved.

What's really interesting about this architecture is that it's purpose-built for machine learning workloads. We have four specific compute engines within the NeuronCore. Starting with the Tensor Engine, which is focused on things like matrix multiplications and transposes. We have the Vector Engine that you can use for max pool and average pool type layers. There's the Scalar Engine for things like activations, and we also have a general-purpose SIMD engine with general-purpose cores that can be reprogrammed for layers that aren't covered in the other three engines. So we have four compute engines that you can actually use in parallel on this chip to maximize compute utilization.

Also worth noting here is we have a couple of on-chip SRAMs: the SBUF and the PSUM. These are obviously on-chip, so they're very close to the compute, offering high bandwidth but maybe lower capacity. We have a set of DMA engines for moving data, and we also have high bandwidth memory as part of the device. Because these are offered via EC2, we also have access to the host memory, which in this case would be DRAM associated with the CPU. So there are three levels of the memory hierarchy that we care about.

Thumbnail 210

There's the on-chip SRAM, which is lower capacity but very high bandwidth memory that we would love all of our data to live on all the time if it could. At the bottom of the tiers, we have the host memory, which is high capacity but lower bandwidth. In between, we have a bit of a compromise with the accelerator HBM: moderate capacity and pretty good bandwidth. One of the major struggles when you're trying to optimize your machine learning workloads is how do you place data, get it as close to the compute exactly when you need it, and how do you juggle moving between the tiers.

The Roofline Model and Performance Optimization Strategies

When we talk about optimizing machine learning workloads, one way of representing a given workload is the roofline model. Has anybody seen a roofline model before? Let me give a quick refresher then. As part of the roofline model, for any given algorithm or machine learning model there is an arithmetic intensity intrinsic to that model: an algorithmic arithmetic intensity. Whatever the mathematical operations are for that model, it requires a certain number of operations per byte of memory read. But for a given accelerator chip, you have a finite memory bandwidth and a finite maximum compute throughput that you can achieve.

If it turns out that your achieved arithmetic intensity for a given workload falls on the left-hand side, meaning there's low ops per byte, you're going to be memory bound. We don't really want to be there because it means that we're limited by how fast we can read from memory for the workload that we're actually operating. What we would like to do is try to push things to the right-hand side of the graph and be compute-bound, meaning that we're doing many operations per byte of memory read so we can take full advantage of the compute accelerators.
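To make the roofline bound concrete, here is a minimal sketch; the numbers below are illustrative placeholders, not Trainium specs. The attainable throughput is simply the minimum of the chip's peak compute and the memory bandwidth multiplied by the workload's arithmetic intensity.

```python
def roofline_bound(arithmetic_intensity, peak_flops, mem_bandwidth):
    """Attainable FLOP/s for a workload under the roofline model.

    arithmetic_intensity: operations per byte read from memory
    peak_flops:           peak compute throughput of the accelerator
    mem_bandwidth:        memory bandwidth in bytes per second
    """
    return min(peak_flops, arithmetic_intensity * mem_bandwidth)

# Illustrative numbers only (not Trainium specs): a workload doing 2 ops/byte
# on a chip with 100 TFLOP/s peak compute and 1 TB/s of memory bandwidth is
# memory bound at 2 TFLOP/s; at 200 ops/byte it hits the 100 TFLOP/s roof.
print(roofline_bound(2, 100e12, 1e12))    # 2e12  -> memory bound
print(roofline_bound(200, 100e12, 1e12))  # 1e14  -> compute bound
```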

The difference between algorithmic, theoretical arithmetic intensity and achieved intensity is basically the implementation. So how can we improve performance and shift the graph to the right?

We can do things like pipelining operations. If you have a major workload, sometimes you can chunk it up into smaller bits and actually run multiple parts of the workload at any given time and keep the engines busier than they would be otherwise. There are things like minimizing data movement. If you're doing a lot of DMAs, moving a lot of data unnecessarily at the wrong time, that's going to reduce your performance. So we want to try to avoid that.

Thumbnail 370

We want to maximize data throughput. In some cases, you might be sending small messages or moving small bits of data around. Is there some way maybe we can coalesce that into larger chunks and improve data throughput? And if we're in a distributed setting, we might have multiple Trainium chips that we need to communicate across sometimes. For example, synchronizing gradients during training. When that happens, it would be nice if we could overlap the collectives with some type of compute or data movement so we're not just sitting there waiting for the collectives to take place.

Thumbnail 400

Introducing the Neuron Kernel Interface (NKI)

So this is why and how we would try to improve performance, but what does it mean for Trainium and Neuron? Like how do we get there? To work with the Trainium and Inferentia chips, we have the Neuron SDK, which is the software development kit that includes all of the various layers of the stack. We have a compiler, runtime, driver, and user tools that allow customers to work with training and inference.

Thumbnail 430

We typically look at our customers as three personas: ML developers, data scientists, and performance engineers. Lots of folks kind of fall between the lines, but for today, we're focused more on the performance engineer persona. Within the Neuron SDK, we have a certain set of tools that are going to be applicable to performance engineers. The one that we want to focus the most on today is called the Neuron Kernel Interface, or NKI. Has anybody worked with NKI before or heard about it?

So what is NKI? NKI is essentially a Python-based domain-specific language for writing kernels for Trainium and Inferentia. Before we had NKI, you could write your machine learning model in PyTorch or JAX and run it through the Neuron compiler, which takes the compute graphs from your model and optimizes them to run on Trainium or Inferentia, and then you could run your model that way. But you're putting a lot of faith in the general-purpose Neuron compiler to perfectly optimize your model.

Thumbnail 510

As we see with our accelerators and other accelerators, oftentimes customers need lower level access. They want to be able to really fine tune the model to take full advantage of the underlying hardware. That's what NKI offers you. If you're familiar with OpenAI Triton, this is not quite Triton, it's a little bit different, but this is our flavor of writing low-level kernels for our accelerators.

NKI integrates directly with PyTorch, JAX, and NumPy. It directly emits Neuron ISA instructions, so the actual hardware instructions that exist on the chips get emitted from this platform. We have a couple of namespaces within NKI. There is nki.lang, which provides higher-level constructs and helper functions to help you out, and we also have direct Neuron ISA instructions as part of the nki.isa namespace. On the left-hand side, you do see a kernel. We're going to hop into that in a minute and go through it line by line to show you what it's like to build a basic kernel before we get into the actual benchmarking.

Thumbnail 540

Exploring a Basic NKI Tensor Add Kernel

For today's session, what are we looking to achieve here in record time? We're going to start off by benchmarking and profiling an LLM. In this case, it's going to be Qwen3. We're going to run that using Hugging Face transformers and the Neuron SDK. We're going to compile the basic Qwen implementation to run on Trainium as is, using naive attention. That's just going to be our baseline. So we'll look at the performance and look at a basic profile.

Thumbnail 580

Then we're going to walk through an actual NKI kernel implementation of attention. We're going to add that into the Qwen3 model that we had previously run, and we're going to redo the benchmarking and the profiling to see what the performance gains hopefully are at the end of the session. And then we'll obviously give you guys some time for questions at the end. The technology stack we're using specifically here today is AWS Trainium2, the second generation of our chip.

On a smaller instance, we're going to be using one of the NeuronCores out of eight that exist on the chip today, just to keep things a little bit more consumable. We're going to be using the Neuron SDK to both compile and run our models. In our case today, we're using Qwen3 0.6B, an embedding model, and this is available via the Hugging Face transformers library.

Thumbnail 620

With that, let's take a quick look at the example kernel that I put on the slide in a code editor. We'll go through it, and then we'll get right into the benchmarking and show you the vanilla implementation.

Here we can see a very basic kernel called the NKI Tensor Add kernel. This is just a basic kernel that takes two similarly shaped tensors and does element-wise addition of them. The first thing I want you to notice is that it's defined as a basic Python function. The only thing that's really special about this function is that we decorate it with @nki.jit() to let the Neuron compiler know that this is intended to be an NKI kernel that we want to compile.

There's a specific constraint around NKI kernels today. The inputs and outputs—you can see the A input and B input, which are the two tensors that will be added, as well as the output—these need to exist on HBM. We talked about the three tiers of memory that we're going to focus on. The inputs and outputs actually have to exist on device memory, but not on the on-chip SBUF. The framework is going to handle that for us. We'll be using PyTorch today.

Let me run through this line by line. We pass in two inputs that we want to add together. There's a basic check to make sure they're the same shape, otherwise the element-wise addition won't work. Here we're going to see a similar flow that exists in all NKI kernels. The first step is we need to allocate some space on the on-chip SBUF. Here we're allocating some space for both the A input tensor and the B input tensor using sbuf.view, and SBUF is one of our on-chip SRAMs.

The next step is we actually use an ISA instruction—this is a DMA copy where we specify that we want to copy from the A input and B input HBM tensors to the on-chip tensors. Similarly, we allocate some space for the result, which we call C tile, and this is on SBUF. We use sbuf.view to allocate the space for the tensor, and then we use the ISA tensor_tensor instruction to actually do the addition for us.

Here you can specify that this tensor_tensor operation is using the addition operator. The two inputs are the two tiles that we had passed in and copied to SBUF, and then we're outputting the results to the destination, which is the C tile. Lastly, because we need to get the outputs back to HBM to return back to the framework, we use hbm.view to allocate an HBM hosted tensor, and then again we use a DMA copy to get the results back to HBM before we return the result.

This is a bit of a trivial kernel just to give you the high-level flow, but you can see that this is fairly low level if you've been used to working with PyTorch. However, if you're really looking to squeeze the best bang for buck out of your model and make the best use of the hardware, there are going to be cases where you might need to get this deep into the space.
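For reference, here is a rough sketch of an element-wise add kernel written against the documented nki.language namespace. The kernel walked through on screen uses the lower-level style described above (sbuf.view, hbm.view, DMA copies, and the tensor_tensor ISA instruction), so treat this as an approximation of the same HBM-to-SBUF-and-back flow rather than the presenters' exact code; it also assumes the inputs fit within a single SBUF tile.

```python
from neuronxcc import nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel(a_input, b_input):
    """Element-wise add of two same-shaped tensors living in HBM."""
    assert a_input.shape == b_input.shape

    # The kernel output, like its inputs, must live in device HBM.
    c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

    # Stage both inputs from HBM into on-chip SBUF tiles.
    a_tile = nl.load(a_input)
    b_tile = nl.load(b_input)

    # Do the element-wise addition on-chip.
    c_tile = nl.add(a_tile, b_tile)

    # Copy the result back to HBM so the framework can consume it.
    nl.store(c_output, value=c_tile)
    return c_output
```

Called from PyTorch, this looks like an ordinary function call on device tensors, for example c = nki_tensor_add_kernel(a, b).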

Thumbnail 790

Thumbnail 800

Benchmarking and Profiling Qwen3 with Naive Attention

With that, let's take a quick look at the benchmarking script here. I didn't want to just blindly run a script, so I'm going to run through the code a little bit with you here just to show you what we're running. This is just a basic Python script that uses Hugging Face Transformers and the Qwen3 model. Essentially, we load the Qwen3 model on CPU first, we're going to compile it for Neuron, and then we're going to run a bunch of inferences and time the latency and measure the throughput that way. We'll also do a sanity check against CPU to make sure that the model results are the same as CPU.

Thumbnail 820

Thumbnail 840

If we skip over the imports, there's a little bit of environment setup that we use for the Neuron runtime. In this case, we want to specify that we're just using a single core, for example. And because we're going to be doing some profiling, we also enable some profiling environment variables here to make sure that we capture the metrics that we need.
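As a sketch, that setup might look like the following. NEURON_RT_NUM_CORES is a documented Neuron runtime variable, while the exact profiling variables differ across Neuron SDK releases, so take those names as assumptions.

```python
import os

# Restrict the Neuron runtime to a single NeuronCore for this demo.
os.environ["NEURON_RT_NUM_CORES"] = "1"

# Enable runtime-level capture so neuron-profile can produce a system profile.
# (Variable names here are assumptions; check the Neuron SDK profiling guide
# for the exact settings in your SDK version.)
os.environ["NEURON_RT_INSPECT_ENABLE"] = "1"
os.environ["NEURON_RT_INSPECT_OUTPUT_DIR"] = "./profile"
```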

Thumbnail 850

Thumbnail 860

We do wrap the model in a basic class here. This is just to make it a little bit easier to get the last hidden state from this embedding model for the demo today. The class doesn't really do anything other than run a forward pass and output only the last hidden state.
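A minimal sketch of such a wrapper might look like this; the class and argument names are assumptions, not the presenters' code.

```python
import torch


class LastHiddenStateWrapper(torch.nn.Module):
    """Run a forward pass and return only the last hidden state."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state
```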

Thumbnail 880

We also have a helper function that encodes some example text and returns a repeating tensor based on the batch size that we specify. It's just going to be the same text repeated over and over based on the batch size that we need just to help us out. We have a few arguments that we use to control the behavior of the benchmark script. Obviously, controlling things like batch size and max length is nice to be able to tweak as you're doing the executions. But we also have an NKI flag here that we use to specify whether or not the benchmarking script should use the NKI implementation that Sadaf is going to add for us a little bit later.
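A sketch of that helper and the benchmark flags might look like this; flag names such as --nki are assumptions based on the description.

```python
import argparse


def make_example_batch(tokenizer, text, batch_size, max_length):
    """Encode one example text and repeat it along the batch dimension."""
    enc = tokenizer(text, padding="max_length", truncation=True,
                    max_length=max_length, return_tensors="pt")
    return (enc["input_ids"].repeat(batch_size, 1),
            enc["attention_mask"].repeat(batch_size, 1))


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--max-length", type=int, default=16384)
parser.add_argument("--nki", action="store_true",
                    help="use the NKI attention implementation instead of naive attention")
args = parser.parse_args()
```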

Thumbnail 900

Here you can see this is where we actually load the model for the first time on the CPU. We do truncate the model to four layers here just to speed things up and get this done in 45 minutes here today.

Thumbnail 920

Essentially the performance for a handful of layers should be similar to the performance for the full model. You just have to scale it up accordingly. If we skip down a little bit here, this is where we actually get into the meat of it. If we've previously run the script for this configuration, we're going to cache the neuron compiled model on disk just to save time in case we want to run it again.

Thumbnail 960

But if it isn't already compiled, we use this torch_neuronx.trace call. This is what actually triggers the Neuron compiler to take the code, in this case just the vanilla implementation of Qwen3 with naive attention, and pass it to the Neuron compiler. The compiler will extract the graphs, compile them to run on Neuron using the example inputs that we've provided, and return a PyTorch-loadable model that we can then save and cache out to disk.

Thumbnail 970

Thumbnail 980

If we haven't compiled it, we compile it, and if it's already on disk, we just load it in. We do a quick warm-up inference, run through 5 timed inference iterations, measure the inference time, and then calculate the accuracy compared to CPU and print the results. This should be pretty straightforward if you've worked with Hugging Face before, but I wanted to show you the code so it wasn't a mystery. Why don't we actually run this and get into the heart of it.
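Continuing the sketches above, the compile-or-load, warm-up, timing, and accuracy-check steps could look roughly like this. torch_neuronx.trace is the documented tracing API for Trainium; the checkpoint name, cache path, and the layer-truncation attribute are assumptions.

```python
import os
import time

import torch
import torch_neuronx
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "Qwen/Qwen3-Embedding-0.6B"                                     # assumed checkpoint name
CACHE_PATH = f"qwen3_neuron_bs{args.batch_size}_len{args.max_length}.pt"   # hypothetical cache file

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID).eval()
model.layers = torch.nn.ModuleList(model.layers[:4])    # truncate to 4 layers (attribute name assumed)
wrapped = LastHiddenStateWrapper(model)

example = make_example_batch(tokenizer, "example text", args.batch_size, args.max_length)

if os.path.exists(CACHE_PATH):
    neuron_model = torch.jit.load(CACHE_PATH)            # reuse the cached compiled model
else:
    neuron_model = torch_neuronx.trace(wrapped, example) # triggers Neuron compilation
    torch.jit.save(neuron_model, CACHE_PATH)

neuron_model(*example)                                   # warm-up inference

latencies = []
for _ in range(5):
    start = time.perf_counter()
    out = neuron_model(*example)
    latencies.append(time.perf_counter() - start)

latency = sum(latencies) / len(latencies)
print(f"latency: {latency:.3f} s, throughput: {args.batch_size / latency:.2f} prompts/s")

# Sanity check against the CPU reference.
with torch.no_grad():
    cpu_out = wrapped(*example)
print("MSE:", torch.nn.functional.mse_loss(out, cpu_out).item())
print("cosine:", torch.nn.functional.cosine_similarity(
    out.flatten(1), cpu_out.flatten(1)).mean().item())
```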

Thumbnail 1020

If you open up the terminal, I'm logged into a Trn2.3xlarge instance, and we have a Neuron SDK environment already available to us. We can run commands like neuron-ls, for example, to get a quick breakdown of what's available on this instance type. Let's just run the benchmarking script here with Python. What we should see is that it detects that the previous CPU outputs are saved on disk and loads those, and it also loads the pre-compiled Neuron model for us, just to save a couple of minutes.

Thumbnail 1040

Thumbnail 1050

Within a few seconds, we should start to see this executing the model on the NeuronCore. This is Qwen3 0.6B, a four-layer version of that model using a 16K sequence length. It's just about done running, and now we get the accuracy just to prove it to you. Both the MSE loss and the cosine similarity are what we would expect. The outputs of this Neuron-compiled model that ran on the accelerator match the CPU outputs 100 percent, which is great.

On the performance side, you can see that this is not a very performant model out of the box. We're only getting 0.35 prompts a second with latency around 2.88 seconds. I think there's definitely room for improvement there. One of the areas where we often look to improve performance for transformer models is the attention block. We'll take a look at that shortly.

Before we do that, I did want to show you that we can actually look at the profiling results that we captured here as well. Sometimes what you'll see is that maybe you're actually getting great device time, like the model's actually running well on Trainium, but there's something else going on at the PyTorch layer that's slowing things down. If we get a system level view with a system profile here, we'll be able to see if this model actually ran as best we could expect for right now before we add optimizations.

Thumbnail 1140

As part of the Neuron SDK, we do have neuron-profile, which is the profiling tool. There are different interfaces for this. Today we're just going to output the system level profile in a Perfetto compatible format so that we can quickly load it up in Perfetto, which is an open source visualization tool that you're probably familiar with. We'll do a neuron-profile view, specify the directory that contains the profiling data, and tell it to use the output format for Perfetto.

Thumbnail 1150

Thumbnail 1160

Thumbnail 1170

Now if we look in the directory, it's created the system profile for us, so we'll just quickly download that. Let's hop into the browser, and this is Perfetto. We'll just open up the trace file and take a quick look here. If you're not familiar, this is just showing the execution timeline starting on the left, kind of scrolling to the right. You can see the first thing that happened is nrt_load. This is actually the neuron runtime loading the weights of the model into HBM.

Beyond that, you see all these nrt_execute calls. What's nice is we can actually see exactly what's happening within the Neuron runtime. Here it's executing the model. You can see there are basically 7 executions of the model here. Our benchmarking loop was 5, but we also did 1 execution for the accuracy check and 1 for the warm-up, so that's why there are 7.

Thumbnail 1210

Thumbnail 1220

If you click on any of these executions, you can actually see metadata associated with the execution. For this model it may not be that interesting, but for a larger model or an execution where you had multiple graphs executing on different cores, this would be very helpful. Here we can see the exact name of the graph, which is the compiled version of the model for Neuron, and we can see which NeuronCore it's running on. On the left-hand side you can see the duration of the execution is pretty much what we saw from our benchmarking script. It's just under 3 seconds of execution latency, which I'd say is not great. Hopefully we can improve on that. This is the baseline performance. We took that original Qwen3 model from Hugging Face, compiled it to run on Neuron, executed it on Trainium, and got those baseline results. Now let's see if Sadaf can improve on that using NKI.

Integrating NKI Flash Attention into the Qwen3 Model

Thanks, Scott. Thanks for setting the stage. Hey everyone, this is Sadaf again. I know it's not a great time to talk about code right after lunch, but I promise we will try to make it as exciting as possible. What we're trying to do here essentially is look at the performance we've seen for this Qwen3 model. We got the performance numbers and the timeline as well, and we'll try to see if NKI, which is the Neuron Kernel Interface as Scott mentioned, can help us improve this performance.

In this part of the talk, we are going to focus on three major objectives. First, I would like to show you how quick and convenient it is to integrate NKI kernels into our existing model code. Second, what the performance impact of doing so is. Third, which is more food for thought from an extensibility and applications point of view: is there a way we can use those NKI kernels in our own workloads when we go back home? These are the three major objectives we are going to address in this part of the talk.

Thumbnail 1370

We're going to do a little bit of code, as I promised. I'll try not to make it very boring for you, but we'll do our best. There are a couple of things worth noting from the profiling we have done so far. For each iteration, it's taking approximately 2 seconds and 812 milliseconds, or roughly 3 seconds. Second, as Scott was trying to show, the throughput we are seeing here is 0.35 prompts per second. Technically, it means that it takes approximately 3 seconds to just process one prompt.

When we are talking about NKI and how it helps us improve the performance, the reason it can do that is because it works at a very low level, which is very close to the hardware, and it takes care of all the optimizations available so that we can get the maximum performance from the underlying hardware, which is Trainium in this case. We are very proud and excited that yesterday we launched the NKI library, which Scott is going to talk about in a minute. We are exposing some predefined, pre-implemented kernels which can be plugged into your models just like we're going to show in this Qwen3 model, and we can take advantage of the performance on Trainium.

Thumbnail 1470

Thumbnail 1480

Thumbnail 1490

Thumbnail 1500

People who have worked with Hugging Face Transformers library before would know that for each model directory we have different model-specific files. In the case of Qwen3, we have a configuration file, a modeling file, and so on. The very first file that I'm super interested in right now is this modeling file. The reason I'm interested in this Qwen3 modeling file is because this file, which is given by the Hugging Face Transformers library, tells us the complete implementation of the Qwen3 model. If I want to understand what the RMS norm looked like for Qwen3, it is here. If I want to understand what MLP looks like for Qwen3, it is here. And how the rotary embedding looks or how it is implemented, everything is here.

Thumbnail 1520

Thumbnail 1540

Thumbnail 1550

Thumbnail 1560

Thumbnail 1570

Thumbnail 1590

The part we are particularly interested in today is attention. Everybody has heard about transformer-based models and attention mechanisms, right? That's what we're going to explore today. We have this eager attention forward method that has been provided by the Hugging Face Transformers library itself for this model, and we're going to see if there's a way we can replace this with NKI attention. We will examine what the performance implications are. That is one module we are going to focus on today. Let me first take this eager attention forward here and paste it. Everybody can see the code, right? Do I need to maximize it? Are we good? Awesome. I've copied the eager attention and let's give it a meaningful name. Let's call it NKI attention forward. I won't touch the input arguments, but I'm going to remove this implementation because we are going to write our own implementation here. We will return the attention output, and because it's inference, we don't need to return the weights. So we're good.

Thumbnail 1600

Thumbnail 1620

Step number one of my objective is to showcase how easy and convenient it is to integrate NKI kernels in our existing models. That's what we're going to look at right now. We have this attention_forward.py file here, and what it contains is the implementation, the NKI implementation for different flavors of attention. As we know, we have SDP attention, sliding window attention, flash attention, and many other variations. It has implementations for the majority of those attention mechanisms. We are going to use one of them. In this case, we are going to leverage flash attention for Qwen3.

Thumbnail 1650

Thumbnail 1690

What this implementation file does is provide an adapter, or you can think of it as an interface function. It will take all the inputs from us and ensure basic hygiene like verifying we're giving the right shapes, formats, and data types, and then it will invoke the right kernel for us. We will definitely look into that kernel as well, but before that, let's invoke this kernel from our modeling file. My attention output will come from here, and I'm going to call this one. As we were looking at it, and we can go through the provided documentation, because it's an NKI implementation of the kernel, it expects us to provide the Q, K, V tensors in a specific format or specific shapes. Because it's very close to the hardware, it expects us to do that little bit of massaging and reshaping before giving those tensors to NKI so that it can perform its operations.

Thumbnail 1720

For that, we need to make a little bit of reshaping for our Q, K, V tensors right now. For the query, the first thing I'm going to do is expose the shape of my query, key, or value. In this case, I will get the batch size, number of heads, sequence length, and head dimension. That's what it will give me. This kernel, the attention kernel, expects the query in a different format. I'm going to write it as a comment here so that we keep track of it. It says you are in this format right now, but I expect you to give me the query in this format. It's just a little bit of reshaping, permute, and all these classical tensor operations that we are going to perform here.

Thumbnail 1800

So what we are going to do is transform from B, H, S, D to B times H, D, S. Essentially, we are contracting two dimensions into one dimension and then changing the sequence of S, D to D, S. That's it. What we can do is first perform a permute to switch S and D, or swap the sequence length and the head dimensions. So we'll make it first B, H, D, S, and then I think it will be easy from there to make it B times H, D, S.

Thumbnail 1810

Thumbnail 1840

Thumbnail 1850

Thumbnail 1860

Let's do that. Query equals to Query.clone() dot permute. The reason we are doing this is because we need to change the shapes from BHSD to BHDS. So 0 will remain 0, 1 will remain 1, 2 will become 3, and 3 will become 2. So far we're good. We have achieved BHDS and then we will reshape it to B times HDS, as simple as that, and we will ask to give us a contiguous memory location so that it's optimized. We have just reshaped our query tensors. Now, similarly, we can do the same for key tensors as well. Our keys have exactly the same shapes, so we just need to make the change to the variable names here: key_states, and everything remains the same.

For values, it's even simpler. For values, it says give me B times HSD. So we don't need to do the permute here. We can simply just reshape. Since we don't have to permute, we can keep the S and D at their original positions. Once we have this, it means I am ready to provide these tensors to the attention kernel adapter, and I'm going to call this so we can provide our query, key_states, value_states, and then scaling. We just need to change the variable names so that it's not key, it's values.
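As a sketch of the reshaping just described, in plain PyTorch; it assumes, as in the talk, that the key and value tensors already have the same number of heads as the query.

```python
import torch


def reshape_qkv_for_nki(query, key, value):
    """Rearrange (B, H, S, D) attention inputs into the layouts the NKI flash
    attention kernel expects: (B*H, D, S) for query/key and (B*H, S, D) for value."""
    bsz, n_heads, seq_len, head_dim = query.shape

    # (B, H, S, D) -> permute to (B, H, D, S) -> flatten heads into the batch: (B*H, D, S)
    query = query.clone().permute(0, 1, 3, 2).reshape(bsz * n_heads, head_dim, seq_len).contiguous()
    key = key.clone().permute(0, 1, 3, 2).reshape(bsz * n_heads, head_dim, seq_len).contiguous()

    # Values keep S and D in place: (B, H, S, D) -> (B*H, S, D)
    value = value.reshape(bsz * n_heads, seq_len, head_dim).contiguous()
    return query, key, value
```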

Thumbnail 1930

Thumbnail 1960

Thumbnail 1970

What we have done so far is reshape our Q, K, and V tensors. We did not make any changes to the tensor values; we just reshaped them into the format that this kernel expects. Once we do that, we are able to call the attention through the NKI kernel adapter, passing our parameters, and we get the attention output. The beauty of it is that it's going to give us the attention output in this format, and we would like the final outcome to be in this format: BSHD. So we need to do the same kind of transformation here as well.

Thumbnail 1990

Thumbnail 2000

Thumbnail 2010

Thumbnail 2030

We can do this: first of all, we are going to reshape it so that it becomes BHSD. And then we just need to swap S and H, right? Right now we have H before S, and we're going to make it S before H. So we'll just call our favorite function, which is transpose, and transpose the 1st index with the 2nd index, then make it contiguous, and then we return the attention output. If we look at the implementation of this, we have just reshaped the query, key, and values, called the attention kernel to get the output, reshaped that output into the format the function is expected to return, and then used it.
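Putting those pieces together, the new forward function sketched on stage might look roughly like this. attention_forward stands in for the NKI adapter from attention_forward.py, and the exact signature is an assumption based on the Hugging Face eager_attention_forward convention.

```python
def nki_attention_forward(module, query, key, value, attention_mask=None,
                          scaling=None, **kwargs):
    bsz, n_heads, seq_len, head_dim = query.shape

    # Reshape Q/K/V into the layouts the NKI kernel expects (see the sketch above).
    q, k, v = reshape_qkv_for_nki(query, key, value)

    # Invoke the pre-built NKI flash attention adapter (name and signature assumed).
    attn_output = attention_forward(q, k, v, scale=scaling)

    # The kernel returns (B*H, S, D); the surrounding model expects (B, S, H, D).
    attn_output = attn_output.reshape(bsz, n_heads, seq_len, head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous()

    # Inference only: no attention weights are returned.
    return attn_output, None
```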

Thumbnail 2040

Thumbnail 2050

Another thing that we have to do is switch the attention implementation. There is a place where we are calling the attention here. What I have done is I have already added this: if my config has NKI as the attention implementation, call nki_attention_forward. So we are switching from eager attention to NKI attention, so that we are 100% sure that we are calling the right attention.
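The switch itself can be as small as a conditional at the call site; the config attribute name below is an assumption based on how Transformers selects attention implementations.

```python
# Inside the attention module's forward, where the attention function is chosen:
if getattr(self.config, "_attn_implementation", None) == "nki":
    attn_output, attn_weights = nki_attention_forward(
        self, query_states, key_states, value_states,
        attention_mask, scaling=self.scaling, **kwargs)
else:
    attn_output, attn_weights = eager_attention_forward(
        self, query_states, key_states, value_states,
        attention_mask, scaling=self.scaling, **kwargs)
```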

Thumbnail 2080

Thumbnail 2090

Let's quickly have a look. I'm going to remove the previous profiles. As we have seen when Scott was running, he was able to get the profiles. I'm going to delete it so that we don't have any cache data or something. And then we are going to try to make it run as well. Let's have our terminal open. And this time, I'm going to call the same script which Scott has run, but this time, I'm going to pass this argument. Right? So, hope it runs the first time. Let's do it.

Dramatic Performance Gains: 6-8X Improvement with NKI Attention

What is going to happen now is that because it has found a new attention implementation, it is going to submit the model to the Neuron compiler again. It will say, "I got this new attention. Just compile it for me so that I can run on Trainium, right?" This compilation will take maybe a couple of minutes. Meanwhile, let's have a quick look at the actual attention kernel that I promised you. I'm going into attention_forward and looking at the adapter function. I will leave all these details for you so you can check them out later. Essentially, I can see that it is calling this attention kernel implementation without swapping.

I am more interested in finding out where the definition of this kernel is. Here it is at line number 657. Just to warn you, it is a pretty decent implementation with about 1000 lines of code, so I do not want to bore you with that. But what I want to share with you is that at the end of the day, attention is nothing but a function of QKVs essentially, right? We are doing QK transpose, Softmax, and then multiplying this with V. This is exactly what it is doing, but it is doing it in a very hardware-aware fashion.
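For reference, the math the kernel implements is the standard attention computation. Here is the naive (non-fused) version in PyTorch that the flash kernel replaces, shown only to make the contrast concrete.

```python
import math
import torch


def naive_attention(q, k, v):
    """Reference attention: softmax(Q K^T / sqrt(d)) V, with Q, K, V shaped (..., S, D).

    The naive version materializes the full S x S score matrix in memory;
    the NKI flash attention kernel computes the same result in tiles, fused
    and pipelined across Trainium's engines, without ever storing it whole.
    """
    d = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)   # (..., S, S)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v)                                # (..., S, D)
```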

What this means is that with NKI, we have capabilities that allow us to take maximum advantage of the underlying hardware, which is Trainium. It employs various memory as well as compute optimization techniques. On the memory side, it uses smart DMAs to ensure that we are using the DMA engines and DMA packet sizes efficiently, along with intelligent tiling, efficient memory layouts, and allocations. On the compute side, it takes advantage of pipelining: as Scott mentioned on one of the slides, we have different compute engines, and the kernel tries to keep all of them as busy as possible in parallel so that we can utilize whatever compute we have available and get the best out of it.

It also uses the concept of fusion, where we are combining multiple operations together so that we do not have to do a lot of DMAs or a lot of memory accesses. We can keep our compute engines more and more busy. So it employs all these mechanisms to get the best performance from Trainium. Once you employ all these mechanisms, what you get is a super-optimized implementation of your model or a specific module that you are talking about.

After a couple of minutes, we see the compiler status is passed. Now it is going to run the inference with this new attention implementation. Remember, it is not running the naive attention anymore. It is running the NKI attention this time. The performance that we see is 2.99 prompts per second. Can anybody remind me what the performance was with the naive attention? Exactly, 0.35 prompts per second, and now we are at almost 3 prompts per second. So we are talking about approximately 6 to 8 times better performance. Not 6 to 8 percent, but a 6 to 8x performance boost here. Isn't that amazing?

So let's have a look at the profile as well. Just like Scott generated the profile, we are going to generate the profile again. We are going to download this system profile and save it with, let's say, an NKI suffix. We go to Perfetto again. Remember, last time Scott was measuring 2 seconds and 812 milliseconds, approximately 3 seconds, right? Now let's open this NKI trace that we have just got and expand it a little bit. Let's see the time. Are we ready for that? If we go there, the duration goes from almost 3 seconds to almost a quarter of a second.

Thumbnail 2410

With this NKI kernel, we are getting almost 6 to 8x performance boost right in front of us. That was my second objective. The first was to showcase how easy it is to integrate an NKI kernel, which is what we have done. The second was to understand the performance implications and performance boost that we get out of it.

The third objective that I was trying to drive is food for thought. Think about it. This is just a small model with just one attention module or attention part of the model that we have replaced with NKI. Now this model has much more than that. It has MLP also, which is super compute bound. It has QKV projections, output projections, and so on. Scott will drive us in a moment and help us understand that we have some of these kernels exposed to the outside world for the first time ever. These are production-scale kernels that we can use in our own models and leverage this amazing performance boost.

Thumbnail 2490

Going back to the performance gain: on the throughput chart, the naive attention bar was like a one-story building, and now we are at almost a ten-story building. From the latency point of view, which we want to be as small as possible, we have come down from there to here. We do not have profiling based on power consumption as of today. With that, I will leave the floor to Scott. He will drive us home from here. Thank you.

Thumbnail 2540

Launching the NKI Library and Closing Remarks

Thank you, Sadaf. As you saw, we were able to go from not-so-great performance to pretty impressive performance gains basically just by adding a few lines of code and taking advantage of a kernel that has been built by the Annapurna Labs team. What we are really excited to talk about, though, is that just yesterday we launched the NKI Library. This is an open source project that contains a number of pre-optimized NKI kernels that are written, curated, tested, and maintained by our team and provided on GitHub.

Starting off, there are a number of kernels available there today. You can grab the QR code and take a look. There is a lot of dense model support in the beginning, and we are going to add additional kernels around mixture of experts and other use cases and workloads as well. This is super exciting. NKI is still in beta, but it gives customers access to the lower-level compute engines in a way that you just did not have before. You saw the naive implementation of attention when we compiled it. It did run and gave accurate results, but the performance was really lackluster out of the box.

With NKI, we were able to drastically close the gap there and really drive amazing performance with this model. We focused on attention today, but there are a lot of other aspects of different models and different layers that you might be able to apply NKI to. It is a little bit like onion peeling. You are going to get into a workload, do a device level profile, figure out where the bottlenecks are, and maybe tackle some of those with NKI. It moves the bottleneck somewhere else, and at some point, you can almost chase it forever. But at least with NKI, it gives the capabilities to the customer's hands so you can actually do this performance engineering that just was not possible from the PyTorch level before.

With that, I want to thank everybody for coming out today. We really appreciate you coming in for an afternoon session to learn about Trainium and NKI. We have a number of other Neuron and Trainium sessions today and tomorrow if you are interested in checking them out, and we are open for questions. Thank you.


This article is entirely auto-generated using Amazon Bedrock.
