Learn AI Series (#84) - Generative Images - Diffusion Models (Part 1)

What will I learn

  • You will learn the forward diffusion process: systematically adding Gaussian noise to images over T timesteps;
  • the reverse process: training a neural network to predict and remove the noise added at each step;
  • DDPM (Denoising Diffusion Probabilistic Models) -- the foundational architecture that started the generative revolution;
  • the variance schedule and the alpha-bar shortcut for jumping directly to any timestep;
  • the U-Net architecture adapted for diffusion with sinusoidal timestep conditioning;
  • how to train a diffusion model from scratch using simple MSE loss on noise prediction;
  • the iterative sampling procedure: going from pure Gaussian noise to a clean image in T denoising steps;
  • why diffusion models beat GANs on training stability, mode coverage, and scalability.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#84) - Generative Images - Diffusion Models (Part 1)

Solutions to Episode #83 Exercises

Exercise 1: 3D convolution parameter analyzer.

import numpy as np


class Conv3DAnalyzer:
    """Compare parameter counts and FLOPs for
    full 3D vs factored (2+1)D convolutions."""

    def params_full_3d(self, in_ch, out_ch,
                       kt, kh, kw):
        """Full 3D conv: out * (in * kt * kh * kw)
        + out bias."""
        return out_ch * (in_ch * kt * kh * kw) + out_ch

    def params_factored(self, in_ch, out_ch,
                        kt, kh, kw):
        """(2+1)D: spatial 1xHxW then temporal
        Tx1x1, using out_ch as intermediate."""
        mid = out_ch
        spatial = mid * (in_ch * 1 * kh * kw) + mid
        temporal = out_ch * (mid * kt * 1 * 1) + out_ch
        return spatial + temporal

    def flops_full_3d(self, in_ch, out_ch,
                      kt, kh, kw,
                      t_in, h_in, w_in):
        """MACs for full 3D conv with stride=1,
        same padding."""
        return (out_ch * t_in * h_in * w_in
                * in_ch * kt * kh * kw)

    def compare(self, specs, input_dims):
        """Print comparison table."""
        print(f"{'Layer':<12} {'Full3D':>10} "
              f"{'Factor':>10} {'Ratio':>8} "
              f"{'FLOPs(M)':>10}")
        print("-" * 54)
        for i, (spec, dims) in enumerate(
                zip(specs, input_dims)):
            ic, oc, kt, kh, kw = spec
            ti, hi, wi = dims
            p_full = self.params_full_3d(
                ic, oc, kt, kh, kw)
            p_fact = self.params_factored(
                ic, oc, kt, kh, kw)
            ratio = p_fact / p_full
            flops = self.flops_full_3d(
                ic, oc, kt, kh, kw,
                ti, hi, wi) / 1e6
            print(f"Layer {i + 1:<6} {p_full:>10,} "
                  f"{p_fact:>10,} {ratio:>7.2f}x "
                  f"{flops:>10,.1f}")


analyzer = Conv3DAnalyzer()
specs = [
    (3, 64, 3, 7, 7),
    (64, 128, 3, 3, 3),
    (128, 256, 3, 3, 3),
]
input_dims = [
    (16, 224, 224),
    (8, 56, 56),
    (4, 28, 28),
]
analyzer.compare(specs, input_dims)

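Despite its large 7x7 spatial kernel, Layer 1 actually shows the smallest savings (about 0.77x the parameters of the full 3D conv). With only 3 input channels the full 3x7x7 kernel is cheap to begin with, and the factored version adds a 64-to-64 temporal conv (12,352 parameters) that the full version doesn't need. Layers 2 and 3 both come out at about 0.56x: with wide inputs, a 3x3 spatial conv plus a 3x1x1 temporal conv is much cheaper than a full 3x3x3 kernel. The FLOPs column (strictly, multiply-accumulates, as the docstring says) shows how expensive the first layer is even with "only" 64 output channels: roughly 22.7 billion MACs, because the 224x224 spatial resolution at 16 temporal frames creates an enormous output volume.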

Exercise 2: Temporal action proposal generator.

import numpy as np


class ActionProposalGenerator:
    """Generate and filter temporal action
    proposals from per-frame scores."""

    def generate_proposals(self, scores,
                           threshold=0.5,
                           min_duration=3):
        above = scores >= threshold
        segments = []
        start = None
        for i in range(len(above)):
            if above[i] and start is None:
                start = i
            elif not above[i] and start is not None:
                segments.append((start, i - 1))
                start = None
        if start is not None:
            segments.append((start, len(scores) - 1))

        # Merge segments separated by a gap of
        # fewer than 2 below-threshold frames
        # (gap = next start - previous end - 1)
        merged = []
        for seg in segments:
            if (merged
                    and seg[0] - merged[-1][1] - 1 < 2):
                merged[-1] = (merged[-1][0], seg[1])
            else:
                merged.append(seg)

        # Filter by duration and compute stats
        proposals = []
        for s, e in merged:
            dur = e - s + 1
            if dur >= min_duration:
                seg_scores = scores[s:e + 1]
                proposals.append({
                    "start": s,
                    "end": e,
                    "duration": dur,
                    "mean_score": float(
                        seg_scores.mean()),
                    "peak_score": float(
                        seg_scores.max()),
                })
        return proposals

    def nms_temporal(self, proposals,
                     overlap_threshold=0.3):
        if not proposals:
            return []
        props = sorted(
            proposals,
            key=lambda p: p["mean_score"],
            reverse=True)
        keep = []
        for p in props:
            overlaps = False
            for k in keep:
                o_start = max(p["start"], k["start"])
                o_end = min(p["end"], k["end"])
                overlap = max(0, o_end - o_start + 1)
                shorter = min(
                    p["duration"], k["duration"])
                if overlap / shorter > overlap_threshold:
                    overlaps = True
                    break
            if not overlaps:
                keep.append(p)
        return keep


rng = np.random.RandomState(42)
scores = np.zeros(200)

# 3 real action peaks
for s, e in [(20, 40), (80, 110), (150, 170)]:
    scores[s:e + 1] = rng.uniform(0.7, 0.95,
                                   e - s + 1)

# 2 brief noise spikes (should be filtered)
scores[55:58] = rng.uniform(0.6, 0.8, 3)
scores[130:133] = rng.uniform(0.6, 0.8, 3)

# Add baseline noise
scores += rng.randn(200) * 0.05
scores = np.clip(scores, 0, 1)

gen = ActionProposalGenerator()
# min_duration=5 so the 3-frame spikes fail the
# filter (3 >= 3 would let them through)
raw = gen.generate_proposals(scores, min_duration=5)
final = gen.nms_temporal(raw)

print(f"Raw proposals: {len(raw)}")
print(f"After NMS:     {len(final)}")
print(f"\n{'Start':>6} {'End':>6} {'Dur':>5} "
      f"{'Mean':>6} {'Peak':>6}")
print("-" * 33)
for p in final:
    print(f"{p['start']:>6} {p['end']:>6} "
          f"{p['duration']:>5} "
          f"{p['mean_score']:>6.3f} "
          f"{p['peak_score']:>6.3f}")

The three genuine action regions are 21, 31, and 21 frames long respectively, so they pass the min_duration=5 filter easily, while the two 3-frame noise spikes are dropped. Note that the default min_duration=3 would NOT catch them, since a 3-frame segment satisfies dur >= 3. Temporal NMS only has work to do when proposals can overlap (for example when you combine proposals generated at several thresholds); the segments produced here are disjoint, so NMS keeps all three.

Exercise 3: Video clip feature comparison tool.

import numpy as np


class VideoFeatureAnalyzer:
    """Compare clip similarity under different
    temporal pooling strategies."""

    def __init__(self, seed=42):
        self.rng = np.random.RandomState(seed)

    def make_clip(self, C=64, T=8, H=7, W=7):
        return self.rng.randn(C, T, H, W).astype(
            np.float32)

    def make_similar(self, anchor, noise=0.1):
        return anchor + self.rng.randn(
            *anchor.shape) * noise

    def temporal_pooling(self, features, method):
        # features: (C, T, H, W)
        if method == "average":
            return features.mean(axis=1)
        elif method == "max":
            return features.max(axis=1)
        elif method == "attention":
            T = features.shape[1]
            weights = self.rng.randn(T)
            weights = np.exp(weights) / np.exp(
                weights).sum()
            pooled = np.zeros_like(features[:, 0])
            for t in range(T):
                pooled += weights[t] * features[:, t]
            return pooled
        raise ValueError(f"Unknown method: {method}")

    def clip_similarity(self, feat1, feat2):
        f1 = feat1.flatten()
        f2 = feat2.flatten()
        dot = np.dot(f1, f2)
        n1 = np.linalg.norm(f1)
        n2 = np.linalg.norm(f2)
        return dot / max(n1 * n2, 1e-8)


analyzer = VideoFeatureAnalyzer()

# Generate anchors, similar, and different clips
anchors = [analyzer.make_clip() for _ in range(5)]
similars = [analyzer.make_similar(a)
            for a in anchors]
differents = [analyzer.make_clip() for _ in range(5)]

methods = ["average", "max", "attention"]
print(f"{'Method':<12} {'Sim-pair':>10} "
      f"{'Diff-pair':>10} {'Gap':>8}")
print("-" * 42)

for method in methods:
    sim_scores = []
    diff_scores = []
    for i in range(5):
        fa = analyzer.temporal_pooling(
            anchors[i], method)
        fs = analyzer.temporal_pooling(
            similars[i], method)
        fd = analyzer.temporal_pooling(
            differents[i], method)
        sim_scores.append(
            analyzer.clip_similarity(fa, fs))
        diff_scores.append(
            analyzer.clip_similarity(fa, fd))
    ms = np.mean(sim_scores)
    md = np.mean(diff_scores)
    print(f"{method:<12} {ms:>10.4f} "
          f"{md:>10.4f} {ms - md:>8.4f}")

All three methods produce higher similarity for similar pairs than for different pairs, which is the basic sanity check. Average pooling tends to show the largest gap: averaging over time keeps the signal-to-perturbation ratio intact, so similar pairs stay near-identical while unrelated clips average out to near-zero cosine similarity. Max pooling shrinks the gap from the other side: the max over 8 timesteps of a zero-mean feature is almost always positive, so every pooled vector carries a large positive offset and even unrelated clips end up with high cosine similarity. Attention pooling here uses random (not learned) weights drawn fresh on every call, so the anchor and its similar clip are pooled with different weights; where it lands depends on those draws.

On to today's episode

Here we go! Eighty-three episodes in, and we're about to enter what is arguably the most impactful area of modern AI: generative image models. Everything we've done in the vision section so far has been about understanding images -- classifying them, detecting objects, segmenting pixels, reading text, analyzing video. All of that is perception. Now we flip the script. We're going to teach neural networks to create images from scratch.

Back in episode #55 we explored GANs (Generative Adversarial Networks) -- two networks locked in a minimax game, generator vs discriminator, competing until the generator produces convincing images. GANs were the dominant generative paradigm from 2014 to roughly 2020 and produced some spectacular results. But they were (and frankly still are) a nightmare to train: mode collapse where the generator only produces a few types of images, training instability where the loss oscillates wildly, and no meaningful training metric to tell you whether things are actually improving ;-)

Then diffusion models arrived and changed the entire landscape.

The idea is almost embarrassingly simple when you hear it: take a clean image, gradually add random noise to it over many steps until it becomes pure static, then train a neural network to reverse that process -- to take a noisy image and predict what the noise looks like so you can subtract it. If you can learn to denoise step by step, you can start from pure random noise and iteratively denoise it into a realistic image. No adversarial training. No mode collapse. Just a single network learning to remove noise. That's it.

So why didn't this take off sooner? People did try it, actually. The mathematical foundations of diffusion processes go back decades in thermodynamics and statistical physics, and Sohl-Dickstein et al. formalized the idea for generative modeling in 2015, just a year after GANs appeared. But it wasn't until the DDPM paper (Ho et al., 2020) -- "Denoising Diffusion Probabilistic Models" -- that the approach was made practical enough to compete with (and then completely surpass) GANs. DDPM is our focus today.

The forward process: destroying information on purpose

The forward diffusion process takes a clean image x_0 and adds Gaussian noise over T timesteps, producing increasingly noisy versions x_1, x_2, ..., x_T. At each step, a small amount of noise is added according to a variance schedule beta_1, beta_2, ..., beta_T:

x_0 (clean) -> x_1 (barely noisy) -> x_2 -> ... -> x_T (pure noise)

Mathematically, each step takes the previous image, scales it down slightly, and adds fresh Gaussian noise: x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise. The scaling is what keeps the overall variance from exploding: if x_{t-1} has unit variance, x_t has variance (1 - beta_t) + beta_t = 1. Without the scaling, every step would add beta_t of variance and nothing would take it away, so the pixel values would keep spreading out as you add steps.
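A quick numerical check of that variance argument, as a minimal NumPy sketch. It uses NumPy rather than PyTorch only to stay dependency-light, and it treats 50,000 independent unit-variance values as stand-ins for normalized pixels (both assumptions for illustration). It runs all 1000 steps of the linear DDPM schedule, once with the sqrt(1 - beta_t) shrink and once without it:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)  # linear schedule from the DDPM paper
n = 50_000                          # independent "pixels" (toy assumption)


def final_variance(shrink_signal):
    """Run T forward steps and return the variance of x_T."""
    x = rng.standard_normal(n)      # unit-variance starting values
    for beta in betas:
        keep = np.sqrt(1.0 - beta) if shrink_signal else 1.0
        # x_t = keep * x_{t-1} + sqrt(beta_t) * fresh Gaussian noise
        x = keep * x + np.sqrt(beta) * rng.standard_normal(n)
    return x.var()


var_scaled = final_variance(shrink_signal=True)
var_unscaled = final_variance(shrink_signal=False)
print(f"with sqrt(1 - beta) scaling: {var_scaled:.3f}")   # stays near 1
print(f"without scaling:             {var_unscaled:.3f}")
print(f"1 + sum(betas):              {1 + betas.sum():.3f}")
```

Without the shrink, each step adds beta_t of variance and nothing removes it, so the final variance lands near 1 + sum(beta_t), roughly 11 for this schedule, and it would keep growing with more steps.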

A critical property of Gaussian distributions makes this process tractable: you don't need to simulate all T steps sequentially to get from x_0 to any intermediate x_t. You can jump directly in one operation. Define alpha_t = 1 - beta_t and alpha_bar_t = alpha_1 * alpha_2 * ... * alpha_t (the cumulative product). Then:

x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

This is a weighted combination of the original image and pure noise, where alpha_bar_t controls the mixing ratio. At t=0, alpha_bar is close to 1 (almost all signal). At t=T, alpha_bar is close to 0 (almost all noise). Let's implement this:

import torch
import numpy as np


class DiffusionSchedule:
    """Manages the noise schedule for the forward
    diffusion process."""

    def __init__(self, T=1000,
                 beta_start=1e-4, beta_end=0.02):
        self.T = T
        # Linear schedule: beta goes from small
        # to large over T steps
        self.betas = torch.linspace(
            beta_start, beta_end, T)
        self.alphas = 1.0 - self.betas
        # Cumulative product: how much original
        # signal survives at each timestep
        self.alpha_bars = torch.cumprod(
            self.alphas, dim=0)

    def add_noise(self, x0, t, noise=None):
        """Jump directly from clean x0 to noisy xt.
        No need to simulate intermediate steps."""
        if noise is None:
            noise = torch.randn_like(x0)
        alpha_bar = self.alpha_bars[t]
        alpha_bar = alpha_bar.view(-1, 1, 1, 1)

        xt = (torch.sqrt(alpha_bar) * x0
              + torch.sqrt(1 - alpha_bar) * noise)
        return xt


schedule = DiffusionSchedule(T=1000)

# Simulate a "clean" image (random for demo)
x0 = torch.randn(1, 3, 64, 64)

# At t=10: almost identical to original
x_early = schedule.add_noise(
    x0, t=torch.tensor([10]))
# At t=500: heavily corrupted
x_mid = schedule.add_noise(
    x0, t=torch.tensor([500]))
# At t=999: essentially pure random noise
x_late = schedule.add_noise(
    x0, t=torch.tensor([999]))

print(f"alpha_bar at t=10:  "
      f"{schedule.alpha_bars[10]:.4f}")
print(f"alpha_bar at t=500: "
      f"{schedule.alpha_bars[500]:.4f}")
print(f"alpha_bar at t=999: "
      f"{schedule.alpha_bars[999]:.4f}")

The alpha_bar values tell you exactly how much of the original image survives (as a fraction of its variance). At t=10, alpha_bar is about 0.998 -- you'd barely notice the noise. At t=500 it has dropped to roughly 0.08. At t=999 it is on the order of 0.00004, so the original image is completely buried under noise. The forward process is fixed -- no learnable parameters, just a deterministic schedule that progressively destroys information.
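The claim that the alpha_bar shortcut needs no intermediate steps is easy to check empirically. A minimal NumPy sketch, separate from the PyTorch class above, that uses a single constant clean value of 1.5 as an illustrative assumption: it applies the single-step update for steps 0..t to many independent noise draws, then draws x_t directly with the shortcut, and compares the two distributions (it also prints the alpha_bar values discussed above):

```python
import numpy as np

rng = np.random.default_rng(1)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)  # alpha_bars[t] = alpha_0 * ... * alpha_t

for step in (10, 500, 999):
    print(f"alpha_bar[{step}] = {alpha_bars[step]:.6f}")

t = 500
x0 = np.full(50_000, 1.5)  # one clean value, many independent noise draws

# Route 1: apply the single-step update for steps 0..t, one at a time
x_seq = x0.copy()
for beta in betas[: t + 1]:
    x_seq = (np.sqrt(1.0 - beta) * x_seq
             + np.sqrt(beta) * rng.standard_normal(x0.shape))

# Route 2: one jump with the closed-form alpha_bar shortcut
x_jump = (np.sqrt(alpha_bars[t]) * x0
          + np.sqrt(1.0 - alpha_bars[t]) * rng.standard_normal(x0.shape))

print(f"sequential: mean {x_seq.mean():.3f}, std {x_seq.std():.3f}")
print(f"shortcut:   mean {x_jump.mean():.3f}, std {x_jump.std():.3f}")
print(f"theory:     mean {np.sqrt(alpha_bars[t]) * 1.5:.3f}, "
      f"std {np.sqrt(1.0 - alpha_bars[t]):.3f}")
```

Both routes give a Gaussian with mean sqrt(alpha_bar_t) * x_0 and standard deviation sqrt(1 - alpha_bar_t), which is why training can jump to any timestep in one operation instead of looping through t updates.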

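Having said that, the choice of schedule matters A LOT. The linear schedule above (beta going from 0.0001 to 0.02) was used in the original DDPM paper. Later work (Nichol & Dhariwal, 2021) found that a cosine schedule, where alpha_bar follows a squared-cosine curve from 1 to 0, gives better results. The linear schedule destroys information too quickly: alpha_bar is already below 0.1 halfway through, so a large share of the timesteps are nearly pure noise and contribute little to learning. The cosine schedule stays close to 1 for the first steps, falls roughly linearly through the middle, and only approaches 0 at the very end: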

def cosine_schedule(T, s=0.008):
    """Cosine schedule from Nichol & Dhariwal."""
    steps = torch.arange(T + 1, dtype=torch.float32)
    f = torch.cos(
        (steps / T + s) / (1 + s) * np.pi / 2
    ) ** 2
    alpha_bars = f / f[0]
    betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
    return torch.clamp(betas, max=0.999)


cos_betas = cosine_schedule(1000)
cos_alphas = 1.0 - cos_betas
cos_alpha_bars = torch.cumprod(cos_alphas, dim=0)

print(f"Cosine alpha_bar at t=10:  "
      f"{cos_alpha_bars[10]:.4f}")
print(f"Cosine alpha_bar at t=500: "
      f"{cos_alpha_bars[500]:.4f}")
print(f"Cosine alpha_bar at t=999: "
      f"{cos_alpha_bars[999]:.4f}")

With the cosine schedule, alpha_bar at t=500 is about 0.49 (roughly half the signal variance remaining) instead of about 0.08 with the linear schedule. This means the model spends more of its capacity learning to handle moderately noisy images -- which is where the important structural decisions happen -- rather than spending it on the nearly clean extreme (where there is little noise to remove) and the pure-noise extreme (where there is almost nothing left of the image to recover).

The reverse process: learning to denoise

The forward process is the easy part -- no learning involved. The reverse process is where the actual ML happens. We train a neural network to predict the noise that was added at each step:

Given a noisy image x_t and the timestep t, predict the noise that was added.

Why predict the noise rather than the clean image directly? Both are mathematically equivalent -- if you know the noise and the mixing formula, you can recover x_0, and vice versa. But noise prediction tends to produce more stable training empirically. The loss landscape is smoother because the target (Gaussian noise) has consistent statistics regardless of the input image, whereas clean image targets vary wildly between training examples.
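The equivalence is just the forward formula solved for the other unknown. A minimal NumPy sketch, with a toy 3x8x8 "image" and alpha_bar = 0.3 chosen arbitrarily for illustration; the helpers x0_from_noise and noise_from_x0 are hypothetical names, not from any library:

```python
import numpy as np

rng = np.random.default_rng(2)
alpha_bar = 0.3                                # arbitrary noise level for the demo
x0 = rng.uniform(-1.0, 1.0, size=(3, 8, 8))    # toy "image" scaled to [-1, 1]
noise = rng.standard_normal(x0.shape)
xt = np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * noise


def x0_from_noise(xt, eps, alpha_bar):
    """Solve the forward formula for x_0, given the noise."""
    return (xt - np.sqrt(1.0 - alpha_bar) * eps) / np.sqrt(alpha_bar)


def noise_from_x0(xt, x0, alpha_bar):
    """Solve the forward formula for the noise, given x_0."""
    return (xt - np.sqrt(alpha_bar) * x0) / np.sqrt(1.0 - alpha_bar)


print(np.allclose(x0_from_noise(xt, noise, alpha_bar), x0))   # True
print(np.allclose(noise_from_x0(xt, x0, alpha_bar), noise))   # True
```

Note the division by sqrt(alpha_bar) in x0_from_noise: at large t, where alpha_bar is tiny, a small error in the predicted noise turns into a large error in the recovered x_0.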

def training_step(model, x0, schedule):
    """One training step for a diffusion model.
    Surprisingly simple."""
    batch_size = x0.shape[0]

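    # 1. Sample random timesteps (on the same
    #    device as the images, so the timestep
    #    embedding doesn't mix CPU and GPU tensors)
    t = torch.randint(0, schedule.T, (batch_size,),
                      device=x0.device)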

    # 2. Sample random noise
    noise = torch.randn_like(x0)

    # 3. Create noisy images
    xt = schedule.add_noise(x0, t, noise)

    # 4. Network predicts the noise
    predicted_noise = model(xt, t)

    # 5. Loss = how wrong was the prediction?
    loss = torch.nn.functional.mse_loss(
        predicted_noise, noise)
    return loss

That's the ENTIRE training step (apart from the optimizer update, obviously). Sample a random timestep, add noise, predict the noise, compute MSE. No adversarial loss, no feature matching, no perceptual loss, no progressive growing. Just "can you tell me what noise I added?" The simplicity is genuinely remarkable when you compare it to the elaborate training procedures GANs required -- and it works better ;-)
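Stripped of the network, that objective is ordinary regression onto the noise. A minimal NumPy sketch makes this concrete by swapping the U-Net for the simplest possible predictor: a separate linear fit noise_hat = a * x_t + b at each timestep, solved by least squares on 1D toy data drawn from N(2, 0.5^2). The toy data and the linear model are assumptions for illustration; a linear predictor happens to be optimal here only because the toy data is Gaussian.

```python
import numpy as np

rng = np.random.default_rng(4)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)

# Toy 1D "dataset" (assumption): samples from N(2.0, 0.5^2)
x0 = rng.normal(2.0, 0.5, size=20_000)


def fit_noise_predictor(t):
    """Least-squares fit of noise_hat = a * x_t + b at timestep t.
    Returns the coefficients (a, b) and the resulting MSE."""
    ab = alpha_bars[t]
    noise = rng.standard_normal(x0.shape)
    xt = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * noise
    features = np.stack([xt, np.ones_like(xt)], axis=1)
    coef, *_ = np.linalg.lstsq(features, noise, rcond=None)
    mse = np.mean((features @ coef - noise) ** 2)
    return coef, mse


results = {t: fit_noise_predictor(t) for t in (10, 250, 500, 999)}
for t, (coef, mse) in results.items():
    print(f"t={t:>3}  a={coef[0]:.3f}  b={coef[1]:.3f}  MSE={mse:.4f}")
```

The MSE is close to 1 at t=10 (the noise is tiny compared with how much the clean values themselves vary, so it is nearly invisible) and close to 0 at t=999 (x_t is almost pure noise, so predicting the noise is nearly trivial). Because training_step samples t uniformly, different timesteps contribute very different loss magnitudes to the average.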

The denoising U-Net

The architecture used to predict noise is a U-Net -- the same encoder-decoder structure with skip connections we saw in image segmentation (episode #80), but adapted for diffusion. The key adaptation: the model needs to know which timestep it's denoising, because removing noise at t=900 (mostly noise, need to recover large-scale structure) requires a very different operation than at t=50 (mostly clean, need to refine fine details like textures and edges).

Timestep information is injected via sinusoidal positional embeddings -- borrowed from transformers (episode #52) -- that are projected through an MLP and added to the intermediate features at every residual block:

import torch
import torch.nn as nn
import math


class TimestepEmbedding(nn.Module):
    """Encode the integer timestep t into a
    continuous vector using sinusoidal
    frequencies, then project through an MLP."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(
            torch.arange(half_dim,
                         device=t.device) * -emb)
        emb = (t.float().unsqueeze(1)
               * emb.unsqueeze(0))
        emb = torch.cat(
            [torch.sin(emb), torch.cos(emb)],
            dim=1)
        return self.mlp(emb)


class ResBlock(nn.Module):
    """Residual block with timestep conditioning.
    The timestep embedding is ADDED to the
    intermediate features after the first conv."""

    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.GroupNorm(8, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1))
        self.time_proj = nn.Linear(time_dim, out_ch)
        self.conv2 = nn.Sequential(
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1))
        self.skip = (nn.Conv2d(in_ch, out_ch, 1)
                     if in_ch != out_ch
                     else nn.Identity())

    def forward(self, x, t_emb):
        h = self.conv1(x)
        # Add timestep info -- broadcast across
        # spatial dimensions
        h = h + self.time_proj(t_emb).unsqueeze(
            -1).unsqueeze(-1)
        h = self.conv2(h)
        return h + self.skip(x)


class SimpleUNet(nn.Module):
    """Minimal diffusion U-Net for noise
    prediction."""

    def __init__(self, in_ch=3, base_ch=64,
                 time_dim=256):
        super().__init__()
        self.time_embed = TimestepEmbedding(time_dim)

        # Stem conv: map the raw image channels (3
        # for RGB, 1 for MNIST) to base_ch first.
        # GroupNorm(8, C) requires C divisible by 8,
        # so a ResBlock can't take in_ch directly.
        self.init_conv = nn.Conv2d(
            in_ch, base_ch, 3, padding=1)

        # Encoder (downsampling path)
        self.enc1 = ResBlock(
            base_ch, base_ch, time_dim)
        self.enc2 = ResBlock(
            base_ch, base_ch * 2, time_dim)
        self.down1 = nn.Conv2d(
            base_ch, base_ch, 3,
            stride=2, padding=1)
        self.down2 = nn.Conv2d(
            base_ch * 2, base_ch * 2, 3,
            stride=2, padding=1)

        # Bottleneck
        self.mid = ResBlock(
            base_ch * 2, base_ch * 2, time_dim)

        # Decoder (upsampling path)
        self.up2 = nn.ConvTranspose2d(
            base_ch * 2, base_ch * 2, 4,
            stride=2, padding=1)
        self.dec2 = ResBlock(
            base_ch * 4, base_ch, time_dim)
        self.up1 = nn.ConvTranspose2d(
            base_ch, base_ch, 4,
            stride=2, padding=1)
        self.dec1 = ResBlock(
            base_ch * 2, base_ch, time_dim)

        self.out = nn.Conv2d(base_ch, in_ch, 1)

    def forward(self, x, t):
        t_emb = self.time_embed(t)

        # Down (stem conv first)
        e1 = self.enc1(self.init_conv(x), t_emb)
        e2 = self.enc2(self.down1(e1), t_emb)

        # Bottom
        m = self.mid(self.down2(e2), t_emb)

        # Up with skip connections
        d2 = self.dec2(
            torch.cat([self.up2(m), e2], dim=1),
            t_emb)
        d1 = self.dec1(
            torch.cat([self.up1(d2), e1], dim=1),
            t_emb)

        return self.out(d1)


model = SimpleUNet()
x = torch.randn(2, 3, 64, 64)
t = torch.randint(0, 1000, (2,))
pred_noise = model(x, t)
print(f"Input shape:  {x.shape}")
print(f"Output shape: {pred_noise.shape}")
# Both (2, 3, 64, 64) -- same shape

Notice how the timestep embedding is injected at every ResBlock. This allows the model to behave very differently depending on how noisy the input is. At high t values (early in the reverse process, lots of noise), the model focuses on large-scale structure -- roughly where objects should be, what the overall composition looks like. At low t values (late in the reverse process, mostly clean), it refines fine details -- textures, edges, precise colors. The same network handles both tasks, guided by the timestep conditioning.
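To see what the sinusoidal part of TimestepEmbedding hands the network before the MLP touches it, here's a NumPy sketch of the same formula. sinusoidal_embedding is a hypothetical helper that mirrors TimestepEmbedding.forward minus the MLP, and dim=128 is chosen for illustration. Nearby timesteps get similar vectors and distant ones get clearly different vectors, which gives the network a smooth signal of the noise level:

```python
import numpy as np


def sinusoidal_embedding(t, dim=128):
    """Same frequencies as TimestepEmbedding.forward, without the MLP."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / (half - 1))
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)


def cos_sim(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


emb = sinusoidal_embedding([500, 505, 900])
near = cos_sim(emb[0], emb[1])   # timesteps 500 vs 505
far = cos_sim(emb[0], emb[2])    # timesteps 500 vs 900
print(f"t=500 vs t=505: {near:.3f}")
print(f"t=500 vs t=900: {far:.3f}")
```

For every frequency f, the sin/cos pair contributes cos(f * (t1 - t2)) to the dot product, so the similarity depends only on the gap between two timesteps, not on where they sit in the schedule.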

The skip connections between encoder and decoder are essential (just as they were for segmentation in episode #80). Without them, the decoder would have to reconstruct spatial detail from the compressed bottleneck alone. With skip connections, the decoder gets direct access to the high-resolution features from the encoder, allowing it to focus on refinement rather than reconstruction.

GroupNorm instead of BatchNorm is standard in diffusion U-Nets. BatchNorm (episode #40) computes statistics across the batch dimension, which becomes problematic when batch sizes are small (common with high-resolution images that eat GPU memory). GroupNorm computes statistics within groups of channels for each sample independently -- no batch dependency, stable regardless of batch size.
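The batch dependence is easy to demonstrate. A minimal NumPy sketch of both normalizations, with no learnable scale/shift and with batch norm in training mode (using the current batch's statistics rather than running averages); group_norm and batch_norm are hypothetical helpers, not PyTorch's modules. It replaces every other sample in the batch and checks whether sample 0's output changes:

```python
import numpy as np


def group_norm(x, groups, eps=1e-5):
    """Per-sample stats over each group of channels and all pixels."""
    n, c, h, w = x.shape
    g = x.reshape(n, groups, c // groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)


def batch_norm(x, eps=1e-5):
    """Training-mode stats: per channel, over the whole batch and all pixels."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(3)
batch = rng.standard_normal((4, 16, 8, 8))
other = batch.copy()
other[1:] = 5.0 * rng.standard_normal((3, 16, 8, 8)) + 2.0  # swap out samples 1..3

gn_same = np.allclose(group_norm(batch, 8)[0], group_norm(other, 8)[0])
bn_same = np.allclose(batch_norm(batch)[0], batch_norm(other)[0])
print("GroupNorm, sample 0 unchanged:", gn_same)  # True
print("BatchNorm, sample 0 unchanged:", bn_same)  # False
```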

Sampling: from noise to image

Once the model is trained, generating an image means starting from pure random noise and iteratively denoising through all T timesteps:

@torch.no_grad()
def sample_ddpm(model, schedule, shape,
                device='cpu'):
    """Generate images by iterative denoising.
    Walks backwards from t=T-1 to t=0."""
    x = torch.randn(shape, device=device)

    for t in reversed(range(schedule.T)):
        t_batch = torch.full(
            (shape[0],), t,
            device=device, dtype=torch.long)

        # Predict noise in the current image
        pred_noise = model(x, t_batch)

        # DDPM reverse step
        alpha = schedule.alphas[t]
        alpha_bar = schedule.alpha_bars[t]
        beta = schedule.betas[t]

        # Compute the mean of p(x_{t-1} | x_t)
        coeff = beta / torch.sqrt(1 - alpha_bar)
        mean = (1 / torch.sqrt(alpha)) * (
            x - coeff * pred_noise)

        # Add noise (except at the very last step)
        if t > 0:
            noise = torch.randn_like(x)
            sigma = torch.sqrt(beta)
            x = mean + sigma * noise
        else:
            x = mean

    return x


# Generate 4 images (untrained model = garbage,
# but the process is correct)
images = sample_ddpm(
    model, schedule, shape=(4, 3, 64, 64))
print(f"Generated: {images.shape}")

This is 1000 forward passes through the U-Net to generate a single batch of images. That's the main drawback of diffusion models compared to GANs (which generate in a single forward pass). For a reasonably sized U-Net on a modern GPU, each forward pass takes maybe 10-50ms, so generating one image takes 10-50 seconds. That's... not great for interactive applications.

The speed problem has driven a huge amount of follow-up research. DDIM (Denoising Diffusion Implicit Models, Song et al., 2021) showed you can skip steps -- instead of denoising through all 1000 timesteps, you can take larger jumps (say, 50 or 100 steps) with a deterministic sampling formula, producing similar quality in a fraction of the time. We'll explore DDIM and other acceleration techniques in the next episode.
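You can check the reverse-step arithmetic from sample_ddpm without training anything by plugging in a noise predictor that is exact for a toy dataset. A minimal NumPy sketch, assuming 1D data that consists of just two points, -2 and +2, with equal probability. For that data the best possible noise prediction has a closed form: a weighted average over the two points, with weights from the Gaussian likelihood of x_t. oracle_eps and sample_ddpm_oracle are hypothetical names for this illustration; the reverse update itself is the same formula as in sample_ddpm:

```python
import numpy as np


def oracle_eps(x, alpha_bar, points):
    """Exact E[noise | x_t] for data that is an equal mix of `points`."""
    # Log-likelihood of x_t under each point: N(sqrt(alpha_bar) * c, 1 - alpha_bar)
    diff = x[:, None] - np.sqrt(alpha_bar) * points[None, :]
    logw = -diff ** 2 / (2.0 * (1.0 - alpha_bar))
    logw -= logw.max(axis=1, keepdims=True)       # numerically stable softmax
    w = np.exp(logw)
    w /= w.sum(axis=1, keepdims=True)
    x0_hat = (w * points[None, :]).sum(axis=1)    # posterior mean of x_0
    return (x - np.sqrt(alpha_bar) * x0_hat) / np.sqrt(1.0 - alpha_bar)


def sample_ddpm_oracle(points, n=1000, T=1000, seed=0):
    rng = np.random.default_rng(seed)
    betas = np.linspace(1e-4, 0.02, T)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    x = rng.standard_normal(n)                    # start from pure noise
    for t in reversed(range(T)):
        eps = oracle_eps(x, alpha_bars[t], points)
        # Same mean of p(x_{t-1} | x_t) as in sample_ddpm
        coeff = betas[t] / np.sqrt(1.0 - alpha_bars[t])
        mean = (x - coeff * eps) / np.sqrt(alphas[t])
        if t > 0:
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(n)
        else:
            x = mean
    return x


points = np.array([-2.0, 2.0])
samples = sample_ddpm_oracle(points)
print(np.round(samples[:8], 3))
print("fraction at +2:", np.mean(samples > 0))
```

Every sample ends on one of the two points, and roughly half end on each, so neither mode is dropped. At the last step the update reduces exactly to the predicted clean value, because 1 - alpha_bar at the first timestep equals beta at that timestep.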

Putting it all together: a complete training script

Let's combine everything into a complete (if minimal) training loop. We'll use MNIST digits for simplicity -- the same dataset we've seen in earlier episodes -- so you can actually run this and see results:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def train_diffusion(epochs=5, batch_size=64,
                    T=200, lr=1e-3):
    """Train a small diffusion model on MNIST."""
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "cpu")

    # MNIST: 28x28 grayscale -> scale to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST(
        "./data", train=True,
        download=True, transform=transform)
    loader = DataLoader(
        dataset, batch_size=batch_size,
        shuffle=True, drop_last=True)

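    # Smaller T for MNIST (28x28 needs fewer
    # steps than 256x256). Scale the betas by
    # 1000 / T so alpha_bar still reaches ~0 at
    # the last step; with the T=1000 values at
    # T=200, alpha_bar_T would only be ~0.13 and
    # sampling from pure noise wouldn't match
    # what the model saw during training.
    schedule = DiffusionSchedule(
        T=T,
        beta_start=1e-4 * 1000 / T,
        beta_end=0.02 * 1000 / T)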
    schedule.betas = schedule.betas.to(device)
    schedule.alphas = schedule.alphas.to(device)
    schedule.alpha_bars = (
        schedule.alpha_bars.to(device))

    model = SimpleUNet(
        in_ch=1, base_ch=32, time_dim=128
    ).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr)

    for epoch in range(epochs):
        total_loss = 0
        n_batches = 0
        for images, _ in loader:
            images = images.to(device)
            loss = training_step(
                model, images, schedule)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        avg = total_loss / n_batches
        print(f"Epoch {epoch + 1}/{epochs}, "
              f"loss: {avg:.4f}")

    return model, schedule


# Uncomment to actually train:
# model, schedule = train_diffusion(epochs=5)
# samples = sample_ddpm(
#     model, schedule,
#     shape=(16, 1, 28, 28),
#     device=next(model.parameters()).device)
# print(f"Samples: {samples.shape}")

On MNIST with just 5 epochs you'll start seeing vaguely digit-shaped blobs. With 20-50 epochs the generated digits become recognizable. It won't match state-of-the-art quality (which requires much larger models, more training time, and various architectural improvements) but it demonstrates the core mechanism. The loss decreases smoothly -- no oscillation, no collapse, just steady improvement. Compare that to GAN training where you're constantly watching two losses hoping neither one runs away from the other.

Why diffusion models won

Several properties make diffusion models superior to GANs for image generation:

Training stability: the loss is plain MSE. No adversarial dynamics, no balancing two networks against each other, no mode collapse. The loss decreases steadily (apart from minibatch noise) and, unlike the GAN discriminator/generator losses, it's a single objective you can actually monitor. You can look at the loss curve and tell whether training is progressing -- a luxury you simply don't have with GANs.

Mode coverage: GANs can "forget" parts of the data distribution -- the generator finds a few modes that fool the discriminator and ignores everything else (this is mode collapse, and it's the single most frustrating failure mode of GANs). Diffusion models have no such escape hatch: every training image contributes to the noise prediction loss, and the objective is derived from a bound on the data likelihood, so ignoring part of the distribution is penalized rather than rewarded. The model is pushed to learn the full diversity of the training data.

Scalability: the training objective is clean enough that scaling up (bigger model, more data, more compute) reliably produces better results. This is the same scaling property that made transformers successful for language (episodes #52-53) -- a simple, well-defined objective that rewards scale without hitting diminishing returns too quickly.

Controllability: the iterative sampling process provides natural hooks for control. You can guide the denoising with text prompts, reference images, class labels, or other conditioning signals at each step. This turns out to be far easier to implement and more reliable than controlling GAN generation, which is partly why text-to-image systems (Stable Diffusion, DALL-E 2, Midjourney) are all built on diffusion rather than GANs.

The cost is speed: ~1000 network evaluations per sample vs 1 for GANs. There are ways to bring that number down dramatically, and the next episode in this series will cover exactly that -- latent diffusion (moving the process into a compressed latent space), advanced samplers like DDIM that skip timesteps, and how text conditioning turns a noise predictor into a full text-to-image system.

Summary

  • The forward diffusion process adds Gaussian noise over T timesteps, turning any image into pure noise; the alpha_bar shortcut lets you jump directly to any timestep without simulating intermediate steps;
  • the reverse process trains a neural network to predict the noise added at each step; the training loss is simple MSE between predicted and true noise -- no adversarial loss, no complex training dynamics;
  • the denoising network is a U-Net with timestep conditioning injected via sinusoidal embeddings at every residual block, using GroupNorm instead of BatchNorm for stability with small batches;
  • sampling starts from pure Gaussian noise and iteratively denoises over T steps, each time predicting and removing noise; DDPM uses 1000 steps by default, which is slow but produces high-quality results;
  • diffusion models offer stable training, broad mode coverage, and smooth scaling compared to GANs, at the cost of slow sampling (1000 forward passes per image vs 1 for GANs);
  • the DDPM paper (Ho et al., 2020) established these foundations and kicked off the generative revolution that led to Stable Diffusion, DALL-E 2, and Midjourney;
  • the variance schedule (linear vs cosine) controls how information is destroyed during the forward process -- cosine schedules spend more capacity on moderately noisy images where the important structural decisions happen.
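The alpha_bar shortcut from the first bullet can be checked numerically. The minimal sketch below uses hypothetical helper names `q_sample` and `step_by_step` and assumes a linear schedule with beta from 1e-4 to 0.02 and T=1000. It noises the same starting value two ways. The first applies the per-step transition t+1 times. The second uses the closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A sum of independent Gaussians is Gaussian, so both routes give the same distribution, with mean sqrt(alpha_bar_t) * x_0 and variance 1 - alpha_bar_t. The closed form simply gets there in one step.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)      # linear variance schedule (assumed)
alpha_bar = np.cumprod(1.0 - betas)     # alpha_bar[t] = prod_{s=0..t} (1 - beta_s)


def q_sample(x0, t, alpha_bar, rng):
    """Closed-form jump straight to timestep t (0-indexed)."""
    eps = rng.standard_normal(np.shape(x0))
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps


def step_by_step(x0, t, betas, rng):
    """Apply x_s = sqrt(1 - beta_s) * x_{s-1} + sqrt(beta_s) * z for s = 0..t."""
    x = np.asarray(x0, dtype=float)
    for s in range(t + 1):
        z = rng.standard_normal(x.shape)
        x = np.sqrt(1.0 - betas[s]) * x + np.sqrt(betas[s]) * z
    return x


rng = np.random.default_rng(0)
x0 = np.full(20_000, 0.8)               # 20k copies of one "pixel" value
t = 499
fast = q_sample(x0, t, alpha_bar, rng)
slow = step_by_step(x0, t, betas, rng)
print(f"closed form : mean={fast.mean():.3f} std={fast.std():.3f}")
print(f"step by step: mean={slow.mean():.3f} std={slow.std():.3f}")
print(f"theory      : mean={np.sqrt(alpha_bar[t]) * 0.8:.3f} "
      f"std={np.sqrt(1.0 - alpha_bar[t]):.3f}")
```

The step-by-step route costs t+1 random draws per sample, while the closed form costs one. That difference is what makes training with randomly sampled timesteps cheap.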

We've now covered the foundational mechanics of diffusion models: how noise is added, how a network learns to remove it, and how iterative sampling generates images from scratch. The core ideas are elegant and (for deep learning) surprisingly simple. But 1000 sampling steps is too slow for practical use, and we haven't yet seen how to control what the model generates. The connection between this noise-prediction framework and systems that can generate a photorealistic image from a text description like "a cat wearing a top hat on the moon" involves several more pieces -- latent spaces, text encoders, classifier-free guidance -- that build directly on what we covered today.

Exercises

Exercise 1: Build a noise schedule comparison tool. Create a class ScheduleComparator that: (a) implements both linear_schedule(T, beta_start, beta_end) and cosine_schedule(T, s=0.008) returning the full alpha_bar sequence for each, (b) implements signal_to_noise_ratio(alpha_bar_t) that computes 10 * log10(alpha_bar_t / (1 - alpha_bar_t)) for any timestep (this is the SNR in decibels -- positive means more signal than noise, negative means more noise than signal), (c) for T=1000, computes and prints a comparison table at timesteps [0, 100, 250, 500, 750, 900, 999] showing alpha_bar and SNR(dB) for both schedules side by side, (d) identifies the timestep where each schedule crosses SNR=0 dB (equal signal and noise) and prints the crossing point for each. Verify that the cosine schedule crosses later (preserves signal longer) than the linear schedule.

Exercise 2: Implement a diffusion noise predictor evaluator. Create a class DiffusionEvaluator that: (a) generates synthetic test data: 50 random "clean images" of shape (1, 32, 32), adds noise at known timesteps using the linear schedule with T=200, (b) implements three "noise predictor" baselines: zero_predictor (always predicts zero noise -- equivalent to saying "the image is clean"), random_predictor (predicts random Gaussian noise), and perfect_predictor (returns the exact noise that was added), (c) implements evaluate_predictor(predictor, test_data) that computes MSE between predicted and true noise averaged across all test samples, (d) groups the test samples into 4 bins by timestep range (t=0-49, 50-99, 100-149, 150-199) and reports per-bin MSE for each predictor, (e) prints a table of the per-bin results. The target is the unit-variance noise epsilon, not the scaled noise sqrt(1 - alpha_bar_t) * epsilon. You should therefore see the zero predictor land near MSE=1.0 and the random predictor near MSE=2.0 in every bin, regardless of timestep. Then switch the target to the scaled noise and observe the zero predictor's MSE rise from near 0 at early timesteps toward 1.0 at late ones. Verify that the perfect predictor gets MSE=0.0 across all bins.

Exercise 3: Build a U-Net architecture analyzer. Create a class UNetAnalyzer that: (a) takes a list of channel configurations for the encoder path (e.g. [64, 128, 256]) and the input channels (e.g. 3), (b) computes the total parameter count for each encoder block (Conv2d with 3x3 kernel + bias), each decoder block (ConvTranspose2d 4x4 for upsampling + Conv2d 3x3 for the skip-concatenated input), and the bottleneck (Conv2d 3x3), (c) computes the feature map size at each level assuming input resolution 64x64 with stride-2 downsampling at each level, (d) computes the memory footprint at each level in bytes (assuming float32) for a batch size of 1, (e) prints a table showing: level name, channels, spatial resolution, parameters, and memory per level. Add a "total" row with the summed parameter count and the peak per-level memory. Test with channel configs [64, 128, 256] and [32, 64, 128, 256] and compare total parameter counts.

That's a wrap. Thanks for your time!

@scipio
