Learn AI Series (#52) - The Transformer Architecture (Part 1)

What will I learn
- what "Attention Is All You Need" actually means and why the transformer changed everything;
- scaled dot-product attention -- the mathematical core driving every transformer;
- multi-head attention -- how the model examines different relationship types simultaneously;
- positional encoding -- giving transformers a sense of order without recurrence;
- the encoder block -- self-attention + feed-forward + residual connections + layer normalization;
- why transformers scale where RNNs could not.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
- Learn AI Series (#50) - Sequence-to-Sequence Models
- Learn AI Series (#51) - Attention Mechanisms
- Learn AI Series (#52) - The Transformer Architecture (Part 1) (this post)
Solutions to Episode #51 Exercises
Exercise 1: Comparative attention decoder experiment -- vanilla seq2seq vs Bahdanau attention on sequence reversal, lengths 20 and 40.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

class Encoder(nn.Module):
    def __init__(self, vocab_sz, emb_d, hid_d):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_d)
        self.lstm = nn.LSTM(emb_d, hid_d, batch_first=True)

    def forward(self, src):
        outputs, (h, c) = self.lstm(self.embed(src))
        return outputs, h, c

class VanillaDecoder(nn.Module):
    def __init__(self, vocab_sz, emb_d, hid_d):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_d)
        self.lstm = nn.LSTM(emb_d, hid_d, batch_first=True)
        self.fc = nn.Linear(hid_d, vocab_sz)

    def forward(self, token, h, c, enc_outputs=None):
        out, (h, c) = self.lstm(self.embed(token), (h, c))
        return self.fc(out.squeeze(1)), h, c

class BahdanauAttention(nn.Module):
    def __init__(self, hid_d, attn_d=64):
        super().__init__()
        self.W_enc = nn.Linear(hid_d, attn_d, bias=False)
        self.W_dec = nn.Linear(hid_d, attn_d, bias=False)
        self.v = nn.Linear(attn_d, 1, bias=False)

    def forward(self, dec_hidden, enc_outputs):
        dec_proj = self.W_dec(dec_hidden).unsqueeze(1)
        enc_proj = self.W_enc(enc_outputs)
        scores = self.v(torch.tanh(dec_proj + enc_proj)).squeeze(-1)
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights.unsqueeze(1), enc_outputs).squeeze(1)
        return context, weights

class AttentionDecoder(nn.Module):
    def __init__(self, vocab_sz, emb_d, hid_d):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_d)
        self.attn = BahdanauAttention(hid_d)
        self.lstm = nn.LSTM(emb_d + hid_d, hid_d, batch_first=True)
        self.fc = nn.Linear(hid_d * 2, vocab_sz)

    def forward(self, token, h, c, enc_outputs):
        emb = self.embed(token)
        ctx, _ = self.attn(h[-1], enc_outputs)
        lstm_in = torch.cat([emb, ctx.unsqueeze(1)], dim=-1)
        out, (h, c) = self.lstm(lstm_in, (h, c))
        pred = self.fc(torch.cat([out.squeeze(1), ctx], dim=-1))
        return pred, h, c

def run_experiment(seq_len, use_attention, epochs=30):
    torch.manual_seed(42)
    vocab_sz, n_train, n_test = 20, 2500, 500
    src = torch.randint(2, vocab_sz, (n_train + n_test, seq_len))
    tgt = src.flip(1)
    X_tr, X_te = src[:n_train], src[n_train:]
    y_tr, y_te = tgt[:n_train], tgt[n_train:]
    loader = DataLoader(TensorDataset(X_tr, y_tr), batch_size=64, shuffle=True)
    enc = Encoder(vocab_sz, 64, 128)
    dec = AttentionDecoder(vocab_sz, 64, 128) if use_attention \
        else VanillaDecoder(vocab_sz, 64, 128)
    params = list(enc.parameters()) + list(dec.parameters())
    opt = torch.optim.Adam(params, lr=1e-3)
    for epoch in range(epochs):
        enc.train(); dec.train()
        tf = max(0.2, 1.0 - epoch * 0.03)
        for sb, tb in loader:
            enc_out, h, c = enc(sb)
            inp = tb[:, :1]
            outputs = []
            for t in range(1, tb.size(1)):
                pred, h, c = dec(inp, h, c, enc_out)
                outputs.append(pred)
                use_tf = torch.rand(1).item() < tf
                inp = tb[:, t:t+1] if use_tf else pred.argmax(-1).unsqueeze(1)
            logits = torch.stack(outputs, dim=1)
            loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab_sz),
                                         tb[:, 1:].reshape(-1))
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 5.0)
            opt.step()
    enc.eval(); dec.eval()
    with torch.no_grad():
        enc_out, h, c = enc(X_te)
        inp = y_te[:, :1]
        preds = []
        for t in range(1, y_te.size(1)):
            pred, h, c = dec(inp, h, c, enc_out)
            tok = pred.argmax(-1)
            preds.append(tok)
            inp = tok.unsqueeze(1)
        preds = torch.stack(preds, dim=1)
    tok_acc = (preds == y_te[:, 1:]).float().mean().item()
    seq_acc = (preds == y_te[:, 1:]).all(dim=1).float().mean().item()
    return tok_acc, seq_acc

for sl in [20, 40]:
    for attn, label in [(False, "Vanilla"), (True, "Attention")]:
        ta, sa = run_experiment(sl, attn)
        print(f"Len={sl:>2d} {label:>9s}: tok_acc={ta:.1%}, seq_acc={sa:.1%}")
At length 20, both models perform reasonably well -- the context vector can still capture enough information for a short reversal. At length 40, the gap widens significantly. The attention model maintains strong accuracy because it can look back at any encoder position directly, while the vanilla model's single context vector gets increasingly lossy as the sequence grows.
Exercise 2: Multi-head self-attention module with 4 heads.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, n_heads=4):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.d_k = dim // n_heads
        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)
        self.W_out = nn.Linear(dim, dim)

    def forward(self, x):
        B, L, D = x.shape
        Q = self.W_q(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = F.softmax(scores, dim=-1)
        attn_out = weights @ V
        concat = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        return self.W_out(concat), weights

mhsa = MultiHeadSelfAttention(dim=64, n_heads=4)
x = torch.randn(2, 10, 64)
out, w = mhsa(x)
print(f"Input: {x.shape}, Output: {out.shape}")
print(f"Weights shape: {w.shape}")  # (2, 4, 10, 10)

# Each head's attention pattern
for h in range(4):
    top_pos = w[0, h, 0].topk(3).indices.tolist()
    print(f"Head {h}: token 0 attends most to positions {top_pos}")

# Parameter comparison
single_params = sum(p.numel() for p in nn.Linear(64, 64).parameters()) * 3
multi_params = sum(p.numel() for p in mhsa.parameters())
print(f"\nSingle-head Q+K+V params: {single_params:,}")
print(f"Multi-head total params: {multi_params:,}")
Even on random (untrained) data, the four heads produce different attention distributions because they have different randomly initialized weight matrices. After training, each head typically specializes -- one head might learn syntactic dependencies, another positional patterns, another semantic similarity. The total parameter count is essentially the same as single-head attention with the same input/output dimensions (only the shared output projection is extra) because we're just splitting the d_model dimensions across heads, not adding new ones.
Exercise 3: Positional encoding breaks permutation equivariance.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SimpleSelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)
        self.scale = dim ** 0.5

    def forward(self, x):
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
        scores = torch.bmm(Q, K.transpose(1, 2)) / self.scale
        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, V)

def sinusoidal_pe(seq_len, d_model):
    pe = np.zeros((seq_len, d_model))
    pos = np.arange(seq_len)[:, np.newaxis]
    div = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(pos * div)
    pe[:, 1::2] = np.cos(pos * div)
    return torch.tensor(pe, dtype=torch.float32).unsqueeze(0)

sa = SimpleSelfAttention(dim=32)
x = torch.randn(1, 10, 32)
perm = [0, 3, 7, 1, 5, 9, 2, 6, 4, 8]
inv_perm = [perm.index(i) for i in range(10)]

# (a) Without positional encoding
out_orig = sa(x)
out_shuf = sa(x[:, perm, :])
out_unshuf = out_shuf[:, inv_perm, :]
diff_no_pe = (out_orig - out_unshuf).norm().item()
print(f"Without PE - L2 diff after shuffle+unshuffle: {diff_no_pe:.8f}")
print(f"Permutation equivariant: {diff_no_pe < 1e-5}")

# (b) With positional encoding
pe = sinusoidal_pe(10, 32)
x_pe = x + pe
out_orig_pe = sa(x_pe)
x_shuf_pe = x[:, perm, :] + pe  # position encoding for NEW positions
out_shuf_pe = sa(x_shuf_pe)
out_unshuf_pe = out_shuf_pe[:, inv_perm, :]
diff_with_pe = (out_orig_pe - out_unshuf_pe).norm().item()
print(f"\nWith PE - L2 diff after shuffle+unshuffle: {diff_with_pe:.6f}")
print(f"Permutation equivariant: {diff_with_pe < 1e-5}")
print(f"\nPE broke symmetry by factor: {diff_with_pe / max(diff_no_pe, 1e-10):.1f}x")
Without positional encoding, the L2 difference is essentially zero (within floating-point precision) -- self-attention is perfectly permutation-equivariant. With positional encoding added, shuffling changes the output significantly because each token now carries position-dependent information. The model "knows" that token at position 3 is different from the same token at position 7, even if the content is identical. This is exactly what we need for language -- "dog bites man" must produce a different representation than "man bites dog."
On to today's episode
Here we go! This is the one. If you've been following this series from episode #1 through episode #51, everything has been building toward this moment. The perceptron (episode #37), forward passes and backpropagation (#38-39), training tricks (#40-41), PyTorch (#42-44), CNNs for vision (#45-47), RNNs for sequences (#48), LSTMs for memory (#49), seq2seq for translation (#50), and attention for looking back at the right places (#51) -- all of it was prerequisite knowledge for the architecture we're about to build.
In June 2017, a team at Google published perhaps the most confidently titled paper in AI history: "Attention Is All You Need." The paper introduced the transformer -- an architecture built entirely from attention mechanisms, with no recurrence and no convolutions. Within two years, transformers had replaced RNNs as the default for language processing. Within five years, they had conquered vision, audio, protein folding, and virtually every other domain. Every large language model you interact with today -- GPT-4, Claude, Gemini, Llama -- is a transformer.
This is Part 1. We'll build the encoder side from scratch: scaled dot-product attention, multi-head attention, positional encoding, the feed-forward network, residual connections, layer normalization, and the full stacked encoder. Part 2 will cover the decoder side with masked attention and cross-attention, and we'll put the complete encoder-decoder transformer together.
Why drop the RNN?
RNNs have two fundamental problems that get worse as sequences get longer, and we've seen both of them first-hand in previous episodes.
Sequential processing. An RNN processes tokens one at a time, left to right. Position 50 can't be computed until positions 1 through 49 are done. This makes RNNs inherently sequential -- you can't parallelize across the sequence length, no matter how many GPUs you throw at it. For a 1000-token input, you need 1000 sequential steps. We measured this in episode #51 when comparing RNN vs self-attention inference times.
Long-range dependency degradation. Even with LSTMs and attention (as we built last episode), information from early positions gets diluted as it passes through many timesteps. The attention mechanism from episode #51 helped by letting the decoder look back at encoder states, but the encoder itself still processes sequentially -- its hidden state for position 100 is conditioned on all previous positions through a sequential chain. The cell state highway helps (episode #49), but it doesn't eliminate the problem entirely.
Self-attention solves both. Every position attends to every other position in a single parallel operation. Position 1 and position 1000 are directly connected -- no chain of sequential steps, no information decay over distance. And the entire operation can be computed as a single matrix multiplication, which is exactly what GPUs are designed for.
Scaled dot-product attention
The mathematical core of the transformer is scaled dot-product attention. We introduced it in episode #51 when building self-attention; now let's formalize it as the transformer uses it.
Given three matrices -- Queries (Q), Keys (K), and Values (V):
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
Step by step:
- Compute QK^T -- a matrix of similarity scores between every query and every key
- Scale by 1/sqrt(d_k), where d_k is the key dimension -- prevents softmax saturation
- Apply softmax row-wise -- each query gets a probability distribution over keys
- Multiply by V -- each query gets a weighted combination of values
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.bmm(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    output = torch.bmm(weights, V)
    return output, weights

# Example: 1 batch, 4 tokens, 8 dimensions
Q = torch.randn(1, 4, 8)
K = torch.randn(1, 4, 8)
V = torch.randn(1, 4, 8)
out, w = scaled_dot_product_attention(Q, K, V)
print(f"Output: {out.shape}")        # (1, 4, 8)
print(f"Weights: {w.shape}")         # (1, 4, 4) - each token attends to all 4
print(f"Weight sums: {w.sum(-1)}")   # all 1.0
The mask parameter is important for the decoder (covered in Part 2) -- it prevents positions from attending to future positions during training. For the encoder, no causal mask is needed: every position should be able to see every other position. (In practice a padding mask is still used so that positions don't attend to padding tokens when batching variable-length sequences.)
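To make the mask argument concrete, here's a quick sketch reusing the function and tensors above (the padding scenario and the pad_mask name are just for illustration): we pretend the last of the four tokens is padding and forbid every query from attending to it.

# Illustrative sketch: mask out key position 3 as if it were padding.
# The mask broadcasts against scores (batch, query_len, key_len);
# entries equal to 0 become -inf before the softmax.
pad_mask = torch.ones(1, 4, 4)
pad_mask[:, :, 3] = 0                      # no query may attend to key position 3
out_m, w_m = scaled_dot_product_attention(Q, K, V, mask=pad_mask)
print(w_m[0, 0])                           # last column is exactly 0
print(w_m.sum(-1))                         # rows still sum to 1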
Why the sqrt(d_k) scaling? Without it, dot products between high-dimensional vectors grow large in magnitude. If d_k = 64, the variance of the dot product is roughly 64 times larger than if d_k = 1. Large values push softmax into saturation where virtually all the weight falls on one token and gradients vanish. Scaling by 1/sqrt(d_k) keeps the scores in a range where softmax produces meaningful, non-degenerate distributions. It's a small detail but absolutely critical for stable training ;-)
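Here's a small numerical illustration of that effect (random vectors, purely for intuition): the spread of unscaled dot products grows like sqrt(d_k), and scores of that magnitude typically push softmax toward a near one-hot distribution.

# Illustrative sketch: why unscaled dot products saturate the softmax.
d_k = 64
q = torch.randn(1000, d_k)
k = torch.randn(1000, d_k)
raw = (q * k).sum(-1)                       # unscaled dot products
scaled = raw / math.sqrt(d_k)
print(f"std of raw scores:    {raw.std():.2f}")     # roughly sqrt(64) = 8
print(f"std of scaled scores: {scaled.std():.2f}")  # roughly 1

scores = torch.randn(1, 8) * 8              # typical unscaled magnitude
print(F.softmax(scores, dim=-1))             # usually close to one-hot -> tiny gradients
print(F.softmax(scores / 8, dim=-1))         # smoother, non-degenerate distribution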
Notice something elegant about self-attention: Q, K, and V all have the same shape. In self-attention (where the sequence attends to itself), they all come from the same input sequence, projected through different learned weight matrices. The three projections let the model ask separate questions for each token: "What am I looking for?" (Q), "What do I offer?" (K), and "What information do I carry?" (V).
Multi-head attention
A single attention head computes one set of attention weights -- one "view" of the relationships between tokens. But different relationships matter for different aspects of language. In "The cat sat on the mat because it was tired," the word "it" needs to attend to "cat" for coreference resolution, but it also needs to attend to "tired" for predicting the next word. A single attention pattern can't easily capture both simultaneously.
Multi-head attention runs multiple attention operations in parallel, each with its own learned projections, then concatenates and projects the results:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch = Q.size(0)
        # Project and reshape to (batch, n_heads, seq_len, d_k)
        Q = self.W_q(Q).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention per head
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        attn_out = weights @ V  # (batch, n_heads, seq_len, d_k)
        # Concatenate heads and project
        concat = attn_out.transpose(1, 2).contiguous().view(
            batch, -1, self.n_heads * self.d_k)
        return self.W_out(concat)

mha = MultiHeadAttention(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64)  # batch=2, seq=10, dim=64
out = mha(x, x, x)          # self-attention: Q=K=V=x
print(f"Input: {x.shape}, Output: {out.shape}")  # same shape
With d_model=64 and n_heads=8, each head operates on 8-dimensional slices. The total computation is roughly the same as single-head attention on the full 64 dimensions -- but the model gets 8 different attention patterns instead of 1. In practice, different heads learn to focus on different linguistic phenomena: one head might learn syntactic dependencies, another semantic relationships, another positional patterns.
The original paper used 8 heads with d_model=512 (so d_k=64 per head). Modern models use more: GPT-3 uses 96 heads with d_model=12288. The key insight is that you're not adding parameters -- you're splitting the same dimensional space into parallel subspaces that each learn their own attention pattern.
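You can check the "splitting, not adding" part directly with the module above -- the parameter count doesn't change with the number of heads, only how d_model is divided among them:

# Sketch: head count doesn't change the parameter count, only the split of d_model.
for h in [1, 8, 16]:
    m = MultiHeadAttention(d_model=512, n_heads=h)
    print(f"n_heads={h:>2d}: {sum(p.numel() for p in m.parameters()):,} params")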
Positional encoding
Self-attention treats its input as a set, not a sequence. The attention between positions i and j depends only on the content at those positions, not on where they appear in the sequence. "The cat chased the dog" and "The dog chased the cat" would produce the same attention patterns (just permuted) without positional information -- but they mean very different things. We demonstrated this permutation equivariance property in episode #51.
The transformer injects positional information by adding a positional encoding to each token's embedding before feeding it into the attention layers. The original paper used fixed sinusoidal functions:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

pe = PositionalEncoding(d_model=64)
tokens = torch.randn(1, 10, 64)
encoded = pe(tokens)
print(f"Before PE: {tokens[0, 0, :4]}")
print(f"After PE:  {encoded[0, 0, :4]}")
print("PE adds a position-dependent signal to each embedding")
Why sines and cosines? Two reasons. First, each position gets a unique encoding. Second, the encoding allows the model to learn relative positions: the offset between positions 5 and 8 looks the same as the offset between positions 100 and 103 in terms of the sinusoidal pattern. The model can learn "three positions apart" as a general concept, regardless of where in the sequence those positions are.
The different frequencies across dimensions create a kind of binary clock: low-frequency dimensions (the first few) change slowly across positions (good for encoding rough "beginning vs end" information), while high-frequency dimensions change rapidly (good for distinguishing nearby positions). Together they form a rich positional signal that the attention mechanism can use.
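A small sketch to make the relative-position point concrete (pe_sim is just a throwaway helper): with the sinusoidal table from the PositionalEncoding module above, the dot product between two positions' encodings depends only on their offset, not on where they sit in the sequence.

# Sketch: similarity between sinusoidal encodings depends only on the offset.
pe_table = PositionalEncoding(d_model=64, max_len=200).pe[0]  # (200, 64)

def pe_sim(i, j):
    return torch.dot(pe_table[i], pe_table[j]).item()

print(f"sim(5, 8)     = {pe_sim(5, 8):.3f}")      # offset 3, early in the sequence
print(f"sim(100, 103) = {pe_sim(100, 103):.3f}")  # offset 3, much later -> same value
print(f"sim(5, 50)    = {pe_sim(5, 50):.3f}")     # offset 45 -> noticeably different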
Modern transformers often replace sinusoidal encodings with learned positional embeddings -- just another embedding table indexed by position. Both approaches work. Some recent architectures (RoPE, ALiBi) modify the attention computation itself to inject relative positional information, which generalizes better to sequence lengths not seen during training. But the sinusoidal approach from the original paper remains elegant and effective.
The feed-forward network
After self-attention, each position passes through a position-wise feed-forward network: two linear transformations with a ReLU (or GELU) activation in between:
FFN(x) = W_2 * ReLU(W_1 * x + b_1) + b_2
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or d_model * 4
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.net(x)

ff = FeedForward(d_model=64)
x = torch.randn(2, 10, 64)
out = ff(x)
print(f"Input: {x.shape}, Output: {out.shape}")
print(f"FFN params: {sum(p.numel() for p in ff.parameters()):,}")
The inner dimension d_ff is typically 4x d_model. This expansion-contraction pattern gives the model a "thinking space" -- it projects to a higher dimension, applies nonlinearity, then projects back. The feed-forward network is applied identically and independently to each position. Attention handles inter-token relationships; the feed-forward network handles per-token transformation.
This might seem like a trivial component compared to attention, but it accounts for about two-thirds of the parameters in a transformer. It's where much of the model's "knowledge" is stored -- factual associations, grammar patterns, and world knowledge learned during training. The attention layers decide what information to move between positions; the feed-forward layers decide what to do with that information once it arrives.
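As a rough sanity check on that claim (using this episode's modules with the original paper's d_model=512 and d_ff=2048), here's the per-block parameter split between attention and feed-forward:

# Sketch: per-block parameter split between attention and feed-forward.
d_model, d_ff, n_heads = 512, 2048, 8
attn = MultiHeadAttention(d_model, n_heads)
ffn = FeedForward(d_model, d_ff)
a = sum(p.numel() for p in attn.parameters())   # 4 * (512*512 + 512), about 1.05M
f = sum(p.numel() for p in ffn.parameters())    # 512*2048 + 2048*512 + biases, about 2.10M
print(f"attention: {a:,}  feed-forward: {f:,}  ff share: {f / (a + f):.0%}")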
Residual connections and layer normalization
Two more ingredients hold the whole thing together, and both should look familiar if you remember our CNN episodes.
Residual connections (skip connections, from ResNet -- episode #46) add the input of each sub-layer to its output: output = sublayer(x) + x. This means the model only needs to learn the difference from the input, not the full transformation from scratch. Residual connections also provide a direct gradient path from output to input, preventing vanishing gradients in deep networks. Without them, stacking 12+ transformer layers would be extremely difficult to train -- the exact same principle as the gradient highway in LSTMs (episode #49), but applied to layer depth instead of sequence length.
Layer normalization normalizes the activations across the feature dimension (not across the batch, like batch normalization in CNNs). It stabilizes training by keeping activations in a consistent range regardless of what the previous layer produced.
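A tiny demonstration of that distinction: nn.LayerNorm normalizes each token's feature vector independently, so every position comes out with roughly zero mean and unit variance across its features, no matter how badly scaled the input was.

# Sketch: layer norm acts over the last (feature) dimension, per token.
ln = nn.LayerNorm(64)
x = torch.randn(2, 10, 64) * 5 + 3          # deliberately badly scaled activations
y = ln(x)
print(x.mean(-1)[0, :3], x.std(-1)[0, :3])  # arbitrary per-token means and stds
print(y.mean(-1)[0, :3], y.std(-1)[0, :3])  # ~0 means, ~1 stds per token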
class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual + norm
        attn_out = self.attn(x, x, x, mask)
        x = self.norm1(x + self.drop1(attn_out))
        # Feed-forward with residual + norm
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop2(ff_out))
        return x

block = TransformerEncoderBlock(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64)
out = block(x)
print(f"Encoder block: {x.shape} -> {out.shape}")
print(f"Block params: {sum(p.numel() for p in block.parameters()):,}")
The original paper uses post-norm (normalize after adding the residual), which is what we've implemented here. Modern transformers almost universally use pre-norm (normalize before the sub-layer): x = x + sublayer(norm(x)). Pre-norm is more stable during training, especially for very deep models. The difference is small conceptually but matters a lot in practice -- pre-norm lets you train with larger learning rates and deeper stacks without gradient explosions.
Having said that, we're implementing post-norm here to match the original paper. If you wanted pre-norm, you'd simply move self.norm1 and self.norm2 to before the sub-layer instead of after.
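For reference, a minimal pre-norm sketch (the PreNormEncoderBlock name is made up for this example; it reuses the post-norm block's sub-layers and only changes the forward pass):

class PreNormEncoderBlock(TransformerEncoderBlock):
    # Sketch: identical sub-layers, only the order of norm and residual changes.
    def forward(self, x, mask=None):
        h = self.norm1(x)                        # normalize before self-attention
        x = x + self.drop1(self.attn(h, h, h, mask))
        x = x + self.drop2(self.ff(self.norm2(x)))
        return x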
Stacking encoder blocks
The full transformer encoder is just N of these blocks stacked sequentially, preceded by embedding + positional encoding:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_sz, d_model, n_heads, n_layers,
                 d_ff=None, max_len=5000, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, d_model)
        self.pe = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.scale = math.sqrt(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, src, mask=None):
        x = self.drop(self.pe(self.embed(src) * self.scale))
        for layer in self.layers:
            x = layer(x, mask)
        return x

encoder = TransformerEncoder(
    vocab_sz=10000, d_model=64, n_heads=8, n_layers=6)
src = torch.randint(0, 10000, (2, 20))
enc_out = encoder(src)
print(f"Encoder output: {enc_out.shape}")  # (2, 20, 64)
total_params = sum(p.numel() for p in encoder.parameters())
print(f"Parameters: {total_params:,}")
The embedding is multiplied by sqrt(d_model) before adding positional encoding -- this is a scaling trick from the original paper that keeps the embedding magnitudes balanced with the positional encoding magnitudes. Without it, the positional signal would dominate the content signal for large d_model values.
Six layers, eight heads, d_model=512: that's the original transformer encoder configuration (we're using smaller dimensions here for demonstration). Each layer refines the representation. Layer 1 might capture local word relationships. Layer 3 might capture phrase-level structure. Layer 6 might capture document-level dependencies. The representation at each position progressively incorporates more context from the rest of the sequence.
Comparing parameter efficiency
Let's put this in perspective by comparing our transformer encoder against the LSTM-based models from previous episodes:
# Compare: transformer encoder vs bidirectional LSTM encoder
vocab_sz = 10000
d_model = 256
seq_len = 100

# Transformer: 6 layers, 8 heads
transformer = TransformerEncoder(vocab_sz, d_model, n_heads=8, n_layers=6)
t_params = sum(p.numel() for p in transformer.parameters())

# BiLSTM: 2 layers (roughly similar depth in terms of representational power)
bilstm = nn.Sequential(
    nn.Embedding(vocab_sz, d_model),
    nn.LSTM(d_model, d_model // 2, num_layers=2,
            batch_first=True, bidirectional=True)
)
l_params = sum(p.numel() for p in bilstm.parameters())
print(f"Transformer (6 layers, 8 heads):  {t_params:>10,} params")
print(f"BiLSTM (2 layers, bidirectional): {l_params:>10,} params")

# Speed comparison
import time
src = torch.randint(0, vocab_sz, (4, seq_len))

transformer.eval()
t0 = time.perf_counter()
for _ in range(20):
    with torch.no_grad():
        _ = transformer(src)
t_time = (time.perf_counter() - t0) / 20

bilstm.eval()
t0 = time.perf_counter()
for _ in range(20):
    with torch.no_grad():
        _ = bilstm(src)
b_time = (time.perf_counter() - t0) / 20

print(f"\nForward pass (batch=4, seq={seq_len}):")
print(f"  Transformer: {t_time*1000:.1f}ms")
print(f"  BiLSTM:      {b_time*1000:.1f}ms")
On CPU the difference might not be dramatic, but on GPU the transformer's advantage is massive. The LSTM must process 100 timesteps sequentially (even the bidirectional version runs two sequential passes). The transformer processes all 100 positions simultaneously -- one big matrix multiplication per layer. The gap widens with sequence length: at 500 tokens, the LSTM is 5x slower. At 2000 tokens, it's practically unusable while the transformer barely notices (though its O(n^2) memory cost starts to bite).
Why transformers scale
The fundamental reason transformers displaced RNNs isn't just that they're better at any individual task -- it's that they're better AND they scale with hardware.
RNNs have O(n) sequential operations for sequence length n. You can't parallelize that. Doubling your GPU count doesn't make a 1000-token RNN twice as fast, because each step depends on the previous step's output.
Transformers have O(1) sequential depth per layer (assuming enough parallel compute). The self-attention computation is a batch of matrix multiplications -- exactly the operation GPUs are designed for. Doubling your GPU count roughly halves the wall-clock time.
The tradeoff: self-attention is O(n^2) in memory and computation per layer. Every token attends to every other token. For sequence length 1000, that's 1 million attention scores per head per layer. This quadratic cost is why context windows were historically limited (512, 1024, 2048 tokens) and why efficient attention variants (sparse attention, linear attention, Flash Attention) are an active research area.
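To put rough numbers on that quadratic growth (simple arithmetic, not a benchmark), here's the size of the attention-score tensor for a few sequence lengths, assuming 8 heads, 6 layers, and float32 scores:

# Sketch: attention-score memory grows quadratically with sequence length.
n_heads, n_layers, bytes_per_float = 8, 6, 4
for n in [512, 2048, 8192]:
    score_bytes = n * n * n_heads * n_layers * bytes_per_float  # one batch element
    print(f"seq_len={n:>5d}: {n*n:>12,} scores per head per layer, "
          f"~{score_bytes / 1e9:.2f} GB of float32 scores per example")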
But within the range of sequence lengths where the O(n^2) cost is manageable, transformers dominate because they fully exploit parallel hardware. And that parallelism enables training on unprecedented amounts of data -- billions of tokens, then trillions -- which is what led to the LLM revolution. RNNs couldn't have scaled to GPT-3's 175 billion parameters trained on 300 billion tokens. The transformer made it possible.
What we have so far -- and what's next
Let's take stock. We've built every component of the transformer encoder:
- Scaled dot-product attention -- the core operation: softmax(QK^T / sqrt(d_k)) V
- Multi-head attention -- parallel attention heads for different relationship types
- Positional encoding -- sinusoidal signals that inject position information
- Feed-forward network -- position-wise expansion-contraction (4x inner dimension)
- Residual connections + layer normalization -- enabling deep stacking
- The full encoder -- N blocks stacked with embedding and positional encoding
This is half the transformer. The encoder takes an input sequence and produces a rich, context-aware representation where every position has "seen" every other position through 6 layers of self-attention and feed-forward processing.
But the original transformer is an encoder-decoder architecture (just like our seq2seq from episode #50, but with attention everywhere instead of RNNs). The decoder side has two additional mechanisms we haven't covered yet: masked self-attention (preventing the decoder from looking at future positions during generation) and cross-attention (letting the decoder attend to the encoder's output). We'll build those in Part 2 and assemble the complete transformer.
The bottom line
- Transformers replace RNNs with self-attention: every position attends to every other position in parallel, eliminating sequential processing;
- Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V is the mathematical core -- scaling by sqrt(d_k) prevents softmax saturation;
- Multi-head attention runs multiple attention patterns in parallel, each learning different relationships between tokens -- it's splitting dimensions, not adding parameters;
- Positional encoding (sinusoidal or learned) injects position information since self-attention is position-agnostic;
- The feed-forward network applies a position-wise expansion-contraction (4x inner dimension) after attention -- it holds most of the model's parameters and "knowledge";
- Residual connections and layer normalization make deep stacking possible -- the same gradient highway principle from ResNets and LSTMs;
- Transformers scale with hardware (parallel matrix multiplications) at the cost of O(n^2) memory per layer;
- Everything from episodes #37 through #51 -- perceptrons, neural nets, backprop, CNNs, RNNs, LSTMs, seq2seq, attention -- has been building to this architecture. The encoder is done. The decoder comes next ;-)
Exercises
Exercise 1: Build a parameter counting and layer analysis tool. Create a TransformerEncoder with vocab_sz=10000, d_model=512, n_heads=8, n_layers=6, d_ff=2048. Print the total parameter count, then break it down by component: how many parameters are in the embedding layer, how many in all attention layers combined (Q, K, V, output projections), how many in all feed-forward layers combined, and how many in all layer normalization layers? What percentage of the total does each component represent? You should find that feed-forward layers dominate.
Exercise 2: Implement a pre-norm vs post-norm comparison. Create two versions of TransformerEncoderBlock -- one with post-norm (as in this episode: x = norm(x + sublayer(x))) and one with pre-norm (x = x + sublayer(norm(x))). Feed the same random input through each. Then stack 12 layers of each type and measure the output norm and gradient norm at the first layer. Pre-norm should produce more stable gradients for deeper stacks. Print the output norm after 1, 6, and 12 layers for both variants.
Exercise 3: Build a complete encoder pipeline demo. Create a TransformerEncoder with vocab_sz=5000, d_model=128, n_heads=4, n_layers=4. Feed in a batch of 3 sequences of length 15 (random token IDs). Print the output shape. Then demonstrate that the encoder output for token position 7 is influenced by ALL other positions by masking out position 3 (set its embedding to zero before encoding) and measuring the L2 distance between the original output at position 7 and the modified output at position 7. Repeat by masking position 14. Both should produce non-zero differences, proving that self-attention connects all positions regardless of distance.