Inside the Transformer Architecture: The Core of Modern AI

Aun Raza

The Transformer architecture has revolutionized the field of Artificial Intelligence, becoming the foundation for state-of-the-art models in Natural Language Processing (NLP), Computer Vision, and beyond. This article delves into the core of this powerful architecture, exploring its purpose and key features and walking through a practical code example.

Purpose:

The primary purpose of the Transformer is to process sequences of data, such as text or images, while effectively capturing long-range dependencies. Unlike Recurrent Neural Networks (RNNs), which process data one element at a time, Transformers process the whole sequence in parallel, significantly improving training speed and scalability. This allows them to model context and relationships between elements within a sequence, leading to superior performance on tasks like machine translation, text generation, and image recognition.

Features:

  • Self-Attention: The heart of the Transformer lies in its self-attention mechanism. This allows the model to weigh the importance of different parts of the input sequence when processing a particular element. Instead of relying on the order of the input, self-attention dynamically learns relationships between all elements simultaneously.
  • Parallel Processing: Unlike sequential models, Transformers can process the entire input sequence in parallel, leveraging the power of modern GPUs. This drastically reduces training time, especially for long sequences.
  • Encoder-Decoder Structure: Many Transformer models employ an encoder-decoder structure. The encoder processes the input sequence and generates a contextualized representation. The decoder then uses this representation to generate the output sequence (see the nn.Transformer sketch after this list).
  • Multi-Head Attention: To capture different aspects of the relationships within the input sequence, Transformers utilize multiple attention heads. Each head learns a different set of attention weights, providing a richer representation of the input.
  • Positional Encoding: Since Transformers process data in parallel, they need a mechanism to understand the order of elements in the sequence. Positional encoding adds information about the position of each element to the input embedding (a sketch of the sinusoidal encoding also follows this list).
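
PyTorch ships a ready-made encoder-decoder Transformer as torch.nn.Transformer. The snippet below is a minimal sketch of how it is called; the layer counts and tensor shapes are just the library defaults, not values from this article:

import torch
import torch.nn as nn

# Full encoder-decoder Transformer with the default 6 encoder and 6 decoder layers
model = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6)

src = torch.randn(10, 4, 512)  # (source_len, batch, d_model) -- batch_first=False by default
tgt = torch.randn(20, 4, 512)  # (target_len, batch, d_model)
out = model(src, tgt)
print(out.shape)  # torch.Size([20, 4, 512])

For positional encoding, the sinusoidal scheme from the original "Attention Is All You Need" paper can be computed directly. The helper below is a minimal, self-contained sketch; the function name and the way it is added to the embeddings are illustrative rather than part of this article's code:

import math
import torch

def sinusoidal_positional_encoding(seq_len, embed_size):
    # One row per position, one column per embedding dimension
    pe = torch.zeros(seq_len, embed_size)
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
    # Frequencies shrink geometrically across dimension pairs, as in the original paper
    div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions use cosine
    return pe  # (seq_len, embed_size)

# The encoding is simply added to the token embeddings before the first attention layer:
# x = token_embeddings + sinusoidal_positional_encoding(seq_len, embed_size)

Because the encoding depends only on position and dimension, the same tensor can be reused for every batch and requires no learned parameters.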

Code Example (PyTorch):

This simplified example implements a single multi-head self-attention layer in PyTorch:

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # Number of examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)  # (N, value_len, heads, head_dim)
        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
        query = self.queries(query)  # (N, query_len, heads, head_dim)

        # Scaled dot-product attention
        energy = torch.einsum("nqhd,nkhd->nhqk", [query, keys])
        # query shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)  # scale by sqrt(d_k), the per-head dimension

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # out shape: (N, query_len, heads, head_dim) then flatten last two dim

        out = self.fc_out(out)
        return out

# Example usage
embed_size = 512
heads = 8
seq_len = 32
N = 4  # Batch size

values = torch.randn((N, seq_len, embed_size))
keys = torch.randn((N, seq_len, embed_size))
query = torch.randn((N, seq_len, embed_size))
attention = SelfAttention(embed_size, heads)
output = attention(values, keys, query, mask=None)
print(output.shape) # Output shape: torch.Size([4, 32, 512])

This code defines a SelfAttention class that performs multi-head self-attention. It takes values, keys, and query tensors as input, representing embedded versions of the input sequence, splits them across the attention heads, computes scaled dot-product attention weights for each head, and projects the concatenated head outputs back to the embedding size.
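
The example usage above passes mask=None, so every position can attend to every other position. In a decoder, a causal mask is typically supplied so that each position only attends to itself and earlier positions. Continuing from the example above, a minimal sketch (assuming the mask is meant to broadcast against the energy tensor of shape (N, heads, query_len, key_len)):

# Lower-triangular causal mask: position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
masked_output = attention(values, keys, query, mask=causal_mask)
print(masked_output.shape)  # torch.Size([4, 32, 512])

Positions where the mask is 0 receive a score of -1e20 before the softmax, so their attention weights collapse to effectively zero.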

Installation:

To run the example above, you need to install PyTorch. You can install it using pip:

pip install torch

This example provides a glimpse into the core of the Transformer architecture. By understanding its fundamental components, developers can leverage its power to build innovative AI solutions. Further exploration of more complex Transformer models, such as BERT and GPT, will reveal the full potential of this groundbreaking architecture.
