Attention & transformers

Tokenization, embeddings, scaled dot-product attention, multi-head attention, the transformer block, and how the same architecture powers modern NLP, vision (ViT), and even graph and audio models.

Why this matters. Transformers are the substrate of nearly every modern AI system — LLMs, image generators, speech recognition, code assistants. Understanding attention is no longer optional.

Tokenization

Models don't see text — they see integer token IDs. The tokenizer converts strings to IDs and back.

# conceptual sketch: a BPE-style tokenizer (here, GPT-2's via Hugging Face transformers)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text   = "Transformers tokenize text into integer IDs."
tokens = tokenizer.encode(text)   # e.g. [1532, 18, 9241, ...]
back   = tokenizer.decode(tokens) # "Transformers tokenize ..."

Embeddings

Each token ID is looked up in an embedding table to produce a dense vector (typically 256–4096 dimensions).

import torch
import torch.nn as nn

vocab_size = 32_000
d_model    = 512

token_embed = nn.Embedding(vocab_size, d_model)
ids   = torch.tensor([[1, 5, 9, 4]])   # batch=1, seq_len=4
x     = token_embed(ids)               # shape (1, 4, 512)

Embeddings are learned during training — semantically similar tokens end up nearby in the vector space.

Positional embeddings. Self-attention is permutation-equivariant, so it has no notion of token order. Position is therefore injected explicitly: learned or sinusoidal positional embeddings are added to the input vectors, while rotary embeddings (RoPE) are applied to the queries and keys inside attention.
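
As a concrete illustration, here is a minimal sketch of the classic sinusoidal encoding (one of the options above); the function name and the choice to simply add it to x are assumptions made for this example, not any particular library's API.

def sinusoidal_positions(seq_len, d_model):
    # assumes even d_model; one row per position, one column per embedding dimension
    pos  = torch.arange(seq_len).unsqueeze(1)                      # (seq_len, 1)
    i    = torch.arange(0, d_model, 2)                             # even dimension indices
    freq = torch.exp(-torch.log(torch.tensor(10000.0)) * i / d_model)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(pos * freq)   # even dimensions
    pe[:, 1::2] = torch.cos(pos * freq)   # odd dimensions
    return pe                             # (seq_len, d_model)

# added to the token embeddings from above; broadcasts over the batch dimension
x = x + sinusoidal_positions(x.size(1), x.size(2))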

Scaled dot-product attention

The core mechanism. Given three matrices — Queries Q, Keys K, Values V — attention computes:

Attention(Q, K, V) = softmax( Q Kᵀ / √dₖ ) · V

import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # Q, K, V: (batch, n_heads, seq_len, d_head)
    d_k = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ V
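
A quick smoke test (the shapes are purely illustrative): a lower-triangular causal mask keeps each position from attending to later ones, since masked scores become -inf before the softmax.

# batch=2, heads=4, seq_len=5, d_head=16
Q = torch.randn(2, 4, 5, 16)
K = torch.randn(2, 4, 5, 16)
V = torch.randn(2, 4, 5, 16)

causal_mask = torch.tril(torch.ones(5, 5))   # (5, 5), broadcasts over batch and heads
out = attention(Q, K, V, mask=causal_mask)   # (2, 4, 5, 16)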

Multi-head attention

Instead of one attention with full d_model dimensions, split into h heads of size d_model / h. Each head learns to look at a different relationship; their outputs are concatenated and linearly projected.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head  = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        # project and reshape to (B, n_heads, T, d_head)
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        out = attention(Q, K, V, mask)            # (B, n_heads, T, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)
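
For instance (shapes chosen only for illustration), the module maps a (batch, seq_len, d_model) tensor to another tensor of the same shape:

mha = MultiHeadAttention(d_model=512, n_heads=8)
x   = torch.randn(1, 4, 512)      # (batch, seq_len, d_model)
y   = mha(x)                      # (1, 4, 512)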

The transformer block

A single transformer block combines two sub-layers, self-attention and a feed-forward MLP, each wrapped in a residual connection and paired with LayerNorm. The implementation below uses the pre-norm arrangement (LayerNorm before each sub-layer), which is standard in most modern models.

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn  = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.drop  = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.drop(self.attn(self.norm1(x), mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x

Stack N of these (typically 6–96) and you have a transformer encoder. Add a causal mask to the attention and you have a decoder (used in GPT-style language models).
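
As a sketch (the variable names and layer count here are invented for the example), stacking blocks and threading a causal mask through them might look like this:

n_layers = 6
blocks = nn.ModuleList(
    [TransformerBlock(d_model=512, n_heads=8, d_ff=2048) for _ in range(n_layers)]
)

T = x.size(1)
causal_mask = torch.tril(torch.ones(T, T))   # decoder-style: no attending to future positions

h = x
for block in blocks:
    h = block(h, mask=causal_mask)           # (batch, T, d_model)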

Pre-training & fine-tuning

NLP applications

Computer vision

Common pitfalls