Learn AI Series (#55) - Generative Adversarial Networks

What will I learn
- You will learn the adversarial setup -- how a generator and discriminator compete to produce realistic outputs;
- the minimax game -- the mathematical formulation of GAN training;
- training dynamics -- why GANs are notoriously hard to train;
- mode collapse -- when the generator gets lazy and repeats itself;
- DCGAN -- deep convolutional GANs that actually work on images;
- progressive growing and StyleGAN -- the path to photorealistic generation;
- when to use GANs vs diffusion models in practice;
- ethical considerations of synthetic media.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
- Learn AI Series (#50) - Sequence-to-Sequence Models
- Learn AI Series (#51) - Attention Mechanisms
- Learn AI Series (#52) - The Transformer Architecture (Part 1)
- Learn AI Series (#53) - The Transformer Architecture (Part 2)
- Learn AI Series (#54) - Vision Transformers
- Learn AI Series (#55) - Generative Adversarial Networks (this post)
Solutions to Episode #54 Exercises
Exercise 1: Complete ViT for CIFAR-10 vs CNN baseline.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
class PatchEmbedding(nn.Module):
def __init__(self, img_size=32, patch_size=4, in_channels=3, d_model=128):
super().__init__()
self.n_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(in_channels, d_model,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
return x.flatten(2).transpose(1, 2)
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.norm2 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
def forward(self, x):
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
x = x + self.ff(self.norm2(x))
return x
class SmallViT(nn.Module):
def __init__(self, img_size=32, patch_size=4, d_model=128,
n_heads=4, n_layers=4, n_classes=10, d_ff=512):
super().__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, 3, d_model)
n_patches = self.patch_embed.n_patches
self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
self.pos_embed = nn.Parameter(
torch.randn(1, n_patches + 1, d_model) * 0.02)
self.layers = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)])
self.norm = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, n_classes)
def forward(self, x):
batch = x.size(0)
patches = self.patch_embed(x)
cls = self.cls_token.expand(batch, -1, -1)
x = torch.cat([cls, patches], dim=1) + self.pos_embed
for layer in self.layers:
x = layer(x)
return self.head(self.norm(x[:, 0]))
# CNN baseline
cnn = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
nn.Flatten(), nn.Linear(256, 10)
)
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_data = datasets.CIFAR10('.', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10('.', train=False, transform=test_transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=256)
def train_model(model, name, epochs=20):
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
model.train()
total_loss = 0
for imgs, labels in train_loader:
logits = model(imgs)
loss = nn.CrossEntropyLoss()(logits, labels)
opt.zero_grad()
loss.backward()
opt.step()
total_loss += loss.item()
model.eval()
correct, total = 0, 0
with torch.no_grad():
for imgs, labels in test_loader:
preds = model(imgs).argmax(-1)
correct += (preds == labels).sum().item()
total += labels.size(0)
acc = correct / total
if epoch % 5 == 0 or epoch == epochs - 1:
print(f"{name} epoch {epoch:>2d}: loss={total_loss/len(train_loader):.3f}, "
f"test_acc={acc:.1%}")
vit = SmallViT()
print(f"ViT params: {sum(p.numel() for p in vit.parameters()):,}")
print(f"CNN params: {sum(p.numel() for p in cnn.parameters()):,}\n")
train_model(cnn, "CNN", epochs=20)
print()
train_model(vit, "ViT", epochs=20)
The CNN trains faster in the first few epochs because of its built-in spatial biases -- convolutions are a strong prior for image data. With data augmentation (random flip + random crop), the ViT closes the gap by epoch 15-20, but on CIFAR-10's modest 50K images it typically doesn't surpass the CNN. On larger datasets, the balance shifts and the ViT eventually overtakes the CNN.
Exercise 2: Patch embedding visualization and position similarity.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor()])
cifar = datasets.CIFAR10('.', train=True, download=True, transform=transform)
img, label = cifar[0] # single image: (3, 32, 32)
print(f"Image shape: {img.shape}, label: {label}")
pe = PatchEmbedding(img_size=32, patch_size=4, d_model=128)
patches = pe(img.unsqueeze(0))
print(f"Patch embeddings: {patches.shape}") # (1, 64, 128)
print(f"Number of patches: {patches.shape[1]}")
print(f"Each patch: 4x4x3 = {4*4*3} pixels -> {patches.shape[2]}-dim vector")
# Verify patch extraction matches raw pixels
n_per_side = 32 // 4 # 8
for row in range(2):
for col in range(2):
patch_idx = row * n_per_side + col
raw_patch = img[:, row*4:(row+1)*4, col*4:(col+1)*4]
print(f"Patch ({row},{col}) idx={patch_idx}: "
f"raw pixel range [{raw_patch.min():.3f}, {raw_patch.max():.3f}]")
# Position embedding similarity (randomly initialized ViT)
vit = SmallViT(img_size=32, patch_size=4, d_model=128,
n_heads=4, n_layers=4, n_classes=10, d_ff=512)
pos = vit.pos_embed.detach().squeeze(0) # (65, 128) -- 64 patches + CLS
sim = F.cosine_similarity(pos.unsqueeze(0), pos.unsqueeze(1), dim=-1)
print(f"\nPosition similarity matrix: {sim.shape}") # (65, 65)
# Top similar pairs (excluding self-similarity)
sim_no_diag = sim.clone()
sim_no_diag.fill_diagonal_(float('-inf'))
flat_idx = sim_no_diag.view(-1).topk(6).indices
for idx in flat_idx:
i, j = idx.item() // 65, idx.item() % 65
print(f" Positions {i} and {j}: similarity = {sim[i, j]:.4f}")
print("\nBefore training, similarities are roughly random.")
print("After training, nearby patches would show high similarity.")
Before training the position embeddings are randomly initialized, so the similarity matrix shows no spatial pattern. After training on real image data, the model discovers that neighboring patches should have similar position embeddings -- it learns the 2D grid structure purely from the optimization objective. This is the effect we discussed in the episode: 1D position embeddings learning 2D spatial structure.
Exercise 3: Window attention vs full attention -- cost and speed comparison.
import torch
import torch.nn as nn
import math
import time
class FullAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.qkv = nn.Linear(d_model, d_model * 3)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
return self.proj(out)
class WindowAttention(nn.Module):
def __init__(self, d_model, n_heads, window_size=7):
super().__init__()
self.window_size = window_size
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.qkv = nn.Linear(d_model, d_model * 3)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x, H, W):
B, N, C = x.shape
ws = self.window_size
x = x.view(B, H, W, C)
pad_h = (ws - H % ws) % ws
pad_w = (ws - W % ws) % ws
if pad_h > 0 or pad_w > 0:
x = nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = x.shape[1], x.shape[2]
nH, nW = Hp // ws, Wp // ws
windows = x.view(B, nH, ws, nW, ws, C)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, ws * ws, C)
qkv = self.qkv(windows).reshape(-1, ws*ws, 3, self.n_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(-1, ws*ws, C)
out = self.proj(out)
out = out.view(B, nH, nW, ws, ws, C)
out = out.permute(0, 1, 3, 2, 4, 5).contiguous()
out = out.view(B, Hp, Wp, C)
if pad_h > 0 or pad_w > 0:
out = out[:, :H, :W, :]
return out.view(B, H * W, C)
d_model = 96
full_attn = FullAttention(d_model, n_heads=3)
win_attn = WindowAttention(d_model, n_heads=3, window_size=7)
full_attn.eval()
win_attn.eval()
H, W = 56, 56
x = torch.randn(1, H * W, d_model)
# Theoretical cost
N = H * W # 3136
w = 7 * 7 # 49
n_windows = (H // 7) * (W // 7) # 64
full_flops = N * N * d_model
win_flops = n_windows * w * w * d_model
print(f"Full attention: O({N}^2 * {d_model}) = {full_flops:,} ops")
print(f"Window attention: O({n_windows} * {w}^2 * {d_model}) = {win_flops:,} ops")
print(f"Theoretical speedup: {full_flops / win_flops:.1f}x\n")
# Wall-clock timing
with torch.no_grad():
t0 = time.perf_counter()
for _ in range(10):
_ = full_attn(x)
t_full = (time.perf_counter() - t0) / 10
t0 = time.perf_counter()
for _ in range(10):
_ = win_attn(x, H, W)
t_win = (time.perf_counter() - t0) / 10
print(f"Full attention: {t_full*1000:.1f}ms")
print(f"Window attention: {t_win*1000:.1f}ms")
print(f"Measured speedup: {t_full / t_win:.1f}x")
The theoretical speedup is dramatic -- full attention on a 56x56 feature map is O(3136^2) while 7x7 window attention is O(64 * 49^2), roughly 64x cheaper. The measured speedup will be lower than theoretical (memory overhead, window reshaping operations, CUDA kernel launch costs on GPU), but still substantial. This is why the Swin Transformer can handle high-resolution inputs that would be impractical with standard ViT's global attention.
On to today's episode
Here we go! For the past several episodes, everything we've built has been discriminative. We've trained models that take an input and predict a label or a value. A CNN takes an image and says "cat" or "dog." A transformer takes a sequence and predicts the next token. Even our ViT from last episode was a classifier -- it chops an image into patches, processes them with self-attention, and outputs a class prediction. All of these are asking the question: given this input, what category does it belong to?
Today we flip the script entirely. Instead of classifying existing data, we're going to create new data. Generative models learn the underlying distribution of a training set and then sample from it -- producing new images, new music, new text that look like they could have been in the training set but never were. And the most fascinating way to train a generative model? Make two neural networks fight each other.
This is the Generative Adversarial Network -- the GAN -- and it was introduced by Ian Goodfellow and his collaborators in 2014. The idea is almost philosophically elegant: one network tries to create convincing fakes, and another network tries to catch them. As the detector gets better, the forger gets better. When the game reaches equilibrium, the forger produces samples indistinguishable from real data.
GANs were the dominant generative approach for images from roughly 2014 to 2022, and understanding them is essential background for the diffusion models and other generative techniques we'll cover later in the series. The adversarial training principle also shows up in robustness training, domain adaptation, and several other areas of modern ML ;-)
The adversarial setup
The analogy everyone uses (and honestly, it's a good one): a counterfeiter and a detective. The counterfeiter creates fake banknotes. The detective examines banknotes and decides whether each is real or fake. If the detective catches a fake, the counterfeiter learns what gave it away and improves. If the detective is fooled, the counterfeiter knows that technique works.
Over time, both get better. The counterfeiter produces increasingly realistic fakes. The detective develops increasingly subtle detection methods. If training goes well, the counterfeiter eventually produces fakes so good that the detective can't do better than random guessing -- 50/50.
In neural network terms:
- Generator G: takes random noise z (sampled from a simple distribution like standard normal) and transforms it into a synthetic sample G(z)
- Discriminator D: takes a sample (real or generated) and outputs a probability that it's real
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, z_dim=64, out_dim=784):
super().__init__()
self.net = nn.Sequential(
nn.Linear(z_dim, 128), nn.ReLU(),
nn.Linear(128, 256), nn.ReLU(),
nn.Linear(256, out_dim), nn.Tanh()
)
def forward(self, z):
return self.net(z)
class Discriminator(nn.Module):
def __init__(self, in_dim=784):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, 256), nn.LeakyReLU(0.2),
nn.Linear(256, 128), nn.LeakyReLU(0.2),
nn.Linear(128, 1), nn.Sigmoid()
)
def forward(self, x):
return self.net(x)
G = Generator(z_dim=64, out_dim=784)
D = Discriminator(in_dim=784)
z = torch.randn(4, 64)
fake = G(z)
score = D(fake)
print(f"Noise: {z.shape}")
print(f"Generated: {fake.shape}") # (4, 784) -- flattened 28x28 images
print(f"D scores: {score.shape}") # (4, 1) -- probabilities
print(f"D thinks real: {score.detach().mean():.4f}")
Two details worth noting. The generator uses Tanh to output values in [-1, 1] (matching normalized image data). The discriminator uses LeakyReLU instead of ReLU -- this prevents "dead neurons" that can cause the discriminator to stop providing useful gradients to the generator. We covered the dead neuron problem back in episode #40 when discussing training challenges; LeakyReLU allows a small negative slope (0.2) so gradients always flow, even for negative inputs.
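To make the dead-neuron point concrete, here's a minimal sketch of the gradient that survives a negative pre-activation under each activation:
import torch
import torch.nn as nn
# Gradient through ReLU at a negative input: exactly zero -- the neuron is "dead"
x = torch.tensor([-2.0], requires_grad=True)
nn.ReLU()(x).backward()
print(f"ReLU gradient at x=-2: {x.grad.item()}")       # 0.0
# Gradient through LeakyReLU(0.2): the negative slope keeps signal flowing
x = torch.tensor([-2.0], requires_grad=True)
nn.LeakyReLU(0.2)(x).backward()
print(f"LeakyReLU gradient at x=-2: {x.grad.item()}")  # 0.2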
The minimax game
The training objective is a minimax game. The discriminator wants to maximize its ability to classify real vs fake. The generator wants to minimize the discriminator's ability -- it wants the discriminator to think generated samples are real.
Mathematically: min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]
The first term: for real data x, the discriminator wants D(x) close to 1 (high log probability). The second term: for fake data G(z), the discriminator wants D(G(z)) close to 0, so log(1 - 0) = 0 (maximum). But the generator wants D(G(z)) close to 1, so log(1 - 1) approaches negative infinity (minimum).
In practice, training alternates between two steps:
import torch.optim as optim
z_dim = 64
G = Generator(z_dim, 784)
D = Discriminator(784)
opt_G = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()
# Simulated training step (real data would come from DataLoader)
batch_size = 64
real_data = torch.randn(batch_size, 784) * 0.5 # placeholder for real images
# Step 1: Train Discriminator
z = torch.randn(batch_size, z_dim)
fake_data = G(z).detach() # detach: don't backprop through G
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
d_loss_real = criterion(D(real_data), real_labels)
d_loss_fake = criterion(D(fake_data), fake_labels)
d_loss = d_loss_real + d_loss_fake
opt_D.zero_grad()
d_loss.backward()
opt_D.step()
# Step 2: Train Generator
z = torch.randn(batch_size, z_dim)
fake_data = G(z)
g_loss = criterion(D(fake_data), real_labels) # G wants D to say "real"
opt_G.zero_grad()
g_loss.backward()
opt_G.step()
print(f"D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")
The .detach() on the generator's output in step 1 is critical. When training the discriminator, we don't want gradients flowing back into the generator -- we're only updating D's parameters. In step 2, we do want gradients flowing through D and into G (but we don't update D's parameters -- only G's).
Notice the generator's loss uses real_labels as the target. The generator is trying to fool the discriminator -- it wants D to output 1 (real) for generated samples. This is sometimes called the "non-saturating" loss because log(D(G(z))) provides stronger gradients when D(G(z)) is small (early in training) compared to the original minimax log(1 - D(G(z))) formulation. Stronger gradients early on mean the generator actually learns something before the discriminator becomes too dominant.
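A quick numeric sketch (treating D's output as a leaf variable just for illustration) shows the difference in gradient strength:
import torch
# Early in training D confidently rejects fakes, so D(G(z)) is near 0
d_out = torch.tensor([0.01], requires_grad=True)
torch.log(1 - d_out).backward()   # original minimax term for G
print(f"minimax gradient: {d_out.grad.item():.2f}")         # ~ -1.01 (weak)
d_out = torch.tensor([0.01], requires_grad=True)
(-torch.log(d_out)).backward()    # non-saturating loss: -log(D(G(z)))
print(f"non-saturating gradient: {d_out.grad.item():.2f}")  # -100.00 (strong)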
Having said that, let me show you a more complete training loop with actual MNIST data to make this concrete:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# MNIST setup
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # scale to [-1, 1]
])
mnist = datasets.MNIST('.', train=True, download=True, transform=transform)
loader = DataLoader(mnist, batch_size=64, shuffle=True)
z_dim = 64
G = Generator(z_dim, 784)
D = Discriminator(784)
opt_G = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()
for epoch in range(5):
d_total, g_total, n = 0, 0, 0
for real_imgs, _ in loader:
batch = real_imgs.size(0)
real_flat = real_imgs.view(batch, -1)
real_lbl = torch.ones(batch, 1) * 0.9 # label smoothing
fake_lbl = torch.zeros(batch, 1)
# Train D
z = torch.randn(batch, z_dim)
fake = G(z).detach()
d_loss = criterion(D(real_flat), real_lbl) + criterion(D(fake), fake_lbl)
opt_D.zero_grad()
d_loss.backward()
opt_D.step()
# Train G
z = torch.randn(batch, z_dim)
fake = G(z)
g_loss = criterion(D(fake), torch.ones(batch, 1))
opt_G.zero_grad()
g_loss.backward()
opt_G.step()
d_total += d_loss.item()
g_total += g_loss.item()
n += 1
print(f"Epoch {epoch}: D_loss={d_total/n:.3f}, G_loss={g_total/n:.3f}")
After a few epochs, the generator starts producing things that vaguely resemble digits. After many more epochs with a proper DCGAN architecture (coming up next), the results get dramatically better. But even with this simple MLP-based GAN on MNIST, the principle works -- the generator learns to produce 28x28 images from pure noise, guided entirely by the discriminator's feedback.
Why GANs are hard to train
GAN training is notoriously unstable. The generator and discriminator are in a dynamic equilibrium, and plenty of things can go wrong:
Oscillation: D gets too strong and G can't make progress. D's loss drops to near zero (it classifies everything correctly), which means G gets no useful gradient signal. Then G overfits to a specific trick, D adjusts, and the cycle repeats without convergence.
Mode collapse: the generator discovers that producing one specific output (or a few specific outputs) consistently fools the discriminator. Instead of generating diverse faces, it generates the same face over and over. The generator has "collapsed" to a few modes of the data distribution instead of covering the full distribution. Think about why this happens -- if the generator finds one image that D consistently rates as "real," the gradient signal says "keep doing this." The generator has no incentive to explore other outputs because exploration risks worse scores.
Vanishing gradients: if D is too confident (outputs very close to 0 or 1), the gradients through the sigmoid are tiny, and G learns very slowly. This is the same saturation problem we saw in episode #40 with sigmoid activations.
Non-convergence: unlike supervised learning where the loss monotonically decreases (at least on the training set), GAN training is a game -- the losses oscillate as G and D take turns improving. There's no simple metric that tells you "training is going well" until you visually inspect the generated samples. This is genuinely one of the most frustrating aspects of working with GANs in practice.
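To put a number on the vanishing-gradient failure mode, here's a minimal sketch of sigmoid saturation:
import torch
# An over-confident discriminator: the sigmoid's gradient vanishes at extreme logits
logit = torch.tensor([8.0], requires_grad=True)
prob = torch.sigmoid(logit)
prob.backward()
print(f"D output: {prob.item():.4f}")                 # ~0.9997
print(f"gradient wrt logit: {logit.grad.item():.6f}") # ~0.000335 -- almost nothing flows back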
Practical tricks that help:
# Tricks that make GAN training more stable
# 1. Label smoothing: 0.9 instead of 1.0 for "real"
real_labels = torch.ones(batch_size, 1) * 0.9
# 2. Adam with beta1=0.5 (standard GAN recipe)
opt = optim.Adam(params, lr=2e-4, betas=(0.5, 0.999))
# 3. LeakyReLU in discriminator
nn.LeakyReLU(0.2) # NOT nn.ReLU()
# 4. Batch normalization (but NOT in D's output layer)
nn.BatchNorm1d(256)
# 5. Train D more than G when needed
for _ in range(2): # 2 D steps per G step
# ... train discriminator ...
# ... train generator once ...
# 6. Spectral normalization on D's weights
nn.utils.spectral_norm(nn.Linear(256, 128))
These tricks were discovered through years of painful experimentation by the community. They're not derived from theory -- they're empirical recipes that make the training game more stable. The GAN literature is full of papers proposing new loss functions, normalization schemes, and architectural constraints specifically to tame training instability.
DCGAN: making it work on images
The original GAN used fully connected layers, which doesn't scale well to images. In 2016, Radford et al. proposed DCGAN (Deep Convolutional GAN) with specific architectural guidelines that made GAN training much more stable on image data:
class DCGANGenerator(nn.Module):
def __init__(self, z_dim=100, channels=1):
super().__init__()
self.net = nn.Sequential(
# z_dim -> 256 x 7 x 7
nn.ConvTranspose2d(z_dim, 256, 7, 1, 0, bias=False),
nn.BatchNorm2d(256), nn.ReLU(),
# 256x7x7 -> 128x14x14
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128), nn.ReLU(),
# 128x14x14 -> channels x 28x28
nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, z):
return self.net(z.view(-1, z.size(1), 1, 1))
class DCGANDiscriminator(nn.Module):
def __init__(self, channels=1):
super().__init__()
self.net = nn.Sequential(
# channels x 28x28 -> 64x14x14
nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2),
# 64x14x14 -> 128x7x7
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
# 128x7x7 -> 1
nn.Conv2d(128, 1, 7, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.net(x).view(-1, 1)
G = DCGANGenerator(z_dim=100)
D = DCGANDiscriminator()
z = torch.randn(4, 100)
fake_imgs = G(z)
scores = D(fake_imgs)
print(f"Generated images: {fake_imgs.shape}") # (4, 1, 28, 28)
print(f"Discriminator scores: {scores.shape}") # (4, 1)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")
The DCGAN guidelines that became standard practice:
- Replace pooling with strided convolutions (in D) and transposed convolutions (in G)
- Use batch normalization in both networks (except D's input layer and G's output layer)
- Remove fully connected hidden layers -- go all-convolutional
- Use ReLU in the generator (except output: Tanh) and LeakyReLU in the discriminator
ConvTranspose2d (transposed convolution, sometimes misleadingly called "deconvolution") is the generator's key operation. If you remember from episodes #45-46, a regular convolution with stride 2 downsamples -- it takes a 14x14 feature map and produces a 7x7 one. A transposed convolution does the opposite: it upsamples -- a 7x7 feature map becomes 14x14, then 28x28. The generator starts from a random vector and progressively "paints" an image at increasing resolution, layer by layer. The discriminator does the reverse -- it takes a full-resolution image and progressively compresses it down to a single real/fake probability.
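A quick shape check of that symmetry -- a standalone sketch using the same 4/2/1 kernel/stride/padding recipe as the DCGAN code above:
import torch
import torch.nn as nn
x = torch.randn(1, 128, 7, 7)
up = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
down = nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1)
print(f"Upsampled: {up(x).shape}")        # (1, 128, 14, 14) -- transposed conv doubles resolution
print(f"Back down: {down(up(x)).shape}")  # (1, 128, 7, 7)   -- strided conv halves it again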
Progressive growing and StyleGAN
Generating tiny 28x28 MNIST digits is one thing. Generating photorealistic 1024x1024 faces is something else entirely. The problem: training a GAN on high-resolution images from the start is extremely unstable. The discriminator easily distinguishes blurry early-stage generator outputs from crisp real images, providing no useful gradient signal.
Progressive growing (Karras et al., 2017, NVIDIA) solved this by starting training at low resolution (4x4) and gradually adding layers to both G and D during training:
- Train at 4x4 until stable
- Add layers for 8x8, blend in gradually (using a lerp between old and new output)
- Continue growing: 16x16 -> 32x32 -> ... -> 1024x1024
Each resolution stage starts easy -- the generator just needs to get the rough structure right -- and refines progressively. By the time the model reaches 1024x1024, it has already learned faces at all lower resolutions, so the high-resolution details are refinements on an already solid foundation.
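Here's a minimal sketch of that fade-in blend -- the tensors stand in for the RGB outputs of two adjacent resolution stages, and the nearest-neighbor upsampling is my simplification of the paper's skip path:
import torch
import torch.nn.functional as F
def fade_in(old_rgb, new_rgb, alpha):
    """Blend the previous stage's output (upsampled 2x) with the new
    stage's output; alpha ramps from 0 to 1 as the new layers train."""
    upsampled = F.interpolate(old_rgb, scale_factor=2, mode='nearest')
    return (1 - alpha) * upsampled + alpha * new_rgb
old = torch.randn(1, 3, 8, 8)    # output of the stable 8x8 stage
new = torch.randn(1, 3, 16, 16)  # output of the freshly added 16x16 stage
for alpha in (0.0, 0.5, 1.0):
    print(f"alpha={alpha}: {fade_in(old, new, alpha).shape}")  # always (1, 3, 16, 16)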
StyleGAN (Karras et al., 2019, also NVIDIA) built on progressive growing with a radically different generator architecture. Instead of feeding the noise vector directly into the first layer, StyleGAN introduces a mapping network that transforms the noise z into an intermediate latent vector w through 8 fully-connected layers, and then injects w into each layer through adaptive instance normalization (AdaIN). Different layers control different aspects of the image:
class MappingNetwork(nn.Module):
"""StyleGAN's mapping network: z -> w through 8 FC layers."""
def __init__(self, z_dim=512, w_dim=512, n_layers=8):
super().__init__()
layers = []
for i in range(n_layers):
in_d = z_dim if i == 0 else w_dim
layers.extend([nn.Linear(in_d, w_dim), nn.LeakyReLU(0.2)])
self.net = nn.Sequential(*layers)
def forward(self, z):
return self.net(z)
class AdaIN(nn.Module):
"""Adaptive Instance Normalization -- injects style at each layer."""
def __init__(self, channels, w_dim=512):
super().__init__()
self.norm = nn.InstanceNorm2d(channels)
self.style = nn.Linear(w_dim, channels * 2) # scale + shift
def forward(self, x, w):
style = self.style(w).unsqueeze(2).unsqueeze(3)
gamma, beta = style.chunk(2, dim=1)
return gamma * self.norm(x) + beta
mapping = MappingNetwork(z_dim=128, w_dim=128)
adain = AdaIN(channels=64, w_dim=128)
z = torch.randn(2, 128)
w = mapping(z)
feature_map = torch.randn(2, 64, 16, 16)
styled = adain(feature_map, w)
print(f"z: {z.shape} -> w: {w.shape}")
print(f"Feature map: {feature_map.shape} -> Styled: {styled.shape}")
The key insight: different layers of the generator control different levels of detail:
- Early layers (4x4, 8x8): coarse features -- face shape, pose, hair style
- Middle layers (16x16, 32x32): medium features -- eye shape, nose, wrinkles
- Late layers (64x64+): fine details -- skin texture, hair strands, pores
This separation means you can mix styles from different images: take the coarse structure from one face and the fine details from another. Feed one w vector to the early layers and a different w vector to the late layers, and you get a controllable blend. StyleGAN2 (2020) and StyleGAN3 (2021) further improved quality and eliminated artifacts. By 2021, StyleGAN-generated faces were virtually indistinguishable from real photographs to the human eye.
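To make style mixing concrete, here's a sketch that routes two different w vectors through the toy modules above (a real StyleGAN generator has convolutions and noise injection between the AdaIN layers; this only shows the routing):
import torch
import torch.nn.functional as F
# Style mixing: coarse layers take w1, fine layers take w2
mapping = MappingNetwork(z_dim=128, w_dim=128)
adain_coarse = AdaIN(channels=64, w_dim=128)
adain_fine = AdaIN(channels=64, w_dim=128)
w1 = mapping(torch.randn(1, 128))  # source A: controls coarse structure
w2 = mapping(torch.randn(1, 128))  # source B: controls fine details
x = torch.randn(1, 64, 8, 8)       # stand-in for an early feature map
x = adain_coarse(x, w1)            # early layer styled by w1
x = F.interpolate(x, scale_factor=4)
x = adain_fine(x, w2)              # later layer styled by w2
print(f"Mixed-style features: {x.shape}")  # (1, 64, 32, 32)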
Wasserstein GAN: a better loss function
One of the most important theoretical improvements to GAN training came from Arjovsky et al. in 2017: the Wasserstein GAN (WGAN). The idea: replace the binary cross-entropy loss with the Wasserstein distance (also called Earth Mover's Distance) between the real and generated distributions.
The original GAN minimizes the Jensen-Shannon divergence between distributions. The problem: when the real and generated distributions don't overlap (which is common early in training -- the generated images are obviously fake), the JS divergence is constant and provides zero gradient. The Wasserstein distance, on the other hand, always provides a meaningful gradient even when distributions don't overlap.
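A toy 1-D illustration of why this matters -- for two equal-size 1-D samples, the Wasserstein-1 distance is simply the mean absolute difference between their sorted values:
import torch
def wasserstein_1d(a, b):
    """W1 between two equal-size 1-D samples: mean |sorted(a) - sorted(b)|."""
    return (a.sort().values - b.sort().values).abs().mean()
real = torch.randn(1000)
for shift in (8.0, 4.0, 2.0, 0.5):
    fake = torch.randn(1000) + shift  # "generated" distribution, offset from real
    print(f"shift={shift}: W1 = {wasserstein_1d(real, fake):.2f}")
# W1 shrinks smoothly as the distributions approach -- a usable training signal.
# The JS divergence would sit at log(2) for every fully non-overlapping pair,
# telling the generator nothing about which direction to move.
With that intuition, the WGAN recipe replaces the discriminator with an unbounded "critic":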
class WGANCritic(nn.Module):
"""WGAN uses a 'critic' (no sigmoid) instead of a discriminator."""
def __init__(self, in_dim=784):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, 256), nn.LeakyReLU(0.2),
nn.Linear(256, 128), nn.LeakyReLU(0.2),
nn.Linear(128, 1) # NO sigmoid -- raw score
)
def forward(self, x):
return self.net(x)
# WGAN training step (simplified)
critic = WGANCritic()
generator = Generator(z_dim=64)
opt_C = optim.RMSprop(critic.parameters(), lr=5e-5)
opt_G = optim.RMSprop(generator.parameters(), lr=5e-5)
batch_size = 64
real_data = torch.randn(batch_size, 784) * 0.5
# Train critic (5 steps per G step is typical for WGAN)
for _ in range(5):
z = torch.randn(batch_size, 64)
fake = generator(z).detach()
c_loss = -(critic(real_data).mean() - critic(fake).mean())
opt_C.zero_grad()
c_loss.backward()
opt_C.step()
# Weight clipping (original WGAN -- WGAN-GP uses gradient penalty instead)
for p in critic.parameters():
p.data.clamp_(-0.01, 0.01)
# Train generator
z = torch.randn(batch_size, 64)
fake = generator(z)
g_loss = -critic(fake).mean()
opt_G.zero_grad()
g_loss.backward()
opt_G.step()
print(f"Critic loss: {c_loss.item():.4f}, G loss: {g_loss.item():.4f}")
Notice three differences from the standard GAN: (1) no sigmoid on the critic (it outputs an unbounded score, not a probability), (2) the critic is trained multiple times per generator step (typically 5), and (3) weight clipping enforces the Lipschitz constraint that the Wasserstein distance requires. WGAN-GP (Gradient Penalty, Gulrajani et al., 2017) later replaced weight clipping with a gradient penalty term, which works much better in practice.
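For completeness, a minimal sketch of the gradient penalty (the helper name is mine; lambda_gp=10 follows the paper):
import torch
def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """WGAN-GP: push the critic's gradient norm toward 1 on random
    interpolations between real and fake samples."""
    eps = torch.rand(real.size(0), 1)  # per-sample mixing ratio
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    grads = torch.autograd.grad(outputs=critic(interp).sum(),
                                inputs=interp, create_graph=True)[0]
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()
gp = gradient_penalty(critic, real_data, generator(torch.randn(64, 64)).detach())
print(f"Gradient penalty: {gp.item():.4f}")
# The penalty is added to the critic loss in place of weight clipping:
# c_loss = -(critic(real).mean() - critic(fake).mean()) + gp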
WGANs are significantly more stable to train -- the critic loss actually correlates with sample quality (unlike standard GAN losses), so you can monitor training progress numerically instead of just staring at generated images ;-)
GANs vs diffusion models
Since roughly 2022, diffusion models (which we'll cover later in this series) have largely replaced GANs as the dominant generative approach for images. Stable Diffusion, DALL-E 2, Midjourney -- all diffusion-based.
Why the shift? Diffusion models are easier to train (no adversarial instability), produce higher diversity (less mode collapse), and handle text-conditioning more naturally. GANs can be faster at inference (one forward pass vs. many denoising steps) and are still used in specific applications -- real-time video synthesis, super-resolution, image editing, face manipulation -- but for general-purpose image generation, diffusion won.
Understanding GANs still matters though. The adversarial principle appears throughout modern AI: adversarial training for robustness (training a classifier on adversarial examples to make it harder to fool), discriminator heads in other architectures (used for domain adaptation, style transfer), and the game-theoretic thinking that GANs introduced. And many production systems still run GAN-based components -- particularly where inference speed matters, since a GAN generator produces an image in a single forward pass while a diffusion model needs 20-50 denoising steps.
Ethical considerations
Generative models that produce photorealistic human faces raise serious ethical questions, and I think it's important we address them directly.
Deepfakes: GAN-generated face-swaps can put anyone's face into any video. This enables fraud, harassment, non-consensual content, and political manipulation. The technology for creating convincing deepfakes has become accessible to anyone with a decent GPU and an afternoon to spare -- and that's a real problem.
Identity fraud: synthetic faces can be used to create fake social media profiles, bypass identity verification systems, and manufacture fake personas at scale. Every time you see "This Person Does Not Exist" generating a perfect face, remember that the same technology can generate thousands of fake LinkedIn profiles.
Consent: StyleGAN was trained on datasets of real people's faces (FFHQ -- Flickr-Faces-HQ, scraped from Flickr). Those people didn't consent to having their likeness used to train a model that generates synthetic faces. This raises fundamental questions about data rights that the field is still grappling with.
Detection: the arms race between generation and detection is ongoing. Techniques exist (frequency analysis, inconsistency detection, forensic classifiers), but they're always playing catch-up with improving generators. No detection method has remained reliable for more than a year or two before generators learned to avoid the artifacts being detected.
The technology itself is neutral -- the same GAN that enables deepfakes also enables medical image synthesis (generating rare pathology examples for training diagnostic models), privacy-preserving data sharing (generating synthetic patient records that preserve statistical properties without exposing real individuals), and creative tools. But responsible development requires acknowledging the dual-use potential and investing in detection and attribution alongside generation.
The bottom line
- GANs train two competing networks: a generator that creates synthetic data from random noise and a discriminator that tries to detect fakes;
- Training is a minimax game: the discriminator maximizes classification accuracy while the generator minimizes it -- when it converges, the generator produces samples indistinguishable from real data;
- GAN training is unstable -- mode collapse (generator repeats itself), oscillation (D and G take turns dominating), and vanishing gradients (D is too confident) are common failure modes;
- DCGAN established architectural guidelines (strided convolutions, batch norm, LeakyReLU) that made image GANs practical -- these guidelines came from empirical experimentation, not theory;
- Progressive growing trains from low to high resolution, enabling photorealistic 1024x1024 generation -- each resolution stage builds on the previous one;
- StyleGAN separates coarse and fine control through a mapping network and adaptive instance normalization at each layer -- enabling style mixing and unprecedented image quality;
- WGAN replaced the JS divergence with Wasserstein distance for more stable training -- the critic loss actually correlates with sample quality, which is a huge practical advantage;
- Diffusion models have largely replaced GANs for general image generation since ~2022, but GANs remain relevant for fast inference and as a foundational concept in generative modeling;
- Synthetic media raises serious ethical questions around consent, identity fraud, and deepfakes -- the detection vs generation arms race continues.
Exercises
Exercise 1: Build a complete GAN training pipeline on MNIST and visualize the results. Create a Generator (z_dim=64, layers: 64->256->512->784, ReLU + Tanh output) and Discriminator (784->512->256->1, LeakyReLU 0.2 + Sigmoid). Train for 50 epochs on MNIST with Adam (lr=2e-4, betas=(0.5, 0.999)), batch size 64, and label smoothing (0.9 for real labels). Every 10 epochs, generate 16 samples from fixed noise vectors and save the 784-dim outputs reshaped to 28x28. Print the discriminator's average score on real data and fake data at each epoch -- track how these evolve. By epoch 50, the generated images should look recognizably like digits (if blurry).
Exercise 2: Implement mode collapse detection. Train the same GAN from Exercise 1, but this time after every epoch, generate 100 samples and compute the pairwise cosine similarity between all generated images (flattened to vectors). The average pairwise similarity tells you about diversity: if all images look the same (mode collapse), similarity is very high (close to 1.0). If images are diverse, similarity is lower. Plot (or print) this metric across 50 training epochs. Also compute the standard deviation of pixel values across the 100 generated images -- mode collapse should show low variance. Compare two training runs: one with the standard training recipe, and one where you intentionally cause mode collapse by training G for 5 steps per D step (the opposite of what WGAN recommends).
Exercise 3: Build a DCGAN for MNIST and compare against the MLP-based GAN. Implement the DCGANGenerator and DCGANDiscriminator from this episode. Train both the MLP GAN (Exercise 1) and the DCGAN for 20 epochs on MNIST with identical hyperparameters (Adam, lr=2e-4, betas=(0.5, 0.999), batch size 64). After training, generate 100 samples from each model. For each model, compute: (a) the average discriminator score on generated samples (higher = more realistic to D), (b) the pixel-level mean and standard deviation of generated images (should roughly match MNIST statistics: mean ~0.13, std ~0.31 for raw [0,1] values), and (c) the diversity metric from Exercise 2. The DCGAN should produce sharper, more structured images because the convolutional architecture has the right inductive biases for spatial data.