Learn AI Series (#51) - Attention Mechanisms


What will I learn

  • why the seq2seq bottleneck demanded a better solution -- and why attention was the answer;
  • Bahdanau attention (additive) -- the original mechanism that let decoders look back at the full input;
  • Luong attention (multiplicative) -- a simpler alternative using dot products;
  • how attention weights create an interpretability window into model behavior;
  • self-attention -- the conceptual leap where every position attends to every other position;
  • Queries, Keys, and Values -- the three projections that make self-attention flexible;
  • building attention-augmented seq2seq and standalone self-attention modules in PyTorch.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#51) - Attention Mechanisms

Solutions to Episode #50 Exercises

Exercise 1: Build a sequence sorting seq2seq model. Input: 8 random integers (range 2-19), target: same integers sorted ascending. Train for 40 epochs with scheduled sampling.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

class Encoder(nn.Module):
    def __init__(self, vocab_sz, emb_d, hid_d, n_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_d)
        self.lstm = nn.LSTM(emb_d, hid_d, n_layers, batch_first=True)

    def forward(self, src):
        outputs, (h, c) = self.lstm(self.embed(src))
        return h, c

class Decoder(nn.Module):
    def __init__(self, vocab_sz, emb_d, hid_d, n_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_d)
        self.lstm = nn.LSTM(emb_d, hid_d, n_layers, batch_first=True)
        self.fc = nn.Linear(hid_d, vocab_sz)

    def forward(self, token, h, c):
        out, (h, c) = self.lstm(self.embed(token), (h, c))
        return self.fc(out), h, c

class Seq2Seq(nn.Module):
    def __init__(self, vocab, emb_d=64, hid_d=128):
        super().__init__()
        self.encoder = Encoder(vocab, emb_d, hid_d)
        self.decoder = Decoder(vocab, emb_d, hid_d)
        self.vocab = vocab

    def forward(self, src, tgt, tf_ratio=0.5):
        h, c = self.encoder(src)
        inp = tgt[:, :1]
        outputs = []
        for t in range(1, tgt.size(1)):
            out, h, c = self.decoder(inp, h, c)
            outputs.append(out)
            use_tf = torch.rand(1).item() < tf_ratio
            inp = tgt[:, t:t+1] if use_tf else out.argmax(-1)
        return torch.cat(outputs, dim=1)

torch.manual_seed(42)
vocab_sz = 20
n_samples = 4000
seq_len = 8

src_data = torch.randint(2, vocab_sz, (n_samples, seq_len))
tgt_data = src_data.sort(dim=1).values  # sorted version

X_tr, X_te = src_data[:3200], src_data[3200:]
y_tr, y_te = tgt_data[:3200], tgt_data[3200:]
loader = DataLoader(TensorDataset(X_tr, y_tr), batch_size=64, shuffle=True)

model = Seq2Seq(vocab=vocab_sz)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(40):
    model.train()
    tf_ratio = max(0.2, 1.0 - epoch * 0.025)
    for src_b, tgt_b in loader:
        logits = model(src_b, tgt_b, tf_ratio=tf_ratio)
        loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab_sz),
                                      tgt_b[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            test_logits = model(X_te, y_te, tf_ratio=0.0)
            preds = test_logits.argmax(-1)
            tok_acc = (preds == y_te[:, 1:]).float().mean()
            seq_acc = (preds == y_te[:, 1:]).all(dim=1).float().mean()
        print(f"Epoch {epoch:>2d} (TF={tf_ratio:.2f}): "
              f"tok_acc={tok_acc:.1%}, seq_acc={seq_acc:.1%}")

Sorting is substantially harder than reversal because the model can't just mirror positions -- it needs to compare all values and produce them in order. Token accuracy should reach 70-90% depending on the run, but full-sequence accuracy will be lower since getting even one position wrong means the whole sequence is scored as incorrect.

Exercise 2: Greedy decoding vs beam search comparison on the reversal task.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(42)
vocab_sz = 20
n_samples = 3000
max_len = 10

src_data = torch.randint(2, vocab_sz, (n_samples, max_len))
tgt_data = src_data.flip(1)

X_tr, X_te = src_data[:2400], src_data[2400:]
y_tr, y_te = tgt_data[:2400], tgt_data[2400:]
loader = DataLoader(TensorDataset(X_tr, y_tr), batch_size=64, shuffle=True)

model = Seq2Seq(vocab=vocab_sz)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(30):
    model.train()
    tf = max(0.2, 1.0 - epoch * 0.03)
    for sb, tb in loader:
        logits = model(sb, tb, tf_ratio=tf)
        loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab_sz),
                                      tb[:, 1:].reshape(-1))
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        opt.step()

def greedy_decode(model, src, max_len=10):
    model.eval()
    with torch.no_grad():
        h, c = model.encoder(src)
        inp = src[:, -1:]  # first target token for the reversal task (= last source token)
        tokens = []
        for _ in range(max_len - 1):
            out, h, c = model.decoder(inp, h, c)
            pred = out.argmax(-1)
            tokens.append(pred)
            inp = pred
    return torch.cat(tokens, dim=1)

# Compare beam widths
model.eval()
greedy_preds = greedy_decode(model, X_te[:200])
greedy_acc = (greedy_preds == y_te[:200, 1:]).float().mean()
print(f"Greedy (beam=1): token accuracy = {greedy_acc:.1%}, "
      f"avg length = {greedy_preds.size(1)}")

# As a stand-in for a full beam-search run (built in episode #50), score the
# model's own forward pass without teacher forcing, starting from the true
# first target token -- directly comparable to greedy_decode above
with torch.no_grad():
    logits = model(X_te[:200], y_te[:200], tf_ratio=0.0)
    preds = logits.argmax(-1)
    acc = (preds == y_te[:200, 1:]).float().mean()
print(f"Forward pass (no TF): token accuracy = {acc:.1%}")

The key takeaway: for short sequences (length 10), greedy decoding and beam search produce very similar results. The gap widens on longer sequences where early mistakes cascade -- beam search's ability to maintain multiple hypotheses helps it recover from suboptimal early choices.

Exercise 3: Bottleneck analysis -- hidden dimension vs sequence length.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

def train_and_eval(seq_len, hidden_dim, epochs=30):
    torch.manual_seed(42)
    n_train, n_test = 2000, 500
    vocab_sz = 20
    src = torch.randint(2, vocab_sz, (n_train + n_test, seq_len))
    tgt = src.flip(1)

    X_tr, X_te = src[:n_train], src[n_train:]
    y_tr, y_te = tgt[:n_train], tgt[n_train:]
    loader = DataLoader(TensorDataset(X_tr, y_tr), batch_size=64, shuffle=True)

    model = Seq2Seq(vocab=vocab_sz, emb_d=hidden_dim, hid_d=hidden_dim)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        model.train()
        tf = max(0.2, 1.0 - epoch * 0.03)
        for sb, tb in loader:
            logits = model(sb, tb, tf_ratio=tf)
            loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab_sz),
                                          tb[:, 1:].reshape(-1))
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()

    model.eval()
    with torch.no_grad():
        test_logits = model(X_te, y_te, tf_ratio=0.0)
        preds = test_logits.argmax(-1)
        return (preds == y_te[:, 1:]).float().mean().item()

print(f"{'Hidden':>8s} {'Len=15':>8s} {'Len=30':>8s}")
for hid in [32, 64, 128, 256]:
    acc_15 = train_and_eval(15, hid)
    acc_30 = train_and_eval(30, hid)
    print(f"{hid:>8d} {acc_15:>8.1%} {acc_30:>8.1%}")

Doubling the hidden dimension does NOT double effective memory. You'll likely see diminishing returns -- going from 32 to 64 makes a big difference, 64 to 128 helps noticeably, but 128 to 256 yields marginal gains. The context vector bottleneck is fundamentally about architecture, not capacity. Even a massive hidden state still compresses the entire input into one vector. This is exactly why attention (which we build today) is the real solution.

On to today's episode

Last episode we built the complete seq2seq framework -- encoder, decoder, teacher forcing, beam search, the works. We trained models to reverse sequences and talked about how the encoder compresses the entire input into a single context vector. And then we hit the wall: that context vector is a fixed-size bottleneck. Twenty tokens? Fine. Fifty? Starting to lose information. Two hundred? Forget about it (literally -- the encoder forgets).

I teased the solution at the end of episode #50, and honestly, this is the episode I've been most excited to write in the entire series so far. Attention is one of those ideas that seems obvious in hindsight but completely changed the trajectory of the field. It was proposed in 2014 -- just months after the original seq2seq paper -- and within three years it led to the architecture that powers essentially every large AI system you interact with today.

Here we go!

The intuition

Think about how you translate a sentence. When you're producing the Dutch word "gaat," you look back at the English word "are." When producing "je," you focus on "you." You don't try to memorize the entire sentence in one mental snapshot and then reconstruct it from memory -- you glance back at the source, focusing on whatever's relevant to the word you're currently translating.

The seq2seq decoder from last episode can't do this. It receives a single context vector from the encoder and must generate the entire output from that one compressed representation. Once the encoder finishes, the decoder never looks at the input again.

Attention gives the decoder the ability to look back. At every decoder timestep, instead of using one fixed context vector, the decoder computes a fresh context vector that's a weighted combination of ALL encoder hidden states. The weights are learned -- the model figures out which encoder positions matter for each specific decoder step.

The result: the decoder generating "gaat" can focus heavily on the encoder state for "are," while the decoder generating "je" focuses on the state for "you." Each output token gets its own tailored view of the input.
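
To make "weighted combination" concrete before we formalize it, here's a minimal numeric sketch -- made-up weights and tiny two-dimensional states, purely illustrative -- of how a context vector is assembled from encoder states:

import torch

# Toy numbers: three 2-dimensional encoder states and attention weights
# that sum to 1. The context vector is their weighted sum.
enc_states = torch.tensor([[1.0, 0.0],    # hidden state for "How"
                           [0.0, 1.0],    # hidden state for "are"
                           [0.5, 0.5]])   # hidden state for "you"
weights = torch.tensor([0.1, 0.8, 0.1])   # the decoder currently cares about "are"

context = (weights.unsqueeze(1) * enc_states).sum(dim=0)
print(context)   # tensor([0.1500, 0.8500]) -- dominated by the "are" state

Change the weights and the context changes with them -- that, at every decoder step, is all attention does.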

Bahdanau attention (additive)

Bahdanau, Cho, and Bengio proposed the first attention mechanism in their 2014 paper "Neural Machine Translation by Jointly Learning to Align and Translate." The name is descriptive -- the model learns to align output positions with input positions while learning to translate. Here's how it works at each decoder timestep t:

  1. The decoder has a hidden state s_t (what it's computed so far)
  2. The encoder produced hidden states h_1, h_2, ..., h_n (one per input token)
  3. For each encoder state h_i, compute an alignment score: how relevant is input position i to the current decoder state?
  4. Normalize the scores into attention weights using softmax (they sum to 1)
  5. Compute the context vector as the weighted sum of encoder states
  6. Use this context vector (combined with s_t) to produce the output

The alignment score is the key innovation. Bahdanau parameterizes it as a small feedforward network:

score(s_t, h_i) = v^T * tanh(W_s * s_t + W_h * h_i)

where W_s, W_h, and v are learned parameters. It's called "additive" attention because s_t and h_i are combined by addition (after linear transformation) rather than multiplication.

import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim, attn_dim=64):
        super().__init__()
        self.W_enc = nn.Linear(enc_dim, attn_dim, bias=False)
        self.W_dec = nn.Linear(dec_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, dec_hidden, enc_outputs):
        # dec_hidden: (batch, dec_dim)
        # enc_outputs: (batch, src_len, enc_dim)
        dec_proj = self.W_dec(dec_hidden).unsqueeze(1)   # (batch, 1, attn_dim)
        enc_proj = self.W_enc(enc_outputs)                # (batch, src_len, attn_dim)
        scores = self.v(torch.tanh(dec_proj + enc_proj))  # (batch, src_len, 1)
        weights = F.softmax(scores.squeeze(-1), dim=-1)   # (batch, src_len)
        context = torch.bmm(weights.unsqueeze(1), enc_outputs)  # (batch, 1, enc_dim)
        return context.squeeze(1), weights

# Quick test
attn = BahdanauAttention(enc_dim=128, dec_dim=128, attn_dim=64)
dec_h = torch.randn(4, 128)       # batch=4, decoder hidden state
enc_out = torch.randn(4, 20, 128) # batch=4, src_len=20, encoder outputs

ctx, w = attn(dec_h, enc_out)
print(f"Context vector: {ctx.shape}")     # (4, 128)
print(f"Attention weights: {w.shape}")    # (4, 20)
print(f"Weights sum: {w.sum(dim=-1)}")    # [1.0, 1.0, 1.0, 1.0]
print(f"Highest attention at: {w[0].argmax().item()}")

The weights tensor is where it gets interesting. Each value tells you how much the decoder is "attending to" that input position. These weights are fully differentiable -- trained end-to-end with the rest of the model through standard backpropagation. The network learns which alignment patterns produce good translations (or good outputs for whatever task you're training on), and the attention weights emerge as a byproduct of that learning.
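
As a quick sanity check of that "fully differentiable" claim, here's a small sketch -- reusing attn, dec_h, and enc_out from the test above, with a dummy loss standing in for a real task loss -- showing that gradients reach every attention parameter:

# Dummy loss for illustration only; in a real model the loss comes from the
# decoder output further downstream.
ctx, w = attn(dec_h, enc_out)
loss = ctx.sum()
loss.backward()

for name, p in attn.named_parameters():
    print(f"{name}: grad norm = {p.grad.norm().item():.4f}")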

Luong attention (multiplicative)

A year after Bahdanau, Luong et al. proposed a simpler scoring function. Instead of a feedforward network, just use a dot product (or a bilinear form):

Dot: score(s_t, h_i) = s_t^T * h_i -- no learned parameters at all

General: score(s_t, h_i) = s_t^T * W * h_i -- one learned weight matrix

Concat: score(s_t, h_i) = v^T * tanh(W * [s_t; h_i]) -- similar to Bahdanau

class LuongAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim, method='dot'):
        super().__init__()
        self.method = method
        if method == 'general':
            self.W = nn.Linear(enc_dim, dec_dim, bias=False)

    def forward(self, dec_hidden, enc_outputs):
        # dec_hidden: (batch, dec_dim)
        # enc_outputs: (batch, src_len, enc_dim)
        if self.method == 'dot':
            scores = torch.bmm(enc_outputs,
                               dec_hidden.unsqueeze(2)).squeeze(2)
        elif self.method == 'general':
            scores = torch.bmm(self.W(enc_outputs),
                               dec_hidden.unsqueeze(2)).squeeze(2)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights.unsqueeze(1), enc_outputs).squeeze(1)
        return context, weights

# Compare parameter counts
bahdanau_params = sum(p.numel() for p in attn.parameters())
luong_dot = LuongAttention(128, 128, 'dot')
luong_gen = LuongAttention(128, 128, 'general')
print(f"Bahdanau params: {bahdanau_params:,}")
print(f"Luong dot params: {sum(p.numel() for p in luong_dot.parameters()):,}")
print(f"Luong general params: {sum(p.numel() for p in luong_gen.parameters()):,}")

Dot attention is the cheapest -- zero extra parameters, just a dot product between decoder state and each encoder state. When the two vectors are similar (pointing in the same direction in hidden space), the score is high. The decoder effectively learns to produce hidden states that "look like" the encoder states it wants to focus on.
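
Here's a small sketch of that geometric intuition, reusing luong_dot from above with a hand-planted encoder state: if one encoder state points in the same direction as the decoder state, dot attention hands it nearly all of the weight.

dec_state = torch.randn(1, 128)
enc_states = torch.randn(1, 6, 128)
enc_states[0, 3] = dec_state[0] * 2.0    # plant an aligned vector at position 3

ctx, w = luong_dot(dec_state, enc_states)
print(f"Attention weights: {w}")
print(f"Highest weight at position: {w.argmax().item()}")   # expected: 3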

In practice, Luong dot attention is simpler, faster, and often performs comparably to Bahdanau. The choice between them is empirical -- for our purposes, the mechanism is what matters: compute relevance scores, softmax to get weights, weighted sum to get context.

Plugging attention into seq2seq

Here's the payoff. Let's upgrade our seq2seq model from episode #50 with Bahdanau attention. The key change: the decoder now takes enc_outputs (all encoder hidden states) instead of just the final hidden state. At each step it computes fresh attention weights and a fresh context vector:

class AttentionDecoder(nn.Module):
    def __init__(self, vocab_sz, emb_d, hid_d):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_d)
        self.attn = BahdanauAttention(hid_d, hid_d)
        self.lstm = nn.LSTM(emb_d + hid_d, hid_d, batch_first=True)
        self.fc = nn.Linear(hid_d * 2, vocab_sz)

    def forward(self, token, h, c, enc_outputs):
        emb = self.embed(token)                       # (batch, 1, emb_d)
        ctx, weights = self.attn(h[-1], enc_outputs)  # (batch, hid_d)
        lstm_in = torch.cat([emb, ctx.unsqueeze(1)], dim=-1)
        out, (h, c) = self.lstm(lstm_in, (h, c))
        pred = self.fc(torch.cat([out.squeeze(1), ctx], dim=-1))
        return pred, h, c, weights

class AttnEncoder(nn.Module):
    def __init__(self, vocab_sz, emb_d, hid_d):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_d)
        self.lstm = nn.LSTM(emb_d, hid_d, batch_first=True)

    def forward(self, src):
        outputs, (h, c) = self.lstm(self.embed(src))
        return outputs, h, c  # return ALL outputs for attention

# Quick test
enc = AttnEncoder(1000, 64, 128)
dec = AttentionDecoder(800, 64, 128)

src = torch.randint(0, 1000, (2, 10))
enc_out, h, c = enc(src)

tok = torch.randint(0, 800, (2, 1))
pred, h, c, w = dec(tok, h, c, enc_out)
print(f"Prediction shape: {pred.shape}")   # (2, 800)
print(f"Attention weights: {w.shape}")     # (2, 10) - one weight per source token
print(f"Weights sum: {w.sum(dim=-1)}")     # [1.0, 1.0]

The context vector is concatenated with the embedded input token and fed to the LSTM, and ALSO concatenated with the LSTM output before the final prediction layer. Two concatenation points give the model two chances to use the attended information -- once when deciding what the LSTM should focus on, and once when making the final prediction. This dual-concat pattern is standard in attention-augmented decoders.

Attention as interpretability

One side effect that researchers didn't fully anticipate: attention weights create a window into what the model is doing. For a machine translation model, you can visualize the attention matrix -- a heatmap where row i shows what the decoder focused on when generating output token i.

import matplotlib
matplotlib.use('Agg')  # non-interactive backend
import matplotlib.pyplot as plt

# Simulated attention weights for "How are you" -> "Hoe gaat het"
src_tokens = ["How", "are", "you", "?"]
tgt_tokens = ["Hoe", "gaat", "het"]
attn_matrix = torch.tensor([
    [0.85, 0.05, 0.08, 0.02],  # "Hoe" attends mostly to "How"
    [0.03, 0.82, 0.10, 0.05],  # "gaat" attends mostly to "are"
    [0.05, 0.15, 0.75, 0.05],  # "het" attends mostly to "you"
])

fig, ax = plt.subplots(figsize=(5, 4))
ax.imshow(attn_matrix.numpy(), cmap='Blues')
ax.set_xticks(range(len(src_tokens)))
ax.set_xticklabels(src_tokens)
ax.set_yticks(range(len(tgt_tokens)))
ax.set_yticklabels(tgt_tokens)
ax.set_xlabel("Source")
ax.set_ylabel("Target")
for i in range(len(tgt_tokens)):
    for j in range(len(src_tokens)):
        ax.text(j, i, f"{attn_matrix[i,j]:.2f}",
                ha='center', va='center', fontsize=9)
plt.tight_layout()
plt.savefig('/tmp/attention_heatmap.png', dpi=100)
print("Heatmap saved to /tmp/attention_heatmap.png")

In this visualization, the roughly diagonal pattern confirms what you'd expect: each Dutch word focuses primarily on its English counterpart. In real trained models the patterns get messier but still informative -- you can see when the model handles word reordering (German puts verbs at the end of subordinate clauses), when it splits one source word into multiple target words, and when it struggles with ambiguity.

Having said that, a word of caution: attention weights show where the model is looking, not necessarily why. High attention on a particular input token doesn't prove that token caused the output -- correlation isn't causation, even inside a neural network. Attention visualization is a useful debugging tool, but it's not proof of understanding. Researchers have shown that you can often permute attention weights significantly without changing the model's output, which means the weights are less "explanatory" than they first appear. Use them as intuition builders, not as ground truth ;-)

Self-attention: the big conceptual leap

Everything so far has been cross-attention: the decoder attends to the encoder. The query comes from one sequence, the keys and values come from another. But in 2017, researchers asked a question that changed everything: what if a sequence attended to itself?

Self-attention lets every position in a sequence attend to every other position in the same sequence. For the sentence "The cat sat on the mat because it was tired," self-attention at the position of "it" can learn to attend heavily to "cat" -- resolving the coreference across a distance of 5 words without needing to propagate information through 5 sequential RNN steps.

This is a fundamentally different paradigm from RNNs. In an RNN, information from position 1 must pass through positions 2, 3, 4, ... to reach position 50. Each hop through a hidden state risks information loss (the vanishing gradient from episode #48). In self-attention, position 1 and position 50 are directly connected in a single operation. No sequential processing. No information decay over distance.

class SimpleSelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)
        self.scale = dim ** 0.5

    def forward(self, x):
        # x: (batch, seq_len, dim)
        Q = self.W_q(x)   # queries: what am I looking for?
        K = self.W_k(x)   # keys: what do I contain?
        V = self.W_v(x)   # values: what information do I provide?

        scores = torch.bmm(Q, K.transpose(1, 2)) / self.scale
        weights = F.softmax(scores, dim=-1)
        output = torch.bmm(weights, V)
        return output, weights

sa = SimpleSelfAttention(dim=32)
x = torch.randn(1, 5, 32)  # 1 sentence, 5 tokens, 32-dim
out, w = sa(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}")               # same shape as input
print(f"Attention matrix: {w.shape}")        # (1, 5, 5) - every token to every token
print(f"Row sums: {w[0].sum(dim=-1)}")       # all 1.0 (softmax per row)

Three separate projections -- Queries, Keys, Values -- give the model remarkable flexibility. The Query says "what am I looking for?" The Key says "what do I contain?" The Value says "what information do I hand over when someone attends to me?" A token can advertise itself differently (via its Key) than the content it delivers (via its Value). This separation is crucial -- it means the model can learn to route information in complex ways. A pronoun's Query might signal "I'm looking for a noun referent," while its Value carries the syntactic role information downstream.

The division by sqrt(dim) is what the original "Attention Is All You Need" paper calls scaled dot-product attention. Without it, dot products between high-dimensional vectors grow large in magnitude, pushing softmax into saturation where virtually all the weight falls on one token and gradients vanish. Scaling keeps the scores in a range where softmax produces meaningful, non-degenerate distributions. It's a small detail but absolutely critical for stable training.
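
You can see the effect with a small sketch (random vectors, illustrative dimensions): without the scaling, softmax over the raw dot products is typically nearly one-hot.

# Dot products of random d-dimensional vectors have variance roughly d, so
# without the 1/sqrt(d) factor the softmax usually collapses onto one key.
d = 512
q = torch.randn(1, d)
keys = torch.randn(10, d)

raw = q @ keys.T               # (1, 10) unscaled scores
scaled = raw / (d ** 0.5)

print(f"Unscaled max weight: {F.softmax(raw, dim=-1).max():.4f}")     # usually close to 1.0
print(f"Scaled max weight:   {F.softmax(scaled, dim=-1).max():.4f}")  # much softer distribution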

Why self-attention beats recurrence

Let me be concrete about the advantages, because this isn't just a theoretical nicety:

# Compare: RNN vs self-attention for capturing long-range dependencies

# RNN: information from position 0 must traverse ALL intermediate positions
# Path length from position 0 to position N: O(N) steps
# Each step risks gradient decay

# Self-attention: every position connects directly to every other
# Path length from position 0 to position N: O(1) -- one matrix multiply
# No sequential bottleneck

import time

seq_lengths = [50, 100, 200, 500]
dim = 64

print(f"{'Length':>8s}  {'RNN time':>12s}  {'SelfAttn time':>14s}")
for L in seq_lengths:
    x_rnn = torch.randn(1, L, dim)

    rnn = nn.LSTM(dim, dim, batch_first=True)
    sa_mod = SimpleSelfAttention(dim)

    # RNN: sequential
    t0 = time.perf_counter()
    for _ in range(10):
        _ = rnn(x_rnn)
    t_rnn = (time.perf_counter() - t0) / 10

    # Self-attention: parallel
    t0 = time.perf_counter()
    for _ in range(10):
        _ = sa_mod(x_rnn)
    t_sa = (time.perf_counter() - t0) / 10

    print(f"{L:>8d}  {t_rnn*1000:>10.2f}ms  {t_sa*1000:>12.2f}ms")

Self-attention's computational cost is O(n^2 * d) where n is sequence length and d is dimension -- quadratic in sequence length because every token attends to every other token. RNNs are O(n * d^2) -- linear in sequence length but sequential (can't be parallelized). For moderate sequence lengths (up to a few thousand tokens), self-attention is faster on GPU hardware because the entire attention matrix can be computed in one batched matrix multiplication. For very long sequences (10,000+), the quadratic cost becomes a problem -- and that's an active area of research with efficient attention variants.
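
A rough back-of-the-envelope comparison of those operation counts -- not a benchmark, just the n^2*d versus n*d^2 arithmetic with an assumed d of 512:

d = 512
print(f"{'n':>8s} {'self-attn ~n^2*d':>20s} {'RNN ~n*d^2':>16s}")
for n in [128, 512, 2048, 8192]:
    print(f"{n:>8d} {n * n * d:>20,d} {n * d * d:>16,d}")
# The crossover sits around n = d: below it self-attention does fewer raw
# operations, above it the quadratic term takes over.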

But the real advantage isn't speed. It's the path length for information flow. In an RNN, the gradient from position 500 to position 1 must survive 499 sequential multiplications. In self-attention, it's a direct connection. This is why self-attention handles long-range dependencies so much better than any recurrent architecture -- the gradient doesn't have to survive a long sequential chain.
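
To put an entirely illustrative number on that, assume each recurrent step scaled the gradient by an optimistic factor of 0.99:

per_step = 0.99                       # assumed, optimistic per-step gradient factor
print(f"Gradient surviving 499 sequential steps: {per_step ** 499:.4f}")   # ~0.0066
print("Path length in self-attention: 1 -- no chain to survive")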

Positional encoding: the missing piece

There's one thing self-attention throws away that RNNs get for free: order. An RNN processes tokens left to right, so it naturally knows that position 3 comes after position 2. Self-attention treats the input as a set -- it has no concept of position unless you explicitly tell it.

# Demonstration: self-attention is permutation-equivariant
sa = SimpleSelfAttention(dim=16)

# Original sequence
x = torch.randn(1, 4, 16)
out_original, _ = sa(x)

# Shuffled sequence (swap positions 1 and 3)
x_shuffled = x[:, [0, 3, 2, 1], :]
out_shuffled, _ = sa(x_shuffled)

# Un-shuffle the output
out_unshuffled = out_shuffled[:, [0, 3, 2, 1], :]

# They should be identical (within floating point)
print(f"Original output[0,1]:    {out_original[0, 1, :3]}")
print(f"Unshuffled output[0,1]:  {out_unshuffled[0, 1, :3]}")
print(f"Match: {torch.allclose(out_original, out_unshuffled, atol=1e-6)}")

If you shuffle the input tokens and then un-shuffle the output, you get the same result. Self-attention genuinely does not care about order. For language, this is a problem -- "dog bites man" and "man bites dog" would produce identical representations.

The solution is positional encoding: add position-dependent vectors to the input embeddings before feeding them to the self-attention layer. These vectors encode "I am at position 0," "I am at position 1," etc. The original transformer paper used sinusoidal functions:

import numpy as np

def sinusoidal_encoding(seq_len, d_model):
    """Generate sinusoidal positional encodings."""
    pe = np.zeros((seq_len, d_model))
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))

    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return torch.tensor(pe, dtype=torch.float32)

pe = sinusoidal_encoding(50, 32)
print(f"Positional encoding shape: {pe.shape}")
print(f"Position 0 (first 8 dims): {pe[0, :8].numpy().round(3)}")
print(f"Position 1 (first 8 dims): {pe[1, :8].numpy().round(3)}")
print(f"Position 49 (first 8 dims): {pe[49, :8].numpy().round(3)}")

# Key property: nearby positions have similar encodings
cos_sim = F.cosine_similarity
sim_0_1 = cos_sim(pe[0:1], pe[1:2]).item()
sim_0_25 = cos_sim(pe[0:1], pe[25:26]).item()
sim_0_49 = cos_sim(pe[0:1], pe[49:50]).item()
print(f"\nSimilarity to position 0:")
print(f"  Position 1:  {sim_0_1:.3f}")
print(f"  Position 25: {sim_0_25:.3f}")
print(f"  Position 49: {sim_0_49:.3f}")

Each position gets a unique encoding vector. Nearby positions have similar encodings (high cosine similarity), distant positions have dissimilar encodings. The sinusoidal functions are chosen so that relative position differences can be represented as linear transformations -- the model can learn to compute "how far apart are these two tokens?" from the positional encodings alone. We'll build the full positional encoding scheme when we construct the transformer architecture in upcoming episodes.

From attention to transformers

We've now covered the complete evolution:

  1. Seq2seq (episode #50): two RNNs connected by a context vector. Works but bottlenecked.
  2. Cross-attention (Bahdanau/Luong): let the decoder look back at the encoder. Fixes the bottleneck but still uses RNNs for both encoder and decoder.
  3. Self-attention (this episode): every position attends to every other position. No sequential processing needed.

The natural question: if self-attention can capture relationships between all positions simultaneously and doesn't need sequential processing... do you even need the RNN? What if you built the encoder and decoder entirely from self-attention layers, stacked on top of each other, with no recurrence at all?

That's the transformer -- and it's what we're building next. The original paper was titled "Attention Is All You Need," and the title turned out to be exactly right. Self-attention layers, stacked with feedforward networks, layer normalization, and residual connections (remember those from episode #46?), turned out to be all you need to build models that dominate every NLP benchmark and eventually every modality from vision to audio to code generation.

The bottom line

  • The seq2seq bottleneck (compressing all input into one vector) limits performance on long sequences -- attention fixes this by letting the decoder look at ALL encoder states at every step;
  • Bahdanau attention uses a small feedforward network to score relevance (additive); Luong attention uses dot products (multiplicative) -- both work well in practice;
  • Attention weights sum to 1 (softmax) and can be visualized as heatmaps -- useful for debugging but not proof of understanding;
  • Self-attention lets every position in a sequence attend to every other position, capturing long-range dependencies in one operation instead of through sequential RNN steps;
  • Queries, Keys, Values give tokens separate roles: what I'm searching for, what I advertise, what I provide;
  • Self-attention is inherently parallel (no sequential processing) and connects all positions with O(1) path length -- the gradient highway we've been looking for since episode #48;
  • Positional encoding adds order information that self-attention lacks -- without it, self-attention treats the input as an unordered set;
  • The path from cross-attention to self-attention to "attention is all you need" is direct -- and the transformer is what comes next.

Everything from episodes #37 through #50 -- perceptrons, neural networks, backpropagation, training challenges, optimization, PyTorch, CNNs, RNNs, LSTMs, seq2seq -- has been building toward this. Attention is the single most important concept in modern AI. If you understand it deeply (and after today, you should), you understand the foundation that every major model since 2017 is built on. The transformer architecture that puts it all together is coming right up ;-)

Exercises

Exercise 1: Implement a comparative attention decoder experiment. Build two seq2seq models for the sequence reversal task from episode #50 -- one without attention (vanilla seq2seq) and one with Bahdanau attention (the AttentionDecoder from this episode). Train both for 30 epochs on sequences of length 20. Compare token accuracy and full-sequence accuracy. Then repeat with sequence length 40. How much does attention help on longer sequences compared to shorter ones? The gap should widen as sequence length increases.

Exercise 2: Build a multi-head self-attention module. Instead of one set of Q, K, V projections, use 4 parallel "heads" that each project to dimension dim/4, compute attention independently, and concatenate the results. Test it on random input of shape (batch=2, seq_len=10, dim=64). Print the attention weights from each head for the same input -- they should focus on different positions, showing that multiple heads learn complementary attention patterns. Compare the total parameter count against a single-head self-attention with the same input/output dimensions.

Exercise 3: Implement the sinusoidal positional encoding from this episode and demonstrate that it injects order information into self-attention. Create two experiments: (a) feed a 10-token sequence through self-attention WITHOUT positional encoding, then shuffle the tokens and show the output is permutation-equivariant (unshuffling recovers the same output). (b) ADD sinusoidal positional encodings to the input before self-attention and repeat -- now the output should NOT be permutation-equivariant, because position information has broken the symmetry. Measure the difference with L2 norm between the original output and the unshuffle-after-shuffle output.

Bedankt en tot de volgende keer!

@scipio


