Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks


What will I learn

  • You will learn nn.Module -- PyTorch's building block for all neural network architectures;
  • how to write custom modules with learnable parameters and full forward-pass control;
  • Sequential, ModuleList, ModuleDict -- organizing layers in complex architectures;
  • skip connections (residual blocks) -- why adding the input to the output changes everything;
  • hooks -- inspecting activations and gradients without touching the model code;
  • parameter groups -- different learning rates for different parts of a model;
  • model surgery -- modifying pretrained architectures for new tasks.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks

Solutions to Episode #43 Exercises

Exercise 1: Create a RangeDataset that labels samples 1 if between 3 and 7, 0 otherwise.

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class RangeDataset(Dataset):
    def __init__(self, n_samples=500):
        np.random.seed(42)
        self.X = np.random.uniform(0, 10, size=(n_samples, 1)).astype(np.float32)
        self.y = ((self.X[:, 0] >= 3) & (self.X[:, 0] <= 7)).astype(np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return torch.tensor(self.X[idx]), torch.tensor(self.y[idx])

dataset = RangeDataset(500)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

for i, (X_batch, y_batch) in enumerate(loader):
    print(f"Batch {i}: X shape={X_batch.shape}, y shape={y_batch.shape}")
    if i >= 2:
        break
# Batch 0: X shape=torch.Size([16, 1]), y shape=torch.Size([16])
# Batch 1: X shape=torch.Size([16, 1]), y shape=torch.Size([16])
# Batch 2: X shape=torch.Size([16, 1]), y shape=torch.Size([16])

Straightforward extension of the CircleDataset pattern from episode #43. The only difference is the labeling function -- here it's a range check instead of a circle boundary. One input feature, one binary label, same Dataset interface.

Exercise 2: Full training loop with 80/20 split, AdamW + cosine annealing.

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split

class CircleDataset(Dataset):
    def __init__(self, n_samples=1000):
        np.random.seed(42)
        self.X = np.random.randn(n_samples, 2).astype(np.float32)
        self.y = ((self.X[:, 0]**2 + self.X[:, 1]**2) < 1.5).astype(np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return torch.tensor(self.X[idx]), torch.tensor(self.y[idx])

# One dataset, split 80/20 (constructing two datasets separately would reuse the same seed,
# so the validation samples would be duplicates of training samples)
full_data = CircleDataset(1000)
train_data, val_data = random_split(full_data, [800, 200])
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)

model = nn.Sequential(
    nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(32, 16), nn.ReLU(),
    nn.Linear(16, 1), nn.Sigmoid()
)

loss_fn = nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

for epoch in range(30):
    model.train()
    train_loss = 0
    for X_b, y_b in train_loader:
        pred = model(X_b).squeeze()
        loss = loss_fn(pred, y_b)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    scheduler.step()

    model.eval()
    val_correct, val_total = 0, 0
    val_loss = 0
    with torch.no_grad():
        for X_v, y_v in val_loader:
            vp = model(X_v).squeeze()
            val_loss += loss_fn(vp, y_v).item()
            val_correct += ((vp > 0.5) == y_v).sum().item()
            val_total += len(y_v)

    if epoch % 5 == 0:
        print(f"Epoch {epoch:>2d}: train_loss={train_loss/len(train_loader):.4f}, "
              f"val_loss={val_loss/len(val_loader):.4f}, "
              f"val_acc={val_correct/val_total:.1%}")

# Final accuracies
model.eval()
for name, loader in [("Train", train_loader), ("Val", val_loader)]:
    correct, total = 0, 0
    with torch.no_grad():
        for X_b, y_b in loader:
            p = model(X_b).squeeze()
            correct += ((p > 0.5) == y_b).sum().item()
            total += len(y_b)
    print(f"{name} accuracy: {correct/total:.1%}")

The cosine annealing scheduler gradually reduces the learning rate from 0.001 down to near-zero over the 30 epochs. Combined with AdamW weight decay, this is the modern default for training PyTorch models (as we discussed in episodes #40 and #41).

Exercise 3: Checkpoint every 10 epochs, compare epoch-10 vs final accuracy.

import os

# Re-train with checkpointing
model2 = nn.Sequential(
    nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(32, 16), nn.ReLU(),
    nn.Linear(16, 1), nn.Sigmoid()
)
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=0.001, weight_decay=0.01)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer2, T_max=30)

for epoch in range(30):
    model2.train()
    for X_b, y_b in train_loader:
        pred = model2(X_b).squeeze()
        loss = loss_fn(pred, y_b)
        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()
    scheduler2.step()

    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model2.state_dict(),
            'optimizer_state_dict': optimizer2.state_dict(),
        }, f'/tmp/checkpoint_epoch_{epoch+1}.pt')
        print(f"Saved checkpoint at epoch {epoch+1}")

# Load epoch-10 checkpoint into fresh model
model_ep10 = nn.Sequential(
    nn.Linear(2, 32), nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(32, 16), nn.ReLU(),
    nn.Linear(16, 1), nn.Sigmoid()
)
ckpt = torch.load('/tmp/checkpoint_epoch_10.pt', weights_only=False)
model_ep10.load_state_dict(ckpt['model_state_dict'])

# Compare
for label, m in [("Epoch 10", model_ep10), ("Final (30)", model2)]:
    m.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_v, y_v in val_loader:
            p = m(X_v).squeeze()
            correct += ((p > 0.5) == y_v).sum().item()
            total += len(y_v)
    print(f"{label} val accuracy: {correct/total:.1%}")
# Epoch 10 will typically be slightly lower than the final model

Typically the final model outperforms the epoch-10 checkpoint since training had another 20 epochs of refinement. The key point: checkpointing lets you go back to any earlier state of training, which is invaluable when experimenting with longer training schedules or debugging overfitting.

On to today's episode

In episodes #42 and #43, we built models using nn.Sequential -- stack a few layers in a list, pass data through them top to bottom. That approach works great for simple architectures. But what happens when you need a network with skip connections (where the input of a block gets added to its output)? Or multiple inputs feeding into different branches? Or shared parameters between two different parts of the network? nn.Sequential can't express any of these because it only knows how to pipe data linearly from one layer to the next.

Every serious PyTorch architecture -- from ResNets to Transformers to GANs -- is built by subclassing nn.Module. It's the base class for everything in PyTorch: individual layers (nn.Linear, nn.Conv2d), activation functions (nn.ReLU), regularization layers (nn.Dropout, nn.BatchNorm1d), and complete models. nn.Sequential itself is just a convenience wrapper around nn.Module. Understanding nn.Module deeply is the key to building, modifying, and debugging any architecture you'll encounter in practice. Here we go!

Your first custom Module

The recipe is simple: subclass nn.Module, define your layers in __init__, and implement the forward method. If you've been following the Learn Python Series (especially the episodes on classes and inheritance), this will feel completely natural -- it's just Python OOP:

import torch
import torch.nn as nn

class SimpleClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = self.dropout(self.relu(self.bn1(self.layer1(x))))
        x = self.dropout(self.relu(self.bn2(self.layer2(x))))
        return self.output(x)

model = SimpleClassifier(20, 64, 3)
print(model)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

There are two rules you absolutely must follow. First, register all layers as attributes in __init__. PyTorch scans your Module object for any attributes that are themselves nn.Module instances and automatically registers their parameters. If you create a layer inside forward() instead of in __init__, PyTorch has no way to discover it -- its parameters won't show up in model.parameters(), the optimizer won't update them, and model.to(device) won't move them to GPU. Worse, a layer constructed inside forward() gets brand-new random weights on every call. Your model will appear to train, but part of it never learns anything. Nasty bug to track down.
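
To make that concrete, here's a minimal sketch of the bug (the class name and layer sizes are just for illustration):

class HiddenLayerBug(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(20, 64)    # registered: assigned in __init__

    def forward(self, x):
        hidden = torch.relu(self.layer1(x))
        out_layer = nn.Linear(64, 3)       # NOT registered: re-created with random weights on every call
        return out_layer(hidden)

buggy = HiddenLayerBug()
print(sum(p.numel() for p in buggy.parameters()))  # 1344 -- only layer1's parameters; the output layer is invisible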

Second, always call super().__init__() at the start of your __init__. This initializes the internal Module machinery that makes parameter registration, hooks, serialization and everything else work. Skip it and you'll get an immediate error the first time you try to assign a sub-module attribute.
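
If you're curious what that failure looks like, a deliberately broken sketch:

class ForgotSuper(nn.Module):
    def __init__(self):
        # super().__init__() missing on purpose
        self.layer = nn.Linear(2, 2)

# ForgotSuper()  # raises AttributeError: cannot assign module before Module.__init__() call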

The forward method defines what happens when you call the model. It's just Python -- use if-statements, loops, whatever you need. PyTorch builds the autograd graph dynamically as the code executes (the same dynamic graph behavior we saw in episode #42).

Sequential vs custom Module: when to use what

So when do you actually need a custom Module instead of nn.Sequential? A few common situations:

# nn.Sequential is fine for this:
simple = nn.Sequential(
    nn.Linear(20, 64), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(32, 3)
)

# But this NEEDS a custom Module (skip connection):
class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim), nn.BatchNorm1d(dim), nn.ReLU(),
            nn.Linear(dim, dim), nn.BatchNorm1d(dim))

    def forward(self, x):
        return torch.relu(self.block(x) + x)  # skip connection!

# And this NEEDS a custom Module (conditional logic):
class ConditionalNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.shared = nn.Linear(input_dim, hidden_dim)
        self.branch_a = nn.Linear(hidden_dim, output_dim)
        self.branch_b = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, use_branch_a=True):
        h = torch.relu(self.shared(x))
        if use_branch_a:
            return self.branch_a(h)
        return self.branch_b(h)

print(f"Simple (Sequential): {sum(p.numel() for p in simple.parameters()):,} params")
resblock = ResidualBlock(64)
print(f"ResidualBlock:       {sum(p.numel() for p in resblock.parameters()):,} params")
cond = ConditionalNet(20, 64, 3)
print(f"ConditionalNet:      {sum(p.numel() for p in cond.parameters()):,} params")

The rule of thumb: if your data flows in one straight line from input to output, nn.Sequential is cleaner. The moment you need addition (+ x), branching (if), multiple inputs, or anything that breaks the linear chain, you need a custom Module. Having said that, I tend to default to custom Modules even for simple architectures once a project gets beyond the prototyping stage -- the extra few lines of code buy you flexibility to add features later without restructuring.

ModuleList and ModuleDict

When you have a variable number of layers, you might be tempted to put them in a regular Python list. Don't. PyTorch can't find parameters inside plain Python lists or dicts:

class FlexibleMLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        # CORRECT: nn.ModuleList -- PyTorch tracks these parameters
        self.layers = nn.ModuleList([
            nn.Linear(layer_sizes[i], layer_sizes[i+1])
            for i in range(len(layer_sizes) - 1)
        ])

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)    # no activation on output layer

# Compare correct vs incorrect
flex = FlexibleMLP([20, 64, 32, 3])
print(f"ModuleList layers: {len(flex.layers)}")
print(f"Parameters found:  {sum(p.numel() for p in flex.parameters()):,}")

# THE BUG: plain Python list hides parameters
class BrokenMLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = [    # plain list -- BAD
            nn.Linear(layer_sizes[i], layer_sizes[i+1])
            for i in range(len(layer_sizes) - 1)
        ]

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)

broken = BrokenMLP([20, 64, 32, 3])
print(f"\nBroken (plain list):")
print(f"Layers exist:     {len(broken.layers)}")
print(f"Parameters found: {sum(p.numel() for p in broken.parameters()):,}")  # 0!

The broken version has functioning layers -- the forward pass works, the outputs look correct. But model.parameters() returns nothing, so the optimizer has nothing to update. The model trains for 100 epochs with zero improvement and you spend an hour wondering why your loss isn't decreasing. I've seen this exact bug in production code from people who should know better ;-)

ModuleDict does the same thing for named layer collections. It's particularly useful for multi-task architectures where you want to route data through different heads depending on the task:

class MultiTaskNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, task_outputs):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
        )
        self.heads = nn.ModuleDict({
            name: nn.Linear(hidden_dim, out_dim)
            for name, out_dim in task_outputs.items()
        })

    def forward(self, x, task_name):
        features = self.backbone(x)
        return self.heads[task_name](features)

multi = MultiTaskNet(20, 64, {'classify': 3, 'regress': 1, 'embed': 16})
x = torch.randn(8, 20)

for task in multi.heads:
    out = multi(x, task)
    print(f"Task '{task}': output shape = {out.shape}")

print(f"\nTotal params: {sum(p.numel() for p in multi.parameters()):,}")

The backbone parameters are shared across all tasks -- trained by gradients from every task head. The individual heads specialize for their own task. This is a common pattern in real-world architectures where you want to leverage shared feature extraction across related problems.
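
A quick sanity check of that claim -- a small sketch (not part of the episode's training code) using dummy squared-output losses just to generate gradients:

# Gradients from two different task heads both reach the shared backbone
multi.zero_grad()
x = torch.randn(8, 20)
loss = multi(x, 'classify').pow(2).mean() + multi(x, 'regress').pow(2).mean()  # dummy losses
loss.backward()

print(f"backbone grad norm: {multi.backbone[0].weight.grad.norm().item():.4f}")
print(f"classify head grad norm: {multi.heads['classify'].weight.grad.norm().item():.4f}")
print(f"regress head grad norm:  {multi.heads['regress'].weight.grad.norm().item():.4f}")
print(f"embed head grad: {multi.heads['embed'].weight.grad}")  # None -- this head wasn't used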

Skip connections (residual blocks)

The ResNet architecture (He et al., 2015) introduced one of the most important ideas in deep learning: skip connections. The concept is remarkably simple -- add the input of a block directly to its output:

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim)
        )

    def forward(self, x):
        return torch.relu(self.block(x) + x)  # THE skip connection

class ResNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_blocks=3):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.blocks = nn.Sequential(*[ResidualBlock(hidden_dim) for _ in range(n_blocks)])
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.input_layer(x))
        x = self.blocks(x)
        return self.output_layer(x)

resnet = ResNet(20, 64, 3, n_blocks=5)
print(f"ResNet with 5 blocks: {sum(p.numel() for p in resnet.parameters()):,} params")

# Test it
x = torch.randn(32, 20)
out = resnet(x)
print(f"Input: {x.shape} -> Output: {out.shape}")

Notice something elegant: ResidualBlock is itself an nn.Module, and we compose five of them inside another nn.Module (ResNet). Modules nest seamlessly because they all share the same interface. This composability is one of the most powerful aspects of the nn.Module design -- you build complex architectures from simple, reusable pieces.

But why do skip connections matter? Remember the vanishing gradient problem from episode #40: in deep networks, gradients shrink as they travel backward through many layers, eventually becoming too small to update the early layers. The skip connection provides a gradient highway -- during backpropagation, gradients can flow directly through the addition operation (the derivative of f(x) + x with respect to x always includes a 1 term), bypassing the block's layers entirely. This means even the earliest layers in a 100-layer network still receive meaningful gradient signal.
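
You can measure this directly. A small experiment (not from the episode's listings; the Tanh-based plain stack is chosen to make the vanishing effect visible):

# Gradient reaching the FIRST layer: deep plain stack vs deep residual stack
torch.manual_seed(0)
plain = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.Tanh()) for _ in range(30)])
resid = nn.Sequential(*[ResidualBlock(64) for _ in range(30)])

x = torch.randn(32, 64)
for name, net in [("plain", plain), ("residual", resid)]:
    net.zero_grad()
    net(x).sum().backward()
    first_weight = next(net[0].parameters())
    print(f"{name:>8s}: grad norm at first layer = {first_weight.grad.norm().item():.2e}")
# The residual stack typically keeps a sizable gradient at the first layer,
# while the plain stack's gradient has all but vanished.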

Without skip connections, plain networks start degrading in performance beyond roughly 20 layers -- adding more layers actually makes accuracy worse (not just due to overfitting but because training fails). With skip connections, ResNets have been trained with 152 layers and beyond. That's a pretty dramatic improvement from a single + x ;-)

Hooks: inspecting the internals

Hooks let you observe (or even modify) what happens inside a model without touching the model's code. This is incredibly useful for debugging, visualization, and feature extraction. Forward hooks fire after a layer computes its output; backward hooks fire during backpropagation:

activation_stats = {}

def make_hook(name):
    def hook(module, input, output):
        activation_stats[name] = {
            'mean': output.detach().mean().item(),
            'std': output.detach().std().item(),
            'min': output.detach().min().item(),
            'max': output.detach().max().item(),
            'dead_frac': (output.detach() == 0).float().mean().item()
        }
    return hook

# Attach hooks to specific layers
model = SimpleClassifier(20, 64, 3)
handles = []
handles.append(model.layer1.register_forward_hook(make_hook('layer1')))
handles.append(model.layer2.register_forward_hook(make_hook('layer2')))
handles.append(model.output.register_forward_hook(make_hook('output')))

# Forward pass triggers all hooks
x = torch.randn(32, 20)
_ = model(x)

print("Activation statistics per layer:")
for name, stats in activation_stats.items():
    print(f"  {name:>8s}: mean={stats['mean']:+.3f}, std={stats['std']:.3f}, "
          f"range=[{stats['min']:.3f}, {stats['max']:.3f}], "
          f"dead={stats['dead_frac']:.1%}")

# Clean up hooks when done
for h in handles:
    h.remove()

The dead fraction tells you what percentage of outputs are exactly zero. One caveat: the hooks above are attached to the Linear layers, so they see pre-activation values and will report a dead fraction near zero; to measure dead ReLU neurons, attach the hook to the activation module itself (or apply the ReLU inside the hook). If a layer then shows 70-80% dead neurons, something is wrong -- those neurons aren't contributing to the network's predictions at all. Possible fixes: lower the learning rate (large updates can push neurons into the permanently-dead zone), use a different initialization, or switch from ReLU to Leaky ReLU (which has a small non-zero gradient for negative inputs, preventing neurons from dying entirely -- we covered this in episode #40).
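
To see the difference, here's a small sketch with a hypothetical dead_fraction() helper that hooks the activation modules directly and compares ReLU with Leaky ReLU:

def dead_fraction(activation_cls):
    net = nn.Sequential(
        nn.Linear(20, 64), activation_cls(),
        nn.Linear(64, 64), activation_cls(),
        nn.Linear(64, 3)
    )
    fractions = {}
    def make_dead_hook(name):
        def hook(module, input, output):
            fractions[name] = (output.detach() == 0).float().mean().item()
        return hook
    # hook only the activation modules, so we see post-activation values
    for i, m in enumerate(net):
        if not isinstance(m, nn.Linear):
            m.register_forward_hook(make_dead_hook(f"act{i}"))
    net(torch.randn(256, 20))
    return fractions

print("ReLU:      ", dead_fraction(nn.ReLU))       # roughly half the outputs are exactly zero at init
print("LeakyReLU: ", dead_fraction(nn.LeakyReLU))  # essentially none -- negatives stay (slightly) non-zero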

Hooks are non-invasive. You attach them from outside the model and remove them when you're done. The model's forward method doesn't change at all. This makes hooks ideal for diagnostics on models you didn't write -- you can inspect any pretrained model's internals without modifying a single line of its source code.

A practical example: extracting intermediate representations for transfer learning. You hook into an early layer of a pretrained image model, run your data through it, and use the activations as features for a completely different task. We'll do exactly this when we start working with pretrained CNNs.
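
A minimal sketch of that pattern, with SimpleClassifier standing in for the pretrained model (the names here are illustrative):

# Capture an intermediate representation with a forward hook
captured = {}

def capture_hook(module, input, output):
    captured['features'] = output.detach()    # store the activations, detached from the graph

feature_model = SimpleClassifier(20, 64, 3)
feature_model.eval()
handle = feature_model.layer2.register_forward_hook(capture_hook)

with torch.no_grad():
    _ = feature_model(torch.randn(10, 20))

print(f"Captured features: {captured['features'].shape}")   # torch.Size([10, 64])
handle.remove()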

# Gradient hook example -- monitor gradient magnitudes
gradient_stats = {}

def grad_hook(name):
    def hook(module, grad_input, grad_output):
        if grad_output[0] is not None:
            gradient_stats[name] = {
                'grad_mean': grad_output[0].detach().mean().item(),
                'grad_std': grad_output[0].detach().std().item(),
                'grad_norm': grad_output[0].detach().norm().item()
            }
    return hook

model2 = SimpleClassifier(20, 64, 3)
model2.layer1.register_full_backward_hook(grad_hook('layer1'))
model2.layer2.register_full_backward_hook(grad_hook('layer2'))
model2.output.register_full_backward_hook(grad_hook('output'))

# Forward + backward to trigger gradient hooks
x = torch.randn(32, 20)
y = torch.randint(0, 3, (32,))
loss = nn.CrossEntropyLoss()(model2(x), y)
loss.backward()

print("\nGradient statistics per layer:")
for name, stats in gradient_stats.items():
    print(f"  {name:>8s}: grad_mean={stats['grad_mean']:+.6f}, "
          f"grad_std={stats['grad_std']:.6f}, "
          f"grad_norm={stats['grad_norm']:.4f}")

If you see gradient norms that are orders of magnitude different between layers (output layer at 0.1, layer1 at 0.00001), you've got a vanishing gradient problem. This is the exact problem we diagnosed theoretically in episode #40 -- hooks give you the empirical measurement.

Parameter groups: fine-grained optimizer control

Different parts of a model often need different learning rates. The classic scenario: you take a pretrained model (trained on millions of images over days of GPU time), add a new classification head, and want to train the head aggressively while barely touching the pretrained weights. Parameter groups make this possible:

resnet = ResNet(20, 64, 3, n_blocks=3)

# Different learning rates per component
optimizer = torch.optim.AdamW([
    {'params': resnet.input_layer.parameters(), 'lr': 1e-4},   # low LR
    {'params': resnet.blocks.parameters(), 'lr': 1e-4},        # low LR
    {'params': resnet.output_layer.parameters(), 'lr': 1e-3},  # high LR (10x)
], weight_decay=0.01)

for i, group in enumerate(optimizer.param_groups):
    n_params = sum(p.numel() for p in group['params'])
    print(f"Group {i}: lr={group['lr']}, params={n_params:,}, "
          f"weight_decay={group['weight_decay']}")

Parameter groups are passed to the optimizer as a list of dictionaries. Each dictionary must have a 'params' key (an iterable of parameters) and can optionally override any optimizer setting -- learning rate, weight decay, momentum, whatever the specific optimizer supports. Settings you don't override in a group fall back to the optimizer-wide defaults (here, weight_decay=0.01 applies to all three groups). Note that a parameter has to appear in some group to be optimized at all -- anything you leave out is simply ignored by the optimizer.

This is the mechanism behind fine-tuning pretrained models. The standard recipe: freeze the backbone (set requires_grad = False on all backbone parameters) or use a very low learning rate, train the new head at a much higher rate, then optionally "unfreeze" layers gradually from top to bottom with progressively lower learning rates. This technique -- called gradual unfreezing -- lets the model adapt its pretrained representations without catastrophically forgetting what it already knows.

# Freezing parameters -- exclude from optimization entirely
for param in resnet.input_layer.parameters():
    param.requires_grad = False

# Only unfrozen parameters get updated
trainable = sum(p.numel() for p in resnet.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in resnet.parameters() if not p.requires_grad)
print(f"Trainable:  {trainable:,}")
print(f"Frozen:     {frozen:,}")
print(f"Total:      {trainable + frozen:,}")

# Unfreeze later when you want to fine-tune everything
for param in resnet.input_layer.parameters():
    param.requires_grad = True
print(f"\nAfter unfreezing: {sum(p.numel() for p in resnet.parameters() if p.requires_grad):,} trainable")

Model surgery: modifying pretrained architectures

One of the most common patterns in modern deep learning: take a model pretrained on a massive dataset, chop off the last layer, and bolt on a new one that fits your specific task. This is transfer learning in its simplest form:

# Simulate a pretrained model (normally you'd load from torchvision)
class PretrainedFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(100, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU()
        )
        self.classifier = nn.Linear(128, 1000)  # ImageNet: 1000 classes

    def forward(self, x):
        feat = self.features(x)
        return self.classifier(feat)

# "Download" the pretrained model
pretrained = PretrainedFeatureExtractor()
print(f"Original classifier: {pretrained.classifier}")
print(f"Original output classes: 1000")

# Surgery: replace the classifier for our 5-class task
pretrained.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(128, 5)    # our task has 5 classes
)

# Freeze the feature extractor
for param in pretrained.features.parameters():
    param.requires_grad = False

print(f"\nNew classifier: {pretrained.classifier}")
print(f"New output classes: 5")

trainable = sum(p.numel() for p in pretrained.parameters() if p.requires_grad)
total = sum(p.numel() for p in pretrained.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({trainable/total:.1%})")

In practice, you'd use torchvision.models.resnet50 with pretrained ImageNet weights (weights="DEFAULT" in current torchvision versions; older ones used pretrained=True), replace its .fc attribute with nn.Linear(2048, num_classes), and freeze everything else. We'll do exactly this when we build CNNs in upcoming episodes -- the pattern will be identical, just with real image data and real pretrained weights. The key insight: you're adapting a model trained on over a million ImageNet images by training only a few thousand new parameters. The feature extractor already knows how to recognize edges, textures, shapes, and objects. Your new classifier just needs to learn the mapping from those rich features to your specific categories.

# Verify the model still works end-to-end after surgery
x = torch.randn(16, 100)
output = pretrained(x)
print(f"\nForward pass: input {x.shape} -> output {output.shape}")

# Train ONLY the new classifier
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, pretrained.parameters()),
    lr=0.001
)

# Quick training loop to show it works
y = torch.randint(0, 5, (16,))
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    pred = pretrained(x)
    loss = loss_fn(pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 3 == 0:
        acc = (pred.argmax(dim=1) == y).float().mean()
        print(f"Epoch {epoch}: loss={loss.item():.4f}, acc={acc.item():.1%}")

Notice the filter(lambda p: p.requires_grad, ...) trick -- this passes only the unfrozen parameters to the optimizer, saving memory (the optimizer doesn't need to maintain Adam state for frozen parameters) and making the optimizer step faster.

Custom parameters with nn.Parameter

Sometimes you need learnable parameters that aren't part of a standard layer. nn.Parameter wraps a tensor and tells PyTorch "this should be optimized":

class ScaledDotProduct(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))    # learnable scale
        self.bias = nn.Parameter(torch.zeros(dim))     # learnable bias
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear(x) * self.scale + self.bias

sdp = ScaledDotProduct(32)
print("Named parameters:")
for name, param in sdp.named_parameters():
    print(f"  {name}: shape={list(param.shape)}")

# scale and bias show up alongside linear's weight and bias

Any tensor wrapped in nn.Parameter and assigned as an attribute of an nn.Module is automatically included in model.parameters(). The optimizer treats it just like any other weight or bias. This is how attention mechanisms implement their query/key/value projections and how layer normalization implements its learnable affine parameters -- the building blocks we'll need for Transformers.
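
As a small illustration of that last point, here's a simplified layer-norm-style module built from raw nn.Parameter (PyTorch's nn.LayerNorm is the real, more careful implementation -- this just shows the mechanics):

class SimpleLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))    # learnable per-feature scale
        self.beta = nn.Parameter(torch.zeros(dim))    # learnable per-feature shift
        self.eps = eps                                # plain attribute, not learned

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

ln = SimpleLayerNorm(32)
print(ln(torch.randn(8, 32)).shape, sum(p.numel() for p in ln.parameters()))  # torch.Size([8, 32]) 64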

Putting it all together: a complete example

Let's combine everything -- custom Module, residual connections, parameter groups, hooks -- into a single training example:

import numpy as np
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic data (10-class classification)
np.random.seed(42)
X_train = torch.randn(2000, 20)
y_train = torch.randint(0, 10, (2000,))
X_val = torch.randn(500, 20)
y_val = torch.randint(0, 10, (500,))

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=128)

# Build the model
model = ResNet(input_dim=20, hidden_dim=64, output_dim=10, n_blocks=3)
print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")

# Parameter groups
optimizer = torch.optim.AdamW([
    {'params': model.input_layer.parameters(), 'lr': 1e-3},
    {'params': model.blocks.parameters(), 'lr': 1e-3},
    {'params': model.output_layer.parameters(), 'lr': 1e-3},
], weight_decay=0.01)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
loss_fn = nn.CrossEntropyLoss()

# Training
for epoch in range(30):
    model.train()
    total_loss = 0
    for X_b, y_b in train_loader:
        pred = model(X_b)
        loss = loss_fn(pred, y_b)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()

    if epoch % 5 == 0:
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for X_v, y_v in val_loader:
                correct += (model(X_v).argmax(1) == y_v).sum().item()
                total += len(y_v)
        print(f"Epoch {epoch:>2d}: train_loss={total_loss/len(train_loader):.4f}, "
              f"val_acc={correct/total:.1%}")

# Final
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for X_v, y_v in val_loader:
        correct += (model(X_v).argmax(1) == y_v).sum().item()
        total += len(y_v)
print(f"\nFinal val accuracy: {correct/total:.1%}")

Random data with 10 classes means a ~10% baseline accuracy -- and since these labels are pure noise (independent of the inputs), don't expect validation accuracy to climb meaningfully above that baseline; the model can only memorize the training set. The point here is that all the pieces fit together: custom Module with residual blocks, parameter groups, cosine annealing, proper train/eval toggling.

The nn.Module cheat sheet

Here's the complete reference for everything nn.Module gives you:

# Everything you get from nn.Module:
demo = ResNet(20, 64, 3)

# 1. List all parameters (for optimizer)
print("parameters():", sum(p.numel() for p in demo.parameters()))

# 2. Named parameters (for debugging)
for n, p in demo.named_parameters():
    print(f"  {n}: {list(p.shape)}")

# 3. Move to device
# demo.to('cuda')  # or .to('mps') on Apple Silicon

# 4. Save / load
torch.save(demo.state_dict(), '/tmp/demo.pth')
demo.load_state_dict(torch.load('/tmp/demo.pth', weights_only=True))

# 5. Train / eval mode
demo.train()   # dropout ON, batchnorm uses batch stats
demo.eval()    # dropout OFF, batchnorm uses running stats

# 6. List sub-modules
print("\nSub-modules:")
for name, module in demo.named_modules():
    if name:
        print(f"  {name}: {module.__class__.__name__}")

What to remember from this one

  • nn.Module is the base class for all PyTorch networks -- subclass it, define layers in __init__, implement forward. Use nn.Sequential for simple linear stacks, custom Module for everything else;
  • Register all layers as attributes in __init__ -- PyTorch automatically tracks their parameters. A layer hidden in a plain Python list or created inside forward() is invisible to the optimizer;
  • ModuleList and ModuleDict are the correct containers for variable-count and named layer collections -- they register parameters properly, plain lists and dicts don't;
  • Skip connections (output + input) enable very deep networks by providing gradient shortcuts. This is the core idea behind ResNets and one of the most impactful architecture innovations in deep learning;
  • Hooks (forward and backward) let you inspect activations and gradients non-invasively -- attach from outside the model, no code changes needed. Essential for debugging, visualization, and feature extraction;
  • Parameter groups give per-layer control over learning rate and weight decay. Freezing parameters with requires_grad = False excludes them from optimization entirely -- both mechanisms are critical for fine-tuning pretrained models;
  • Model surgery (replacing layers, freezing the backbone) enables transfer learning: take a model trained on millions of samples, replace the head, freeze the features, and train on your specific task with a fraction of the data and compute.

We now have the full PyTorch toolkit: tensors and autograd (#42), data pipelines and training loops (#43), and custom architectures with nn.Module (this episode). With these three building blocks, you can construct and train any neural network architecture. The next step is to start building architectures that actually exploit the structure of specific data types -- starting with images, where spatial structure matters enormously, and where a specialized architecture called the Convolutional Neural Network has been dominating since 2012 ;-)

Exercises

Exercise 1: Create a custom nn.Module called GatedMLP with 3 hidden layers. In the forward pass, each hidden layer's output should be multiplied element-wise by a learned "gate" (a separate nn.Linear followed by nn.Sigmoid). This means each hidden layer has a companion gate layer. Train it on 2D circle classification data (same data as episode #42) and print the final accuracy.

Exercise 2: Build a DeepResNet with a configurable number of residual blocks (using nn.ModuleList). Add forward hooks to every residual block that record the mean activation magnitude. Train on the 10-class random data from this episode for 20 epochs, then print the activation statistics for each block. Are the activations staying stable through the depth of the network?

Exercise 3: Implement model surgery: create a "pretrained" 3-block ResNet with 10 output classes, train it for 30 epochs, then replace the output layer with one that has 5 classes. Freeze the residual blocks, train only the new output layer for 15 epochs. Compare the training speed (epochs to reach peak accuracy) of the fine-tuned model vs training a fresh 5-class ResNet from scratch for 30 epochs.

Thanks for reading, tot de volgende!

@scipio


