About a year ago, AMD released their AI Max+ series CPUs (aka Strix Halo). It seemed like my entire YouTube feed was praising the architectural decision to bring unified memory to non-Mac hardware. I finally bought a GMKTec EVO x2 with 128 GB of RAM in November last year. I started following different tutorials, trying to install ROCm via unofficial releases and to run local LLMs!
I wanted to write about my journey of trying to train a 1B-parameter LLM on the Strix Halo, and to share all the things I learned along the way. I will also link to the resources on each topic that I found helpful.
TL;DR:
This is the repo for the AI training framework I created using Antigravity: https://github.com/genuinelucifer/onix
The repo has all the tools needed from downloading dataset, to pretraining to finetuning to running it locally.
The repo also contains this doc which outlines the full process to train and fine-tune a model using the tiny-stories dataset. It also mentions how to run it locally using the framework: https://github.com/genuinelucifer/onix/blob/main/docs/training_llama1b_on_tinystories.md
Starting phase:
I got my mini PC with Windows 11 because I also wanted to play games on it. At the time, I didn't know how far Proton support for Steam games had come. I tried to install ROCm drivers and run models on Windows, and it was a mistake. Windows support is very limited, and none of the YouTubers seem to use Windows for running AI models (at least on Strix Halo).
After trying for multiple weeks and failing to get ComfyUI or Qwen-Image-Edit working properly on Windows, I finally dual-booted Ubuntu. This unlocked a wide set of tutorials people had created for running and training models on Linux. Anyone searching for tutorials on running models on Strix Halo will almost certainly stumble upon the awesome Strix Halo toolboxes created by Donato Capitella. These helped me start my journey of running models on my PC.
However, the toolbox documentation asks you to set the lowest amount of dedicated VRAM in the BIOS and then add the following kernel parameters to the GRUB config, which increase how much of the system RAM the GPU can use as shared memory:
amd_iommu=off amdgpu.gttsize=126976 ttm.pages_limit=32505856
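As a quick sanity check on those two numbers, the sketch below converts them to GiB. It assumes amdgpu.gttsize is given in MiB and that ttm.pages_limit counts 4 KiB pages (the usual x86 page size); under those assumptions both parameters describe the same ~124 GiB limit.

```python
# Sanity-check the GRUB kernel parameters.
# Assumptions: amdgpu.gttsize is in MiB, ttm.pages_limit counts 4 KiB pages.
GTT_SIZE_MIB = 126976
TTM_PAGES_LIMIT = 32505856
PAGE_SIZE_BYTES = 4096

GIB = 1024 ** 3

gtt_gib = GTT_SIZE_MIB * 1024 ** 2 / GIB
ttm_gib = TTM_PAGES_LIMIT * PAGE_SIZE_BYTES / GIB

print(f"amdgpu.gttsize  -> {gtt_gib:.0f} GiB")
print(f"ttm.pages_limit -> {ttm_gib:.0f} GiB")
```

Note that this is a cap on how much system memory the GPU may map, not a reservation; the CPU can still use that memory when the GPU is not holding it.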
Applying this change eagerly was my second mistake. It might be a great hack for running models that need 100+ GB of VRAM, but I never ran any model that large. I found that the Strix Halo is far more compute-limited than memory-limited. Much later, I found that setting 96 GB of dedicated VRAM (leaving 32 GB of RAM) was a much, much better setup for me (more on this later).
I ran a few models using ComfyUI and LM Studio and got about 40 tokens/second on 30B-parameter models, which felt pretty good. Creating a 1080p image with a FLUX model in ComfyUI takes about 50 seconds. So the PC is good for local coding agents and image creation. Creating videos is still out of reach, though: it took me about 6 hours to create a 4-second video using the Hunyuan video model in ComfyUI!
After all this experimentation with running models, I wanted to train my own model; just running existing models wasn't why I bought this PC.
Phase Two: Learning about LLMs and setting goals
Note that I had already learned about CNNs, RNNs, etc., so I knew the basics of how these models work. I knew the maths involved and what a basic model's architecture looks like.
What I wanted to learn was all the components of an actual LLM, without going too deep into the maths. I was also more focused on learning it from a software perspective than from a research perspective.
I bought the book Build a Large Language Model from Scratch by Sebastian Raschka to learn the fundamentals of how LLMs work and which components they are made of. This was the perfect book for my use case! It took me about 1.5 months to go through it the old-fashioned way (reading the text and writing every line of code manually). My code is uploaded to my yallm repo.
After this, I wanted to speed up my workflow. I also wanted to test all the hype around AI tools on a personal project. So I got a Gemini Pro subscription, installed Antigravity, and decided to give the agent the code from my yallm repo and have it write all the new code. I also wanted to create a generic framework to train an LLM from scratch on my PC, because (in my head) I was going to train all kinds of large language models.
I also set myself the goal of NOT writing any code for this framework. I would review some of the code from time to time and just see what I got when training, finetuning and running my models.
So I went full steam ahead and created my Onix framework (a slight pun on the ONNX model format), which I used to fully train, finetune and run a 1B-parameter model on my Strix Halo. I intend to keep developing the framework (maybe not 100% via AI going forward) and hopefully add support for training more model architectures with more optimizations. Currently it supports training a text-based LLM, a VQ-VAE embedding for images, and then a “multimodal” model that uses the trained VQ-VAE to generate images from text.
I looked for English datasets to pretrain my model on and selected the tiny-stories dataset as my first one. It has about 470M tokens. It seemed small enough to run some useful tests before moving on to bigger datasets (considering I was trying to train a 1B-parameter model).
Phase Three: Memory footprint for pretraining the model
After this, I selected Llama as the base model architecture to start with. Llama's architecture is very well documented and easy to reuse. I stuck with the GPT-2 tokenizer that I read about in Sebastian Raschka's book, so that I could reuse most of my code even with new model architectures. Then I started the pretraining process.
When I first ran training on a 1B Llama model, I started with a 1024-token context window and saw that it used about 26 GB of VRAM! This was both a relief and a shock: a shock because I didn't know why a 1B-parameter model would take 26 GB of VRAM, and a relief because I had 110+ GB of shared memory at this point. So I could increase both the context window and the batch size. I could even increase the model size and train much larger models. Oh, such were those times of naivety!
So I looked through the memory usage to see where it was going. I am training my model in the default fp32 precision, and it takes 4 bytes (32 bits) to store one fp32 value.
Here is my rough estimate of how much VRAM each part needs during training (I am not sure how to measure the exact values):
1 - Total static overhead (~16 GB):
- ~1B model parameters (~4 GB)
- ~1B model gradients (~4 GB)
- ~2B values of AdamW optimizer state (~8 GB), since it keeps 2 values, momentum and variance, per model parameter
2 - Total activation memory (~5 GB):
- Self-attention matrices (~context_size² × n_heads × n_layers)
- Feed-forward block outputs
3 - Cached values for the model's eval stage, since I run eval after the first iteration of training.
4 - PyTorch/ROCm caching adds some overhead as well.
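The first two items can be turned into a back-of-the-envelope calculation. The sketch below is an illustration, not a profiler: the static part follows directly from the list above, while the attention part uses a hypothetical shape of 32 heads and 16 layers (roughly the shape of Llama 3.2 1B; the actual model config may differ) and counts only one fp32 score matrix per head per layer, so it is a lower bound.

```python
# Back-of-the-envelope VRAM estimate for fp32 training with AdamW.
BYTES_FP32 = 4
GIB = 1024 ** 3

def static_bytes(n_params: int, bytes_per_value: int = BYTES_FP32) -> int:
    """Weights + gradients + AdamW's two moment buffers (momentum, variance)."""
    copies = 1 + 1 + 2
    return n_params * copies * bytes_per_value

def attention_score_bytes(context: int, n_heads: int, n_layers: int,
                          batch: int = 1, bytes_per_value: int = BYTES_FP32) -> int:
    """One (context x context) score matrix per head, per layer, per sequence.
    Naive attention keeps at least this much around for the backward pass."""
    return batch * n_layers * n_heads * context * context * bytes_per_value

print(f"static state: {static_bytes(1_000_000_000) / 1e9:.0f} GB")

# Hypothetical shape: 32 heads x 16 layers.
for context in (1024, 8192):
    gib = attention_score_bytes(context, n_heads=32, n_layers=16) / GIB
    print(f"context {context}: >= {gib:.0f} GiB of attention scores")
```

Under these assumptions, an 8x longer context costs 64x more attention memory (2 GiB becomes 128 GiB), because the score matrix grows with the square of the context size.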
So I increased the batch size to 8 and saw that it used about 65 GB of VRAM. This was also a shock, because I expected it to take about 8x the VRAM and cause an OOM. But it did not!
It turned out that the static overhead stays the same, since there is only one instance of the model, so the model and optimizer state still need only about 16 GB of VRAM. Only the activations grow with the batch size, because every sample in the batch needs its own set of activations.
Next, I decided to train with a larger context window. I presumed that whatever data I finally trained on, the model would work sort of like a chatbot, and 1024 tokens is a very small context window. So I increased the context window to 8192 tokens (8x the earlier size, and the minimum needed according to my uneducated guess).
And to my shock, even with a batch size of 1, I ran into OOM when trying to train an 8192-context Llama model.
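Using the two measurements above (26 GB at batch size 1, 65 GB at batch size 8), a straight-line fit separates the fixed cost from the per-sample cost. This is a rough model that assumes memory grows linearly with batch size; allocator caching makes real numbers somewhat noisier.

```python
def fit_memory_model(batch_a, gb_a, batch_b, gb_b):
    """Fit vram = fixed + per_sample * batch through two measurements."""
    per_sample = (gb_b - gb_a) / (batch_b - batch_a)
    fixed = gb_a - per_sample * batch_a
    return fixed, per_sample

fixed_gb, per_sample_gb = fit_memory_model(1, 26, 8, 65)
print(f"fixed ~{fixed_gb:.1f} GB, per sample ~{per_sample_gb:.1f} GB")

# Extrapolating to an unmeasured batch size (an estimate, not a measurement).
print(f"batch 16 -> ~{fixed_gb + 16 * per_sample_gb:.0f} GB")
```

The ~20 GB fixed part is consistent with ~16 GB of weights, gradients and AdamW state plus eval and caching overhead, and the ~5.6 GB per sample is close to the ~5 GB activation estimate.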
In this case as well, the static memory overhead stays the same. But the attention score matrices are quadratic in the context size (proportional to context_size²), so increasing the context size by 8x increases the attention memory by 64x.
So the self-attention memory alone would be over 150 GB!
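Now I needed to optimize this memory usage, and the naive self-attention implementation was no longer usable. So I switched to flash attention, which computes exactly the same attention in small tiles without ever storing the full score matrix. This makes attention memory scale linearly with the context size instead of quadratically.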
With flash attention (via PyTorch's SDPA, scaled_dot_product_attention), my model took 58 GB with a batch size of 1.
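The difference between the two approaches is visible in PyTorch. The naive version below materializes a (seq x seq) score matrix per head, while torch.nn.functional.scaled_dot_product_attention can dispatch to a fused flash-attention kernel that never stores it. This is a minimal sketch with made-up tensor sizes: it only checks that both give the same result, not the memory savings, and which backend SDPA picks depends on the device and PyTorch build.

```python
import torch
import torch.nn.functional as F

def naive_causal_attention(q, k, v):
    """q, k, v: (batch, heads, seq, head_dim). Materializes a (seq x seq) matrix per head."""
    scale = q.size(-1) ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    seq = q.size(-2)
    future = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))  # block attention to later tokens
    return torch.softmax(scores, dim=-1) @ v

def sdpa_causal_attention(q, k, v):
    """Same math; PyTorch may dispatch to a fused flash-attention kernel."""
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 256, 64) for _ in range(3))
out_naive = naive_causal_attention(q, k, v)
out_sdpa = sdpa_causal_attention(q, k, v)
print(torch.allclose(out_naive, out_sdpa, atol=1e-4))
```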
I then tried gradient checkpointing, which discards most intermediate activations from the forward pass and recomputes them during the backward pass, trading extra compute for memory. With it, training took only about 29 GB of VRAM, and every iteration took about 47 seconds.
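In PyTorch, one way to do this is torch.utils.checkpoint.checkpoint, applied per block. The sketch below uses a hypothetical residual MLP block as a stand-in for a transformer block and checks that checkpointing changes only when activations are computed, not the resulting gradients.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Hypothetical stand-in for a transformer block (residual MLP only)."""
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)

class Stack(nn.Module):
    def __init__(self, dim: int, n_layers: int, use_checkpoint: bool):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(n_layers)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # The block's internal activations are recomputed during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

torch.manual_seed(0)
plain = Stack(64, n_layers=4, use_checkpoint=False)
ckpt = Stack(64, n_layers=4, use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())  # identical weights

x = torch.randn(8, 64)
plain(x).sum().backward()
ckpt(x).sum().backward()
same = all(torch.allclose(a.grad, b.grad) for a, b in zip(plain.parameters(), ckpt.parameters()))
print("same gradients:", same)
```

The extra forward computation inside each checkpointed block is what makes each step slower.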
All the data till this point is summarized in this table:
| Stage | Context size | Batch size | Optimization | VRAM (GB) | Time (s/step) |
|---|---|---|---|---|---|
| 0 | 1024 | 1 | - | 26 | 3.5 |
| 1 | 1024 | 8 | - | 65 | 23.5 |
| 2 | 8192 | 1 | - | OOM | - |
| 3 | 8192 | 1 | Flash Attention | 58 | 38 |
| 4 | 8192 | 2 | Flash Attention | 95 | 80 |
| 5 | 8192 | 1 | Flash Attention + Grad Checkpointing | 29 | 47 |
| 6 | 8192 | 2 | Flash Attention + Grad Checkpointing | 34 | 98 |
Note that I re-measured the numbers in this table while writing this post, so they may not align 100% with the next table. By then I had already switched to dedicated VRAM and had also upgraded PyTorch, ROCm, Ubuntu, etc.
The main issue now came to light: doubling the batch size roughly doubled the training time per step (both with and without grad checkpointing). So it became very clear that the GPU was limited by compute, not by memory.
I stopped trying to use more VRAM at this point. I now needed to focus on improving the training time. I decided to not use grad checkpointing since it was about 25% slower and time matters more at this scale.
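Dividing tokens per step by seconds per step for the flash-attention rows of the table makes the compute-bound behaviour explicit: doubling the batch size leaves throughput roughly flat, and the same numbers give the ~25% checkpointing overhead mentioned above.

```python
CONTEXT = 8192

# (batch size, seconds per step) for stages 3-6 of the table
runs = {
    "flash, batch 1":        (1, 38),
    "flash, batch 2":        (2, 80),
    "flash + ckpt, batch 1": (1, 47),
    "flash + ckpt, batch 2": (2, 98),
}

throughput = {name: batch * CONTEXT / sec for name, (batch, sec) in runs.items()}
for name, tok_s in throughput.items():
    print(f"{name:22s} {tok_s:5.0f} tokens/s")

ckpt_overhead = 47 / 38 - 1
print(f"grad checkpointing costs ~{ckpt_overhead:.0%} more time per step")
```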
Now each training step took about 80 seconds. For the tiny-stories dataset with a batch size of 2, I still needed 24,414 steps to complete one epoch.
So it would take about one month of continuous training to complete just 1 epoch (after adding evaluation steps every few hundred training steps)! That was not feasible.
And there was no point in moving to a larger dataset, since training on it would be completely impractical. So I decided that tiny-stories would be the only dataset I would train my model on.
Since I had already spent so much time optimizing memory, I decided to continue with a context window of 8192 tokens and a batch size of 2 for training my model. This was clearly a mistake: tiny-stories doesn't need such a huge context window. I would have gotten much faster training times if I had stuck to a 512-token context window, which would have been enough for this dataset. But then I wouldn't have had to learn how to optimize training time, which I got to do here. So it was a win, in hindsight.
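The epoch-time estimate is simple arithmetic on steps per epoch and seconds per step. The sketch below ignores evaluation and checkpointing time, so real runs take somewhat longer.

```python
SECONDS_PER_DAY = 86_400

def days_per_epoch(steps: int, seconds_per_step: float) -> float:
    """Wall-clock days for one epoch, ignoring evaluation and checkpointing."""
    return steps * seconds_per_step / SECONDS_PER_DAY

STEPS_PER_EPOCH = 24_414  # batch size 2, 8192-token context

print(f"{days_per_epoch(STEPS_PER_EPOCH, 80):.1f} days at 80 s/step")
print(f"{days_per_epoch(STEPS_PER_EPOCH, 11):.1f} days at 11 s/step")
```

Adding periodic evaluation pushes the first figure towards a month; the second is close to the ~3.5 days the optimized run actually took.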
Phase Four: Improving training speed on my hardware
At this point, I still had 110+ GB of shared GTT memory and only 1 GB of dedicated VRAM. Since the GPU and CPU share this memory, it can get fragmented, and every GPU access has to go through the GTT (Graphics Translation Table).
So I removed the GRUB parameters and set the dedicated VRAM in my BIOS to 96 GB, which left 32 GB for system RAM. And voilà, I got about a 1.5x speedup from that change alone!
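After this, I learned about BF16 (bfloat16) numbers. BF16 uses the same 8 exponent bits as FP32, so it covers the same range of values, but it keeps only 7 mantissa bits instead of 23, so it is much less precise. Neural network training usually copes well with the lower precision as long as the range is preserved, which is why BF16 works well for training while halving the memory and bandwidth needed per value.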
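To see the range-versus-precision trade-off concretely, torch.finfo reports the size, largest finite value and machine epsilon of each floating-point type. This is a small illustration, not part of any training code.

```python
import torch

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    # bits: storage size; max: largest finite value (range); eps: gap above 1.0 (precision)
    print(f"{str(dtype):15s} bits={info.bits:2d} max={info.max:.3e} eps={info.eps:.1e}")

# Range: bf16 keeps fp32's 8 exponent bits, so 1e20 survives; fp16 overflows to inf.
big = torch.tensor(1e20)
print(big.to(torch.bfloat16).item(), big.to(torch.float16).item())

# Precision: bf16 has only 7 stored mantissa bits, so 1.001 rounds to 1.0.
print(torch.tensor(1.001).to(torch.bfloat16).item())
```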
Switching to BF16 for training gave me a further ~5x improvement on top of that (about 7.3x in total)!
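The post doesn't say exactly how BF16 was enabled. One common approach is PyTorch's automatic mixed precision, sketched below under that assumption with a toy nn.Linear model: matmuls run in bf16 inside torch.autocast, while the parameters and their gradients stay in fp32.

```python
import torch
import torch.nn as nn

# ROCm builds of PyTorch also expose the GPU through the "cuda" device type.
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(0)
model = nn.Linear(256, 256).to(device)          # weights stay fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(32, 256, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
    y = model(x)                                 # matmul runs in bf16
    loss = y.float().pow(2).mean()               # reduce in fp32 for a stable loss

loss.backward()                                  # gradients land in the fp32 parameters
optimizer.step()
optimizer.zero_grad(set_to_none=True)
print(y.dtype, model.weight.dtype)
```

Because BF16 has the same exponent range as FP32, it usually doesn't need the gradient scaler that FP16 mixed precision relies on.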
Next, I learned about compiling the model with torch.compile, which can fuse several operations into a single kernel (and optimize some operations away). This improved training time further.
On ROCm, PyTorch's flash-attention kernels come from AOTriton (ahead-of-time compiled Triton kernels). On some GPUs they are still marked experimental and have to be enabled with the TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL environment variable, which I enabled along with torch.compile.
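If you prefer to set the flag from Python rather than the shell, a minimal sketch is to put it into the environment before importing torch. Whether the flag is needed at all depends on your GPU and PyTorch version; the two prints are only a quick check of which build you are on and whether the flash SDPA backend is allowed.

```python
import os

# Set before torch is imported so it is in place before any attention kernel is chosen.
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"

import torch

print("HIP/ROCm version:", torch.version.hip)  # None on CUDA or CPU-only builds
print("flash SDPA allowed:", torch.backends.cuda.flash_sdp_enabled())
```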
After all these optimizations, I got about a 10x total improvement in training speed, which was good enough to start training the model.
All the data for speed is summarized in this table:
| Stage | Optimization | Time/step | Improvement |
|---|---|---|---|
| 0 | Baseline | 110s | 1.0x |
| 1 | 96GB Dedicated VRAM | 74s | ~1.5x |
| 2 | + BF16 Precision | 15s | ~7.3x |
| 3 | + torch.compile + AOTriton | ~11s | ~10.0x |
Note: After all 3 optimizations, the model takes about 63 GB VRAM while training which was a welcome surprise. I was not looking to reduce memory footprint at this point, so I did not capture memory footprint after every optimization.
After this, I saw that my GPU utilization graph dipped every iteration, and the CPU usage spiked at exactly the same time. It looked something like this:
This was because after every iteration, the CPU needed to load the next set of data for the next iteration. So I needed to optimize my data loader to get full 100% GPU usage at all times.
I did so by setting num_workers and prefetch_factor on my data loaders. This didn't noticeably increase the training speed, but at least I can now see my GPU at 100% utilization throughout training! :)
After all these optimizations, I finally ran the full pretraining on the tiny-stories dataset. It took about 3.5 days for 1 epoch of training to finish. I did have to build a pause-and-resume mechanism so that I could resume training whenever my PC restarted due to power outages.
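A pause-and-resume mechanism can be as simple as saving the model, optimizer and step counter together. The sketch below is one way to do it (not necessarily how Onix does it), using a toy model and a hypothetical temp-file path. Writing to a temporary file and then renaming it means a power cut during saving leaves the previous checkpoint intact.

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer, step: int) -> None:
    tmp = path + ".tmp"
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step}, tmp)
    os.replace(tmp, path)  # atomic rename: never leaves a half-written checkpoint

def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer) -> int:
    """Restore state if a checkpoint exists; return the step to resume from."""
    if not os.path.exists(path):
        return 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["step"] + 1

# Usage with a toy model; the path is hypothetical.
ckpt_path = os.path.join(tempfile.gettempdir(), "resume_demo.pt")
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
start = load_checkpoint(ckpt_path, model, optimizer)
for step in range(start, start + 3):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    save_checkpoint(ckpt_path, model, optimizer, step)  # in practice, every N steps
print("next run resumes at step", load_checkpoint(ckpt_path, model, optimizer))
```

A fuller version would also store the data loader position and RNG state so that a resumed run sees the same data order.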
I did not have the heart to continue training for more epochs. I ran the model at this stage and, IMO, got good enough output. Example (at this point the model auto-completes a story that I started with "There was a boy"):
Phase Five: Model fine-tuning
Although the model could auto-complete tiny stories, interacting with it didn't feel natural, so I wanted to do instruction fine-tuning on it. I used the TinyStoriesInstruct dataset for this. However, this dataset isn’t really what I expected: it is more of a “story-elements-to-story” dataset. So I wrote a Python script to convert it into the format I wanted, something like:
User: Tell me a story about a king and a queen. It should have the words river, gate and knight.
Model: … <tells a story matching the request>
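A conversion script along these lines could look like the sketch below. It assumes each record is a block of "Features:", "Words:", "Summary:" and "Story:" lines; that layout, the prompt wording and the sample record are illustrative assumptions, not the actual script or data, so adjust the parser to the real files.

```python
import re
from typing import Optional

FIELD = re.compile(r"^(Features|Words|Summary|Story):\s*(.*)$")

def parse_record(record: str) -> dict:
    """Split one record into {'Words': ..., 'Summary': ..., 'Story': ...}."""
    fields, current = {}, None
    for line in record.strip().splitlines():
        match = FIELD.match(line)
        if match:
            current = match.group(1)
            fields[current] = match.group(2).strip()
        elif current is not None:  # continuation of a multi-line field
            fields[current] += "\n" + line
    return fields

def to_chat_example(record: str) -> Optional[str]:
    fields = parse_record(record)
    if "Story" not in fields or "Summary" not in fields:
        return None  # skip incomplete records
    request = f"Tell me a story. {fields['Summary']}"
    if fields.get("Words"):
        words = [w.strip() for w in fields["Words"].split(",")]
        listed = ", ".join(words[:-1]) + " and " + words[-1] if len(words) > 1 else words[0]
        request += f" It should have the words {listed}."
    return f"User: {request}\nModel: {fields['Story'].strip()}"

# Made-up record, for illustration only.
sample = """Features: Dialogue
Words: river, gate, knight
Summary: A king and a queen cross a river to reach a castle gate.
Story: Once upon a time, a king and a queen lived by a river.
One day, a kind knight opened the gate for them."""
print(to_chat_example(sample))
```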
Then I finetuned the model on the newly created dataset. Fine-tuning was exactly the same as training: just take the earlier model and train it further. Surprisingly, this dataset was larger (in number of data points) than the actual pretraining data.
With the benefit of hindsight, I decided to use only the first 1024 tokens of the context window for finetuning, which drastically reduces the training time. I also enabled grad checkpointing to drastically reduce the memory needed for finetuning, and then increased the batch size to 32.
With a 1024-token context, grad checkpointing enabled and a batch size of 32, finetuning took about 9 seconds per step and consumed about 53 GB of VRAM.
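A rough comparison of token throughput between the optimized pretraining run (batch size 2, 8192 tokens, ~11 s/step) and this finetuning setup shows how much the shorter context helps. It is an upper bound that assumes every sequence fills its whole context, and the two runs differ in more than context size (for example, grad checkpointing), so it is only indicative.

```python
def tokens_per_second(batch: int, context: int, seconds_per_step: float) -> float:
    """Upper bound: assumes every sequence in the batch is a full `context` tokens."""
    return batch * context / seconds_per_step

pretrain = tokens_per_second(batch=2, context=8192, seconds_per_step=11)
finetune = tokens_per_second(batch=32, context=1024, seconds_per_step=9)
print(f"pretraining: ~{pretrain:,.0f} tokens/s")
print(f"finetuning:  ~{finetune:,.0f} tokens/s ({finetune / pretrain:.1f}x)")
```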
At this rate, it would take about 6 days to finetune the model on the full instruct dataset, just for 1 epoch. I ran the finetune on only half the dataset and checked the results after 3 days. They looked okay to me, considering it was my first time finetuning a model.
Example:
Next Steps:
Next, I want to work on the following items (in no particular order); which, I think, will lead me to improve my framework (& knowledge of LLMs) significantly more:
- Work on the model runner to learn about what optimizations are applied for inference (for both memory and time).
- Retrain the same model with a smaller context size.
- Convert the model to safetensors and share it via Hugging Face.
- Quantize the BF16 models to 4-bit and compare the outputs.
- Train an 8-bit or maybe even a 1-bit model and compare the outputs.
- Implement LoRA to fine-tune the model with a smaller set of parameters, and learn how it works.
- Train a model to create small images... Maybe galaxies and 32x32 platformer characters. Learn about image models.
Final thoughts:
If you are starting today, install Ubuntu 26.04 and go hacking. Try my framework and let me know what other things I should try to make the model training framework better!
Please share your thoughts and experiences on training LLMs, and what I should learn next to better understand how everything works.


