Learn AI Series (#85) - Generative Images - Diffusion Models (Part 2)

What will I learn

  • You will learn latent diffusion: operating in compressed latent space instead of pixel space;
  • the Stable Diffusion architecture: VAE encoder, U-Net denoiser, text encoder;
  • text conditioning: how prompts guide image generation via cross-attention;
  • CLIP: the model that connects text and images;
  • classifier-free guidance: the trick that makes prompts actually work;
  • DDIM and advanced samplers: generating in 20-50 steps instead of 1000;
  • ControlNet: adding structural guidance like edge maps and pose skeletons;
  • the evolution from Stable Diffusion 1.5 to SDXL to SD3.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#85) - Generative Images - Diffusion Models (Part 2)

Solutions to Episode #84 Exercises

Exercise 1: Noise schedule comparison tool.

import torch
import numpy as np


class ScheduleComparator:
    """Compare linear and cosine noise schedules
    for diffusion models."""

    def linear_schedule(self, T, beta_start=1e-4,
                        beta_end=0.02):
        betas = torch.linspace(beta_start, beta_end, T)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        return alpha_bars

    def cosine_schedule(self, T, s=0.008):
        steps = torch.arange(
            T + 1, dtype=torch.float32)
        f = torch.cos(
            (steps / T + s) / (1 + s)
            * np.pi / 2) ** 2
        alpha_bars = f / f[0]
        # Drop the extra element to get T values
        return alpha_bars[1:]

    def signal_to_noise_ratio(self, alpha_bar_t):
        """SNR in decibels. Positive = more signal,
        negative = more noise."""
        ratio = alpha_bar_t / (1.0 - alpha_bar_t)
        return 10.0 * torch.log10(ratio)

    def find_zero_crossing(self, alpha_bars):
        """Find timestep where SNR crosses 0 dB
        (alpha_bar = 0.5)."""
        for t in range(len(alpha_bars)):
            if alpha_bars[t] <= 0.5:
                return t
        return len(alpha_bars) - 1

    def compare(self, T=1000):
        lin_ab = self.linear_schedule(T)
        cos_ab = self.cosine_schedule(T)

        checkpoints = [0, 100, 250, 500,
                       750, 900, 999]

        print(f"{'t':>5}  {'Lin abar':>9} "
              f"{'Lin SNR':>9}  {'Cos abar':>9} "
              f"{'Cos SNR':>9}")
        print("-" * 48)

        for t in checkpoints:
            la = lin_ab[t].item()
            ca = cos_ab[t].item()
            ls = self.signal_to_noise_ratio(
                lin_ab[t]).item()
            cs = self.signal_to_noise_ratio(
                cos_ab[t]).item()
            print(f"{t:>5}  {la:>9.4f} {ls:>9.2f}  "
                  f"{ca:>9.4f} {cs:>9.2f}")

        lin_cross = self.find_zero_crossing(lin_ab)
        cos_cross = self.find_zero_crossing(cos_ab)
        print(f"\nSNR=0 dB crossing:")
        print(f"  Linear:  t={lin_cross}")
        print(f"  Cosine:  t={cos_cross}")
        print(f"  Cosine preserves signal "
              f"{cos_cross - lin_cross} steps longer")


comp = ScheduleComparator()
comp.compare()

The comparison shows the cosine schedule preserving signal much longer. The linear schedule's alpha_bar drops to ~0.08 at t=500, meaning roughly 92% noise and only 8% signal (in variance terms). The cosine schedule is still around 0.5 at t=500 -- equal parts signal and noise. The zero-crossing confirms this: the linear schedule hits SNR=0 dB somewhere around t=250, while the cosine schedule doesn't cross until roughly t=500. This means the cosine schedule gives the model far more timesteps in the "interesting" regime where both signal and noise are present -- and that's exactly where the model needs to make the hardest decisions about image structure.

Exercise 2: Diffusion noise predictor evaluator.

import torch
import numpy as np


class DiffusionEvaluator:
    """Evaluate different noise prediction
    baselines for diffusion models."""

    def __init__(self, T=200, n_samples=50,
                 img_shape=(1, 32, 32), seed=42):
        torch.manual_seed(seed)
        self.T = T
        betas = torch.linspace(1e-4, 0.02, T)
        alphas = 1.0 - betas
        self.alpha_bars = torch.cumprod(
            alphas, dim=0)

        # Generate test data
        self.clean = torch.randn(
            n_samples, *img_shape)
        self.timesteps = torch.randint(
            0, T, (n_samples,))
        self.noise = torch.randn_like(self.clean)

        # Create noisy versions
        ab = self.alpha_bars[self.timesteps]
        ab = ab.view(-1, 1, 1, 1)
        self.noisy = (torch.sqrt(ab) * self.clean
                      + torch.sqrt(1 - ab)
                      * self.noise)

    def zero_predictor(self, xt, t):
        return torch.zeros_like(xt)

    def random_predictor(self, xt, t):
        return torch.randn_like(xt)

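    def perfect_predictor(self, xt, t):
        # Cheating: return the actual noise. Match on
        # the noisy sample itself, because several
        # samples can share the same timestep t.
        match = (self.noisy == xt).flatten(1).all(dim=1)
        idx = match.nonzero(as_tuple=True)[0]
        if len(idx) > 0:
            return self.noise[idx[0]].unsqueeze(0)
        return torch.zeros_like(xt)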

    def evaluate_predictor(self, predictor, name):
        bins = [(0, 49), (50, 99),
                (100, 149), (150, 199)]
        results = {}

        for lo, hi in bins:
            mask = ((self.timesteps >= lo)
                    & (self.timesteps <= hi))
            idxs = mask.nonzero(as_tuple=True)[0]
            if len(idxs) == 0:
                results[(lo, hi)] = float('nan')
                continue
            mse_total = 0.0
            for i in idxs:
                pred = predictor(
                    self.noisy[i].unsqueeze(0),
                    self.timesteps[i])
                true = self.noise[i].unsqueeze(0)
                mse_total += torch.nn.functional.mse_loss(
                    pred, true).item()
            results[(lo, hi)] = mse_total / len(idxs)
        return results

    def run_evaluation(self):
        print(f"{'Predictor':<12} "
              + "  ".join(f"{'t=' + str(lo) + '-' + str(hi):>10}"
                          for lo, hi in [(0, 49), (50, 99),
                                          (100, 149), (150, 199)]))
        print("-" * 58)

        for name, pred in [
            ("zero", self.zero_predictor),
            ("random", self.random_predictor),
            ("perfect", self.perfect_predictor),
        ]:
            results = self.evaluate_predictor(
                pred, name)
            vals = "  ".join(
                f"{v:>10.4f}"
                for v in results.values())
            print(f"{name:<12} {vals}")


evaluator = DiffusionEvaluator()
evaluator.run_evaluation()

All three baselines are scored against the noise itself, and that target is drawn from a standard normal at every timestep, so none of the numbers should depend much on t. The zero predictor lands at an MSE of roughly 1.0 in every bin, because the expected squared magnitude of unit-variance noise is 1. The random predictor sits at roughly 2.0 across all bins: the difference between two independent standard normal samples has variance 2. The perfect predictor should score 0.0000 in every bin, as long as the lookup returns the noise belonging to the sample being evaluated -- a sanity check confirming the test setup is correct. The takeaway: a trained model has to get below 1.0 to beat the trivial "predict nothing" baseline.

Exercise 3: U-Net architecture analyzer.

class UNetAnalyzer:
    """Analyze parameter counts and memory usage
    for different U-Net configurations."""

    def __init__(self, in_ch, channel_config):
        self.in_ch = in_ch
        self.config = channel_config

    def conv_params(self, c_in, c_out,
                    kernel=3):
        """Conv2d parameter count."""
        return c_out * (c_in * kernel * kernel) + c_out

    def analyze(self, input_res=64):
        levels = []
        res = input_res

        # Encoder
        prev_ch = self.in_ch
        for i, ch in enumerate(self.config):
            params = self.conv_params(prev_ch, ch, 3)
            mem = ch * res * res * 4  # float32
            levels.append({
                "name": f"enc_{i + 1}",
                "channels": ch,
                "resolution": res,
                "params": params,
                "memory_bytes": mem,
            })
            prev_ch = ch
            res = res // 2

        # Bottleneck
        bot_ch = self.config[-1]
        params = self.conv_params(bot_ch, bot_ch, 3)
        mem = bot_ch * res * res * 4
        levels.append({
            "name": "bottleneck",
            "channels": bot_ch,
            "resolution": res,
            "params": params,
            "memory_bytes": mem,
        })

        # Decoder (reverse order)
        for i, ch in enumerate(
                reversed(self.config)):
            res = res * 2
            skip_ch = ch
            in_ch_dec = (self.config[-(i + 1)]
                         + skip_ch)
            out_ch = (self.config[-(i + 2)]
                      if i + 2 <= len(self.config)
                      else self.in_ch)
            up_params = self.conv_params(
                self.config[-(i + 1)], out_ch, 4)
            conv_params = self.conv_params(
                in_ch_dec, out_ch, 3)
            total_p = up_params + conv_params
            mem = out_ch * res * res * 4
            levels.append({
                "name": f"dec_{len(self.config) - i}",
                "channels": out_ch,
                "resolution": res,
                "params": total_p,
                "memory_bytes": mem,
            })

        return levels

    def print_analysis(self, input_res=64):
        levels = self.analyze(input_res)
        total_params = 0
        peak_mem = 0

        print(f"{'Level':<12} {'Ch':>5} {'Res':>5} "
              f"{'Params':>10} {'Memory':>10}")
        print("-" * 46)

        for lv in levels:
            total_params += lv["params"]
            peak_mem = max(
                peak_mem, lv["memory_bytes"])
            mem_kb = lv["memory_bytes"] / 1024
            print(
                f"{lv['name']:<12} "
                f"{lv['channels']:>5} "
                f"{lv['resolution']:>5} "
                f"{lv['params']:>10,} "
                f"{mem_kb:>8.1f} KB")

        print("-" * 46)
        print(f"{'TOTAL':<12} {'':>5} {'':>5} "
              f"{total_params:>10,} "
              f"{peak_mem / 1024:>8.1f} KB pk")


print("=== Config: [64, 128, 256] ===")
a1 = UNetAnalyzer(3, [64, 128, 256])
a1.print_analysis()

print("\n=== Config: [32, 64, 128, 256] ===")
a2 = UNetAnalyzer(3, [32, 64, 128, 256])
a2.print_analysis()

The [32, 64, 128, 256] config has more levels but starts with smaller channel counts. Adding a level adds one more downsampling/upsampling stage and keeps the bottleneck at the same channel width (256), though at half the spatial resolution (4x4 instead of 8x8 for a 64x64 input). The four-level config has fewer parameters in the early layers (because 32 is smaller than 64) but similar parameter counts in the deeper layers. The memory profile also shifts -- the first level in [64, ...] at 64x64 resolution uses twice as much memory for its feature map as the first level in [32, ...] at the same resolution. Both configurations spend most of their parameters in the decoder, where skip connections double the input channel count before the convolution.

On to today's episode

Welcome back! In episode #84 we built a diffusion model that denoises in pixel space -- a 512x512 image means the U-Net processes a 512x512x3 tensor at every single step. That works fine for something like 64x64 MNIST digits, but becomes prohibitively expensive at higher resolutions. A thousand forward passes through a U-Net operating at 512x512? Your GPU memory consumption goes through the roof, and the wallclock time becomes... impractical, to put it diplomatically ;-)

The researchers at LMU Munich (Robin Rombach, Andreas Blattmann, and others) realized something important: most of the perceptually relevant information in an image lives in a much lower-dimensional space. You don't need to denoise every single pixel -- you can denoise a compact representation of the image instead. That insight gave birth to latent diffusion, and latent diffusion gave birth to Stable Diffusion.

This episode covers how we get from the "slow but mathematically beautiful" DDPM framework of episode #84 to the actual text-to-image systems people use every day. We'll walk through the full architecture piece by piece.

Latent diffusion: compressing first, denoising second

The core idea behind latent diffusion is deceptively simple: instead of running the diffusion process in pixel space (which for a 512x512 RGB image means operating on a 512x512x3 = 786,432-dimensional space), first compress the image into a much smaller latent representation using a separately trained VAE (Variational Autoencoder), then run the entire diffusion process in that compressed latent space.

A well-trained VAE can compress a 512x512x3 image down to a 64x64x4 latent -- that's a 48x reduction in dimensionality. The diffusion model never sees raw pixels. It only sees (and generates) latent vectors, which the VAE decoder then expands back to full resolution:

Text prompt --> Text Encoder --> text embeddings
                                      |
                                  (cross-attention)
                                      |
Random noise --> [Denoise in latent space] --> clean latent --> VAE Decoder --> image
                      U-Net x N steps

The VAE is trained once, on a large image dataset, then completely frozen. Here's a conceptual implementation -- real ones like the one in Stable Diffusion are much deeper with residual blocks and attention layers, but the principle is identical:

import torch
import torch.nn as nn


class SimpleVAE(nn.Module):
    """Conceptual VAE for latent diffusion.
    Real ones are much deeper with residual
    blocks and attention."""

    def __init__(self, in_ch=3, latent_ch=4,
                 base_ch=128):
        super().__init__()
        # Encoder: (3, 512, 512) -> (4, 64, 64)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, base_ch, 3,
                      stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_ch, base_ch * 2, 3,
                      stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_ch * 2, base_ch * 4,
                      3, stride=2, padding=1),
            nn.SiLU(),
        )
        # mean + logvar (double channels)
        self.to_latent = nn.Conv2d(
            base_ch * 4, latent_ch * 2, 1)

        # Decoder: (4, 64, 64) -> (3, 512, 512)
        self.from_latent = nn.Conv2d(
            latent_ch, base_ch * 4, 1)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(base_ch * 4,
                               base_ch * 2, 4,
                               stride=2, padding=1),
            nn.SiLU(),
            nn.ConvTranspose2d(base_ch * 2,
                               base_ch, 4,
                               stride=2, padding=1),
            nn.SiLU(),
            nn.ConvTranspose2d(base_ch, in_ch, 4,
                               stride=2, padding=1),
            nn.Tanh(),
        )

    def encode(self, x):
        h = self.encoder(x)
        params = self.to_latent(h)
        mean, logvar = params.chunk(2, dim=1)
        # Reparameterization trick
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)
        return z

    def decode(self, z):
        return self.decoder(self.from_latent(z))


vae = SimpleVAE()
x = torch.randn(1, 3, 512, 512)
z = vae.encode(x)
recon = vae.decode(z)
print(f"Input:   {x.shape}")
print(f"Latent:  {z.shape}")
print(f"Recon:   {recon.shape}")

# Compression ratio
pixels = 512 * 512 * 3
latent_size = 64 * 64 * 4
print(f"Compression: {pixels / latent_size:.0f}x")

The reparameterization trick (episode #55 touched on this for GANs) is what makes the VAE trainable with backpropagation. Instead of sampling directly from a distribution (which is not differentiable), you sample from a standard normal and then shift/scale using the predicted mean and standard deviation. The gradient flows through the mean and std parameters cleanly.
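To make "the gradient flows through the mean and std" a bit more tangible, here is a small framework-free sketch in plain Python (not part of the VAE code above). It draws eps once, holds it fixed, and then treats z = mean + exp(0.5 * logvar) * eps as an ordinary function of mean and logvar, comparing finite-difference slopes against the analytic derivatives. The helper name sample_z and the numeric values are illustrative assumptions.

```python
import math
import random


def sample_z(mean, logvar, eps):
    """Reparameterized sample: the randomness lives in
    eps, so z is a smooth function of mean and logvar."""
    std = math.exp(0.5 * logvar)
    return mean + std * eps


rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)   # drawn once, then held fixed
mean, logvar = 0.3, -1.2    # illustrative values
h = 1e-6

# Central finite differences with eps held fixed
dz_dmean = (sample_z(mean + h, logvar, eps)
            - sample_z(mean - h, logvar, eps)) / (2 * h)
dz_dlogvar = (sample_z(mean, logvar + h, eps)
              - sample_z(mean, logvar - h, eps)) / (2 * h)

# Analytic derivatives of z = mean + exp(0.5 * logvar) * eps
analytic_dmean = 1.0
analytic_dlogvar = 0.5 * math.exp(0.5 * logvar) * eps

print(f"dz/dmean:   {dz_dmean:.6f} vs {analytic_dmean:.6f}")
print(f"dz/dlogvar: {dz_dlogvar:.6f} vs {analytic_dlogvar:.6f}")
```

Both pairs should agree to several decimals. That is the whole point: once eps is treated as an input, the sample is differentiable with respect to the encoder's outputs.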

Having said that, the VAE's reconstruction quality is not perfect. There's always some information loss -- subtle textures might get slightly smoothed, very fine details might shift. In practice the quality is excellent for photographic content but you can sometimes spot VAE artifacts if you zoom in very carefully. Stability AI's VAE for Stable Diffusion was trained with both reconstruction loss and a perceptual loss (comparing VGG features rather than raw pixels), which significantly improves visual fidelity at the cost of more complex training.

Text conditioning via cross-attention

So we have the diffusion process running efficiently in latent space. But how does a text prompt like "a sunset over the ocean, oil painting" actually guide the image generation?

Through cross-attention in the U-Net, exactly as we covered in episode #51. Remember: in cross-attention, the queries come from one source and the keys/values come from another. In latent diffusion, the queries come from the U-Net's intermediate spatial features (representing what the image currently "looks like" at this denoising step) and the keys/values come from the text embeddings (representing what the image should look like):

class CrossAttentionBlock(nn.Module):
    """Cross-attention: image features attend
    to text embeddings."""

    def __init__(self, dim, context_dim=768,
                 num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(context_dim, dim)
        self.to_v = nn.Linear(context_dim, dim)
        self.to_out = nn.Linear(dim, dim)
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

    def forward(self, x, context):
        # x: (batch, seq_len, dim)
        #    -- flattened spatial features
        # context: (batch, text_len, context_dim)
        #    -- text embeddings from CLIP
        b, n, d = x.shape
        h = self.num_heads

        x = self.norm(x)
        q = self.to_q(x).view(
            b, n, h, d // h).transpose(1, 2)
        k = self.to_k(context).view(
            b, -1, h, d // h).transpose(1, 2)
        v = self.to_v(context).view(
            b, -1, h, d // h).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(
            1, 2).reshape(b, n, d)
        return self.to_out(out)

At each U-Net block, the spatial features are reshaped into a sequence, cross-attended with the text embeddings, then reshaped back into a spatial grid. The attention weights form a kind of "soft routing" -- different spatial positions in the image attend to different words in the prompt. When the model is denoising the top portion of the image, it might attend heavily to "sunset" and "sky". When refining the bottom, it attends more to "ocean" and "waves". The model learns these correspondences entirely from the training data -- nobody hard-codes where "sunset" should appear ;-)
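The "soft routing" idea can be sketched with a toy, single-head cross-attention on plain Python lists. Everything here is made up for illustration: the two "spatial positions", the three "tokens", and their hand-picked key/value vectors are assumptions, not values from any real model. The only thing the sketch demonstrates is that each spatial position ends up with its own probability distribution over the prompt tokens.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries, keys, values):
    """Single-head cross-attention on plain lists.
    queries: one vector per spatial position
    keys/values: one vector per text token"""
    scale = len(queries[0]) ** -0.5
    weights, outputs = [], []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k))
                  for k in keys]
        w = softmax(scores)
        weights.append(w)
        outputs.append([
            sum(wi * v[j] for wi, v in zip(w, values))
            for j in range(len(values[0]))])
    return weights, outputs


# Toy setup: 2 spatial positions, 3 "tokens"
tokens = ["sunset", "ocean", "painting"]
keys = [[4.0, 0.0], [0.0, 4.0], [1.0, 1.0]]
values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
queries = [[2.0, 0.0],   # hypothetical "top of image" feature
           [0.0, 2.0]]   # hypothetical "bottom of image" feature

weights, outputs = cross_attention(queries, keys, values)
for name, w in zip(["top", "bottom"], weights):
    routed = ", ".join(
        f"{t}={p:.2f}" for t, p in zip(tokens, w))
    print(f"{name:>6}: {routed}")
```

Each row of weights sums to 1, and with these hand-picked vectors the "top" query puts almost all of its weight on "sunset" while the "bottom" query favors "ocean". In a trained model the projections that produce queries, keys and values are learned, not chosen by hand.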

CLIP: connecting words to images

The text encoder used in Stable Diffusion is CLIP (Contrastive Language-Image Pre-training, from OpenAI). CLIP was trained on 400 million image-text pairs with a contrastive objective: make the embeddings of matching image-text pairs similar, and non-matching pairs dissimilar. This produces a text encoder that understands visual concepts at a deep semantic level -- not just keyword matching, but actual understanding of composition, style, atmosphere, and spatial relationships:

import torch
from transformers import (CLIPTokenizer,
                          CLIPTextModel)

# CLIP text encoder (frozen during
# diffusion training)
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14")

prompt = "a sunset over the ocean, oil painting"
tokens = tokenizer(
    prompt, return_tensors="pt",
    padding="max_length",
    max_length=77, truncation=True)

with torch.no_grad():
    text_emb = text_encoder(
        **tokens).last_hidden_state

print(f"Text embeddings: {text_emb.shape}")
# (1, 77, 768) -- 77 token positions,
# 768 dimensions each

The 77-token limit is worth noting -- CLIP's text encoder was trained with a maximum sequence length of 77 tokens. This is why very long, elaborate prompts get truncated and lose information. SDXL keeps the same 77-token window but concatenates the outputs of two different text encoders (CLIP ViT-L and OpenCLIP ViT-bigG) along the feature dimension, which improves prompt comprehension without lengthening the prompt.
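Going back to the contrastive objective CLIP was trained with: the idea can be sketched in a few lines of plain Python. This is a simplified, symmetric cross-entropy over a tiny batch, not OpenAI's training code. The 3-dimensional embeddings are invented for illustration, and the fixed temperature of 0.07 is an assumption (in CLIP the temperature is a learned parameter).

```python
import math


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cross_entropy(logits, target):
    m = max(logits)
    log_sum = m + math.log(
        sum(math.exp(x - m) for x in logits))
    return log_sum - logits[target]


def contrastive_loss(image_embs, text_embs,
                     temperature=0.07):
    """Symmetric contrastive loss over a batch where
    image i and text i are the matching pair."""
    imgs = [normalize(v) for v in image_embs]
    txts = [normalize(v) for v in text_embs]
    n = len(imgs)
    # Cosine similarity matrix, scaled by temperature
    logits = [[sum(a * b for a, b in zip(i, t)) / temperature
               for t in txts] for i in imgs]
    # Image -> text: each row should pick its own column
    loss_i2t = sum(cross_entropy(logits[i], i)
                   for i in range(n)) / n
    # Text -> image: each column should pick its own row
    loss_t2i = sum(
        cross_entropy([logits[r][c] for r in range(n)], c)
        for c in range(n)) / n
    return 0.5 * (loss_i2t + loss_t2i)


image_embs = [[1.0, 0.1, 0.0],
              [0.0, 1.0, 0.1],
              [0.1, 0.0, 1.0]]
matched = [[0.9, 0.0, 0.1],
           [0.1, 0.9, 0.0],
           [0.0, 0.1, 0.9]]
shuffled = [matched[1], matched[2], matched[0]]

loss_matched = contrastive_loss(image_embs, matched)
loss_shuffled = contrastive_loss(image_embs, shuffled)
print(f"matched pairs:  {loss_matched:.4f}")
print(f"shuffled pairs: {loss_shuffled:.4f}")
```

The loss is low when each image embedding is closest to its own caption and high when the captions are shuffled, which is exactly the pressure that pulls matching pairs together in the shared embedding space.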

Classifier-free guidance (CFG)

Training the U-Net with text conditioning is straightforward: at each step, pass the text embeddings alongside the noisy latent. But at inference time there's a remarkably clever trick that dramatically improves how well the generated image matches the prompt: classifier-free guidance.

During training, the text conditioning is randomly dropped (replaced with null/empty embeddings) some percentage of the time, typically around 10%. This means the model learns both conditional generation (with text) and unconditional generation (without text) simultaneously.
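The conditioning dropout is a one-liner in spirit. Here is a minimal sketch in plain Python, assuming a hypothetical helper maybe_drop_condition that a training loop would call once per training example; the string stand-ins replace what would be real embedding tensors.

```python
import random


def maybe_drop_condition(text_emb, null_emb,
                         p_drop=0.1, rng=random):
    """With probability p_drop, swap the prompt
    embedding for the null embedding."""
    if rng.random() < p_drop:
        return null_emb
    return text_emb


rng = random.Random(42)
text_emb = "TEXT"   # stand-in for a (77, 768) tensor
null_emb = "NULL"   # stand-in for the empty-prompt embedding

steps = 10_000
dropped = sum(
    maybe_drop_condition(text_emb, null_emb, 0.1, rng)
    is null_emb
    for _ in range(steps))
print(f"Unconditional steps: {dropped / steps:.1%}")
```

Roughly one in ten training examples ends up unconditional. The same network weights see both cases, which is what later lets a single model produce both predictions at inference time.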

At inference, you run the model twice per step -- once with the real prompt, once without. Then you extrapolate away from the unconditional prediction toward the conditional one:

def guided_denoise(model, xt, t, text_emb,
                   null_emb, guidance_scale=7.5):
    """Classifier-free guidance: amplify the
    text's influence on the generated image."""
    # Unconditional prediction (no text)
    noise_uncond = model(xt, t, null_emb)
    # Conditional prediction (with text)
    noise_cond = model(xt, t, text_emb)

    # Guided prediction: move AWAY from
    # unconditional, TOWARD conditional
    noise_guided = (noise_uncond
                    + guidance_scale
                    * (noise_cond - noise_uncond))
    return noise_guided

The guidance_scale parameter is what you're adjusting when you set "CFG scale" in Stable Diffusion UIs. At scale=1.0, you get the raw conditional prediction -- the model's natural output given the prompt, unmodified. At scale=7.5 (the widely used default), the text has strong influence and images closely match the prompt while still looking natural. At scale=20+, the image becomes oversaturated and artifact-heavy but extremely prompt-adherent. There's a surprisingly narrow sweet spot where you get both good prompt following and visual quality.

Why does this work so well? Think of it geometrically. The unconditional prediction represents "what a random image might look like at this noise level." The conditional prediction represents "what an image matching this text might look like." The difference between them is a vector pointing in the direction of "what makes this image match the text." By scaling that vector beyond 1.0, you're pushing the image further in the "matches the text" direction than the model would naturally go. You're amplifying the text's influence beyond what the model learned, which is why high CFG values can produce artifacts -- you're pushing beyond the model's training distribution.
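The geometric picture is easy to check with numbers. The sketch below applies the same guidance formula to two made-up 2-D "noise predictions" (the vectors are illustrative assumptions, real predictions have the shape of the latent) and prints the result for a few scales.

```python
def guide(noise_uncond, noise_cond, scale):
    """Classifier-free guidance on plain lists."""
    return [u + scale * (c - u)
            for u, c in zip(noise_uncond, noise_cond)]


noise_uncond = [0.0, 1.0]   # toy 2-D "predictions"
noise_cond = [1.0, 1.5]

for scale in [0.0, 1.0, 7.5]:
    guided = guide(noise_uncond, noise_cond, scale)
    print(f"scale={scale:>4}: {guided}")
```

At scale 0.0 the result is the unconditional prediction, at 1.0 it is the conditional one, and at 7.5 it lies on the same line but well past the conditional point. That overshoot is the extrapolation that makes the prompt "louder" than the model would make it on its own.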

Faster sampling: DDIM and modern schedulers

DDPM (from episode #84) requires 1000 steps to generate one image. That's 1000 forward passes through the U-Net, which even on a fast GPU takes many seconds. DDIM (Denoising Diffusion Implicit Models, Song et al., 2021) made a critical observation: you can rewrite the reverse process as a deterministic mapping that allows skipping steps. Instead of taking tiny stochastic steps through all 1000 timesteps, DDIM selects a subset (say 50 timesteps, evenly spaced) and takes larger, deterministic jumps between them:

@torch.no_grad()
def ddim_sample(model, schedule, shape,
                num_steps=50, eta=0.0):
    """DDIM sampling: fewer steps, deterministic
    when eta=0."""
    # Select a subset of timesteps
    step_size = schedule.T // num_steps
    timesteps = list(
        range(0, schedule.T, step_size))[::-1]

    x = torch.randn(shape)

    for i, t in enumerate(timesteps):
        t_batch = torch.full(
            (shape[0],), t, dtype=torch.long)
        pred_noise = model(x, t_batch)

        alpha_bar_t = schedule.alpha_bars[t]
        if i + 1 < len(timesteps):
            alpha_bar_prev = schedule.alpha_bars[
                timesteps[i + 1]]
        else:
            alpha_bar_prev = torch.tensor(1.0)

        # Predict x0 from current xt
        pred_x0 = (
            (x - torch.sqrt(1 - alpha_bar_t)
             * pred_noise)
            / torch.sqrt(alpha_bar_t))
        pred_x0 = pred_x0.clamp(-1, 1)

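        # Noise scale for this step: zero when
        # eta=0, so the jump is deterministic
        sigma = eta * torch.sqrt(
            (1 - alpha_bar_prev)
            / (1 - alpha_bar_t)
            * (1 - alpha_bar_t / alpha_bar_prev))

        # Direction pointing to xt
        dir_xt = (torch.sqrt(
            (1 - alpha_bar_prev - sigma ** 2)
            .clamp(min=0)) * pred_noise)

        # DDIM step (eta=0 = fully deterministic,
        # eta>0 re-injects fresh noise)
        x = (torch.sqrt(alpha_bar_prev) * pred_x0
             + dir_xt
             + sigma * torch.randn_like(x))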

    return x

50 steps instead of 1000. Nearly identical quality. The eta parameter controls stochasticity: at eta=0 (DDIM default), the same initial noise always produces the same image -- fully deterministic and reproducible. At eta=1.0, you recover the original DDPM stochastic sampling.
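Two details of DDIM are worth checking by hand: which timesteps get visited, and why a big jump is legitimate at all. The plain-Python sketch below works on a single scalar "pixel" instead of a tensor and assumes a linear beta schedule with the usual 1e-4 to 0.02 range; the helper names are hypothetical. It shows that if the noise prediction were exact, one deterministic DDIM jump would land precisely on the same sample re-noised at the earlier timestep.

```python
import math


def linear_alpha_bars(T, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a
    linear beta schedule."""
    alpha_bars, running = [], 1.0
    for i in range(T):
        beta = (beta_start
                + (beta_end - beta_start) * i / (T - 1))
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def ddim_step(x_t, pred_noise, ab_t, ab_prev):
    """One deterministic (eta=0) DDIM update on a scalar."""
    pred_x0 = ((x_t - math.sqrt(1 - ab_t) * pred_noise)
               / math.sqrt(ab_t))
    return (math.sqrt(ab_prev) * pred_x0
            + math.sqrt(1 - ab_prev) * pred_noise)


T, num_steps = 1000, 50
timesteps = list(range(0, T, T // num_steps))[::-1]
print(len(timesteps), timesteps[:3], timesteps[-3:])

# With an exact noise prediction, one jump reproduces the
# sample re-noised at the earlier timestep.
alpha_bars = linear_alpha_bars(T)
x0, eps = 0.5, -1.3          # toy clean value and noise
t, t_prev = timesteps[0], timesteps[1]
x_t = (math.sqrt(alpha_bars[t]) * x0
       + math.sqrt(1 - alpha_bars[t]) * eps)
x_prev = ddim_step(x_t, eps, alpha_bars[t],
                   alpha_bars[t_prev])
expected = (math.sqrt(alpha_bars[t_prev]) * x0
            + math.sqrt(1 - alpha_bars[t_prev]) * eps)
print(f"jump {t} -> {t_prev}: "
      f"{x_prev:.6f} (expected {expected:.6f})")
```

A real model's prediction is never exact, so errors accumulate with larger jumps. That is the tradeoff behind picking a step count.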

More advanced schedulers have pushed this even further. DPM-Solver and DPM-Solver++ use higher-order ODE solvers to get good results in 15-25 steps. Euler and Euler Ancestral are popular for their simplicity and good quality at 20-30 steps. The PNDM (Pseudo Numerical Methods for Diffusion Models) scheduler is what the original Stable Diffusion demo used. Each makes different tradeoffs between speed, quality, and whether the output is deterministic:

# Using HuggingFace diffusers, switching
# schedulers is trivial:
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)

# DDIM: 50 steps, deterministic
# DPM-Solver++: 20-25 steps, high quality
# Euler: 25-30 steps, good balance

The practical impact is enormous. Going from 1000 steps to 25 steps means generation times drop from minutes to seconds on consumer hardware. This is what made Stable Diffusion viable for interactive use.

ControlNet: structural guidance beyond text

Text prompts are expressive but imprecise. "A person standing with arms raised" will give you a person with arms raised, but you can't control the exact pose, the exact camera angle, or the precise composition. ControlNet (Zhang et al., 2023) adds structural conditioning: alongside the text prompt, you provide an edge map, depth map, pose skeleton, normal map, or other spatial guide.

ControlNet works by copying the encoder half of the U-Net and training the copy to incorporate the structural input, while keeping the original U-Net weights completely frozen. The copy's outputs are added to the original network via zero-initialized convolutions:

import copy


class ControlNet(nn.Module):
    """ControlNet: trainable copy of U-Net
    encoder with zero-init outputs."""

    def __init__(self, base_unet_encoder,
                 encoder_channels):
        super().__init__()
        # Copy of the U-Net encoder (trainable)
        self.control_encoder = copy.deepcopy(
            base_unet_encoder)

        # Zero-initialized output projections
        self.zero_convs = nn.ModuleList([
            nn.Conv2d(ch, ch, 1)
            for ch in encoder_channels
        ])
        # Initialize to output zeros at start
        for conv in self.zero_convs:
            nn.init.zeros_(conv.weight)
            nn.init.zeros_(conv.bias)

    def forward(self, control_image, t_emb):
        features = self.control_encoder(
            control_image, t_emb)
        # Project through zero convolutions; the
        # results get added to the base U-Net's
        # skip connections
        return [
            zc(f) for zc, f
            in zip(self.zero_convs, features)
        ]

The zero-initialization trick is elegant: at the start of training, the ControlNet outputs all zeros and has zero effect on the base model. As training progresses, the zero convolution weights gradually move away from zero and the control signal gets injected more strongly. This prevents the structural conditioning from destroying the pretrained model's capabilities during the early, unstable phase of training -- a problem that plagued earlier attempts at adding conditioning to pretrained diffusion models.
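A fair question is how a layer that starts at exactly zero ever learns anything. A scalar stand-in for the 1x1 convolution makes it visible; the numbers and the squared-error toy loss below are illustrative assumptions, not ControlNet's actual training objective.

```python
def zero_conv(feature, weight, bias):
    """Scalar stand-in for a 1x1 convolution."""
    return weight * feature + bias


base_skip = 0.8          # frozen U-Net skip activation
control_feature = 2.5    # output of the trainable copy
weight, bias = 0.0, 0.0  # zero initialization

# At initialization the control branch adds nothing
combined = base_skip + zero_conv(
    control_feature, weight, bias)
print(f"combined == base: {combined == base_skip}")

# Toy loss (combined - target)^2: the gradient w.r.t. the
# zero-initialized weight is nonzero, so it can still learn
target = 1.0
grad_weight = 2 * (combined - target) * control_feature
grad_bias = 2 * (combined - target)
print(f"dL/dweight = {grad_weight:.2f}, "
      f"dL/dbias = {grad_bias:.2f}")

# One plain gradient-descent step moves the weight off zero
lr = 0.1
weight -= lr * grad_weight
print(f"weight after one step: {weight:.2f}")
```

The gradient with respect to the weight depends on the incoming feature, not on the weight itself, so it is nonzero from the first step. The gradient flowing back into the copied encoder through this layer is proportional to the weight, so it is zero on the very first step and only picks up once the zero convolution has moved.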

In practice ControlNet is extraordinarily versatile. With a Canny edge map as input, you can redraw an existing image in a completely different style while preserving its structure. With an OpenPose skeleton, you can precisely control the pose of a generated person. With a depth map from a 3D scene, you can generate photorealistic renders of virtual environments. And because the base model stays frozen, a single Stable Diffusion checkpoint can work with many different ControlNet adapters simultaneously.

The complete Stable Diffusion pipeline

Putting all the pieces together, here's what happens when you generate an image with Stable Diffusion:

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe(
    prompt="a sunset over the ocean, "
           "oil painting style, golden light",
    negative_prompt="blurry, low quality",
    num_inference_steps=30,
    guidance_scale=7.5,
).images[0]

image.save("sunset.png")

Behind that simple API call, four things happen in sequence:

  1. The text encoder encodes the prompt (and the negative prompt) into text embeddings -- 77 tokens each, with 1024 dimensions per token for the SD 2.1 checkpoint loaded above (OpenCLIP ViT-H/14) and 768 for SD 1.5 (CLIP ViT-L/14)
  2. A random latent is sampled from a standard normal distribution -- shape (1, 4, 64, 64) for a 512x512 output, or (1, 4, 96, 96) at this checkpoint's default 768x768
  3. The U-Net denoises for 30 steps with classifier-free guidance, running twice per step (once conditioned on the prompt, once unconditioned), blending the predictions according to guidance_scale
  4. The VAE decoder expands the clean latent back to pixels, e.g. from (4, 64, 64) to a full (3, 512, 512) image
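The latent shapes in the steps above follow directly from the VAE's downsampling factor. A small helper makes the relationship explicit; latent_shape is a hypothetical function for illustration, and it assumes the 8x spatial downsampling and 4 latent channels described earlier for the SD 1.x/2.x VAE.

```python
def latent_shape(height, width, batch=1,
                 latent_ch=4, factor=8):
    """Latent tensor shape for a given output size,
    assuming an 8x-downsampling VAE with 4 latent
    channels."""
    if height % factor or width % factor:
        raise ValueError(
            f"height and width must be multiples "
            f"of {factor}")
    return (batch, latent_ch,
            height // factor, width // factor)


for size in [512, 768, 1024]:
    print(f"{size}x{size} -> {latent_shape(size, size)}")
```

This is also why pipelines typically expect image dimensions that are multiples of 8: anything else does not map onto a whole number of latent positions.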

The negative prompt is a subtle but important addition. Instead of using a null/empty embedding for the unconditional prediction in CFG, you can provide a "negative prompt" describing what you don't want (artifacts, blur, distortion). The model then steers away from the negative prompt and toward the positive prompt, giving you finer control over the output.

The evolution: SD 1.5 to SDXL to SD3

Stable Diffusion hasn't stood still since the original 1.4/1.5 release. Each version brought significant architectural improvements:

Stable Diffusion 1.5 (2022): the model that changed everything. 860M parameter U-Net, single CLIP text encoder (ViT-L/14, 768-dim), 512x512 native resolution, VAE with 4-channel latent space. Trained on a subset of LAION-5B.

Stable Diffusion 2.1 (2022): upgraded to OpenCLIP ViT-H/14 (1024-dim text embeddings), native 768x768 resolution option, improved aesthetic quality. Somewhat controversially, it was also trained with NSFW content filtered more aggressively, which actually reduced its understanding of human anatomy in general.

SDXL (2023): major architecture upgrade. 2.6B parameter U-Net (3x larger), dual text encoders (CLIP ViT-L + OpenCLIP ViT-bigG, concatenated to 2048 dimensions), native 1024x1024 resolution, additional conditioning on image size and crop parameters. SDXL also introduced a refiner model -- a second diffusion model that takes the output of the base model and adds fine details, effectively a two-stage generation process.

Stable Diffusion 3 (2024): replaced the U-Net entirely with a DiT (Diffusion Transformer) architecture -- using transformer blocks instead of convolutional blocks for the denoising network. This aligns with the broader trend in AI where transformers are replacing convolutions everywhere. SD3 also uses three text encoders (two CLIPs plus T5-XXL for long-form text understanding) and a rectified flow formulation instead of the traditional DDPM noise schedule.

The trend is clear: bigger models, better text understanding (more and larger text encoders), higher native resolutions, and architectural shifts from convolutions toward transformers. The core latent diffusion framework -- VAE for compression, cross-attention for text conditioning, iterative denoising for generation -- remains fundamentally the same across all versions.

In summary

  • Latent diffusion moves the denoising process from pixel space to a 48x-smaller compressed latent space using a pretrained VAE, making high-resolution generation practical on consumer GPUs;
  • the architecture has three components: a VAE (compresses/decompresses images, frozen), a CLIP text encoder (embeds prompts into semantic vectors, frozen), and a trained U-Net denoiser that operates in latent space;
  • text conditioning works through cross-attention in the U-Net: spatial image features (queries) attend to text embeddings (keys/values), learning which words should influence which regions of the image;
  • classifier-free guidance amplifies prompt adherence by running the model twice per step (with and without text) and extrapolating toward the conditioned prediction; the guidance_scale parameter controls strength;
  • DDIM and modern schedulers (DPM-Solver, Euler) reduce sampling from 1000 steps down to 20-50 steps with minimal quality loss, making interactive generation feasible;
  • ControlNet adds structural guidance (edges, depth, pose) through a trainable copy of the U-Net encoder with zero-initialized outputs, preserving the base model's capabilities while adding precise spatial control;
  • the evolution from SD 1.5 (860M params, 512x512) to SDXL (2.6B, 1024x1024) to SD3 (DiT architecture, triple text encoders) shows the field moving toward bigger models, better text understanding, and transformer-based denoisers -- but the core latent diffusion framework stays satisfyingly consistent.

Exercises

Exercise 1: Build a VAE compression analyzer. Create a class VAECompressionAnalyzer that: (a) takes a list of image resolutions (e.g., [256, 512, 768, 1024]) and a fixed latent channel count (default 4) with a downsampling factor (default 8), (b) for each resolution, computes the pixel-space size (H x W x 3 floats), the latent-space size (H/f x W/f x latent_ch floats), the compression ratio, and the memory saved in MB (assuming float32), (c) prints a comparison table showing resolution, pixel size, latent size, compression ratio, and memory savings, (d) also computes and prints how many DDIM steps at 50 steps/image you could run on the latent-space representation in the same memory budget as a single pixel-space forward pass (assume memory per step scales linearly with tensor size). Verify that the compression ratio is constant across resolutions (since the spatial downsampling factor is fixed) but the absolute memory savings grow quadratically.

Exercise 2: Build a CFG scale simulator. Create a class CFGSimulator that: (a) generates synthetic 1D "image features" (length 100) representing a denoising step: unconditional = np.random.randn(100) and conditional = unconditional + signal where signal is a known offset vector, (b) implements apply_cfg(unconditional, conditional, scale) that returns the guided prediction, (c) for guidance scales [1.0, 3.0, 5.0, 7.5, 10.0, 15.0, 20.0], computes and prints: the cosine similarity between the guided output and the pure signal (measures how well CFG recovers the text direction), the L2 norm of the guided output (measures magnitude/saturation), and the "signal amplification factor" (ratio of the guided signal component to the raw conditional signal component), (d) identifies the scale at which the L2 norm of the guided output exceeds 3x the L2 norm of the conditional output (this approximates where artifacts start appearing in real systems). Verify that similarity to the signal increases with scale but the norm grows without bound.

Exercise 3: Build a DDIM step quality estimator. Create a class DDIMStepEstimator that: (a) implements a toy 1D diffusion process: x0 is a known signal, xt = sqrt(alpha_bar_t) * x0 + sqrt(1-alpha_bar_t) * noise, and the "model" is a perfect noise predictor that returns the exact noise, (b) implements ddim_reconstruct(x0, num_steps) that adds noise to x0 at t=T then runs DDIM backwards for num_steps steps (using the perfect predictor) to recover x0, (c) for step counts [5, 10, 20, 50, 100, 200, 500, 1000], computes the reconstruction MSE between the original x0 and the DDIM-recovered version, (d) prints a table showing step count, MSE, and relative quality (MSE normalized by the 1000-step MSE), (e) identifies the "knee" -- the step count beyond which doubling the steps reduces MSE by less than 10%. Use T=1000 with a linear schedule and a signal of length 256. Verify that the perfect predictor achieves near-zero MSE at all step counts (the error comes only from discretization of the ODE, not from prediction quality).

Thanks for reading!

@scipio


