Learn AI Series (#53) - The Transformer Architecture (Part 2)


What will I learn

  • You will learn the decoder block: masked self-attention, cross-attention, and feed-forward;
  • why masking is necessary for autoregressive generation and how causal masks work;
  • cross-attention: how the decoder queries the encoder's representation;
  • the complete encoder-decoder transformer assembled end-to-end;
  • implementing a small but functional transformer in PyTorch;
  • training vs inference: why the decoder runs differently in each mode;
  • pre-norm vs post-norm and why modern models switched;
  • why the decoder-only variant (GPT) won the race for general-purpose AI.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#53) - The Transformer Architecture (Part 2)

Solutions to Episode #52 Exercises

Exercise 1: Parameter counting and layer analysis for a full-size transformer encoder.

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch = Q.size(0)
        Q = self.W_q(Q).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        out = (weights @ V).transpose(1, 2).contiguous().view(
            batch, -1, self.n_heads * self.d_k)
        return self.W_out(out)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.drop(self.attn(x, x, x, mask)))
        x = self.norm2(x + self.drop(self.ff(x)))
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_sz, d_model, n_heads, n_layers, d_ff, max_len=5000):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)])
        self.scale = math.sqrt(d_model)

    def forward(self, src, mask=None):
        x = self.pe(self.embed(src) * self.scale)
        for layer in self.layers:
            x = layer(x, mask)
        return x

enc = TransformerEncoder(vocab_sz=10000, d_model=512, n_heads=8,
                         n_layers=6, d_ff=2048)
total = sum(p.numel() for p in enc.parameters())

emb_params = sum(p.numel() for p in enc.embed.parameters())
attn_params = sum(sum(p.numel() for p in layer.attn.parameters())
                  for layer in enc.layers)
ff_params = sum(sum(p.numel() for p in layer.ff.parameters())
                for layer in enc.layers)
norm_params = sum(sum(p.numel() for p in layer.norm1.parameters()) +
                  sum(p.numel() for p in layer.norm2.parameters())
                  for layer in enc.layers)

print(f"Total parameters: {total:,}")
print(f"  Embedding:      {emb_params:>10,} ({emb_params/total:.1%})")
print(f"  Attention (all): {attn_params:>10,} ({attn_params/total:.1%})")
print(f"  Feed-forward:   {ff_params:>10,} ({ff_params/total:.1%})")
print(f"  LayerNorm:      {norm_params:>10,} ({norm_params/total:.1%})")

The feed-forward layers dominate at roughly two-thirds of the non-embedding parameters. Each FFN does a 512->2048 expansion (1,048,576 weights + 2048 bias) and a 2048->512 contraction (1,048,576 weights + 512 bias) -- that's over 2 million parameters per layer, times 6 layers. The attention projections (Q, K, V, output) account for about one-third of the per-layer parameters. LayerNorm is negligible.

Exercise 2: Pre-norm vs post-norm gradient stability comparison.

import torch
import torch.nn as nn
import math

class PostNormBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=None):
        super().__init__()
        d_ff = d_ff or d_model * 4
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.norm1(x + self.attn(x, x, x))
        x = self.norm2(x + self.ff(x))
        return x

class PreNormBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=None):
        super().__init__()
        d_ff = d_ff or d_model * 4
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

torch.manual_seed(42)
d_model, n_heads = 128, 4

for label, BlockClass in [("Post-norm", PostNormBlock),
                           ("Pre-norm", PreNormBlock)]:
    blocks = nn.ModuleList([BlockClass(d_model, n_heads) for _ in range(12)])
    x = torch.randn(2, 10, d_model, requires_grad=True)

    out = x
    norms_at = {}
    for i, block in enumerate(blocks):
        out = block(out)
        if (i + 1) in [1, 6, 12]:
            norms_at[i + 1] = out.norm().item()

    loss = out.sum()
    loss.backward()
    grad_norm = x.grad.norm().item()

    print(f"{label}:")
    for depth, norm_val in norms_at.items():
        print(f"  After {depth:>2d} layers: output norm = {norm_val:.2f}")
    print(f"  Gradient norm at input: {grad_norm:.4f}\n")

Pre-norm produces more stable output norms as depth increases because the residual path carries raw (unnormalized) values -- the signal doesn't get repeatedly squeezed by normalization. The gradient norms at the input should also be more stable for pre-norm, meaning the first layers still receive useful learning signals even in a 12-layer stack. This is why GPT-2, GPT-3, and essentially every modern transformer uses pre-norm.

Exercise 3: Demonstrating that encoder output at one position is influenced by ALL other positions.

import torch
import torch.nn as nn
import math

enc = TransformerEncoder(vocab_sz=5000, d_model=128, n_heads=4,
                         n_layers=4, d_ff=512)
enc.eval()

torch.manual_seed(42)
src = torch.randint(0, 5000, (3, 15))

with torch.no_grad():
    out_original = enc(src)
    print(f"Output shape: {out_original.shape}")  # (3, 15, 128)

    # Mask position 3: zero out its embedding
    emb_original = enc.embed(src) * enc.scale
    emb_masked_3 = emb_original.clone()
    emb_masked_3[:, 3, :] = 0.0

    x = enc.pe(emb_masked_3)
    for layer in enc.layers:
        x = layer(x)
    out_masked_3 = x

    # Mask position 14: zero out its embedding
    emb_masked_14 = emb_original.clone()
    emb_masked_14[:, 14, :] = 0.0

    x = enc.pe(emb_masked_14)
    for layer in enc.layers:
        x = layer(x)
    out_masked_14 = x

    # Measure L2 distance at position 7
    diff_3 = (out_original[:, 7] - out_masked_3[:, 7]).norm(dim=-1)
    diff_14 = (out_original[:, 7] - out_masked_14[:, 7]).norm(dim=-1)

    print(f"\nL2 distance at position 7 when masking position 3:")
    for b in range(3):
        print(f"  Batch {b}: {diff_3[b]:.4f}")

    print(f"\nL2 distance at position 7 when masking position 14:")
    for b in range(3):
        print(f"  Batch {b}: {diff_14[b]:.4f}")

    print(f"\nBoth non-zero: position 7 is influenced by ALL positions")

Both masking experiments produce non-zero L2 distances at position 7. This proves that self-attention connects all positions regardless of distance -- masking position 3 (nearby) and position 14 (far away) both change the representation at position 7. In an RNN, position 14's influence on position 7 would have to travel through 7 sequential hidden-state updates (and only in a bidirectional RNN at all, since a plain forward RNN never lets later positions affect earlier ones), with potential information decay at every step. In a transformer, it's a direct connection through self-attention in a single layer, and that connection is reinforced across all 4 layers.

On to today's episode

Here we go! Last episode we built the encoder half of the transformer from the ground up: embedding, positional encoding, multi-head self-attention, feed-forward networks, residual connections, layer normalization. The encoder reads the entire input sequence and produces a rich contextual representation where every position has been informed by every other position through 6 layers of parallel self-attention.

Now we build the other half: the decoder. And then we assemble the complete transformer end-to-end.

The decoder is trickier than the encoder because it has to do two things simultaneously. It needs to attend to the encoder's output (to read the source), AND it needs to attend to its own previously generated tokens (to maintain coherence) -- but it must NOT look at future tokens (because at inference time, those tokens don't exist yet). Solving this constraint elegantly is what makes the transformer decoder so clever ;-)

The causal mask problem

During training, we feed the entire target sequence to the decoder at once. For a translation from "How are you?" to "Hoe gaat het", the decoder receives all three Dutch tokens simultaneously. But at inference time, the decoder generates one token at a time: first "Hoe", then "gaat" (conditioned on "Hoe"), then "het" (conditioned on "Hoe gaat").

If self-attention in the decoder lets position 1 ("gaat") see position 2 ("het") during training, the model learns to cheat -- it looks at the answer instead of predicting it. At inference time, when "het" doesn't exist yet, the model's predictions would be useless because it was trained with information it won't have.

The solution is masked self-attention: we mask out all positions to the right of the current position, so each position can only attend to itself and previous positions. This is called a causal mask because it enforces causality -- the output at position t depends only on positions 0 through t, never on future positions.

import torch, torch.nn as nn, torch.nn.functional as F
import math

def make_causal_mask(seq_len):
    """Lower-triangular mask: position i can attend to positions 0..i"""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask  # (seq_len, seq_len)

mask = make_causal_mask(5)
print(mask)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])

Position 0 can only see itself. Position 2 can see positions 0, 1, and 2. Position 4 can see everything. Wherever the mask is 0, the corresponding attention score is set to -inf before softmax, which pushes that attention weight to exactly 0. The model literally cannot look ahead.
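
To see the mask in action, here's a minimal sketch that applies the causal mask from the snippet above to a made-up 5x5 score matrix and runs softmax -- every weight above the diagonal comes out exactly zero:

# Minimal sketch: continuing from the snippet above (mask = make_causal_mask(5)).
scores = torch.randn(5, 5)                            # pretend these are raw Q·K^T / sqrt(d_k) scores
masked = scores.masked_fill(mask == 0, float('-inf'))
weights = torch.softmax(masked, dim=-1)
print(weights)
# Each row sums to 1, and every entry above the diagonal is exactly 0:
# position i distributes its attention only over positions 0..i.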

This is computationally elegant -- the mask is nothing more than an elementwise operation on the attention score matrix. The attention computation is identical to the encoder's, except we apply the mask to the scores before softmax. All positions are still computed in parallel during training; the mask just prevents information leakage from future positions. Same parallelism advantage, but with the autoregressive constraint baked in through the mask rather than through sequential processing. Pretty slick if you ask me.

The decoder block: three sub-layers

The encoder block has two sub-layers (self-attention + feed-forward). The decoder block has three:

  1. Masked self-attention: the decoder attends to itself, with the causal mask preventing future positions from being visible
  2. Cross-attention: the decoder attends to the encoder's output -- queries come from the decoder, keys and values come from the encoder
  3. Feed-forward: same position-wise expansion as the encoder

Each sub-layer has its own residual connection and layer normalization, exactly like the encoder (episode #52).

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch = Q.size(0)
        Q = self.W_q(Q).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = (weights @ V).transpose(1, 2).contiguous().view(
            batch, -1, self.n_heads * self.d_k)
        return self.W_out(out)

class TransformerDecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1):
        super().__init__()
        d_ff = d_ff or d_model * 4
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, enc_out, causal_mask=None, src_mask=None):
        # 1. Masked self-attention
        sa = self.self_attn(x, x, x, causal_mask)
        x = self.norm1(x + self.drop(sa))
        # 2. Cross-attention: queries from decoder, K/V from encoder
        ca = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + self.drop(ca))
        # 3. Feed-forward
        ff = self.ff(x)
        x = self.norm3(x + self.drop(ff))
        return x

Pay close attention to the cross-attention call: self.cross_attn(x, enc_out, enc_out). The queries come from the decoder's current representation (x). The keys and values come from the encoder's output (enc_out). This is how the decoder "reads" the source: at each position, it asks "what in the source input is most relevant to what I'm currently generating?" and gets back a weighted combination of encoder representations.

This is the same mechanism as Bahdanau attention from episode #51 -- but implemented with multi-head scaled dot-product attention instead of a feedforward scoring network. The principle is identical: the decoder queries the encoder. The implementation is simply more powerful (multiple heads) and far more parallelizable (no sequential RNN processing).

Having said that, notice how the three sub-layers build on each other in a very specific order. First, the decoder figures out what it has generated so far (masked self-attention). Then, it reads the source (cross-attention). Then, it processes the combined information (feed-forward). Each step enriches the representation, and each gets its own residual connection so information can flow freely through the block.
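
To make the data flow concrete, here's a quick shape check -- a sketch with random tensors, reusing the classes defined above. The decoder block takes a target-side tensor plus the encoder output and returns a tensor with the target-side shape; the source length only shows up inside the cross-attention weights:

# Sketch: shape check for a single decoder block (random inputs, untrained weights).
d_model, n_heads = 64, 8
block = TransformerDecoderBlock(d_model, n_heads)
enc_out = torch.randn(2, 12, d_model)    # encoder output: (batch, src_len, d_model)
x = torch.randn(2, 15, d_model)          # decoder input:  (batch, tgt_len, d_model)
out = block(x, enc_out, causal_mask=make_causal_mask(15))
print(out.shape)                         # torch.Size([2, 15, 64]) -- target-side shape preserved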

The complete transformer

Now we assemble everything. Encoder + decoder + a final linear layer that projects decoder outputs to vocabulary logits:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1):
        super().__init__()
        d_ff = d_ff or d_model * 4
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.drop(self.attn(x, x, x, mask)))
        x = self.norm2(x + self.drop(self.ff(x)))
        return x

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=64, n_heads=8,
                 n_enc=6, n_dec=6, d_ff=256, max_len=512):
        super().__init__()
        self.d_model = d_model
        self.src_emb = nn.Embedding(src_vocab, d_model)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        self.enc_layers = nn.ModuleList(
            [TransformerEncoderBlock(d_model, n_heads, d_ff)
             for _ in range(n_enc)]
        )
        self.dec_layers = nn.ModuleList(
            [TransformerDecoderBlock(d_model, n_heads, d_ff)
             for _ in range(n_dec)]
        )
        self.out_proj = nn.Linear(d_model, tgt_vocab)
        self.scale = math.sqrt(d_model)

    def encode(self, src, src_mask=None):
        x = self.pe(self.src_emb(src) * self.scale)
        for layer in self.enc_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, enc_out, causal_mask, src_mask=None):
        x = self.pe(self.tgt_emb(tgt) * self.scale)
        for layer in self.dec_layers:
            x = layer(x, enc_out, causal_mask, src_mask)
        return x

    def forward(self, src, tgt):
        causal_mask = make_causal_mask(tgt.size(1)).to(tgt.device)
        enc_out = self.encode(src)
        dec_out = self.decode(tgt, enc_out, causal_mask)
        return self.out_proj(dec_out)

# Build and test
model = Transformer(src_vocab=8000, tgt_vocab=6000,
                    d_model=64, n_heads=8, n_enc=3, n_dec=3)
src = torch.randint(0, 8000, (2, 12))   # English: batch=2, len=12
tgt = torch.randint(0, 6000, (2, 15))   # Dutch: batch=2, len=15
logits = model(src, tgt)
print(f"Source: {src.shape}")
print(f"Target: {tgt.shape}")
print(f"Output logits: {logits.shape}")  # (2, 15, 6000)

n_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {n_params:,}")

That's a complete transformer. About 150 lines of PyTorch. The output logits have shape (batch, target_length, target_vocab) -- for each position in the target sequence, a probability distribution over the entire target vocabulary. During training, you compare these logits to the actual target tokens using cross-entropy loss. During inference, you take the argmax (or sample) at each position.

Look at what we've got here. Two embedding tables (source and target). One shared positional encoding. An encoder stack that reads the source in parallel. A decoder stack that reads its own output (masked) and the encoder's output (cross-attention) at every layer. And a final projection to vocab logits. That's the whole thing. Every major AI system since 2017 is a variant of this architecture -- GPT is the decoder stack alone, BERT is the encoder stack alone, T5 uses both.

Training vs inference: a crucial difference

This is important to understand, and I see people get confused about it constantly, so let's be very explicit.

During training, the decoder receives the complete target sequence shifted right by one position. If the target is [<sos>, Hoe, gaat, het, <eos>], the input to the decoder is [<sos>, Hoe, gaat, het] and we predict [Hoe, gaat, het, <eos>]. The causal mask ensures that the position predicting "gaat" -- the one whose input token is "Hoe" -- only sees [<sos>, Hoe], not "gaat" itself or anything after it. All positions are predicted in parallel. This is called teacher forcing (we covered it in episode #50).
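
In code, one teacher-forced training step looks roughly like this (a sketch with dummy token ids; the shift and the loss computation are the parts that matter):

# Sketch of one teacher-forced training step (dummy data, hypothetical padding id 0).
criterion = nn.CrossEntropyLoss(ignore_index=0)
src = torch.randint(1, 8000, (2, 12))    # source batch
tgt = torch.randint(1, 6000, (2, 16))    # full target batch: <sos> ... <eos>

dec_input = tgt[:, :-1]                  # e.g. [<sos>, Hoe, gaat, het]
dec_target = tgt[:, 1:]                  # e.g. [Hoe, gaat, het, <eos>]

logits = model(src, dec_input)           # (batch, tgt_len - 1, tgt_vocab)
loss = criterion(logits.reshape(-1, logits.size(-1)), dec_target.reshape(-1))
loss.backward()                          # then optimizer.step(), zero_grad(), etc.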

During inference, the decoder runs autoregressively -- one token at a time:

  1. Feed [<sos>] -> predict "Hoe"
  2. Feed [<sos>, Hoe] -> predict "gaat"
  3. Feed [<sos>, Hoe, gaat] -> predict "het"
  4. Feed [<sos>, Hoe, gaat, het] -> predict <eos> -> stop

Here's what that looks like as a greedy decoding function:
@torch.no_grad()
def greedy_decode(model, src, max_len=50, sos_id=1, eos_id=2):
    enc_out = model.encode(src)
    tgt_ids = [sos_id]
    for _ in range(max_len):
        tgt = torch.tensor([tgt_ids], device=src.device)
        mask = make_causal_mask(tgt.size(1)).to(src.device)
        dec_out = model.decode(tgt, enc_out, mask)
        logits = model.out_proj(dec_out[:, -1, :])  # last position only
        next_id = logits.argmax(-1).item()
        if next_id == eos_id:
            break
        tgt_ids.append(next_id)
    return tgt_ids[1:]  # strip <sos>
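
A quick usage sketch -- note that greedy_decode as written expects a single source sequence (batch size 1), and an untrained model will of course produce gibberish:

src = torch.randint(0, 8000, (1, 12))    # one source sequence
print(greedy_decode(model, src, max_len=20))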

This training-inference gap matters for performance. Training is fast because everything is parallel -- the causal mask simulates the autoregressive constraint without actually doing it sequentially. Inference is sequential by nature: you can't predict token 5 until you know token 4.

The encoder only runs once per input during inference (it processes the full source in parallel). The decoder runs once per generated token. This is why inference latency scales linearly with output length, not input length. And it's one of the reasons why KV caching (storing and reusing previously computed key-value pairs instead of recomputing them at every step) is such an important optimization in production -- without it, generating 100 tokens means re-running the decoder on sequences of length 1, 2, 3, ..., 100, which is quadratically wasteful.
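
A back-of-the-envelope sketch of that waste: without a cache, step t re-processes all t tokens generated so far, so the total number of token positions pushed through the decoder grows quadratically, while a KV cache processes each new token exactly once:

# Rough cost comparison (counting decoder token positions, not wall-clock time).
n_tokens = 100
without_cache = sum(t for t in range(1, n_tokens + 1))   # 1 + 2 + ... + 100 = 5050
with_cache = n_tokens                                     # each new token processed once
print(without_cache, with_cache)                          # 5050 vs 100 -> ~50x less work per layer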

Layer normalization: pre-norm vs post-norm

Our implementation uses post-norm: x = norm(x + sublayer(x)). The original 2017 paper also used post-norm. But starting around 2019, almost every major model switched to pre-norm: x = x + sublayer(norm(x)).

The difference is subtle but practically important. Pre-norm applies normalization before the sub-layer, which means the residual connection carries raw (unnormalized) values. This creates a more direct gradient path and makes training more stable, especially for deep models (24+ layers). GPT-2, GPT-3, and most modern transformers use pre-norm.

# Post-norm (original paper)
x = self.norm1(x + self.drop(self.attn(x, x, x, mask)))

# Pre-norm (modern standard)
x = x + self.drop(self.attn(self.norm1(x), self.norm1(x), self.norm1(x), mask))

If you're building a production model, use pre-norm. We implemented post-norm here to match the original paper, but the tradeoff is clear: post-norm requires more careful learning rate tuning and warmup schedules, while pre-norm is far more forgiving. The architecture is the same either way -- just a question of where you put the norm() call ;-)

What the original transformer looked like

The 2017 "Attention Is All You Need" paper tested the transformer on English-to-German and English-to-French translation with these hyperparameters:

  • d_model = 512
  • n_heads = 8 (d_k = 64 per head)
  • n_layers = 6 encoder + 6 decoder
  • d_ff = 2048 (4x d_model)
  • dropout = 0.1
  • ~65 million parameters (for the base model)
  • Base model trained for roughly 12 hours (100,000 steps) on 8 NVIDIA P100 GPUs

It achieved state-of-the-art translation quality while training significantly faster than the best RNN-based models. The "big" variant (d_model=1024, n_heads=16, d_ff=4096) used ~213 million parameters and trained for 3.5 days on 8 GPUs.
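
As a rough sanity check you can plug the base-model hyperparameters into our own Transformer class. A sketch -- keep in mind that the original paper shares a single ~37,000-token BPE vocabulary between source and target and ties the embedding and output-projection weights, so our untied count lands well above the reported ~65 million:

base = Transformer(src_vocab=37000, tgt_vocab=37000,
                   d_model=512, n_heads=8, n_enc=6, n_dec=6, d_ff=2048)
print(f"{sum(p.numel() for p in base.parameters()):,} parameters (untied embeddings)")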

For context: GPT-3 (2020) is a decoder-only transformer with 175 billion parameters, 96 layers, 96 heads, d_model=12288. The architecture is fundamentally the same -- just scaled up by a factor of a few thousand. That scaling is what made the transformer the foundation of modern AI. The key insight wasn't just that attention works -- it's that this particular architecture scales predictably. Double the parameters, double the data, get measurably better results. No other architecture in the history of ML has shown this level of predictable scaling, and it's why hundreds of billions of dollars are being poured into training ever-larger transformers.

The encoder-decoder split

Not every application needs both halves of the transformer. This is important because the variants that dropped one half turned out to be more influential than the full encoder-decoder model:

Encoder-only (BERT and its descendants): for understanding tasks -- classification, named entity recognition, question answering. The encoder produces rich contextual representations; a task-specific head on top does the prediction. The input is bidirectional: every token sees every other token in both directions.

Decoder-only (GPT and its descendants): for generation tasks -- text completion, conversation, code generation. No encoder at all. The decoder attends to its own previous tokens via masked self-attention and generates the next token. Unidirectional: each token only sees tokens before it.

Encoder-decoder (the original transformer, T5, BART): for sequence-to-sequence tasks -- translation, summarization, question answering with generation. The full architecture as we've built it today.

The decoder-only variant turned out to be the winner for general-purpose AI. GPT showed that if you train a decoder-only transformer on enough data, it can do classification, translation, summarization, and generation all through next-token prediction alone. No encoder needed. No task-specific heads. Just predict the next word, over and over, trained on trillions of tokens. This simplification -- plus the scaling laws that showed predictable improvement with more compute -- is what led to the LLM revolution.
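
As a tiny preview (a sketch, not how GPT is actually built or trained): architecturally, a decoder-only block is just our encoder block run with a causal mask -- masked self-attention plus feed-forward, no cross-attention:

# Sketch: a "decoder-only" (GPT-style) block reuses the encoder block plus a causal mask.
x = torch.randn(2, 10, 64)                        # (batch, seq_len, d_model)
gpt_block = TransformerEncoderBlock(d_model=64, n_heads=8, d_ff=256)
out = gpt_block(x, mask=make_causal_mask(10))     # each position attends only to itself and the past
print(out.shape)                                  # torch.Size([2, 10, 64])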

We'll explore exactly how encoder-only and decoder-only models work in upcoming episodes. They're both specializations of the architecture we've built today, and understanding the full encoder-decoder version makes understanding the variants straightforward.

Putting it in perspective

Let's step back for a moment and appreciate what we've built across episodes #52 and #53. Starting from episode #37 (the perceptron -- a single neuron), we went through:

  • Multi-layer neural networks and backpropagation (#38-39)
  • Training challenges and optimization (#40-41)
  • PyTorch as our framework (#42-44)
  • CNNs for spatial patterns (#45-47)
  • RNNs for sequential processing (#48)
  • LSTMs for long-range memory (#49)
  • Seq2seq for sequence transformation (#50)
  • Attention to fix the bottleneck (#51)
  • And now the transformer, the architecture that unifies it all (#52-53)

Every piece was necessary. The transformer uses linear projections (from the basics of neural networks), nonlinear activations (ReLU in the feed-forward layers), residual connections (from ResNet, which we covered in the CNN episodes), layer normalization (related to batch norm from CNNs), attention (from the seq2seq chapter), and positional encoding (because we dropped the sequential processing that RNNs gave us for free). None of it came from nowhere -- it all connects back.

And the complete implementation? ~150 lines. That's not because it's simple. It's because the components are composable. Attention blocks, feed-forward blocks, residual connections, normalization -- they stack cleanly because each component has a well-defined interface (tensor in, same-shape tensor out). This composability is what makes the architecture so successful. You can scale it by adding layers, adding heads, widening dimensions, or all three -- and it just works. No fundamental rewiring needed.

The bottom line

  • The decoder has three sub-layers: masked self-attention (can't see future tokens), cross-attention (reads the encoder), and feed-forward;
  • Causal masking prevents information leakage during training while allowing parallel computation -- a lower-triangular mask whose zero entries turn into -inf attention scores before softmax;
  • Cross-attention uses queries from the decoder and keys/values from the encoder -- the mechanism that connects the two halves;
  • Training processes all target positions in parallel (masked to simulate autoregression); inference generates one token at a time;
  • Pre-norm (normalize before sub-layer) is the modern standard -- more stable than the original post-norm for deep models;
  • Not all transformers need both halves: encoder-only (BERT), decoder-only (GPT), and encoder-decoder (T5) serve different purposes -- and decoder-only turned out to be the most versatile;
  • The complete transformer is remarkably compact in code (~150 lines) yet scales from 65M parameters to 175B+ by adjusting hyperparameters;
  • Every component in the transformer traces back to concepts we've built throughout this series -- nothing appeared from thin air.

Exercises

Exercise 1: Build a training loop for our complete Transformer on a simple task. Use the sequence reversal task from episode #50 (input: random token sequences, output: the same tokens in reverse order). Configure the transformer with src_vocab=tgt_vocab=20, d_model=64, n_heads=4, n_enc=2, n_dec=2, d_ff=128. Train for 30 epochs on sequences of length 10 using cross-entropy loss and Adam optimizer. Implement proper teacher forcing (feed the shifted target to the decoder). Print token accuracy and full-sequence accuracy on a held-out test set. Compare the results against the LSTM-based seq2seq from episode #50 -- the transformer should reach higher accuracy faster.

Exercise 2: Implement greedy decoding and beam search for the Transformer. Using the trained model from Exercise 1, implement both decoding strategies. For beam search, use beam width 3. Generate predictions for 20 test sequences using both strategies and compare: (a) how many sequences each gets perfectly right, and (b) the average token accuracy. Print side-by-side comparisons for 5 example sequences showing source, target, greedy prediction, and beam search prediction.

Exercise 3: Build a KV cache for faster inference. Modify the greedy_decode function so that at each step, instead of re-running the decoder over the full partial target sequence, it reuses previously computed key-value pairs and only computes attention for the new token. Measure the wall-clock time for generating 50 tokens with and without KV caching (use a dummy trained model, accuracy doesn't matter here). Print the speedup factor. The caching version should be significantly faster because it avoids redundant computation -- each step does O(1) work per layer instead of O(t), where t is the current sequence length.

Thanks for your time!

@scipio


