Learn AI Series (#55) - Generative Adversarial Networks


What will I learn

  • You will learn the adversarial setup -- how a generator and discriminator compete to produce realistic outputs;
  • the minimax game -- the mathematical formulation of GAN training;
  • training dynamics -- why GANs are notoriously hard to train;
  • mode collapse -- when the generator gets lazy and repeats itself;
  • DCGAN -- deep convolutional GANs that actually work on images;
  • progressive growing and StyleGAN -- the path to photorealistic generation;
  • when to use GANs vs diffusion models in practice;
  • ethical considerations of synthetic media.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#55) - Generative Adversarial Networks

Solutions to Episode #54 Exercises

Exercise 1: Complete ViT for CIFAR-10 vs CNN baseline.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, d_model=128):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, d_model,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x):
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.ff(self.norm2(x))
        return x

class SmallViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, d_model=128,
                 n_heads=4, n_layers=4, n_classes=10, d_ff=512):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, d_model)
        n_patches = self.patch_embed.n_patches
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, n_patches + 1, d_model) * 0.02)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        batch = x.size(0)
        patches = self.patch_embed(x)
        cls = self.cls_token.expand(batch, -1, -1)
        x = torch.cat([cls, patches], dim=1) + self.pos_embed
        for layer in self.layers:
            x = layer(x)
        return self.head(self.norm(x[:, 0]))

# CNN baseline
cnn = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
    nn.Flatten(), nn.Linear(256, 10)
)

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_data = datasets.CIFAR10('.', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10('.', train=False, transform=test_transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=256)

def train_model(model, name, epochs=20):
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for imgs, labels in train_loader:
            logits = model(imgs)
            loss = nn.CrossEntropyLoss()(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                preds = model(imgs).argmax(-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        acc = correct / total
        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f"{name} epoch {epoch:>2d}: loss={total_loss/len(train_loader):.3f}, "
                  f"test_acc={acc:.1%}")

vit = SmallViT()
print(f"ViT params: {sum(p.numel() for p in vit.parameters()):,}")
print(f"CNN params: {sum(p.numel() for p in cnn.parameters()):,}\n")
train_model(cnn, "CNN", epochs=20)
print()
train_model(vit, "ViT", epochs=20)

The CNN trains faster in the first few epochs because of its built-in spatial biases -- convolutions are a strong prior for image data. With data augmentation (random flip + random crop), the ViT closes the gap by epoch 15-20, but on CIFAR-10's modest 50K images it typically doesn't surpass the CNN. On larger datasets, the crossover point favors the ViT.

Exercise 2: Patch embedding visualization and position similarity.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])
cifar = datasets.CIFAR10('.', train=True, download=True, transform=transform)
img, label = cifar[0]  # single image: (3, 32, 32)
print(f"Image shape: {img.shape}, label: {label}")

pe = PatchEmbedding(img_size=32, patch_size=4, d_model=128)  # class from Exercise 1
patches = pe(img.unsqueeze(0))
print(f"Patch embeddings: {patches.shape}")  # (1, 64, 128)
print(f"Number of patches: {patches.shape[1]}")
print(f"Each patch: 4x4x3 = {4*4*3} pixels -> {patches.shape[2]}-dim vector")

# Verify patch extraction matches raw pixels
n_per_side = 32 // 4  # 8
for row in range(2):
    for col in range(2):
        patch_idx = row * n_per_side + col
        raw_patch = img[:, row*4:(row+1)*4, col*4:(col+1)*4]
        print(f"Patch ({row},{col}) idx={patch_idx}: "
              f"raw pixel range [{raw_patch.min():.3f}, {raw_patch.max():.3f}]")

# Position embedding similarity (randomly initialized ViT)
vit = SmallViT(img_size=32, patch_size=4, d_model=128,  # class from Exercise 1
               n_heads=4, n_layers=4, n_classes=10, d_ff=512)
pos = vit.pos_embed.detach().squeeze(0)  # (65, 128) -- 64 patches + CLS
sim = F.cosine_similarity(pos.unsqueeze(0), pos.unsqueeze(1), dim=-1)
print(f"\nPosition similarity matrix: {sim.shape}")  # (65, 65)

# Top similar pairs (excluding self-similarity)
sim_no_diag = sim.clone()
sim_no_diag.fill_diagonal_(float('-inf'))
flat_idx = sim_no_diag.view(-1).topk(6).indices
for idx in flat_idx:
    i, j = idx.item() // 65, idx.item() % 65
    print(f"  Positions {i} and {j}: similarity = {sim[i, j]:.4f}")

print("\nBefore training, similarities are roughly random.")
print("After training, nearby patches would show high similarity.")

Before training the position embeddings are randomly initialized, so the similarity matrix shows no spatial pattern. After training on real image data, the model discovers that neighboring patches should have similar position embeddings -- it learns the 2D grid structure purely from the optimization objective. This is what we discussed in the episode about 1D position embeddings learning 2D spatial structure.

Exercise 3: Window attention vs full attention -- cost and speed comparison.

import torch
import torch.nn as nn
import math
import time

class FullAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

class WindowAttention(nn.Module):
    def __init__(self, d_model, n_heads, window_size=7):
        super().__init__()
        self.window_size = window_size
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, H, W):
        B, N, C = x.shape
        ws = self.window_size
        x = x.view(B, H, W, C)
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        if pad_h > 0 or pad_w > 0:
            x = nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x.shape[1], x.shape[2]
        nH, nW = Hp // ws, Wp // ws
        windows = x.view(B, nH, ws, nW, ws, C)
        windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, ws * ws, C)
        qkv = self.qkv(windows).reshape(-1, ws*ws, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(-1, ws*ws, C)
        out = self.proj(out)
        out = out.view(B, nH, nW, ws, ws, C)
        out = out.permute(0, 1, 3, 2, 4, 5).contiguous()
        out = out.view(B, Hp, Wp, C)
        if pad_h > 0 or pad_w > 0:
            out = out[:, :H, :W, :]
        return out.view(B, H * W, C)

d_model = 96
full_attn = FullAttention(d_model, n_heads=3)
win_attn = WindowAttention(d_model, n_heads=3, window_size=7)
full_attn.eval()
win_attn.eval()

H, W = 56, 56
x = torch.randn(1, H * W, d_model)

# Theoretical cost
N = H * W  # 3136
w = 7 * 7  # 49
n_windows = (H // 7) * (W // 7)  # 64
full_flops = N * N * d_model
win_flops = n_windows * w * w * d_model
print(f"Full attention: O({N}^2 * {d_model}) = {full_flops:,} ops")
print(f"Window attention: O({n_windows} * {w}^2 * {d_model}) = {win_flops:,} ops")
print(f"Theoretical speedup: {full_flops / win_flops:.1f}x\n")

# Wall-clock timing
with torch.no_grad():
    t0 = time.perf_counter()
    for _ in range(10):
        _ = full_attn(x)
    t_full = (time.perf_counter() - t0) / 10

    t0 = time.perf_counter()
    for _ in range(10):
        _ = win_attn(x, H, W)
    t_win = (time.perf_counter() - t0) / 10

print(f"Full attention:   {t_full*1000:.1f}ms")
print(f"Window attention: {t_win*1000:.1f}ms")
print(f"Measured speedup: {t_full / t_win:.1f}x")

The theoretical speedup is dramatic -- full attention on a 56x56 feature map is O(3136^2) while 7x7 window attention is O(64 * 49^2), roughly 64x cheaper. The measured speedup will be lower than theoretical (memory overhead, window reshaping operations, CUDA kernel launch costs on GPU), but still substantial. This is why the Swin Transformer can handle high-resolution inputs that would be impractical with standard ViT's global attention.

On to today's episode

Here we go! For the past several episodes, everything we've built has been discriminative. We've trained models that take an input and predict a label or a value. A CNN takes an image and says "cat" or "dog." A transformer takes a sequence and predicts the next token. Even our ViT from last episode was a classifier -- it chops an image into patches, processes them with self-attention, and outputs a class prediction. All of these are asking the question: given this input, what category does it belong to?

Today we flip the script entirely. Instead of classifying existing data, we're going to create new data. Generative models learn the underlying distribution of a training set and then sample from it -- producing new images, new music, new text that look like they could have been in the training set but never were. And the most fascinating way to train a generative model? Make two neural networks fight each other.

This is the Generative Adversarial Network -- the GAN -- and it was introduced by Ian Goodfellow in 2014. The idea is almost philosophically elegant: one network tries to create convincing fakes, and another network tries to catch them. As the detector gets better, the forger gets better. When the game reaches equilibrium, the forger produces samples indistinguishable from real data.

GANs were the dominant generative approach for images from roughly 2014 to 2022, and understanding them is essential background for the diffusion models and other generative techniques we'll cover later in the series. The adversarial training principle also shows up in robustness training, domain adaptation, and several other areas of modern ML ;-)

The adversarial setup

The analogy everyone uses (and honestly, it's a good one): a counterfeiter and a detective. The counterfeiter creates fake banknotes. The detective examines banknotes and decides whether each is real or fake. If the detective catches a fake, the counterfeiter learns what gave it away and improves. If the detective is fooled, the counterfeiter knows that technique works.

Over time, both get better. The counterfeiter produces increasingly realistic fakes. The detective develops increasingly subtle detection methods. If training goes well, the counterfeiter eventually produces fakes so good that the detective can't do better than random guessing -- 50/50.

In neural network terms:

  • Generator G: takes random noise z (sampled from a simple distribution like standard normal) and transforms it into a synthetic sample G(z)
  • Discriminator D: takes a sample (real or generated) and outputs a probability that it's real
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim=64, out_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 128), nn.ReLU(),
            nn.Linear(128, 256), nn.ReLU(),
            nn.Linear(256, out_dim), nn.Tanh()
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, in_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 1), nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

G = Generator(z_dim=64, out_dim=784)
D = Discriminator(in_dim=784)
z = torch.randn(4, 64)
fake = G(z)
score = D(fake)
print(f"Noise: {z.shape}")
print(f"Generated: {fake.shape}")    # (4, 784) -- flattened 28x28 images
print(f"D scores: {score.shape}")    # (4, 1) -- probabilities
print(f"D thinks real: {score.detach().mean():.4f}")

Two details worth noting. The generator uses Tanh to output values in [-1, 1] (matching normalized image data). The discriminator uses LeakyReLU instead of ReLU -- this prevents "dead neurons" that can cause the discriminator to stop providing useful gradients to the generator. We covered the dead neuron problem back in episode #40 when discussing training challenges; LeakyReLU allows a small negative slope (0.2) so gradients always flow, even for negative inputs.
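To see the difference concretely, here's a quick autograd check (a toy sketch; the tensor values are arbitrary) of how gradients flow through ReLU versus LeakyReLU for a negative input:

```python
import torch
import torch.nn as nn

x = torch.tensor([-2.0, 3.0], requires_grad=True)

# ReLU: the negative input gets exactly zero gradient -- a "dead" path
nn.ReLU()(x).sum().backward()
relu_grad = x.grad.clone()

# LeakyReLU(0.2): the negative input still passes a 0.2 gradient
x.grad = None
nn.LeakyReLU(0.2)(x).sum().backward()
leaky_grad = x.grad.clone()

print(f"ReLU gradients:      {relu_grad}")
print(f"LeakyReLU gradients: {leaky_grad}")
```

The negative input is completely blocked by ReLU but still contributes a small gradient through LeakyReLU, which is exactly why it keeps the discriminator's feedback alive.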

The minimax game

The training objective is a minimax game. The discriminator wants to maximize its ability to classify real vs fake. The generator wants to minimize the discriminator's ability -- it wants the discriminator to think generated samples are real.

Mathematically: min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]

The first term: for real data x, the discriminator wants D(x) close to 1 (high log probability). The second term: for fake data G(z), the discriminator wants D(G(z)) close to 0, so log(1 - 0) = 0 (maximum). But the generator wants D(G(z)) close to 1, so log(1 - 1) approaches negative infinity (minimum).

In practice, training alternates between two steps:

import torch.optim as optim

z_dim = 64
G = Generator(z_dim, 784)
D = Discriminator(784)
opt_G = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# Simulated training step (real data would come from DataLoader)
batch_size = 64
real_data = torch.randn(batch_size, 784) * 0.5  # placeholder for real images

# Step 1: Train Discriminator
z = torch.randn(batch_size, z_dim)
fake_data = G(z).detach()  # detach: don't backprop through G
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

d_loss_real = criterion(D(real_data), real_labels)
d_loss_fake = criterion(D(fake_data), fake_labels)
d_loss = d_loss_real + d_loss_fake
opt_D.zero_grad()
d_loss.backward()
opt_D.step()

# Step 2: Train Generator
z = torch.randn(batch_size, z_dim)
fake_data = G(z)
g_loss = criterion(D(fake_data), real_labels)  # G wants D to say "real"
opt_G.zero_grad()
g_loss.backward()
opt_G.step()

print(f"D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")

The .detach() on the generator's output in step 1 is critical. When training the discriminator, we don't want gradients flowing back into the generator -- we're only updating D's parameters. In step 2, we do want gradients flowing through D and into G (but we don't update D's parameters -- only G's).

Notice the generator's loss uses real_labels as the target. The generator is trying to fool the discriminator -- it wants D to output 1 (real) for generated samples. This is sometimes called the "non-saturating" loss because log(D(G(z))) provides stronger gradients when D(G(z)) is small (early in training) compared to the original minimax log(1 - D(G(z))) formulation. Stronger gradients early on means the generator actually learns something before the discriminator becomes too dominant.
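The gradient claim is easy to verify numerically with autograd (a toy check; the 0.01 stands in for an early-training discriminator output on an obvious fake):

```python
import torch

# Early in training, D easily spots fakes: D(G(z)) is small, say 0.01
d_out = torch.tensor([0.01], requires_grad=True)

# Original minimax loss for G: minimize log(1 - D(G(z)))
torch.log(1 - d_out).backward()
grad_saturating = d_out.grad.item()      # -1 / (1 - 0.01), small magnitude

# Non-saturating loss: minimize -log(D(G(z)))
d_out.grad = None
(-torch.log(d_out)).backward()
grad_non_saturating = d_out.grad.item()  # -1 / 0.01, large magnitude

print(f"Saturating gradient:     {grad_saturating:.2f}")
print(f"Non-saturating gradient: {grad_non_saturating:.2f}")
```

The non-saturating loss gives a gradient roughly 100x larger when the discriminator is confidently rejecting the generator's output -- precisely the regime where the generator needs a strong learning signal.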

Having said that, let me show you a more complete training loop with actual MNIST data to make this concrete:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# MNIST setup
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # scale to [-1, 1]
])
mnist = datasets.MNIST('.', train=True, download=True, transform=transform)
loader = DataLoader(mnist, batch_size=64, shuffle=True)

z_dim = 64
G = Generator(z_dim, 784)
D = Discriminator(784)
opt_G = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()

for epoch in range(5):
    d_total, g_total, n = 0, 0, 0
    for real_imgs, _ in loader:
        batch = real_imgs.size(0)
        real_flat = real_imgs.view(batch, -1)
        real_lbl = torch.ones(batch, 1) * 0.9   # label smoothing
        fake_lbl = torch.zeros(batch, 1)

        # Train D
        z = torch.randn(batch, z_dim)
        fake = G(z).detach()
        d_loss = criterion(D(real_flat), real_lbl) + criterion(D(fake), fake_lbl)
        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # Train G
        z = torch.randn(batch, z_dim)
        fake = G(z)
        g_loss = criterion(D(fake), torch.ones(batch, 1))
        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

        d_total += d_loss.item()
        g_total += g_loss.item()
        n += 1

    print(f"Epoch {epoch}: D_loss={d_total/n:.3f}, G_loss={g_total/n:.3f}")

After a few epochs, the generator starts producing things that vaguely resemble digits. After many more epochs with a proper DCGAN architecture (coming up next), the results get dramatically better. But even with this simple MLP-based GAN on MNIST, the principle works -- the generator learns to produce 28x28 images from pure noise, guided entirely by the discriminator's feedback.

Why GANs are hard to train

GAN training is notoriously unstable. The generator and discriminator are in a dynamic equilibrium, and many things can go wrong:

Oscillation: D gets too strong, G can't make progress. D's loss drops to near zero (it classifies everything correctly), which means G gets no useful gradient signal. Then G overfits to a specific trick, D adjusts, and the cycle repeats without convergence.

Mode collapse: the generator discovers that producing one specific output (or a few specific outputs) consistently fools the discriminator. Instead of generating diverse faces, it generates the same face over and over. The generator has "collapsed" to a few modes of the data distribution instead of covering the full distribution. Think about why this happens -- if the generator finds one image that D consistently rates as "real," the gradient signal says "keep doing this." The generator has no incentive to explore other outputs because exploration risks worse scores.

Vanishing gradients: if D is too confident (outputs very close to 0 or 1), the gradients through the sigmoid are tiny, and G learns very slowly. This is the same saturation problem we saw in episode #40 with sigmoid activations.

Non-convergence: unlike supervised learning where the loss monotonically decreases (at least on the training set), GAN training is a game -- the losses oscillate as G and D take turns improving. There's no simple metric that tells you "training is going well" until you visually inspect the generated samples. This is genuinely one of the most frustrating aspects of working with GANs in practice.

Practical tricks that help:

# Tricks that make GAN training more stable

# 1. Label smoothing: 0.9 instead of 1.0 for "real"
real_labels = torch.ones(batch_size, 1) * 0.9

# 2. Adam with beta1=0.5 (standard GAN recipe)
opt = optim.Adam(params, lr=2e-4, betas=(0.5, 0.999))

# 3. LeakyReLU in discriminator
nn.LeakyReLU(0.2)  # NOT nn.ReLU()

# 4. Batch normalization (but NOT in D's output layer)
nn.BatchNorm1d(256)

# 5. Train D more than G when needed
for _ in range(2):   # 2 D steps per G step
    ...              # train discriminator here
# ... then train generator once

# 6. Spectral normalization on D's weights
nn.utils.spectral_norm(nn.Linear(256, 128))

These tricks were discovered through years of painful experimentation by the community. They're not derived from theory -- they're empirical recipes that make the training game more stable. The GAN literature is full of papers proposing new loss functions, normalization schemes, and architectural constraints specifically to tame training instability.

DCGAN: making it work on images

The original GAN used fully connected layers, which doesn't scale well to images. In 2016, Radford et al. proposed DCGAN (Deep Convolutional GAN) with specific architectural guidelines that made GAN training much more stable on image data:

class DCGANGenerator(nn.Module):
    def __init__(self, z_dim=100, channels=1):
        super().__init__()
        self.net = nn.Sequential(
            # z_dim -> 256 x 7 x 7
            nn.ConvTranspose2d(z_dim, 256, 7, 1, 0, bias=False),
            nn.BatchNorm2d(256), nn.ReLU(),
            # 256x7x7 -> 128x14x14
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128), nn.ReLU(),
            # 128x14x14 -> channels x 28x28
            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z.view(-1, z.size(1), 1, 1))

class DCGANDiscriminator(nn.Module):
    def __init__(self, channels=1):
        super().__init__()
        self.net = nn.Sequential(
            # channels x 28x28 -> 64x14x14
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            # 64x14x14 -> 128x7x7
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            # 128x7x7 -> 1
            nn.Conv2d(128, 1, 7, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x).view(-1, 1)

G = DCGANGenerator(z_dim=100)
D = DCGANDiscriminator()
z = torch.randn(4, 100)
fake_imgs = G(z)
scores = D(fake_imgs)
print(f"Generated images: {fake_imgs.shape}")   # (4, 1, 28, 28)
print(f"Discriminator scores: {scores.shape}")   # (4, 1)
print(f"G params: {sum(p.numel() for p in G.parameters()):,}")
print(f"D params: {sum(p.numel() for p in D.parameters()):,}")

The DCGAN guidelines that became standard practice:

  • Replace pooling with strided convolutions (in D) and transposed convolutions (in G)
  • Use batch normalization in both networks (except D's input layer and G's output layer)
  • Remove fully connected hidden layers -- go all-convolutional
  • Use ReLU in the generator (except output: Tanh) and LeakyReLU in the discriminator

ConvTranspose2d (transposed convolution, sometimes misleadingly called "deconvolution") is the generator's key operation. If you remember from episodes #45-46, a regular convolution with stride 2 downsamples -- it takes a 14x14 feature map and produces a 7x7 one. A transposed convolution does the opposite: it upsamples -- a 7x7 feature map becomes 14x14, then 28x28. The generator starts from a random vector and progressively "paints" an image at increasing resolution, layer by layer. The discriminator does the reverse -- it takes a full-resolution image and progressively compresses it down to a single real/fake probability.
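A quick shape check makes the up/down symmetry concrete (the channel counts here are arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 128, 14, 14)  # a 14x14 feature map

# Strided convolution downsamples: 14x14 -> 7x7
down = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
# Transposed convolution with the same hyperparameters upsamples: 14x14 -> 28x28
up = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)

print(f"Input:           {tuple(x.shape)}")
print(f"Conv2d:          {tuple(down(x).shape)}")    # (1, 256, 7, 7)
print(f"ConvTranspose2d: {tuple(up(x).shape)}")      # (1, 64, 28, 28)
```

With identical kernel size, stride, and padding, the two operations are exact mirrors of each other -- which is why the DCGAN generator and discriminator architectures look like reflections.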

Progressive growing and StyleGAN

Generating tiny 28x28 MNIST digits is one thing. Generating photorealistic 1024x1024 faces is something else entirely. The problem: training a GAN on high-resolution images from the start is extremely unstable. The discriminator easily distinguishes blurry early-stage generator outputs from crisp real images, providing no useful gradient signal.

Progressive growing (Karras et al., 2017, NVIDIA) solved this by starting training at low resolution (4x4) and gradually adding layers to both G and D during training:

  1. Train at 4x4 until stable
  2. Add layers for 8x8, blend in gradually (using a lerp between old and new output)
  3. Continue growing: 16x16 -> 32x32 -> ... -> 1024x1024

Each resolution stage starts easy -- the generator just needs to get the rough structure right -- and refines progressively. By the time the model reaches 1024x1024, it has already learned faces at all lower resolutions, so the high-resolution details are refinements on an already solid foundation.
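The blending in step 2 can be sketched in a few lines (a simplified illustration -- the function and variable names are ours, not from the paper):

```python
import torch
import torch.nn.functional as F

def fade_in(old_rgb, new_rgb, alpha):
    """Blend the upsampled output of the old (stable) stage with the
    output of the newly added stage. alpha ramps from 0 to 1 as the
    new layers are faded in during training."""
    old_up = F.interpolate(old_rgb, scale_factor=2, mode='nearest')
    return (1 - alpha) * old_up + alpha * new_rgb

old = torch.randn(1, 3, 8, 8)    # output of the stable 8x8 stage
new = torch.randn(1, 3, 16, 16)  # output of the freshly added 16x16 stage

for alpha in (0.0, 0.5, 1.0):
    out = fade_in(old, new, alpha)
    print(f"alpha={alpha}: output shape {tuple(out.shape)}")
```

At alpha=0 the network still effectively outputs the old stage's (upsampled) image; at alpha=1 the new layers have fully taken over -- so the new stage never has to produce sensible output from scratch.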

StyleGAN (Karras et al., 2019, also NVIDIA) built on progressive growing with a radically different generator architecture. Instead of feeding the noise vector directly into the first layer, StyleGAN introduces a mapping network that transforms the noise z into an intermediate latent vector w through 8 fully connected layers, and then injects w into each layer through adaptive instance normalization (AdaIN). Different layers control different aspects of the image:

class MappingNetwork(nn.Module):
    """StyleGAN's mapping network: z -> w through 8 FC layers."""
    def __init__(self, z_dim=512, w_dim=512, n_layers=8):
        super().__init__()
        layers = []
        for i in range(n_layers):
            in_d = z_dim if i == 0 else w_dim
            layers.extend([nn.Linear(in_d, w_dim), nn.LeakyReLU(0.2)])
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        return self.net(z)

class AdaIN(nn.Module):
    """Adaptive Instance Normalization -- injects style at each layer."""
    def __init__(self, channels, w_dim=512):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels)
        self.style = nn.Linear(w_dim, channels * 2)  # scale + shift

    def forward(self, x, w):
        style = self.style(w).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, dim=1)
        return gamma * self.norm(x) + beta

mapping = MappingNetwork(z_dim=128, w_dim=128)
adain = AdaIN(channels=64, w_dim=128)

z = torch.randn(2, 128)
w = mapping(z)
feature_map = torch.randn(2, 64, 16, 16)
styled = adain(feature_map, w)
print(f"z: {z.shape} -> w: {w.shape}")
print(f"Feature map: {feature_map.shape} -> Styled: {styled.shape}")

The key insight: different layers of the generator control different levels of detail:

  • Early layers (4x4, 8x8): coarse features -- face shape, pose, hair style
  • Middle layers (16x16, 32x32): medium features -- eye shape, nose, wrinkles
  • Late layers (64x64+): fine details -- skin texture, hair strands, pores

This separation means you can mix styles from different images: take the coarse structure from one face and the fine details from another. Feed one w vector to the early layers and a different w vector to the late layers, and you get a controllable blend. StyleGAN2 (2020) and StyleGAN3 (2021) further improved quality and eliminated artifacts. By 2021, StyleGAN-generated faces were virtually indistinguishable from real photographs to the human eye.
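Style mixing can be sketched with the building blocks above; the AdaIN module is repeated here so the sketch runs on its own, and the interpolate call standing in for the generator's upsampling stages is an illustrative assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaIN(nn.Module):
    """Same AdaIN as above, repeated so this sketch is self-contained."""
    def __init__(self, channels, w_dim=128):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels)
        self.style = nn.Linear(w_dim, channels * 2)

    def forward(self, x, w):
        style = self.style(w).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, dim=1)
        return gamma * self.norm(x) + beta

# Two latent codes, as if mapped from two different source images
w_structure = torch.randn(1, 128)  # drives early (coarse) layers
w_details = torch.randn(1, 128)    # drives late (fine) layers

adain_early = AdaIN(channels=32)
adain_late = AdaIN(channels=32)

x = torch.randn(1, 32, 8, 8)          # coarse feature map
x = adain_early(x, w_structure)       # coarse structure from image A
x = F.interpolate(x, scale_factor=4)  # stand-in for upsampling stages
x = adain_late(x, w_details)          # fine details from image B
print(f"Mixed feature map: {tuple(x.shape)}")
```

The key point: nothing forces the same w at every layer, so swapping codes between the early and late AdaIN calls blends the coarse attributes of one latent with the fine attributes of another.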

Wasserstein GAN: a better loss function

One of the most important theoretical improvements to GAN training came from Arjovsky et al. in 2017: the Wasserstein GAN (WGAN). The idea: replace the binary cross-entropy loss with the Wasserstein distance (also called Earth Mover's Distance) between the real and generated distributions.

The original GAN minimizes the Jensen-Shannon divergence between distributions. The problem: when the real and generated distributions don't overlap (which is common early in training -- the generated images are obviously fake), the JS divergence is constant and provides zero gradient. The Wasserstein distance, on the other hand, always provides a meaningful gradient even when distributions don't overlap.

class WGANCritic(nn.Module):
    """WGAN uses a 'critic' (no sigmoid) instead of a discriminator."""
    def __init__(self, in_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 1)  # NO sigmoid -- raw score
        )

    def forward(self, x):
        return self.net(x)

# WGAN training step (simplified)
critic = WGANCritic()
generator = Generator(z_dim=64)
opt_C = optim.RMSprop(critic.parameters(), lr=5e-5)
opt_G = optim.RMSprop(generator.parameters(), lr=5e-5)

batch_size = 64
real_data = torch.randn(batch_size, 784) * 0.5

# Train critic (5 steps per G step is typical for WGAN)
for _ in range(5):
    z = torch.randn(batch_size, 64)
    fake = generator(z).detach()
    c_loss = -(critic(real_data).mean() - critic(fake).mean())
    opt_C.zero_grad()
    c_loss.backward()
    opt_C.step()
    # Weight clipping (original WGAN -- WGAN-GP uses gradient penalty instead)
    for p in critic.parameters():
        p.data.clamp_(-0.01, 0.01)

# Train generator
z = torch.randn(batch_size, 64)
fake = generator(z)
g_loss = -critic(fake).mean()
opt_G.zero_grad()
g_loss.backward()
opt_G.step()

print(f"Critic loss: {c_loss.item():.4f}, G loss: {g_loss.item():.4f}")

Notice three differences from the standard GAN: (1) no sigmoid on the critic (it outputs an unbounded score, not a probability), (2) the critic is trained multiple times per generator step (typically 5), and (3) weight clipping enforces the Lipschitz constraint that the Wasserstein distance requires. WGAN-GP (Gradient Penalty, Gulrajani et al.) later replaced weight clipping with a gradient penalty term, which works much better in practice.
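The gradient penalty that replaced weight clipping can be sketched as follows (a simplified version; the toy linear critic is just for demonstration):

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    """WGAN-GP: penalize the critic's gradient norm for deviating from 1,
    evaluated at random interpolations between real and fake samples."""
    batch = real.size(0)
    eps = torch.rand(batch, 1)  # per-sample mixing coefficient
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True)[0]
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()

critic = nn.Linear(784, 1)  # toy stand-in for a real critic network
real = torch.randn(8, 784)
fake = torch.randn(8, 784)

gp = gradient_penalty(critic, real, fake)
print(f"Gradient penalty: {gp.item():.4f}")
# In training, the penalty is added to the critic loss in place of clipping:
# c_loss = -(critic(real).mean() - critic(fake).mean()) + gp
```

The create_graph=True is what makes this work: the penalty is itself a function of gradients, so we need to differentiate through the gradient computation when updating the critic.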

WGANs are significantly more stable to train -- the critic loss actually correlates with sample quality (unlike standard GAN losses), so you can monitor training progress numerically instead of just staring at generated images ;-)

GANs vs diffusion models

Since roughly 2022, diffusion models (which we'll cover later in this series) have largely replaced GANs as the dominant generative approach for images. Stable Diffusion, DALL-E 2, Midjourney -- all diffusion-based.

Why the shift? Diffusion models are easier to train (no adversarial instability), produce higher diversity (less mode collapse), and handle text-conditioning more naturally. GANs can be faster at inference (one forward pass vs. many denoising steps) and are still used in specific applications -- real-time video synthesis, super-resolution, image editing, face manipulation -- but for general-purpose image generation, diffusion won.

Understanding GANs still matters though. The adversarial principle appears throughout modern AI: adversarial training for robustness (training a classifier on adversarial examples to make it harder to fool), discriminator heads in other architectures (used for domain adaptation, style transfer), and the game-theoretic thinking that GANs introduced. And many production systems still run GAN-based components -- particularly where inference speed matters, since a GAN generator produces an image in a single forward pass while a diffusion model needs 20-50 denoising steps.
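To make the adversarial-training idea concrete, here is a minimal sketch using the Fast Gradient Sign Method (FGSM, Goodfellow et al.) -- the classic way to generate adversarial examples. The toy model and function names are my own illustration, not code from earlier in this series:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fgsm_example(model, x, y, epsilon=0.1):
    """One-step adversarial example: perturb x in the direction
    that maximally increases the classification loss."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    # Step along the sign of the input gradient
    return (x + epsilon * x.grad.sign()).detach()

# One adversarial training step: train on the perturbed inputs
model = nn.Linear(784, 10)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 784), torch.randint(0, 10, (16,))
x_adv = fgsm_example(model, x, y)
loss = F.cross_entropy(model(x_adv), y)
opt.zero_grad()
loss.backward()
opt.step()
```

The same two-player logic as a GAN is at work: an attacker crafts inputs to fool the model, and the model is trained to resist them.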

Ethical considerations

Generative models that produce photorealistic human faces raise serious ethical questions, and I think it's important we address them directly.

Deepfakes: GAN-generated face-swaps can put anyone's face into any video. This enables fraud, harassment, non-consensual content, and political manipulation. The technology for creating convincing deepfakes has become accessible to anyone with a decent GPU and an afternoon to spare -- and that's a real problem.

Identity fraud: synthetic faces can be used to create fake social media profiles, bypass identity verification systems, and manufacture fake personas at scale. Every time you see "This Person Does Not Exist" generating a perfect face, remember that the same technology can generate thousands of fake LinkedIn profiles.

Consent: StyleGAN was trained on datasets of real people's faces (FFHQ -- Flickr-Faces-HQ, scraped from Flickr). Those people didn't consent to having their likeness used to train a model that generates synthetic faces. This raises fundamental questions about data rights that the field is still grappling with.

Detection: the arms race between generation and detection is ongoing. Techniques exist (frequency analysis, inconsistency detection, forensic classifiers), but they're always playing catch-up with improving generators. No detection method has remained reliable for more than a year or two before generators learned to avoid the artifacts being detected.
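As a toy illustration of the frequency-analysis idea (my own sketch, nowhere near a production detector): GAN upsampling layers often leave periodic high-frequency artifacts, so a crude statistic like the fraction of spectral energy away from the center of the 2-D FFT can hint at synthetic content:

```python
import torch

def high_freq_energy_ratio(img):
    """Fraction of spectral energy outside the central (low-frequency)
    quarter of the shifted 2-D FFT of a single-channel image."""
    spec = torch.fft.fftshift(torch.fft.fft2(img)).abs() ** 2
    h, w = spec.shape
    ch, cw = h // 4, w // 4
    low = spec[h//2 - ch:h//2 + ch, w//2 - cw:w//2 + cw].sum()
    return ((spec.sum() - low) / spec.sum()).item()

noise = torch.randn(28, 28)                        # lots of high frequencies
smooth = torch.linspace(0, 1, 28).repeat(28, 1)    # smooth horizontal ramp
print(high_freq_energy_ratio(noise))   # high ratio
print(high_freq_energy_ratio(smooth))  # low ratio
```

Real forensic classifiers learn much subtler cues than this, which is exactly why the arms race keeps moving.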

The technology itself is neutral -- the same GAN that enables deepfakes also enables medical image synthesis (generating rare pathology examples for training diagnostic models), privacy-preserving data sharing (generating synthetic patient records that preserve statistical properties without exposing real individuals), and creative tools. But responsible development requires acknowledging the dual-use potential and investing in detection and attribution alongside generation.

The bottom line

  • GANs train two competing networks: a generator that creates synthetic data from random noise and a discriminator that tries to detect fakes;
  • Training is a minimax game: the discriminator maximizes classification accuracy while the generator minimizes it -- when it converges, the generator produces samples indistinguishable from real data;
  • GAN training is unstable -- mode collapse (generator repeats itself), oscillation (D and G take turns dominating), and vanishing gradients (D is too confident) are common failure modes;
  • DCGAN established architectural guidelines (strided convolutions, batch norm, LeakyReLU) that made image GANs practical -- these guidelines came from empirical experimentation, not theory;
  • Progressive growing trains from low to high resolution, enabling photorealistic 1024x1024 generation -- each resolution stage builds on the previous one;
  • StyleGAN separates coarse and fine control through a mapping network and adaptive instance normalization at each layer -- enabling style mixing and unprecedented image quality;
  • WGAN replaced the JS divergence with Wasserstein distance for more stable training -- the critic loss actually correlates with sample quality, which is a huge practical advantage;
  • Diffusion models have largely replaced GANs for general image generation since ~2022, but GANs remain relevant for fast inference and as a foundational concept in generative modeling;
  • Synthetic media raises serious ethical questions around consent, identity fraud, and deepfakes -- the detection vs generation arms race continues.

Exercises

Exercise 1: Build a complete GAN training pipeline on MNIST and visualize the results. Create a Generator (z_dim=64, layers: 64->256->512->784, ReLU + Tanh output) and Discriminator (784->512->256->1, LeakyReLU 0.2 + Sigmoid). Train for 50 epochs on MNIST with Adam (lr=2e-4, betas=(0.5, 0.999)), batch size 64, and label smoothing (0.9 for real labels). Every 10 epochs, generate 16 samples from fixed noise vectors and save the 784-dim outputs reshaped to 28x28. Print the discriminator's average score on real data and fake data at each epoch -- track how these evolve. By epoch 50, the generated images should look recognizably like digits (if blurry).

Exercise 2: Implement mode collapse detection. Train the same GAN from Exercise 1, but this time after every epoch, generate 100 samples and compute the pairwise cosine similarity between all generated images (flattened to vectors). The average pairwise similarity tells you about diversity: if all images look the same (mode collapse), similarity is very high (close to 1.0). If images are diverse, similarity is lower. Plot (or print) this metric across 50 training epochs. Also compute the standard deviation of pixel values across the 100 generated images -- mode collapse should show low variance. Compare two training runs: one with the standard training recipe, and one where you intentionally cause mode collapse by training G for 5 steps per D step (the opposite of what WGAN recommends).
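If you get stuck on the diversity metric in Exercise 2, here is one way to sketch the average pairwise cosine similarity (a helper sketch only, not the full exercise solution; the function name is mine):

```python
import torch
import torch.nn.functional as F

def avg_pairwise_cosine(samples):
    """Mean cosine similarity over all distinct pairs of flattened
    samples -- values near 1.0 suggest mode collapse."""
    flat = samples.view(samples.size(0), -1)
    normed = F.normalize(flat, dim=1)
    sim = normed @ normed.t()              # (N, N) cosine matrix
    n = sim.size(0)
    off_diag = sim.sum() - sim.diagonal().sum()
    return (off_diag / (n * (n - 1))).item()

# Diverse batch vs. a "collapsed" batch (one vector repeated)
diverse = torch.randn(100, 784)
collapsed = torch.randn(1, 784).repeat(100, 1)
print(avg_pairwise_cosine(diverse))    # near 0 for random vectors
print(avg_pairwise_cosine(collapsed))  # ~1.0 -- total collapse
```

Run this on your generator's outputs after each epoch and plot the value over time.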

Exercise 3: Build a DCGAN for MNIST and compare against the MLP-based GAN. Implement the DCGANGenerator and DCGANDiscriminator from this episode. Train both the MLP GAN (Exercise 1) and the DCGAN for 20 epochs on MNIST with identical hyperparameters (Adam, lr=2e-4, betas=(0.5, 0.999), batch size 64). After training, generate 100 samples from each model. For each model, compute: (a) the average discriminator score on generated samples (higher = more realistic to D), (b) the pixel-level mean and standard deviation of generated images (should roughly match MNIST statistics: mean ~0.13, std ~0.31 for raw [0,1] values), and (c) the diversity metric from Exercise 2. The DCGAN should produce sharper, more structured images because the convolutional architecture has the right inductive biases for spatial data.

Thanks for reading!

@scipio


