Key Insights Gained from Building and Training LLM

Introduction

This article explains how to build and train a Large Language Model (LLM) from scratch. I also use it as an opportunity to document my understanding of the purpose of, and trade-offs behind, each component of the LLM.

This article is possible thanks to resources such as Build a Large Language Model (From Scratch) by Sebastian Raschka, and Stanford CS336 Language Modeling from Scratch, Assignment 1. I will not be sharing any code, so as to avoid violating the Honor Code; instead, I will share my conceptual takeaways from the experience of building an LLM.


Architecture of the LLM

In my first article, I drew the following diagram to represent an LLM architecture closely aligned with the original "Attention is all you need" paper:

Original architecture

In this article, I instead implement the following architecture, which is more closely aligned with modern LLMs such as Qwen3 and Llama 3.2.

New architecture

The three most significant differences are the placement of layer normalization, the method of positional encoding, and the overall architectural pattern.

| Feature | Design 1 (Classic Flowchart) | Design 2 (Modern Block Diagram) | Significance |
| --- | --- | --- | --- |
| Layer Normalization | Post-Layer Normalization (Post-LN). The add (residual connection) happens before the layer normalization: (x + SubLayer(x)) -> Norm | Pre-Layer Normalization (Pre-LN). The layer normalization is applied before the sub-layer (Attention or FFN): x + SubLayer(Norm(x)) | This is the most critical difference for training stability. Pre-LN is much more stable for very deep networks. |
| Positional Encoding | Absolute Positional Encoding. Explicitly adds a separate positional encoding vector to the token embeddings at the very beginning. | Rotary Positional Embedding (RoPE). Injects position information directly into the Query and Key matrices within the attention mechanism. | RoPE is a form of relative positional encoding that performs better and generalizes better to sequence lengths not seen during training. |
| Final Normalization | No normalization layer after the final transformer block and before the final linear layer. | An extra Norm layer is present after the final transformer block, just before the final linear (output) layer. | This is a common practice in modern architectures like Llama to stabilize the final output projection. |
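To make the Pre-LN pattern concrete, here is a minimal sketch of a single transformer block in the modern layout. The attention and ffn callables are placeholders, and nn.LayerNorm stands in for the RMSNorm built later in this article.

import torch
import torch.nn as nn

class PreLNBlock(nn.Module):
    """Pre-LN transformer block: normalize, apply the sub-layer, then add the residual."""
    def __init__(self, d_model: int, attention: nn.Module, ffn: nn.Module):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_model)  # stand-in for the RMSNorm described below
        self.ffn_norm = nn.LayerNorm(d_model)
        self.attention = attention
        self.ffn = ffn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attn_norm(x))  # x + SubLayer(Norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x

# Toy usage with linear layers standing in for the real attention and FFN sub-layers.
d_model = 256
block = PreLNBlock(d_model, attention=nn.Linear(d_model, d_model), ffn=nn.Linear(d_model, d_model))
print(block(torch.randn(2, 10, d_model)).shape)  # torch.Size([2, 10, 256])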

This article will touch on all the components of the LLM except the softmax self-attention mechanism. If you are interested in understanding how the self-attention mechanism works and how to implement it, check out my first article.


Byte pair encoding (BPE) tokenizer

What are the alternatives to a BPE tokenizer, and why are they not used?
i. Character-based tokenization can result in a very large vocabulary, which places additional pressure on memory. Such a large vocabulary is also inefficient because many characters are infrequently used; for example, how often do we use "孖" in text?
ii. Byte-based tokenization, i.e. representing any string as a sequence of integers between 0 and 255. While it is the simplest approach, it produces very long token sequences (compared to other tokenization methods), and because the time and space complexity of softmax self-attention is quadratic in sequence length, this method leads to both performance loss and increased memory pressure.
iii. Word-based tokenization. It can also lead to a very large vocabulary because many words are rare (e.g. "supercalifragilisticexpialidocious"), the vocabulary size is not fixed, and we may encounter new words not seen in training.

Why is the BPE tokenizer popular? It is more efficient than byte-based tokenization; because it is trained on raw text, its vocabulary consists of frequently occurring units (unlike character-based tokenization, where many entries are rarely seen); and it will never encounter tokens outside its vocabulary at inference time.

The core idea of BPE is to start with a small vocabulary of basic units and iteratively merge the most frequent adjacent pair of units to create a new, longer unit. This process is repeated for a fixed number of merges.
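Here is a minimal, self-contained sketch of that loop on a toy corpus. The word frequencies and the num_merges value are made up for illustration; the real tokenizer works on bytes and on far more data.

from collections import Counter

# Toy corpus: each "word" is a tuple of single-character tokens with a frequency.
word_counts = Counter({
    ("l", "o", "w"): 5,
    ("l", "o", "w", "e", "r"): 2,
    ("n", "e", "w", "e", "s", "t"): 6,
})
num_merges = 3  # hypothetical setting; real vocabularies use thousands of merges
merges = []

for _ in range(num_merges):
    # Count every adjacent pair across all words, weighted by word frequency.
    pair_stats = Counter()
    for word, count in word_counts.items():
        for pair in zip(word, word[1:]):
            pair_stats[pair] += count

    best_pair = max(pair_stats, key=pair_stats.get)
    merges.append(best_pair)

    # Rewrite every word, replacing occurrences of the best pair with a merged token.
    new_counts = Counter()
    for word, count in word_counts.items():
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == best_pair:
                merged.append(word[i] + word[i + 1])
                i += 2
            else:
                merged.append(word[i])
                i += 1
        new_counts[tuple(merged)] += count
    word_counts = new_counts

print(merges)  # the learned merge rules, most frequent pairs first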

A. How is a string broken down in this implementation?

i. Representation as bytes: We represent strings as bytes to make the tokenizer robust across languages and text formats. The initial vocabulary is therefore composed of all 256 possible byte values (from 0 to 255).

# The initial vocabulary contains one entry for each byte value.
vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}

ii. Reading as raw bytes: The input file is opened in binary mode ("rb"), meaning it is read as a sequence of raw bytes, not text characters. This avoids any issues with text encoding (like UTF-8, ASCII, etc.) at the initial stage.

iii. Pre-tokenization (Word Splitting): While BPE can run on the entire byte stream, it is more efficient to first split the text into "words" and perform merges within these words. This prevents merges from crossing word boundaries (e.g., merging the "e" from "nice" and the "d" from "dog").

The code decodes a chunk of bytes to a string and then uses a regular expression to split the text into a list of "pre-tokens" or "words".

iv. Mapping Bytes to "Characters" for Processing: A raw byte sequence is hard to work with in Python strings. To solve this, the algorithm maps each of the 256 raw byte values to a unique, printable Unicode character.

For example, byte 35 (the # symbol) might map to the printable character '#', while a byte that is awkward to keep in a string, such as a space or a control character, might be mapped to a distinct placeholder character like 'ā'.

A "word" like "hello" is encoded to bytes (b'hello'), and then each byte is mapped to its corresponding character, resulting in a tuple of characters: ('h', 'e', 'l', 'l', 'o').

What is the difference between breaking down into bytes vs. characters?

i. Character-based BPE: If you split a string like "résumé" into characters, you get ('r', 'é', 's', 'u', 'm', 'é'). The algorithm would see 'é' as a single, atomic unit. This works well for a single language but can lead to a massive vocabulary if the text contains many different languages and symbols (e.g., English, Chinese, and Emojis).

ii. Byte-based BPE (this implementation): A string like "résumé" is first encoded to its UTF-8 byte representation: b'r\xc3\xa9\x73\x75\x6d\xc3\xa9'. The algorithm then splits this into its constituent bytes. The initial sequence would be (b'r', b'\xc3', b'\xa9', b's', b'u', b'm', b'\xc3', b'\xa9').
* Advantage: The base vocabulary is always fixed at 256 bytes. The algorithm learns to represent common multi-byte characters (like 'é' -> b'\xc3\xa9') and common words (like 'the' -> b'the') as single tokens. This is efficient and language-agnostic. The "character" é is not lost; it is simply learned as a merge of the bytes \xc3 and \xa9.
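You can see this byte-level breakdown for yourself in plain Python:

text = "résumé"

# Encode to UTF-8 bytes; 'é' becomes the two-byte sequence 0xC3 0xA9.
raw = text.encode("utf-8")
print(raw)                  # b'r\xc3\xa9sum\xc3\xa9'
print(list(raw))            # [114, 195, 169, 115, 117, 109, 195, 169]
print(len(text), len(raw))  # 6 characters, 8 bytes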

B. Design Considerations for Very Large Streams of Text

Handling files that are too large to fit in RAM (e.g., 100GB) requires a streaming or chunking approach. To handle large input files, we implement a "Map-Reduce" pattern.

i. Chunking the File Safely: It is non-trivial to split a large file at arbitrary byte positions (e.g., every 1GB), because you might slice a multi-byte UTF-8 character in half, causing a decoding error.
* We can implement a function that seeks to an approximate split point and then scans forward until it finds a safe delimiter (like a newline \n or a special token like <|endoftext|>). This ensures that each chunk is a self-contained, valid piece of text.

ii. Parallel Processing (Map Phase): Instead of one process reading the file sequentially, the code spawns multiple worker processes, one for each CPU core (multiprocessing.Pool(num_processes)).
* Each worker is assigned a single chunk (defined by start and end boundaries).
* The worker function _process_chunk reads only its assigned chunk from the disk, performs pre-tokenization, and counts the frequencies of the initial byte sequences ("words").
* Each worker only holds its small chunk and its resulting Counter object in memory.

iii. Aggregation (Reduce Phase): The main process waits for all workers to finish. It then collects the Counter object from each worker and merges them into a single master word_counts dictionary.

This design ensures that the entire massive text file is never loaded into memory at once. The only large object held in memory is the final word_counts dictionary, which is significantly smaller than the source text itself.
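Below is a hedged sketch of this Map-Reduce pattern. The helper names (find_safe_chunk_boundaries, _process_chunk, count_words) and the pre-tokenization regex are simplified stand-ins rather than the assignment's actual code.

import re
from collections import Counter
from multiprocessing import Pool

# Simplified stand-in for GPT-2-style pre-tokenization: a word plus its leading whitespace.
PRETOKEN_RE = re.compile(r"\s*\S+")

def find_safe_chunk_boundaries(path: str, num_chunks: int) -> list[int]:
    """Return byte offsets that split the file near-evenly, but only at newlines."""
    with open(path, "rb") as f:
        f.seek(0, 2)                            # jump to the end to get the file size
        size = f.tell()
        boundaries = [0]
        for i in range(1, num_chunks):
            f.seek(size * i // num_chunks)      # approximate split point
            f.readline()                        # scan forward to the next newline
            boundaries.append(f.tell())
        boundaries.append(size)
    return sorted(set(boundaries))

def _process_chunk(args: tuple[str, int, int]) -> Counter:
    """Map phase: read one chunk, pre-tokenize it, and count word frequencies."""
    path, start, end = args
    with open(path, "rb") as f:
        f.seek(start)
        text = f.read(end - start).decode("utf-8", errors="ignore")
    return Counter(m.group() for m in PRETOKEN_RE.finditer(text))

def count_words(path: str, num_processes: int = 4) -> Counter:
    boundaries = find_safe_chunk_boundaries(path, num_processes)
    jobs = list(zip([path] * (len(boundaries) - 1), boundaries, boundaries[1:]))
    with Pool(num_processes) as pool:
        partial_counts = pool.map(_process_chunk, jobs)   # Map phase, one chunk per worker
    word_counts = Counter()
    for counts in partial_counts:                         # Reduce phase in the main process
        word_counts.update(counts)
    return word_counts

if __name__ == "__main__":
    # word_counts = count_words("tinystories.txt", num_processes=4)  # hypothetical corpus file
    pass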

C. Design for Reducing Processing Time

The code employs several key optimizations to reduce the training time, which can be very long for large vocabularies.

i. Parallelism: As described above, the most time-consuming initial step—reading the disk and counting initial word frequencies—is parallelized across all available CPU cores. This provides a near-linear speedup for this "Map" phase.

ii. In-Memory Operation after Initial Scan: Once the word_counts dictionary is built, the code never reads the input file again. The entire iterative merging loop operates on this in-memory word_counts dictionary. This avoids slow disk I/O in the main loop.

iii. Incremental Updates (The Core Optimization): A naive BPE implementation would, after every single merge, have to rescan the entire corpus to find the next most frequent pair. This would be incredibly slow. This code uses a much smarter approach:

  • After finding the best pair to merge (e.g., ('t', 'h') -> 'th'), do not rescan the raw text.

  • Instead, update the existing word_counts dictionary. Find only the words that are affected by the merge (i.e., those containing both 't' and 'h' in sequence).

  • Then iterate through this much smaller subset (words_to_update), create the new merged word (e.g., ('s', 't', 'r', 'e', 'n', 'g', 't', 'h') becomes ('s', 't', 'r', 'e', 'n', 'g', 'th')), and update the word_counts dictionary by removing the old word's count and adding the new word's count.

This incremental update of the word frequency map makes the training efficient. While the code still rebuilds the pair_stats from word_counts in each loop, word_counts itself is updated efficiently, preventing a full rescan of the original data.
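A hedged sketch of the incremental update for a single merge step follows. Names such as merge_pair_in_word and apply_merge are illustrative stand-ins that mirror the description above, not the assignment's exact code.

from collections import Counter

def merge_pair_in_word(word: tuple[str, ...], pair: tuple[str, str]) -> tuple[str, ...]:
    """Replace every adjacent occurrence of `pair` in `word` with the merged token."""
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
            merged.append(word[i] + word[i + 1])
            i += 2
        else:
            merged.append(word[i])
            i += 1
    return tuple(merged)

def apply_merge(word_counts: Counter, best_pair: tuple[str, str]) -> Counter:
    """Update only the words that actually contain the pair; everything else is untouched."""
    words_to_update = [
        w for w in word_counts
        if any((w[i], w[i + 1]) == best_pair for i in range(len(w) - 1))
    ]
    for word in words_to_update:
        count = word_counts.pop(word)
        word_counts[merge_pair_in_word(word, best_pair)] += count
    return word_counts

# Example: merging ('t', 'h') inside 'strength'; 'cat' is untouched.
counts = Counter({("s", "t", "r", "e", "n", "g", "t", "h"): 3, ("c", "a", "t"): 7})
print(apply_merge(counts, ("t", "h")))
# Counter({('c', 'a', 't'): 7, ('s', 't', 'r', 'e', 'n', 'g', 'th'): 3})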


Token embedder

The token embedder is essentially a learnable lookup table. Its job is to convert a sequence of integer token IDs into a sequence of vectors (embeddings).

A. The Process

i. Initialization (__init__):
* A large matrix, self.weight, is created with the shape (num_embeddings, embedding_dim).
* num_embeddings: This is the size of your vocabulary. For example, if your model knows 50,000 unique words/sub-words (tokens), num_embeddings would be 50,000.
* embedding_dim: This is the size of the vector you want to represent each token with. Common values are 768, 1024, or 4096. Each of the 50,000 tokens will get its own unique vector of this size.
* The self.weight matrix is initialized with small random numbers drawn from a truncated normal distribution. This specific initialization (related to Xavier/Glorot initialization) is a best practice that helps the model train more stably by preventing the gradients from becoming too large or too small early in the training process. This is achieved by setting the variance of the weights such that the variance of the outputs of a layer is equal to the variance of its inputs.

ii. Forward Pass (forward):
* The module receives an input tensor called token_ids. This is a tensor of integers.
* The key operation is self.weight[token_ids]. This is an efficient indexing operation. For each integer ID in the token_ids tensor, it looks up the corresponding row in the self.weight matrix.
* The output is a new tensor where each integer ID has been replaced by its corresponding embedding vector.

A concrete example:

Imagine a dictionary where the keys are integers (from 0 to num_embeddings - 1) and the values are long lists of numbers (the embedding vectors). The forward function takes a list of keys and returns the corresponding list of values.

Example:

  • num_embeddings = 50,000
  • embedding_dim = 768
  • self.weight is a [50,000, 768] matrix.
  • Input token_ids = torch.tensor([42, 801, 19]).
  • The module returns a [3, 768] tensor containing rows 42, 801, and 19 of the self.weight matrix.
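Putting the pieces together, a minimal sketch of such an embedding module might look like this. The initialization settings shown are one reasonable choice, not necessarily the assignment's exact ones.

import torch
import torch.nn as nn

class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        # One learnable row per token ID, initialized from a truncated normal distribution.
        weight = torch.empty(num_embeddings, embedding_dim)
        nn.init.trunc_normal_(weight, mean=0.0, std=1.0, a=-3.0, b=3.0)
        self.weight = nn.Parameter(weight)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Indexing replaces every integer ID with its corresponding row vector.
        return self.weight[token_ids]

embedder = Embedding(num_embeddings=50_000, embedding_dim=768)
token_ids = torch.tensor([42, 801, 19])
print(embedder(token_ids).shape)  # torch.Size([3, 768])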

B. Design Considerations for Large Streams of Text

The module itself does not handle streaming directly; the architecture that uses this module does. In order to process large amounts of text:

i. Batching and Fixed Context Windows: LLMs do not process an entire book at once. Text is broken down into smaller, manageable chunks of a fixed length (e.g., 512, 2048, or 4096 tokens). This fixed size is often called the "context window" or sequence_length. Note that this length directly influences the maximum context window a model can effectively handle. While it is possible to go beyond this length during inference, performance of the LLM may degrade (source).
ii. Statelessness: The Embedding module has no memory of previous inputs. It performs the exact same lookup for token 2054 every time it sees it. The memory of the sequence is handled later in the LLM by mechanisms like the attention layers.
iii. Predictable Memory Usage: The memory consumed by the embedding parameters (self.weight) is fixed: num_embeddings * embedding_dim * bytes_per_float. It does not depend on the length of the input text. The memory used during the forward pass (the activations) depends on batch_size * sequence_length * embedding_dim, which is predictable and manageable due to the fixed context window.

C. Design for Reducing Processing Time

The efficiency comes from avoiding a computationally expensive operation.

A naive way to implement this would be:
i. Take an integer token ID, say 42.
ii. Convert it to a "one-hot" vector. This is a vector of size num_embeddings (e.g., 50,000) that is all zeros except for a 1 at the 42nd position.
iii. Perform a matrix multiplication between this one-hot vector (1, 50000) and the weight matrix (50000, 768).

This matrix multiplication is mathematically correct, but incredibly inefficient because it involves thousands of multiplications by zero.

Instead, use self.weight[token_ids] to implement an indexing lookup. This is vastly faster for two main reasons:

i. No Useless Computations: It directly accesses the required memory location for the row vector. There are no multiplications by zero. It's a memory-bound operation, not a compute-bound one.
ii. Parallelism: This lookup operation is highly parallelizable on GPUs. A GPU can fetch the embedding vectors for all tokens in a batch simultaneously, leading to massive speedups compared to processing them one by one.
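A small check that the two routes agree, using toy sizes purely for illustration:

import torch
import torch.nn.functional as F

vocab_size, d_model = 1_000, 16
weight = torch.randn(vocab_size, d_model)
token_ids = torch.tensor([42, 801, 19])

# Naive route: build (3, vocab_size) one-hot vectors and matrix-multiply.
one_hot = F.one_hot(token_ids, num_classes=vocab_size).float()
via_matmul = one_hot @ weight          # (3, d_model), mostly multiplications by zero

# Fast route: direct row indexing.
via_lookup = weight[token_ids]         # (3, d_model), pure memory reads

print(torch.allclose(via_matmul, via_lookup))  # True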

D. How the Embedder is Trained

Unlike older models like Word2Vec, the embedder is trained as an integral part of the LLM.

The training process is described below:

i. Not Pre-trained and Frozen: Unlike Word2Vec, where you might train embeddings on a large corpus and then use them as static features in a downstream task, the self.weight matrix in this module is a nn.Parameter. This tells the PyTorch framework that its values are learnable and should be updated during training.

ii. Joint Training: The embedding layer is just the first layer of the much larger LLM. The embeddings it produces are fed through the subsequent transformer layers. The model makes a prediction (e.g., predicting the next token in a sentence), and a loss is calculated based on how wrong that prediction was.

iii. Backpropagation: This loss is then backpropagated through the entire network. The gradients flow backward through the transformer layers and all the way back to the Embedding module.

iv. Updating the Vectors: The optimizer (e.g., AdamW) uses these gradients to update the values in the self.weight matrix. The embedding for a token is "nudged" in a direction that would have made the final prediction more accurate.

In essence, the model learns the "meaning" of a token by learning what vector representation for that token is most useful for the rest of the network to solve its task (predicting the next word). The embeddings are not learned in isolation; they are continuously refined and contextualized by the very task the LLM is being trained on. This makes them much more powerful and nuanced than static, pre-trained embeddings.


RMS Norm

RMSNorm (Root Mean Square Normalization) is a simplification of the more common Layer Normalization. Its primary goal is to re-scale the activations of a layer based on their magnitude, without changing their mean.

The core idea is to normalize the input vector x by its Root Mean Square (RMS) value. The step-by-step mechanism is described below:

i. Input Vector: Take an input vector x (which corresponds to the activations for a single token in the sequence). Let's say this vector has d dimensions.
ii. Calculate the Mean of Squares: Square every element in the vector x and then compute the average (mean) of these squared values.
- Formula: $\text{Mean}(x^2) = \frac{1}{d} \sum_{i=1}^{d} x_i^2$
iii. Calculate the Root Mean Square (RMS): Take the square root of the value from the previous step. This is the RMS value of the vector.
- Formula: $\text{RMS}(x) = \sqrt{\text{Mean}(x^2) + \epsilon}$
- An epsilon (ε), a very small number, is added inside the square root for numerical stability to prevent division by zero if the vector x is all zeros.
iv. Normalize the Vector: Divide each element of the original input vector x by its RMS value. This scales the vector so that its new RMS value is 1.
- Formula: $x_{\text{normalized}} = \frac{x}{\text{RMS}(x)}$
v. Apply a Learnable Gain: Multiply the normalized vector by a learnable parameter g (gain). This allows the network to learn the optimal scale for the activations in that layer. Without this gain, the network's expressive capacity could be limited.
- Formula: $y = x_{\text{normalized}} \cdot g$

The final formula is:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot g$$
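A minimal PyTorch sketch of this formula follows. The epsilon value and the float32 upcast are common choices I am assuming here, not requirements.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(d_model))  # learnable gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast to float32 so the mean of squares is computed at full precision.
        x_fp32 = x.float()
        rms = torch.sqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x_fp32 / rms).to(x.dtype) * self.gain

x = torch.randn(2, 4, 512)          # (batch, seq_len, d_model)
print(RMSNorm(512)(x).shape)        # torch.Size([2, 4, 512])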

A. Design Considerations for RMS Norm vs. Other Norms

Normalization layers are designed to control the distribution of activations flowing between layers, which helps stabilize training and speed up convergence. The key design differences lie in what is being normalized and how.

| Normalization Type | Key Idea | Design Consideration |
| --- | --- | --- |
| Batch Norm | Normalize each feature across the batch dimension. | Assumes that the statistics (mean/variance) of a mini-batch are a good estimate for the entire dataset. It is dependent on the batch size and works best with larger batches. |
| Layer Norm | Normalize each data point (e.g., a token's embedding) across the feature dimension. | Assumes that all features for a single data point should be re-centered and re-scaled together. It is independent of the batch size, making it suitable for NLP and variable-length sequences. It performs both re-centering (subtracting the mean) and re-scaling. |
| RMS Norm | A simplification of Layer Norm. Normalize each data point across the feature dimension. | The core hypothesis is that re-scaling invariance is the main benefit of Layer Norm, and re-centering invariance provides little to no benefit while adding computational cost. RMSNorm deliberately drops the mean-centering step to improve efficiency. |

The design of RMSNorm is a direct trade-off: it sacrifices the re-centering property of LayerNorm for a significant gain in computational speed. The bet is that for large language models, this trade-off is beneficial.

B. Pros and Cons of Different Norms

| Normalization Type | Pros | Cons |
| --- | --- | --- |
| RMS Norm | - Computationally Efficient: Up to 25-40% faster than LayerNorm because it avoids calculating the mean. <br> - Simple: Fewer operations and parameters (no bias term b). <br> - Effective: Empirically shown to be as effective as, or sometimes better than, LayerNorm in modern LLMs (e.g., Llama, PaLM). | - No Re-centering: Lacks invariance to shifts in the mean of activations, which could theoretically be a disadvantage in some (uncommon) scenarios. |
| Layer Norm | - Very Stable: Normalizes each example independently of the batch, making it robust to batch size and effective for RNNs/Transformers. <br> - Invariant to shift and scale: Full invariance to affine transformations of the input. | - More Computationally Expensive: Requires calculating both the mean and the variance, which is slower than RMSNorm's single calculation. <br> - More Parameters: Has both a gain g and a bias b parameter. |
| Batch Norm | - Excellent for CNNs: Speeds up convergence and acts as a regularizer in computer vision tasks. <br> - Well-established: Deeply studied and understood. | - Batch Dependent: Performance degrades significantly with small batch sizes. <br> - Poor for RNNs/Transformers: Statistics are unstable across variable-length sequences. <br> - Complex behavior at inference: Requires tracking running population statistics. |

Source: Batch Normalization, Layer Normalization and Root Mean Square Layer Normalization: A Comprehensive Guide with Python Implementations

RoPE positional embedding

At its core, Rotary Positional Embedding (RoPE) encodes absolute positional information into query and key vectors in a way that allows the self-attention mechanism to naturally handle relative positions.

Instead of adding a positional vector to the token embeddings (like traditional absolute positional embeddings), RoPE rotates the query and key vectors. The angle of rotation depends on the token's absolute position in the sequence.

A. The Intuition

Imagine the i-th token in the prompt has a vector x with 6 dimensions. We group these 6 dimensions into 3 consecutive pairs and index each pair with k.

  • Dimensions (0,1) form the first 2D pair: $[x_{i,0}, x_{i,1}]$. For this pair, k = 1.
  • Dimensions (2,3) form the second 2D pair: $[x_{i,2}, x_{i,3}]$. For this pair, k = 2.
  • Dimensions (4,5) form the third 2D pair: $[x_{i,4}, x_{i,5}]$. For this pair, k = 3.

Next, we define a set of rotation angles $\theta_{i,k}$, one for each dimension pair. A common way to define these is a geometric progression:

$$\theta_{i,k} = \frac{i}{\text{base}^{\frac{2k}{d}}}$$

  • i is the position of the token in the sequence (5 in the example below)
  • k indexes the dimension pair (1, 2, or 3 in our case)
  • d is the total number of dimensions (6)
  • base is an arbitrarily large constant; we can set it to 10,000

Assuming i = 5, i.e. the token at position 5, we can calculate a unique angle for each pair k:

  • Angle for Pair 1 (k=1):

    • θ_{5,1} = 5 / 10000^(2*1 / 6) = 5 / 10000^(1/3) ≈ 5 / 21.54 ≈ 0.232 radians.
    • This is a relatively high-frequency rotation (a larger angle).
  • Angle for Pair 2 (k=2):

    • θ_{5,2} = 5 / 10000^(2*2 / 6) = 5 / 10000^(2/3) ≈ 5 / 464.16 ≈ 0.011 radians.
  • Angle for Pair 3 (k=3):

    • θ_{5,3} = 5 / 10000^(2*3 / 6) = 5 / 10000¹ = 0.0005 radians.
    • This is a very low-frequency rotation (a tiny angle).

Next, we rotate each pair by its corresponding angle:

  1. Rotate Pair 1 ([x₀, x₁]) using θ_{5,1}:

    • x'₀ = x₀ * cos(0.232) - x₁ * sin(0.232)
    • x'₁ = x₀ * sin(0.232) + x₁ * cos(0.232)
  2. Rotate Pair 2 ([x₂, x₃]) using θ_{5,2}:

    • x'₂ = x₂ * cos(0.011) - x₃ * sin(0.011)
    • x'₃ = x₂ * sin(0.011) + x₃ * cos(0.011)
  3. Rotate Pair 3 ([x₄, x₅]) using θ_{5,3}:

    • x'₄ = x₄ * cos(0.0005) - x₅ * sin(0.0005)
    • x'₅ = x₄ * sin(0.0005) + x₅ * cos(0.0005)

The final, positionally-encoded query vector x' is formed by putting the newly rotated components back together:

x' = [x'₀, x'₁, x'₂, x'₃, x'₄, x'₅]

Finally, we use the transformed vectors to measure relative position. Imagine a key token at position j = 7. We would do this exact same process for its key vector k, but we would use j=7 in the angle calculations (θ_{7,1}, θ_{7,2}, etc.). This would produce a rotated key vector k'.

When the attention mechanism computes the dot product q' · k', the mathematical properties of these rotations ensure that the resulting score is sensitive to the relative distance i - j = 5 - 7 = -2.

This means the attention score between two tokens is inherently sensitive to how far apart they are, which is exactly what we want for capturing contextual relationships.
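This property is easy to verify numerically for a single dimension pair: rotating q by its position's angle and k by its position's angle gives a dot product that depends only on the difference between the two positions. Here is a plain-Python sketch with toy numbers.

import math

def rotate(vec, angle):
    """2D rotation by `angle` radians."""
    x, y = vec
    return (x * math.cos(angle) - y * math.sin(angle),
            x * math.sin(angle) + y * math.cos(angle))

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k = (0.3, -1.2), (0.7, 0.5)
theta = 1.0 / 10000 ** (2 * 1 / 6)   # frequency for the first dimension pair (k = 1, d = 6)

# Positions (5, 7) and (15, 17) both have a relative distance of -2 ...
score_a = dot(rotate(q, 5 * theta), rotate(k, 7 * theta))
score_b = dot(rotate(q, 15 * theta), rotate(k, 17 * theta))
print(round(score_a, 6) == round(score_b, 6))  # True: the score depends only on i - j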

B. Design Considerations for RoPE

RoPE was designed to address shortcomings in other positional embedding methods.

| Design Consideration | Absolute Positional Embedding (APE) | Relative Positional Embedding (e.g., T5) | Rotary Positional Embedding (RoPE) |
| --- | --- | --- | --- |
| How it's applied | Additive: A positional vector is added to the token embedding. | Additive (in attention): A learned bias is added to the attention score based on relative distance. | Multiplicative: The query/key vectors are rotated (a multiplicative operation). |
| Absolute vs. Relative | Encodes absolute position. The model must learn to interpret this as relative. | Directly encodes relative position. | Encodes absolute position in a way that becomes relative during the attention calculation. |
| Sequence Length Extrapolation | Poor. Learned embeddings for positions 0-511 don't generalize well to position 512. | Good. The bias for a distance of 10 is the same regardless of absolute position. | Good. The rotation is a continuous function. It can be extrapolated, although performance can degrade if stretched too far without fine-tuning. |
| Parameter Efficiency | Requires a learned matrix of (max_seq_len, d_model). | Requires a smaller learned matrix for relative distance biases. | No learned parameters. The rotation values are fixed and calculated on the fly or pre-computed. |

Source: Comparison- Sinusoidal vs. Learned Embeddings

RoPE's design is a clever "best of both worlds" approach. It applies a transformation based on the absolute position i, but the result is a relative-aware attention mechanism. The lack of learned parameters makes it elegant and robust.

C. Pros and Cons of Different Positional Embeddings

| Type | Pros | Cons |
| --- | --- | --- |
| Learned Absolute (APE) | - Simple to implement and understand. <br> - Was the standard for a long time (BERT, GPT-2). | - Poor extrapolation to longer sequences. <br> - Fixed max_seq_len. <br> - Learned embeddings can "interfere" with token embeddings. |
| Relative (T5-style) | - Excellent at modeling relative positions. <br> - Good extrapolation capabilities. | - More complex to implement within the attention mechanism. <br> - Adds a learned parameter for each relative distance, which can be memory-intensive. |
| Rotary (RoPE) | - Excellent performance, now standard in many top models (Llama, PaLM). <br> - Good extrapolation properties. <br> - No learned parameters for position, making it very efficient. <br> - Decouples positional information from the token embedding's norm. | - Conceptually more complex than APE. <br> - Requires the feature dimension (d_k) to be even, as it operates on pairs of dimensions. |

Source: Comparison- Sinusoidal vs. Learned Embeddings

D. Optimization tricks

As calculating cos and sin is computationally more expensive than simple arithmetic, instead of computing them for every token in every forward pass, we can pre-calculate all the cos and sin values once at model initialization and cache them. The forward pass then becomes a simple, memory-efficient table lookup.

Avoid the use of Python loops in the forward pass. Instead, rely on operations such as:
* Slicing (x[..., ::2]): This is a highly optimized memory-access pattern.
* Stacking and Rearranging: The torch.stack and einops.rearrange operations are compiled down to highly efficient GPU kernels for moving data.
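A hedged sketch of a RoPE module with cached tables follows; the pairing convention, frequency indexing, and buffer names are illustrative and may differ in detail from the assignment's code.

import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int, base: float = 10_000.0):
        super().__init__()
        # One frequency per dimension pair, computed once at initialization.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)           # (max_seq_len, head_dim // 2)
        # Cache cos/sin as buffers: no gradients, and they move with .to(device).
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        # x: (..., seq_len, head_dim); positions: (seq_len,) token indices.
        cos = self.cos_cached[positions]   # simple table lookups, no trig in the forward pass
        sin = self.sin_cached[positions]
        x_even, x_odd = x[..., 0::2], x[..., 1::2]           # split into dimension pairs
        rot_even = x_even * cos - x_odd * sin                # apply the 2D rotation per pair
        rot_odd = x_even * sin + x_odd * cos
        out = torch.stack((rot_even, rot_odd), dim=-1)       # interleave the pairs back together
        return out.flatten(-2)

rope = RotaryEmbedding(head_dim=64, max_seq_len=256)
q = torch.randn(2, 8, 10, 64)                    # (batch, heads, seq_len, head_dim)
print(rope(q, torch.arange(10)).shape)           # torch.Size([2, 8, 10, 64])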


SwiGLU feed-forward network (FFN)

The SwiGLU FFN is an enhancement over the standard feed-forward network used in the original Transformer architecture. Its core innovation is a gating mechanism that dynamically controls the flow of information through the network.

The formula:

$$\text{FFN}_{\text{SwiGLU}}(x) = W_{\text{down}}\left(\text{SiLU}(x W_{\text{gate}}) \odot (x W_{\text{up}})\right)$$

i. Input (x): This is the tensor representing a token, coming from the preceding self-attention layer. Its shape is (..., d_model).

ii. Parallel Linear Projections: The input x is fed into two separate linear layers simultaneously:
* gate_proj: This projection, x @ W_gate, creates an intermediate representation that will become the "gate".
* up_proj: This projection, x @ W_up, creates the main content or "value" that we want to process.
Both projections typically expand the dimension from d_model to a larger intermediate dimension, d_ff.

iii. The Gating Mechanism: This is the heart of SwiGLU.
* The output of the gate_proj is passed through an activation function, in this case, SiLU (Sigmoid-weighted Linear Unit), also known as Swish.
* The SiLU function (f(x) = x * sigmoid(x)) is smooth and non-monotonic. Its output is not bounded between 0 and 1.
* The activated gate, SiLU(gate_output), is then multiplied element-wise (*) with the output of the up_proj.

iv. Selective Information Flow: This element-wise multiplication is crucial. The SiLU(gate_output) acts as a filter.
* If a value in the activated gate is close to zero, the corresponding value in up_output is suppressed (multiplied by zero), effectively "closing the gate" for that piece of information.
* If a value in the activated gate is large, the corresponding value in up_output is allowed to pass through, "opening the gate".
Because the gate's values depend on the input x, this filtering is dynamic and data-dependent. The network learns which information is important for a given token and context.

v. Down Projection: Finally, the filtered result (fused_output) is passed through a final linear layer, down_proj (W_down), which projects the dimension back down from d_ff to the original d_model. This allows the output to be added back into the residual stream of the Transformer block.
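A minimal sketch of these three projections and the gate is shown below. The layer names follow the LLaMA-style convention, and the d_ff sizing is one common heuristic, not necessarily the assignment's value.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)  # produces the gate
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)    # produces the content
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)  # projects back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.gate_proj(x))        # SiLU(x W_gate)
        content = self.up_proj(x)               # x W_up
        return self.down_proj(gate * content)   # W_down applied to the gated content

d_model = 256
d_ff = 2 * 4 * d_model // 3                     # (2/3) * 4 * d_model = 682, LLaMA-style sizing
ffn = SwiGLUFFN(d_model, d_ff)
x = torch.randn(2, 10, d_model)                 # (batch, seq_len, d_model)
print(ffn(x).shape)                             # torch.Size([2, 10, 256])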

Analogy: Imagine the input data as a fleet of cars approaching a major, intelligent traffic interchange.

The up_output is a wide, multi-lane expressway carrying all the cars (the core information). In parallel, the gate_output is a service road with advanced sensors that analyze the traffic flow—what types of cars are in each lane and where they are headed.
This sensor data is fed into a central control system (the SiLU function) which doesn't just use simple red or green lights. Instead, it creates a dynamic set of ramp meters and lane-specific speed controls for the main expressway.

The element-wise multiplication is the act of these meters and controls engaging with the traffic. A meter might completely stop a lane (a value of zero), let it flow freely (a high value), or just let a few cars trickle through (a fractional value), all depending on the real-time sensor readings. This ensures only the most relevant traffic proceeds smoothly.

Finally, the down_proj is where all the lanes, having been filtered and controlled, merge back together into a single, streamlined highway, with the flow of traffic now optimized and ready for the next leg of its journey.

A. Design Considerations for SwiGLU vs. Other FFNs

The primary alternative is the Standard ReLU-based FFN from the original "Attention Is All You Need" paper. Its formula is: $\text{FFN}_{\text{ReLU}}(x) = \max(0, x W_1 + b_1) W_2 + b_2$.

Here are the key design considerations when comparing them:

| Consideration | Standard ReLU FFN | SwiGLU FFN |
| --- | --- | --- |
| Non-linearity | A single, static non-linearity (ReLU). It zeroes out negative values. | A dynamic, gated non-linearity. The activation (SiLU) is combined with a data-dependent gate to selectively filter information. |
| Parameter Count | Two weight matrices (W1, W2) and two bias vectors. | Three weight matrices (W_gate, W_up, W_down) and typically no biases (as in the LLaMA implementation). |
| Parameter Efficiency | To achieve similar performance to SwiGLU, d_ff is often set to 4 * d_model. | Because it has three matrices, d_ff is often reduced to keep the total parameter count comparable. A common choice in LLaMA is d_ff = (2/3) * 4 * d_model. |
| Computational Cost | Involves two matrix multiplications and one activation. | Involves three matrix multiplications, one activation, and one element-wise multiplication. It is more computationally intensive. |
| Expressiveness | Less expressive. The filtering is a simple on/off switch based on the sign of the value. | More expressive. The gate can learn complex relationships in the input to decide what information to pass on, leading to better model quality. |

B. Pros and Cons of Each FFN

| Feature | Standard ReLU FFN | SwiGLU FFN |
| --- | --- | --- |
| Pros | - Faster: Computationally cheaper due to fewer matrix multiplications. <br> - Simpler: Easier to implement and understand. <br> - Fewer Parameters: Requires only two weight matrices for the same d_ff. | - Higher Quality: Consistently improves model performance; a key factor in models like LLaMA, PaLM, and Mixtral. <br> - More Expressive: The dynamic gating mechanism allows for learning more complex functions and better information flow control. <br> - Helps with Gradients: The smooth nature of SiLU and the multiplicative interaction can help alleviate vanishing gradient problems. |
| Cons | - Lower Quality: Empirically leads to worse model performance. <br> - "Dying ReLU" Problem: Neurons can become stuck outputting zero. <br> - Static Filtering: The non-linearity is fixed and does not adapt to the input. | - Slower: Requires an additional matrix multiplication, increasing computational expense. <br> - More Parameters: Uses a third weight matrix, which can increase model size if d_ff is not adjusted. |

Why is SwiGLU preferred? Let me quote the original paper by Shazeer et al.:

> We offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence.

Linear layer

The Linear module performs a linear transformation, one of the most common operations in a neural network. It maps a vector from one dimensionality to another using a matrix multiplication.

It contains the following:

i. Dimensions (in_features, out_features): These define the transformation. The module will take input vectors of size in_features and produce output vectors of size out_features. For example, in an LLM, this could be transforming a 768-dimension token embedding into a 3072-dimension vector for an internal "feed-forward" layer.

ii. Weight Matrix Creation (self.W):
* The core of the linear layer is the weight matrix, W. The transformation is mathematically defined as y = x @ W.T (input x matrix-multiplied by the transpose of W).
* If an input x has shape (..., in_features), to get an output of (..., out_features), the matrix we multiply by must have the shape (in_features, out_features).
* This means W.T must be (in_features, out_features). Consequently, the code correctly creates W with the shape (out_features, in_features).
* torch.empty(...) allocates memory for this matrix but doesn't initialize it with meaningful values yet. This is slightly more efficient than creating a tensor of zeros.

iii. Weight Initialization:
* This is a critical step for stable and effective training. Initializing weights to zero or purely random values can cause problems like vanishing or exploding gradients.
* We can use a Truncated Normal Distribution. This means it draws random numbers from a normal (Gaussian) distribution but discards and re-samples any values that fall outside a specific range.
* The range is determined by a standard deviation (std_dev) calculated with the formula sqrt(2 / (in_features + out_features)). This is the Xavier/Glorot initialization scheme, designed to keep the variance of the outputs roughly equal to the variance of the inputs, which helps learning.
* Truncating at [-3*std_dev, 3*std_dev] prevents any single initial weight from being excessively large, which further promotes stability.

iv. Parameter Registration (nn.Parameter(weight)):
* The created weight tensor is wrapped in nn.Parameter. This is how you tell PyTorch: "This tensor is a trainable parameter of the model."
* When you train the model, PyTorch knows to calculate gradients for this parameter during backpropagation and update it using an optimizer (like AdamW).

A. forward method

The forward method essentially contains:

i. Input (x: torch.Tensor): It accepts an input tensor x. The shape is (N, ..., in_features), where N is the batch size and ... represents any other dimensions, like the sequence length of text. This flexibility is powerful.

ii. The Transformation:
* This is the heart of the module's operation. We can use torch.einsum (Einstein summation notation) to perform the matrix multiplication.
* "...i, oi -> ...o" is a concise way to define the operation:
* ...i: Represents the input tensor x, where i is the last dimension (in_features).
* oi: Represents the weight tensor self.W, where o is out_features and i is in_features.
* The i is the "summation index." The operation multiplies elements along this shared dimension and sums them up—this is the definition of matrix multiplication.
* ...o: Specifies the output shape. The ... dimensions from the input are preserved, and the new final dimension is o (out_features).
* This operation is functionally identical to torch.matmul(x, self.W.T), but einsum can be more explicit and sometimes more efficient for complex tensor contractions.
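Putting the initialization and the einsum forward pass together, a minimal sketch of such a Linear module might look like this. It is a plausible reconstruction, not necessarily identical to the assignment's code.

import math
import torch
import torch.nn as nn

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # W is stored as (out_features, in_features) so that y = x @ W.T.
        weight = torch.empty(out_features, in_features)
        std_dev = math.sqrt(2.0 / (in_features + out_features))
        nn.init.trunc_normal_(weight, mean=0.0, std=std_dev, a=-3 * std_dev, b=3 * std_dev)
        self.W = nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # "...i,oi->...o": sum over the shared in_features index i.
        return torch.einsum("...i,oi->...o", x, self.W)

layer = Linear(in_features=768, out_features=3072)
x = torch.randn(32, 128, 768)          # (batch, seq_len, in_features)
print(layer(x).shape)                  # torch.Size([32, 128, 3072])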

A note on vanilla PyTorch vs einops

The einops library helps write tensor operations in a more readable, intuitive, and less error-prone way. It uses a string-based mini-language to describe operations like transposing, reshaping, and splitting tensors.

Let's consider a common operation in a multi-head attention mechanism: combining the head dimension with the feature dimension.

Scenario: The output of a multi-head attention layer has shape (batch_size, num_heads, sequence_length, head_dim). To pass it to the next layer (like the FFN), you need to bring the head dimension next to head_dim and combine num_heads and head_dim back into d_model (where d_model = num_heads * head_dim), giving a shape of (batch_size, sequence_length, d_model).

Example: Plain PyTorch Code

import torch

batch_size = 32
seq_len = 128
num_heads = 12
head_dim = 64
d_model = num_heads * head_dim # 768

# Input tensor with shape (batch, heads, seq, head_dim), as produced by the attention layer
x = torch.randn(batch_size, num_heads, seq_len, head_dim)

# To bring 'num_heads' next to 'head_dim', first swap the 'num_heads' and 'seq_len' dimensions.
# This is a common pattern, but it's not immediately obvious what it's doing.
x_transposed = x.transpose(1, 2) # Shape: (batch, seq, heads, head_dim)

# The transposed tensor is no longer contiguous in memory, so .view() will fail
# unless we make a contiguous copy first.
if not x_transposed.is_contiguous():
    x_contiguous = x_transposed.contiguous()
else:
    x_contiguous = x_transposed

# Finally, reshape to the target shape.
# Passing d_model here relies on you knowing the other dimensions are correct.
output = x_contiguous.view(batch_size, seq_len, d_model)

print(f"PyTorch Output Shape: {output.shape}")

Critique:

  • transpose(1, 2): What do 1 and 2 represent? You have to mentally map them to num_heads and sequence_length. This is a frequent source of bugs.
  • .contiguous(): This is a memory layout detail that leaks into the high-level logic. You have to remember to call it after certain transpositions before you can view it.
  • .view(..., d_model): The reshape operation is disconnected from the transpose. It is hard to see the full transformation at a glance.

Example: Code with einops

import torch
from einops import rearrange

batch_size = 32
seq_len = 128
num_heads = 12
head_dim = 64

# Input tensor, as produced by the attention layer
x = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Use rearrange to describe the entire transformation in one go.
# 'b h s d -> b s (h d)'
# This reads as: "rearrange from (batch, heads, seq, dim) to (batch, seq, (heads and dim combined))"
output = rearrange(x, 'b h s d -> b s (h d)')

print(f"einops Output Shape: {output.shape}")

Critique & Benefits:

  • Self-Documenting: The string 'b h s d -> b s (h d)' is clearer. It names the dimensions and explicitly shows how they are being transformed: b stays put, s moves up beside it, and h and d are combined.
  • Concise and Robust: A complex, multi-step, error-prone operation is reduced to a single, readable line. einops handles the low-level details like contiguity checks for you.
  • Error Prevention: It is much harder to make a mistake. If you misspell a dimension name or the number of dimensions doesn't match the input tensor, einops will raise an informative error. With transpose and view, you might just get a tensor with the wrong shape and not notice until much later.

Results and findings

After building my LLM, I trained it on a dataset called tinystories_10k_valid locally on my Nvidia RTX 4070 Super. I followed the parameters specified in the assignment:

  • vocab_size: 10,000
  • context_length: 256
  • d_model: 256 (note that I reduced it from 512)
  • d_ff: 1344
  • RoPE_theta_parameter: 10,000
  • number_of_layers: 4
  • number_of_heads: 4 (note that I reduced the number of heads from 16 to 4)

After 1000 iterations, my results are saved here. I have also reproduced a screenshot below for easier reference.

training result

We can see from the val_loss graph that the loss started high, dropped to a minimum at step 3, and then steadily increased for the rest of the training. This means that after step 3 the model was no longer learning general patterns from the data and had instead started to memorise the specific examples and noise in the training data set.

Nonetheless, even this model was able to generate text. For example, given the prompt "Once upon a time, there was a pretty girl named Lily", my LLM generated the following result, which was rather satisfying.

LLM generation

Conclusion

Overall, it has been a fun experience completing assignment 1 of CS336, and I have learnt a lot of engineering lessons through the process of building each component of the LLM from scratch.

I would also like to give a shout-out to Google Gemini as a learning tool. Many times, when I came across an equation I found difficult to understand, or notation that was not properly documented, Gemini was really good at answering my questions and even gave me concrete examples to illustrate how to implement the equation by hand. I must say it is not easy to find the right balance between acceleration and learning, and I am still searching for it. So if you have any suggestions on how to use AI to accelerate learning while still ensuring that you actually learn, please let me know!

For my future articles, I am keen to explore generating synthetic datasets and fine-tuning using such datasets. Stay tuned!
