Learn AI Series (#90) - Self-Supervised Learning for Vision

What will I learn
- You will learn why supervised learning's dependence on labeled data creates a fundamental bottleneck for vision;
- contrastive learning with SimCLR and MoCo: training visual representations by comparing augmented image views;
- BYOL and DINO: learning without negative pairs through momentum encoders and self-distillation;
- masked image modeling with MAE: the BERT of vision, masking 75% of image patches;
- vision foundation models like DINOv2 and how they produce general-purpose features;
- the practical recipe for using self-supervised pre-training in your own projects.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
- Learn AI Series (#50) - Sequence-to-Sequence Models
- Learn AI Series (#51) - Attention Mechanisms
- Learn AI Series (#52) - The Transformer Architecture (Part 1)
- Learn AI Series (#53) - The Transformer Architecture (Part 2)
- Learn AI Series (#54) - Vision Transformers
- Learn AI Series (#55) - Generative Adversarial Networks
- Learn AI Series (#56) - Mini Project - Building a Transformer From Scratch
- Learn AI Series (#57) - Language Modeling - Predicting the Next Word
- Learn AI Series (#58) - GPT Architecture - Decoder-Only Transformers
- Learn AI Series (#59) - BERT and Encoder Models
- Learn AI Series (#60) - Training Large Language Models
- Learn AI Series (#61) - Instruction Tuning and Alignment
- Learn AI Series (#62) - Prompt Engineering - Getting the Most from LLMs
- Learn AI Series (#63) - Embeddings and Vector Search
- Learn AI Series (#64) - Retrieval-Augmented Generation (RAG) - Basics
- Learn AI Series (#65) - RAG - Advanced Techniques
- Learn AI Series (#66) - Working with LLM APIs
- Learn AI Series (#67) - Building AI Agents (Part 1) - Foundations
- Learn AI Series (#68) - Building AI Agents (Part 2) - Advanced Patterns
- Learn AI Series (#69) - Fine-Tuning Language Models
- Learn AI Series (#70) - Running Local Models
- Learn AI Series (#71) - Text Generation Techniques
- Learn AI Series (#72) - Tokenization Deep Dive
- Learn AI Series (#73) - LLM Evaluation
- Learn AI Series (#74) - The Hugging Face Ecosystem
- Learn AI Series (#75) - Multimodal Models - Text Meets Vision
- Learn AI Series (#76) - Mini Project - Your Own AI Assistant
- Learn AI Series (#77) - Image Processing Fundamentals
- Learn AI Series (#78) - Object Detection (Part 1) - Foundations
- Learn AI Series (#79) - Object Detection (Part 2) - Modern Approaches
- Learn AI Series (#80) - Image Segmentation
- Learn AI Series (#81) - Pose Estimation and Tracking
- Learn AI Series (#82) - Optical Character Recognition
- Learn AI Series (#83) - Video Understanding
- Learn AI Series (#84) - Generative Images - Diffusion Models (Part 1)
- Learn AI Series (#85) - Generative Images - Diffusion Models (Part 2)
- Learn AI Series (#86) - Image-to-Image and Editing
- Learn AI Series (#87) - 3D Vision
- Learn AI Series (#88) - Face Analysis
- Learn AI Series (#89) - Medical and Scientific Imaging
- Learn AI Series (#90) - Self-Supervised Learning for Vision (this post)
Solutions to Episode #89 Exercises
Exercise 1: Medical dataset augmentation validator.
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.ndimage import map_coordinates
from scipy.ndimage import rotate as ndimage_rotate
class AugmentationValidator:
    """Validate augmentations for medical images
    by checking anatomical consistency and
    signal preservation."""
    def __init__(self, size=64, seed=42):
        self.size = size
        self.rng = np.random.RandomState(seed)
        self.original = self._make_xray()
    def _make_xray(self):
        img = self.rng.normal(
            128, 15, (self.size, self.size)
        ).astype(np.float64)
        y, x = np.ogrid[:self.size, :self.size]
        # Heart: left half, brighter
        hx, hy, hr = 20, 32, 10
        heart = ((x - hx) ** 2
                 + (y - hy) ** 2) <= hr ** 2
        img[heart] = 200
        # Nodule: right half, smaller
        nx, ny, nr = 45, 30, 5
        nodule = ((x - nx) ** 2
                  + (y - ny) ** 2) <= nr ** 2
        img[nodule] = 180
        return np.clip(img, 0, 255)
    def check_anatomical_consistency(
            self, original, augmented):
        mid = original.shape[1] // 2
        orig_left = original[:, :mid].sum()
        orig_right = original[:, mid:].sum()
        aug_left = augmented[:, :mid].sum()
        aug_right = augmented[:, mid:].sum()
        orig_brighter = (
            "left" if orig_left > orig_right
            else "right")
        aug_brighter = (
            "left" if aug_left > aug_right
            else "right")
        return orig_brighter == aug_brighter
    def signal_preservation(self, original,
                            augmented):
        a = original.flatten()
        b = augmented.flatten()
        a = a - a.mean()
        b = b - b.mean()
        denom = (np.sqrt((a ** 2).sum()
                         * (b ** 2).sum()))
        if denom < 1e-12:
            return 0.0
        return float(np.dot(a, b) / denom)
    def augment_rotate(self, img, angle=10):
        return ndimage_rotate(
            img, angle, reshape=False,
            mode='nearest')
    def augment_hflip(self, img):
        return img[:, ::-1].copy()
    def augment_elastic(self, img,
                        alpha=30, sigma=4):
        shape = img.shape
        dx = gaussian_filter(
            self.rng.randn(*shape),
            sigma) * alpha
        dy = gaussian_filter(
            self.rng.randn(*shape),
            sigma) * alpha
        yy, xx = np.mgrid[
            0:shape[0], 0:shape[1]]
        indices = [
            np.clip(yy + dy, 0, shape[0] - 1),
            np.clip(xx + dx, 0, shape[1] - 1)]
        return map_coordinates(
            img, indices, order=1)
    def augment_brightness(self, img, factor=1.2):
        return np.clip(img * factor, 0, 255)
    def augment_noise(self, img, sigma=0.05):
        noise = self.rng.randn(
            *img.shape) * sigma * 255
        return np.clip(img + noise, 0, 255)
    def run(self):
        augmentations = [
            ("Rotation 10deg",
             self.augment_rotate),
            ("Horizontal flip",
             self.augment_hflip),
            ("Elastic deform",
             self.augment_elastic),
            ("Brightness +20%",
             self.augment_brightness),
            ("Gaussian noise",
             self.augment_noise),
        ]
        print(f"{'Augmentation':<20} "
              f"{'Anatomy':>8} {'Corr':>7}")
        print("-" * 38)
        for name, fn in augmentations:
            aug = fn(self.original)
            consistent = (
                self.check_anatomical_consistency(
                    self.original, aug))
            corr = self.signal_preservation(
                self.original, aug)
            status = "PASS" if consistent else "FAIL"
            print(f"{name:<20} "
                  f"{status:>8} {corr:>7.4f}")
validator = AugmentationValidator()
validator.run()
Horizontal flip is the only augmentation that fails the anatomical consistency check -- flipping swaps left and right, moving the "heart" circle to the right half of the image. All other augmentations preserve the left-right intensity distribution. Brightness shift has the highest correlation (it's a linear scaling of pixel values, preserving all spatial structure perfectly), while elastic deformation has a lower correlation because it physically warps pixel positions. Gaussian noise also drops correlation slightly since it adds random variation at every pixel. The key takeaway: horizontal flipping is genuinely dangerous for medical images because it violates anatomical assumptions that any downstream classifier will rely on.
Exercise 2: Class imbalance strategy comparator.
import numpy as np
from sklearn.linear_model import (
    LogisticRegression)
from sklearn.model_selection import (
    train_test_split)
from sklearn.metrics import (
    roc_auc_score, f1_score,
    confusion_matrix)
class ImbalanceComparator:
    """Compare strategies for handling class
    imbalance in binary classification."""
    def __init__(self, n_total=1000,
                 n_positive=30, n_dims=10,
                 seed=42):
        rng = np.random.RandomState(seed)
        n_neg = n_total - n_positive
        X_neg = rng.normal(
            0.0, 1.0, (n_neg, n_dims))
        X_pos = rng.normal(
            1.0, 1.5, (n_positive, n_dims))
        self.X = np.vstack([X_neg, X_pos])
        self.y = np.array(
            [0] * n_neg + [1] * n_positive)
    def random_oversample(self, X, y, seed=42):
        rng = np.random.RandomState(seed)
        pos_idx = np.where(y == 1)[0]
        neg_idx = np.where(y == 0)[0]
        n_needed = len(neg_idx) - len(pos_idx)
        extra = rng.choice(
            pos_idx, size=n_needed, replace=True)
        X_out = np.vstack([X, X[extra]])
        y_out = np.concatenate([y, y[extra]])
        return X_out, y_out
    def evaluate(self, model, X_test, y_test):
        y_prob = model.predict_proba(
            X_test)[:, 1]
        y_pred = model.predict(X_test)
        cm = confusion_matrix(y_test, y_pred)
        tn, fp, fn, tp = cm.ravel()
        sens = tp / max(tp + fn, 1)
        spec = tn / max(tn + fp, 1)
        auc = roc_auc_score(y_test, y_prob)
        f1 = f1_score(y_test, y_pred)
        # Find the HIGHEST threshold that still
        # reaches 90% sensitivity. Scan from 1
        # down to 0: scanning upward would stop
        # at t=0, where everything is predicted
        # positive and specificity is 0.
        thresholds = np.linspace(1, 0, 500)
        best_t, best_spec = 0.5, spec
        for t in thresholds:
            preds = (y_prob >= t).astype(int)
            s = preds[y_test == 1].sum() / max(
                (y_test == 1).sum(), 1)
            if s >= 0.90:
                sp = (1 - preds[
                    y_test == 0]).sum() / max(
                    (y_test == 0).sum(), 1)
                best_t = t
                best_spec = sp
                break
        return {
            "sens": sens, "spec": spec,
            "auc": auc, "f1": f1,
            "t90_thresh": best_t,
            "t90_spec": best_spec}
    def run(self):
        X_tr, X_te, y_tr, y_te = (
            train_test_split(
                self.X, self.y, test_size=0.2,
                random_state=42, stratify=self.y))
        # Strategy 1: Default
        m1 = LogisticRegression(
            max_iter=1000, random_state=42)
        m1.fit(X_tr, y_tr)
        # Strategy 2: Balanced weights
        m2 = LogisticRegression(
            class_weight='balanced',
            max_iter=1000, random_state=42)
        m2.fit(X_tr, y_tr)
        # Strategy 3: Random oversampling
        X_os, y_os = self.random_oversample(
            X_tr, y_tr)
        m3 = LogisticRegression(
            max_iter=1000, random_state=42)
        m3.fit(X_os, y_os)
        results = {
            "Default": self.evaluate(
                m1, X_te, y_te),
            "Balanced": self.evaluate(
                m2, X_te, y_te),
            "Oversample": self.evaluate(
                m3, X_te, y_te),
        }
        print(f"{'Strategy':<12} {'Sens':>6} "
              f"{'Spec':>6} {'AUC':>6} "
              f"{'F1':>6}")
        print("-" * 40)
        for name, r in results.items():
            print(f"{name:<12} "
                  f"{r['sens']:>6.3f} "
                  f"{r['spec']:>6.3f} "
                  f"{r['auc']:>6.3f} "
                  f"{r['f1']:>6.3f}")
        print("\nAt 90% sensitivity:")
        print(f"{'Strategy':<12} {'Thresh':>7} "
              f"{'Spec':>6}")
        print("-" * 28)
        for name, r in results.items():
            print(f"{name:<12} "
                  f"{r['t90_thresh']:>7.3f} "
                  f"{r['t90_spec']:>6.3f}")
comp = ImbalanceComparator()
comp.run()
The default (unweighted) classifier achieves high specificity but poor sensitivity -- it rarely predicts positive because the training data is 97% negative, so predicting "negative" is almost always correct. Balanced weighting and oversampling both improve sensitivity substantially by forcing the model to pay more attention to the minority class, at the cost of some specificity. The AUC values are similar across strategies because AUC measures ranking quality (which doesn't depend on the threshold), while sensitivity, specificity, and F1 all depend on the decision threshold. The 90%-sensitivity analysis shows that all strategies can reach 90% sensitivity, but they require different thresholds and sacrifice different amounts of specificity to get there.
Exercise 3: Multi-channel scientific image normalizer and analyzer.
import numpy as np
class ScientificImageAnalyzer:
    """Normalize and analyze multi-channel
    scientific images (satellite, microscopy,
    etc.)."""
    CHANNEL_NAMES = [
        "Red", "Green", "Blue",
        "NIR-1", "NIR-2", "Thermal",
        "Radar-1", "Radar-2"]
    def __init__(self, size=128, seed=42):
        rng = np.random.RandomState(seed)
        self.image = np.zeros((8, size, size))
        # RGB: 0-255
        for c in range(3):
            self.image[c] = rng.uniform(
                20, 235, (size, size))
        # NIR: 500-3000
        self.image[3] = rng.uniform(
            500, 3000, (size, size))
        self.image[4] = rng.uniform(
            500, 3000, (size, size))
        # Thermal (Kelvin): 250-320
        self.image[5] = rng.uniform(
            250, 320, (size, size))
        # Radar (dB): -25 to 5
        self.image[6] = rng.uniform(
            -25, 5, (size, size))
        self.image[7] = rng.uniform(
            -25, 5, (size, size))
    def per_channel_normalize(self, image):
        result = np.zeros_like(
            image, dtype=np.float64)
        for c in range(image.shape[0]):
            mu = image[c].mean()
            std = image[c].std()
            if std < 1e-8:
                std = 1.0
            result[c] = (image[c] - mu) / std
        return result
    def percentile_clip(self, image,
                        low_pct=2, high_pct=98):
        result = np.zeros_like(
            image, dtype=np.float64)
        for c in range(image.shape[0]):
            lo = np.percentile(
                image[c], low_pct)
            hi = np.percentile(
                image[c], high_pct)
            clipped = np.clip(
                image[c], lo, hi)
            rng = hi - lo
            if rng > 1e-8:
                result[c] = (
                    (clipped - lo) / rng)
        return result
    def compute_ndvi(self, image):
        red = image[0]
        nir = image[3]
        return (nir - red) / (
            nir + red + 1e-8)
    def channel_correlation(self, image,
                            c1, c2):
        a = image[c1].flatten()
        b = image[c2].flatten()
        a = a - a.mean()
        b = b - b.mean()
        denom = np.sqrt(
            (a ** 2).sum() * (b ** 2).sum())
        if denom < 1e-12:
            return 0.0
        return float(np.dot(a, b) / denom)
    def run(self):
        normed = self.per_channel_normalize(
            self.image)
        clipped = self.percentile_clip(
            self.image)
        print("Per-channel statistics:")
        print(f"{'Ch':<8} {'RawMean':>10} "
              f"{'RawStd':>10} {'NormMean':>10} "
              f"{'NormStd':>10}")
        print("-" * 52)
        for c in range(8):
            print(f"{self.CHANNEL_NAMES[c]:<8} "
                  f"{self.image[c].mean():>10.2f} "
                  f"{self.image[c].std():>10.2f} "
                  f"{normed[c].mean():>10.6f} "
                  f"{normed[c].std():>10.6f}")
        # NDVI analysis
        ndvi_raw = self.compute_ndvi(self.image)
        ndvi_clip = self.compute_ndvi(clipped)
        veg_raw = (ndvi_raw > 0.3).mean()
        veg_clip = (ndvi_clip > 0.3).mean()
        print(f"\nNDVI (raw): mean={ndvi_raw.mean():.4f}"
              f", std={ndvi_raw.std():.4f}"
              f", veg fraction={veg_raw:.4f}")
        print(f"NDVI (clip): mean="
              f"{ndvi_clip.mean():.4f}"
              f", std={ndvi_clip.std():.4f}"
              f", veg fraction={veg_clip:.4f}")
        # Correlation preservation
        pairs = [(0, 1), (0, 3), (3, 4),
                 (5, 6)]
        print("\nCorrelation preservation:")
        print(f"{'Pair':<14} {'Raw':>8} "
              f"{'Normed':>8}")
        print("-" * 32)
        for c1, c2 in pairs:
            cr = self.channel_correlation(
                self.image, c1, c2)
            cn = self.channel_correlation(
                normed, c1, c2)
            name = (f"{self.CHANNEL_NAMES[c1]}-"
                    f"{self.CHANNEL_NAMES[c2]}")
            print(f"{name:<14} {cr:>8.5f} "
                  f"{cn:>8.5f}")
analyzer = ScientificImageAnalyzer()
analyzer.run()
Per-channel normalization brings every channel to mean approximately 0 and standard deviation approximately 1, regardless of whether the original values ranged from 0-255 (RGB), 500-3000 (NIR), 250-320 (thermal), or -25 to 5 (radar). The NDVI values computed from the raw image are dominated by the NIR channel (which has values in the hundreds to thousands) relative to the Red channel (which maxes out at 255), so most pixels show NDVI well above 0.3. The correlation between channel pairs is preserved exactly under z-score normalization because z-scoring is a linear transformation that doesn't change the Pearson correlation coefficient -- it only shifts and scales each variable independently. This confirms that normalization preserves the inter-channel relationships that downstream models rely on for tasks like vegetation mapping and land-use classification.
On to today's episode
Here we go! Episode ninety -- a nice round number, and we've arrived at something I consider one of the most important developments in modern computer vision. We've been building up to this for a LONG time.
Think about everything we've done in the computer vision arc so far. From basic image processing (#77) through object detection (#78-79), segmentation (#80), pose estimation (#81), OCR (#82), video (#83), diffusion models (#84-85), image editing (#86), 3D vision (#87), face analysis (#88), and medical/scientific imaging (#89). That's thirteen episodes of increasingly sophisticated vision techniques. And nearly all of them shared one common assumption: you need labeled data to train your models.
ImageNet has 14 million images with human-assigned category labels. COCO has 330,000 images with hand-drawn bounding boxes and segmentation masks. Medical datasets (as we saw last episode) can cost $150 per hour per specialist annotator. The label bottleneck is REAL, and it has been the single biggest constraint on how far supervised learning can scale.
But what if you could learn powerful visual representations without any labels at all?
That's exactly what self-supervised learning does. And it's the same paradigm shift that happened in NLP -- remember how BERT (episode #59) pre-trains on massive amounts of unlabeled text using masked language modeling, then fine-tunes on small labeled datasets? Self-supervised learning brings that revolution to vision. The internet has billions of unlabeled images sitting there, doing nothing useful for model training under the supervised paradigm. Self-supervised methods turn all of that unlabeled data into a training signal ;-)
The label bottleneck
Before we get into solutions, let's be concrete about the problem. Consider these numbers:
import numpy as np
class LabelBottleneckAnalyzer:
    """Quantify the gap between available
    images and available labels."""
    def __init__(self):
        self.datasets = {
            "ImageNet": {
                "images": 14_000_000,
                "labeled": 14_000_000,
                "cost_per_label": 0.01,
                "years_to_label": 3},
            "COCO": {
                "images": 330_000,
                "labeled": 330_000,
                "cost_per_label": 0.50,
                "years_to_label": 2},
            "Medical (chest)": {
                "images": 500,
                "labeled": 500,
                "cost_per_label": 37.50,
                "years_to_label": 0.5},
            "Internet images": {
                "images": 5_000_000_000,
                "labeled": 0,
                "cost_per_label": 0.01,
                "years_to_label": None},
        }
    def analyze(self):
        print(f"{'Dataset':<18} {'Images':>14} "
              f"{'Labeled':>14} {'Cost':>14}")
        print("-" * 64)
        for name, d in self.datasets.items():
            cost = (d['labeled']
                    * d['cost_per_label'])
            print(f"{name:<18} "
                  f"{d['images']:>14,} "
                  f"{d['labeled']:>14,} "
                  f"${cost:>13,.0f}")
        # What would it cost to label
        # 1% of internet images?
        one_pct = 5_000_000_000 * 0.01
        cost_cheap = one_pct * 0.01
        print(f"\nLabeling 1% of internet "
              f"images at $0.01 each:")
        print(f"  {one_pct:,.0f} images = "
              f"${cost_cheap:,.0f}")
        print(f"  That's half a million dollars "
              f"for 1 percent.")
analyzer = LabelBottleneckAnalyzer()
analyzer.analyze()
Five billion images on the internet. Labeling even one percent of them (fifty million images) at the cheapest possible rate costs half a million dollars, and labeling all of them would cost fifty million. At COCO-style rates of $0.50 per label, that one percent alone comes to twenty-five million dollars. This is the fundamental motivation for self-supervised learning: extract training signal from the structure of the data itself, without any human annotation. The labels come "for free" from the data.
Contrastive learning: same image, different views
The core idea behind contrastive learning is beautifully simple. Take an image, create two different augmented "views" of it (random crops, color jitter, blurring, flipping), and train the network to recognize that these two views came from the same source. Views of different images should produce different representations.
SimCLR (Chen et al., 2020) is the cleanest implementation of this idea. It's the one I'd recommend reading first if you want to understand the paradigm:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
# Two random augmentations of the same image
# produce a "positive pair"
simclr_augment = T.Compose([
    T.RandomResizedCrop(
        224, scale=(0.2, 1.0)),
    T.RandomHorizontalFlip(),
    T.RandomApply([T.ColorJitter(
        0.4, 0.4, 0.4, 0.1)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.GaussianBlur(
        kernel_size=23, sigma=(0.1, 2.0)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
])
class SimCLR(nn.Module):
    """SimCLR: contrastive learning with
    augmented image views."""
    def __init__(self, backbone, proj_dim=128):
        super().__init__()
        self.backbone = backbone
        self.feature_dim = (
            backbone.fc.in_features)
        backbone.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(self.feature_dim,
                      self.feature_dim),
            nn.ReLU(),
            nn.Linear(self.feature_dim,
                      proj_dim),
        )
    def forward(self, x):
        features = self.backbone(x)
        projections = self.projector(features)
        return F.normalize(projections, dim=1)
The training procedure works like this: for each batch of N images, you apply simclr_augment twice to create 2N views. Each image's two views form a positive pair (they should have similar representations). All other 2(N-1) views are negatives (they should have different representations). The loss function pushes positives together and negatives apart:
def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent loss: Normalized Temperature-
    scaled Cross Entropy. The contrastive
    loss used by SimCLR."""
    batch_size = z1.shape[0]
    z = torch.cat([z1, z2], dim=0)
    # Cosine similarity matrix: (2N, 2N)
    sim = torch.mm(z, z.t()) / temperature
    # Mask out self-similarity (diagonal)
    mask = torch.eye(
        2 * batch_size, device=z.device).bool()
    sim.masked_fill_(mask, -1e9)
    # Positive pairs: (i, i+N) and (i+N, i)
    pos_indices = torch.arange(
        batch_size, device=z.device)
    labels = torch.cat([
        pos_indices + batch_size,
        pos_indices])
    return F.cross_entropy(sim, labels)
# Quick test
z1 = F.normalize(torch.randn(32, 128), dim=1)
z2 = F.normalize(torch.randn(32, 128), dim=1)
loss = nt_xent_loss(z1, z2)
print(f"NT-Xent loss: {loss.item():.4f}")
print(f"Batch size: 32, total views: 64")
print(f"Negatives per view: 62")
The temperature parameter controls how sharply the loss differentiates between similar and dissimilar pairs. Lower temperature makes the loss more sensitive to small differences in similarity (sharper contrastive signal, but harder to optimize). The code above defaults to 0.5; the SimCLR paper's ablations try several values (0.1 and 0.5 among them) and the best one differs between its experiments, so treat temperature as a hyperparameter to tune. Small changes in temperature can meaningfully affect final performance.
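To make the effect of temperature concrete, here is a minimal NumPy sketch (not taken from the SimCLR code base). The similarity values are made up for illustration, and `softmax_over_candidates` is a hypothetical helper name. It shows how dividing the same cosine similarities by a smaller temperature concentrates more probability on the positive:

```python
import numpy as np

def softmax_over_candidates(similarities, temperature):
    """Softmax over cosine similarities scaled by 1/temperature."""
    logits = np.asarray(similarities, dtype=np.float64) / temperature
    logits = logits - logits.max()  # subtract the max for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum()

# Hypothetical similarities: index 0 is the positive,
# the remaining entries are negatives of decreasing difficulty.
sims = [0.9, 0.7, 0.1, -0.3]

positive_prob = {}
for tau in (1.0, 0.5, 0.1):
    probs = softmax_over_candidates(sims, tau)
    positive_prob[tau] = probs[0]
    print(f"tau={tau:<4} positive={probs[0]:.3f} "
          f"hardest negative={probs[1]:.3f}")
```

With a high temperature the four candidates get fairly similar probabilities, so the gradient is spread over all negatives. With a low temperature almost all of the competition comes from the hardest negative.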
The catch: SimCLR requires large batch sizes (4096+) to work well because it needs many negatives in each batch. With a batch size of 32, you only have 62 negatives -- not enough diversity for the model to learn useful representations. This is computationally demanding. Not every lab has access to 128 TPU cores.
MoCo (Momentum Contrast, He et al., 2019) solves this elegantly with a queue of recent embeddings and a momentum encoder:
class MoCo(nn.Module):
    """MoCo: Momentum Contrast. Uses a queue
    of negatives and a slowly-updated key
    encoder to decouple batch size from
    negative count."""
    def __init__(self, backbone, dim=128,
                 K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T
        self.encoder_q = backbone
        self.encoder_k = self._copy(backbone)
        for p in self.encoder_k.parameters():
            p.requires_grad = False
        self.register_buffer(
            "queue",
            F.normalize(
                torch.randn(dim, K), dim=0))
        self.register_buffer(
            "queue_ptr",
            torch.zeros(1, dtype=torch.long))
    def _copy(self, model):
        import copy
        return copy.deepcopy(model)
    @torch.no_grad()
    def momentum_update(self):
        """Slowly update key encoder toward
        query encoder. k = m*k + (1-m)*q"""
        for p_q, p_k in zip(
                self.encoder_q.parameters(),
                self.encoder_k.parameters()):
            p_k.data = (self.m * p_k.data
                        + (1 - self.m)
                        * p_q.data)
    @torch.no_grad()
    def enqueue(self, keys):
        batch_size = keys.shape[0]
        # K must be a multiple of the batch
        # size, otherwise a write would run
        # past the end of the queue
        assert self.K % batch_size == 0
        ptr = int(self.queue_ptr)
        self.queue[:, ptr:ptr + batch_size] = (
            keys.T)
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr
The momentum coefficient m=0.999 means the key encoder changes very slowly: at each step, every weight moves only 0.1% of the way toward the corresponding query-encoder weight. This provides consistent embeddings for the queue: if the key encoder changed rapidly, old entries in the queue would be stale and useless. With the queue storing 65,536 recent key embeddings, even a batch size of 256 gives you 65,536 negatives, roughly 128x more than the 2(256-1) = 510 that SimCLR would get with the same batch size. MoCo was also the first of these methods to report self-supervised features that match or beat supervised ImageNet pre-training when transferred to several detection and segmentation benchmarks.
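The class above stops at the queue bookkeeping, so here is a small NumPy sketch of the two remaining pieces: how slowly the momentum update moves, and how the contrastive logits are assembled from one positive plus the queued negatives. This is an illustration under assumptions rather than the reference implementation. `ema_update`, `moco_logits` and `unit_rows` are hypothetical helper names, the embeddings are random, and the queue is shrunk to 1,024 entries to keep it light:

```python
import numpy as np

def ema_update(key_w, query_w, m=0.999):
    """One momentum step: k = m*k + (1-m)*q."""
    return m * key_w + (1.0 - m) * query_w

def moco_logits(q, k, queue, temperature=0.07):
    """Column 0 is the positive (q . k); columns 1..K are
    similarities to the queued negatives. The target class
    for cross-entropy is therefore always index 0."""
    l_pos = np.sum(q * k, axis=1, keepdims=True)   # (N, 1)
    l_neg = q @ queue                              # (N, K)
    return np.concatenate([l_pos, l_neg], axis=1) / temperature

def unit_rows(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

# Toy scalar "weight": the key starts at 0, the query stays at 1.
key_w, query_w = 0.0, 1.0
for _ in range(693):
    key_w = ema_update(key_w, query_w)
print(f"key weight after 693 steps: {key_w:.3f}")

# Negatives per view at batch size 256
simclr_negatives = 2 * (256 - 1)
queue_size = 65536
ratio = queue_size / simclr_negatives
print(f"SimCLR: {simclr_negatives}, MoCo queue: {queue_size}, "
      f"ratio: {ratio:.1f}x")

rng = np.random.default_rng(0)
q = unit_rows(rng.normal(size=(4, 128)))
k = unit_rows(rng.normal(size=(4, 128)))
queue = unit_rows(rng.normal(size=(1024, 128))).T  # (dim, K)
logits = moco_logits(q, k, queue)
print("logits shape:", logits.shape)
```

With m=0.999 it takes roughly 700 steps for the key weight to cover half the distance to the query weight, which is why entries that have been in the queue for a while are still comparable to fresh ones.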
Beyond negatives: BYOL and DINO
A genuinely surprising discovery followed: you don't actually need negative pairs at all.
BYOL (Bootstrap Your Own Latent, Grill et al., 2020) uses two networks: an online network and a target network (momentum-updated, like MoCo's key encoder). The online network predicts the target network's output for a different view of the same image. No negatives, no contrastive loss -- just match your own target:
class BYOL(nn.Module):
    """BYOL: learns representations without
    negative pairs. The online network
    predicts the target network's output
    through an asymmetric predictor."""
    def __init__(self, backbone,
                 proj_dim=256, pred_dim=4096,
                 feat_dim=2048):
        super().__init__()
        import copy
        self.online = backbone
        self.online_proj = nn.Sequential(
            nn.Linear(feat_dim, pred_dim),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(),
            nn.Linear(pred_dim, proj_dim))
        # Predictor: ONLY the online network
        # has this -- the asymmetry is what
        # prevents collapse
        self.predictor = nn.Sequential(
            nn.Linear(proj_dim, pred_dim),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(),
            nn.Linear(pred_dim, proj_dim))
        # Target encoder and projector start
        # as copies of the online ones and get
        # no gradients; they are meant to be
        # momentum-updated (like MoCo's key
        # encoder), not freshly initialized
        self.target = copy.deepcopy(backbone)
        self.target_proj = copy.deepcopy(
            self.online_proj)
        for p in self.target.parameters():
            p.requires_grad = False
        for p in self.target_proj.parameters():
            p.requires_grad = False
    def forward(self, view1, view2):
        online_z = self.online_proj(
            self.online(view1))
        online_p = self.predictor(online_z)
        with torch.no_grad():
            target_z = self.target_proj(
                self.target(view2))
        loss = 2 - 2 * F.cosine_similarity(
            online_p,
            target_z.detach(),
            dim=-1).mean()
        return loss
The obvious question: why doesn't this collapse? If the model maps everything to the same constant vector, the cosine similarity is always 1.0 and the loss is always 0 -- seems like a trivial optimum. Two factors prevent collapse: the momentum update provides a slowly-moving target that the online network can't trivially match by producing constants (because the target keeps changing), and the asymmetric predictor creates an information bottleneck that forces meaningful representations. Having said that, the theoretical understanding of why BYOL works is still incomplete -- it's one of those results where the empirical success came before the full theoretical explanation.
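The "trivial optimum" is easy to verify numerically. The sketch below is a NumPy illustration with made-up vectors, not part of any BYOL code base. It evaluates the same 2 - 2*cos loss on a fully collapsed batch and on random embeddings, and also computes the per-dimension standard deviation across the batch, which is a common (assumed here, not from the source) way to monitor collapse during training:

```python
import numpy as np

def byol_loss(pred, target):
    """2 - 2*cos(pred, target), averaged over the batch."""
    pred = pred / np.linalg.norm(pred, axis=1, keepdims=True)
    target = target / np.linalg.norm(target, axis=1, keepdims=True)
    return float(np.mean(2.0 - 2.0 * np.sum(pred * target, axis=1)))

rng = np.random.default_rng(0)
batch, dim = 8, 16

# Collapsed case: every image maps to the same constant vector.
constant = np.tile(rng.normal(size=(1, dim)), (batch, 1))
collapsed_loss = byol_loss(constant, constant)

# Non-collapsed case: unrelated random embeddings.
random_loss = byol_loss(rng.normal(size=(batch, dim)),
                        rng.normal(size=(batch, dim)))

# Collapse monitor: spread of each feature across the batch.
collapsed_std = float(constant.std(axis=0).mean())

print(f"collapsed loss: {collapsed_loss:.6f}")
print(f"random loss:    {random_loss:.6f}")
print(f"feature std when collapsed: {collapsed_std:.6f}")
```

The collapsed batch reaches the minimum of the loss while carrying no information at all, which is why a loss curve alone cannot tell you whether training is healthy.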
DINO (Self-Distillation with No Labels, Caron et al., 2021) builds on this by adding a centering operation to explicitly prevent collapse and extending the approach to Vision Transformers. The most remarkable finding: ViTs trained with DINO spontaneously learn to segment objects in their self-attention maps. The attention heads in the final layer cleanly separate foreground from background, without EVER seeing a segmentation label. According to the DINO authors, this does not emerge nearly as clearly in supervised ViTs or in convnets -- it appears to be a by-product of the self-supervised objective ;-)
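Centering is simple enough to sketch in a few lines of NumPy. What follows is a simplified single-view illustration under assumptions, not the official DINO code: the helper names are hypothetical, the temperatures (0.1 for the student, 0.04 for the teacher) and the center momentum (0.9) are the values I recall from the paper and should be checked against it, and the "network outputs" are random numbers with one artificially dominant dimension:

```python
import numpy as np

def softmax(x, temperature):
    z = x / temperature
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def update_center(center, teacher_out, momentum=0.9):
    """Running mean of teacher outputs over batches."""
    batch_mean = teacher_out.mean(axis=0, keepdims=True)
    return momentum * center + (1.0 - momentum) * batch_mean

def dino_loss(student_out, teacher_out, center,
              tau_s=0.1, tau_t=0.04):
    """Cross-entropy between the centered, sharpened teacher
    distribution and the student distribution."""
    teacher_probs = softmax(teacher_out - center, tau_t)
    student_log_probs = np.log(softmax(student_out, tau_s) + 1e-12)
    return float(-(teacher_probs * student_log_probs).sum(axis=1).mean())

rng = np.random.default_rng(0)
teacher_out = rng.normal(size=(16, 32))
teacher_out[:, 0] += 5.0          # one output dimension dominates
student_out = rng.normal(size=(16, 32))

center = np.zeros((1, 32))
uncentered = softmax(teacher_out - center, 0.04).mean(axis=0)
for _ in range(50):               # let the running center settle
    center = update_center(center, teacher_out)
centered = softmax(teacher_out - center, 0.04).mean(axis=0)
loss = dino_loss(student_out, teacher_out, center)

print(f"mass on dim 0 without centering: {uncentered[0]:.3f}")
print(f"mass on dim 0 with centering:    {centered[0]:.3f}")
print(f"loss: {loss:.3f}")
```

Subtracting the running mean stops any single dimension from winning for every image, while the low teacher temperature pushes in the opposite direction and keeps the targets from becoming uniform.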
Masked image modeling: BERT for vision
If you remember BERT from episode #59, the idea was simple: mask out some words in a sentence, and train the model to predict the missing words from context. MAE (Masked Autoencoders, He et al., 2022) applies the exact same idea to images. Mask out patches of the image, and train a ViT to reconstruct them.
The big difference from BERT: the masking ratio. BERT masks 15% of tokens. MAE masks 75% of image patches. Seventy-five percent! This works because images have far more spatial redundancy than text. A missing word in a sentence can completely change the meaning. A missing patch in an image can often be inferred from its neighbors through texture continuation, symmetry, and object priors:
class MAE(nn.Module):
    """Simplified Masked Autoencoder.
    Masks 75% of image patches and trains
    a ViT to reconstruct them."""
    def __init__(self, encoder, decoder,
                 mask_ratio=0.75):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.mask_ratio = mask_ratio
    def random_masking(self, x, mask_ratio):
        """Remove random patches."""
        N, L, D = x.shape
        keep = int(L * (1 - mask_ratio))
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(
            noise, dim=1)
        ids_keep = ids_shuffle[:, :keep]
        x_masked = torch.gather(
            x, 1,
            ids_keep.unsqueeze(-1).expand(
                -1, -1, D))
        return x_masked, ids_keep, ids_shuffle
    def forward(self, images):
        patches = self.encoder.patch_embed(
            images)
        # Mask 75% -- remove entirely
        visible, ids_keep, ids_shuffle = (
            self.random_masking(
                patches, self.mask_ratio))
        # Encode only visible patches
        # (25% of tokens -- FAST)
        encoded = self.encoder(visible)
        # Decode: insert mask tokens,
        # reconstruct all patches
        decoded = self.decoder(
            encoded, ids_shuffle)
        return decoded
The efficiency win is massive: during pre-training, the encoder only processes 25% of tokens. That's a 3-4x speedup compared to processing the full image. After pre-training, you throw away the decoder entirely -- it was just a training tool. The encoder becomes your feature extractor for downstream tasks.
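The class above hands `ids_shuffle` to the decoder without showing what the decoder does with it. A minimal NumPy sketch of that bookkeeping follows. It assumes a 224x224 image cut into 16x16 patches (196 tokens), and `ids_restore` and `restore_with_mask_tokens` are names introduced for this illustration. The point is that the inverse permutation lets you put the encoded tokens and the placeholder mask tokens back in their original positions:

```python
import numpy as np

def random_masking(tokens, mask_ratio, rng):
    """tokens: (N, L, D). Returns the visible tokens, a binary
    mask (1 = masked) and the indices that undo the shuffle."""
    n, length, _ = tokens.shape
    keep = int(length * (1 - mask_ratio))
    noise = rng.random((n, length))
    ids_shuffle = np.argsort(noise, axis=1)        # random permutation
    ids_restore = np.argsort(ids_shuffle, axis=1)  # its inverse
    ids_keep = ids_shuffle[:, :keep]
    visible = np.take_along_axis(tokens, ids_keep[:, :, None], axis=1)
    mask = np.ones((n, length))
    mask[:, :keep] = 0                             # in shuffled order
    mask = np.take_along_axis(mask, ids_restore, axis=1)
    return visible, mask, ids_restore

def restore_with_mask_tokens(visible, ids_restore, mask_token):
    """Append mask tokens, then unshuffle to the original order."""
    n, keep, dim = visible.shape
    length = ids_restore.shape[1]
    filler = np.broadcast_to(mask_token, (n, length - keep, dim))
    shuffled = np.concatenate([visible, filler], axis=1)
    return np.take_along_axis(shuffled, ids_restore[:, :, None], axis=1)

rng = np.random.default_rng(0)
num_patches = (224 // 16) ** 2                     # 196 tokens
tokens = rng.normal(size=(2, num_patches, 8))
visible, mask, ids_restore = random_masking(tokens, 0.75, rng)
full = restore_with_mask_tokens(
    visible, ids_restore, mask_token=np.zeros((1, 1, 8)))

print("tokens per image:", num_patches)
print("tokens seen by the encoder:", visible.shape[1])
print("masked per image:", int(mask[0].sum()))
```

The encoder only ever sees the `visible` array, which is where the speedup comes from. The lightweight decoder works on the restored full-length sequence.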
The high masking ratio also forces the model to learn genuine semantic understanding rather than cheap local interpolation. If only a few scattered patches remain visible, the model can't reconstruct a masked region just by copying textures from immediate neighbors. It has to understand what the object IS -- "this is a dog, so the masked region should contain fur with this coloring" -- which requires high-level semantic features.
Evaluating self-supervised representations
How do you know if your self-supervised features are any good? You can't just look at training loss (which measures reconstruction or contrastive accuracy, not downstream usefulness). The standard evaluation protocol is linear probing: freeze the pre-trained backbone, attach a single linear layer on top, and train ONLY that linear layer on a labeled dataset:
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearProbeEvaluator:
    """Evaluate pre-trained representations
    with a frozen backbone + linear head.
    The gold standard for measuring feature
    quality from self-supervised models."""

    def __init__(self, backbone, feature_dim,
                 num_classes, lr=1e-3):
        self.backbone = backbone
        for param in backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
        self.head = nn.Linear(
            feature_dim, num_classes)
        self.optimizer = torch.optim.Adam(
            self.head.parameters(), lr=lr)

    def train_epoch(self, loader):
        self.head.train()
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in loader:
            # Backbone is frozen: no gradients needed
            with torch.no_grad():
                features = self.backbone(images)
            # Head runs outside no_grad so it can learn
            logits = self.head(features)
            loss = F.cross_entropy(
                logits, labels)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            total_loss += loss.item()
            correct += (
                logits.argmax(1) == labels
            ).sum().item()
            total += len(labels)
        return total_loss / len(loader), (
            correct / total)

    def evaluate(self, loader):
        self.head.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in loader:
                features = self.backbone(images)
                logits = self.head(features)
                correct += (
                    logits.argmax(1) == labels
                ).sum().item()
                total += len(labels)
        return correct / total
# Usage pattern:
# backbone = load_pretrained_dinov2()
# evaluator = LinearProbeEvaluator(
# backbone, feature_dim=384,
# num_classes=1000)
# for epoch in range(100):
# loss, acc = evaluator.train_epoch(
# imagenet_train)
# test_acc = evaluator.evaluate(
# imagenet_val)
print("Linear probe: freeze backbone, "
"train only a linear layer")
print("Top-1 accuracy on ImageNet:")
print(" Supervised ResNet50: ~76.5%")
print(" SimCLR: ~69.3%")
print(" MoCo v2: ~71.1%")
print(" BYOL: ~74.3%")
print(" DINO (ViT-S): ~77.0%")
print(" MAE (ViT-L): ~75.8%")
print(" DINOv2 (ViT-g): ~86.5%")
The progression is remarkable. Early contrastive methods (SimCLR, 2020) lagged behind supervised training by ~7 points. BYOL closed the gap to ~2 points. DINO edged past the supervised ResNet50 baseline. And DINOv2, at scale, beats that baseline by roughly 10 points with a linear probe (its k-NN accuracy alone is ~83.5%). Keep in mind that these rows mix architectures and model sizes, so the comparison is indicative rather than like-for-like. Even so, self-supervised pre-training has gone from "interesting research direction" to a serious alternative to supervised pre-training, and often the better one, in about three years.
Vision foundation models: DINOv2
Self-supervised pre-training at scale produces foundation models -- general-purpose feature extractors that work across tasks without task-specific training. This is the vision equivalent of what GPT and BERT are for language.
DINOv2 (Meta, 2023) scales DINO training to 142 million curated images with a ViT-Giant backbone (1.1 billion parameters). The resulting features work out of the box for classification, segmentation, depth estimation, and retrieval -- often without ANY fine-tuning:
import torch
import torch.nn as nn
# DINOv2: pretrained vision foundation model
dinov2 = torch.hub.load(
'facebookresearch/dinov2',
'dinov2_vits14')
dinov2.eval()
# Extract features from any image
with torch.no_grad():
    # Global feature: (1, 384)
    features = dinov2(
        torch.randn(1, 3, 224, 224))
print(f"Global feature: {features.shape}")
# These features work directly for:
# - k-NN classification (no training!)
# - Linear probing (one linear layer)
# - Dense prediction (reshape patches)
# - Image retrieval (cosine similarity)
# The practical recipe
print("\nPractical workflow:")
print("1. Load DINOv2 (or CLIP, or MAE)")
print("2. Extract features from your data")
print("3. Try k-NN first (zero training)")
print("4. Try linear probe (minutes)")
print("5. Fine-tune only if needed (hours)")
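Step 3 of that workflow, k-NN on frozen features, needs no training loop at all. The sketch below uses plain Python lists as stand-ins for embeddings (with DINOv2 you would use the 384-dimensional vectors the model returns). `l2_normalize`, `knn_predict`, and the toy data are illustrative names, not part of any library.

```python
import math
from collections import Counter


def l2_normalize(vec):
    """Scale to unit length so a dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def knn_predict(query, bank, bank_labels, k=3):
    """Majority vote over the k most cosine-similar reference features."""
    q = l2_normalize(query)
    sims = [
        sum(a * b for a, b in zip(q, l2_normalize(feat)))
        for feat in bank
    ]
    # Indices of the k highest similarities
    top_k = sorted(
        range(len(bank)), key=lambda i: sims[i], reverse=True)[:k]
    votes = Counter(bank_labels[i] for i in top_k)
    return votes.most_common(1)[0][0]


# Toy "feature bank": two clusters pointing in different directions
bank = [
    [1.0, 0.1, 0.0], [0.9, 0.2, 0.1], [1.1, 0.0, 0.1],  # class "cat"
    [0.0, 1.0, 0.9], [0.1, 0.9, 1.0], [0.0, 1.1, 1.0],  # class "dog"
]
bank_labels = ["cat", "cat", "cat", "dog", "dog", "dog"]

print(knn_predict([0.8, 0.1, 0.05], bank, bank_labels))
print(knn_predict([0.05, 1.0, 1.1], bank, bank_labels))
```

Drop the vote and return the ranked indices instead, and the same cosine-similarity ranking becomes image retrieval.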
The "few labels" scenario is where self-supervised pre-training really shines. With 100% of ImageNet labels, supervised and self-supervised are comparable. With 10% of labels, self-supervised pulls ahead significantly. With 1% of labels (just 13 images per class!), self-supervised methods massively outperform training from scratch. This has direct practical consequences for domains like medical imaging (episode #89), where labeled data is scarce and expensive -- exactly the scenario where the pre-train-then-fine-tune paradigm delivers the biggest wins.
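The "13 images per class" figure follows from the size of the ImageNet-1k training set (1,281,167 images across 1,000 classes). The snippet below just does that division, assuming labels are sampled evenly across classes; `labels_per_class` is a throwaway helper name.

```python
IMAGENET_TRAIN_IMAGES = 1_281_167  # ImageNet-1k training set size
IMAGENET_CLASSES = 1_000


def labels_per_class(fraction):
    """Average labeled images per class when keeping `fraction` of labels."""
    return IMAGENET_TRAIN_IMAGES * fraction / IMAGENET_CLASSES


for fraction in (1.0, 0.10, 0.01):
    per_class = labels_per_class(fraction)
    print(f"{fraction:>5.0%} of labels -> "
          f"~{per_class:.0f} images per class")
```

At 1% the average is just under 13 labeled images per class, which is the regime where a good pre-trained backbone does most of the work.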
The practical recipe
For most practitioners working on real problems, the workflow is straightforward:
import torch
import torch.nn as nn
import torch.nn.functional as F
def practical_ssl_pipeline(
        backbone_name='dinov2_vits14',
        num_classes=10, num_epochs=20):
    """The recipe most practitioners should
    follow for using self-supervised models."""
    # Step 1: Load foundation model
    backbone = torch.hub.load(
        'facebookresearch/dinov2',
        backbone_name)
    # Step 2: Freeze backbone
    for param in backbone.parameters():
        param.requires_grad = False
    backbone.eval()
    # Step 3: Linear probe
    feature_dim = backbone.embed_dim
    linear_head = nn.Linear(
        feature_dim, num_classes)
    optimizer = torch.optim.Adam(
        linear_head.parameters(), lr=1e-3)
    # Count parameters once, then report
    n_backbone = sum(
        p.numel() for p in backbone.parameters())
    n_head = sum(
        p.numel() for p in linear_head.parameters())
    print(f"Backbone: {backbone_name}")
    print(f"Feature dim: {feature_dim}")
    print(f"Backbone params: {n_backbone:,} (frozen)")
    print(f"Head params: {n_head:,} (trainable)")
    print(f"Ratio: "
          f"{n_head / n_backbone * 100:.3f}% trainable")
    # Step 4: If linear probe isn't enough,
    # unfreeze backbone with small LR
    # optimizer = torch.optim.Adam([
    #     {'params': backbone.parameters(),
    #      'lr': 1e-5},
    #     {'params': linear_head.parameters(),
    #      'lr': 1e-3}])
    return backbone, linear_head
backbone, head = practical_ssl_pipeline()
Start with linear probing -- it's fast, doesn't require much data, and tells you how good the features already are. If linear probing gives you 90%+ accuracy, you probably don't need fine-tuning. If it gives you 70%, fine-tuning the backbone with a very small learning rate (100x lower than the head) will usually close the gap. The key insight: the fewer labeled examples you have, the more the pre-trained features matter and the less you should fine-tune (because fine-tuning on tiny datasets leads to overfitting).
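That rule of thumb can be written down as a tiny helper. This is only a sketch: the function name `adaptation_plan`, the `good_enough=0.90` threshold, and the 100x learning-rate divisor are illustrative defaults lifted from the paragraph above, not universal constants, and the helper deliberately ignores dataset size.

```python
def adaptation_plan(probe_accuracy, head_lr=1e-3,
                    good_enough=0.90, backbone_lr_divisor=100):
    """Suggest the next step after running a linear probe.

    Returns a dict with the strategy and learning rates to try.
    backbone_lr is None when the backbone should stay frozen.
    """
    if probe_accuracy >= good_enough:
        # Features are already strong: keep the backbone frozen
        return {"strategy": "linear_probe",
                "head_lr": head_lr,
                "backbone_lr": None}
    # Otherwise fine-tune, with a much smaller LR on the backbone
    return {"strategy": "fine_tune",
            "head_lr": head_lr,
            "backbone_lr": head_lr / backbone_lr_divisor}


print(adaptation_plan(0.93))
print(adaptation_plan(0.70))
```

With very few labeled examples you would lean toward the frozen option even below the threshold, since fine-tuning on tiny datasets tends to overfit.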
Summary
- Self-supervised learning trains visual representations from unlabeled data by designing pretext tasks that generate labels from the data itself, eliminating the annotation bottleneck that limits supervised learning;
- contrastive learning (SimCLR, MoCo) trains encoders to produce similar representations for augmented views of the same image and different representations for different images; SimCLR needs large batches (4096+) while MoCo decouples batch size from negative count using a momentum encoder and a queue of 65K embeddings;
- BYOL and DINO show that negative pairs aren't necessary: a momentum-updated target network prevents collapse, paired with an asymmetric predictor in BYOL and with centering and sharpening of the teacher outputs in DINO; DINO-trained ViTs spontaneously learn object segmentation in their attention maps without ever seeing segmentation labels;
- masked image modeling (MAE) masks 75% of image patches and trains a ViT to reconstruct them, mirroring BERT's approach for text; the high masking ratio forces semantic understanding rather than local interpolation, and encoding only 25% of patches makes pre-training 3-4x faster;
- linear probing is the standard evaluation protocol: freeze the backbone, train only a linear classifier; DINOv2's ViT-g features reach about 86.5% top-1 on ImageNet with just a linear probe (about 83.5% with k-NN), roughly 10 points above the supervised ResNet50 baseline of ~76.5%;
- vision foundation models (DINOv2, CLIP, MAE) pre-train at scale without manual class labels (DINOv2 and MAE on images alone, CLIP on image-text pairs), producing features that generalize across classification, segmentation, depth estimation, and retrieval; the fewer labels you have, the more pre-training helps;
- the practical recipe: start with a foundation model, try linear probing first, fine-tune only if needed with a small learning rate on the backbone; this approach consistently outperforms training from scratch, especially in low-data regimes like medical imaging.
We've reached the end of the computer vision arc with this episode. Ninety episodes in, and we've gone from "what is machine learning" all the way through the complete vision stack -- pixels, convolutions, detection, segmentation, generation, 3D, faces, medical imaging, and now self-supervised representations that tie it all together. The vision foundation model paradigm we covered today is where the field is heading: pre-train once on massive unlabeled data, then adapt to any task with minimal supervision. It's the same trajectory that language models followed, and the results speak for themselves.
Exercises
Exercise 1: Build a contrastive learning augmentation analyzer. Create a class AugmentationAnalyzer that: (a) generates a synthetic 64x64 RGB image as a numpy array with distinct features -- a red circle in the top-left quadrant, a blue rectangle in the bottom-right, and a green diagonal stripe across the middle, all against a grey background (value 128), (b) implements 5 augmentations: (1) random crop to 48x48 then resize back to 64x64, (2) horizontal flip, (3) color jitter (random brightness +/-20%), (4) Gaussian blur (sigma=2.0), (5) grayscale conversion (average RGB channels), (c) for each pair of augmentations (25 combinations including self-pairs), computes the overlap coefficient: the Pearson correlation between the flattened augmented images. This measures how much information is shared between augmented views, (d) prints a 5x5 matrix showing the average overlap coefficient for each augmentation pair, (e) identifies which augmentation pair produces the lowest overlap (hardest positive pairs) and which produces the highest (easiest positive pairs). Verify that self-pairs (same augmentation applied twice with different random seeds) have high but not perfect correlation, and that combining crop with color jitter produces lower overlap than combining blur with brightness.
Exercise 2: Build a momentum encoder dynamics simulator. Create a class MomentumSimulator that: (a) simulates a 1D "encoder" as a vector of 100 parameters (initialized to zeros for the online encoder, ones for the momentum encoder), (b) at each training step, the online encoder's parameters are updated by adding Gaussian noise with std=0.1 (simulating gradient updates), (c) the momentum encoder is updated via the EMA rule: target = m * target + (1 - m) * online, (d) runs 500 steps for momentum values [0.9, 0.99, 0.999, 0.9999], (e) for each momentum value, tracks: the L2 distance between online and momentum encoders at each step, the "staleness" (how many steps behind the momentum encoder effectively is, estimated as the step at which the online encoder was closest to the current momentum encoder), and the parameter variance of the momentum encoder over time (lower = more stable), (f) prints a comparison table and identifies the momentum value that best balances being close to the online encoder (low lag) while maintaining stability (low variance). Verify that higher momentum means more stable but more stale representations, and that m=0.999 (MoCo's default) is a reasonable tradeoff.
Exercise 3: Build a masking strategy comparator for MAE. Create a class MaskingComparator that: (a) creates a synthetic 8x8 "image" represented as a 64-element vector where each element is a "patch" with a unique value from 1 to 64, (b) implements three masking strategies: (1) random masking: randomly select patches to mask, (2) block masking: mask a contiguous rectangular region, (3) grid masking: mask every Nth patch in a regular pattern, (c) for each strategy and masking ratios [0.25, 0.50, 0.75, 0.90], generates 100 random masks and computes: the average number of visible patches, the average distance between visible patches (using their 2D grid positions), the "coverage uniformity" (divide the 8x8 grid into 4 quadrants and measure how evenly visible patches are distributed -- standard deviation of per-quadrant counts, lower = more uniform), and the "reconstruction difficulty" (estimated as the average distance from each masked patch to its nearest visible neighbor), (d) prints a comparison table for all strategies and ratios, (e) identifies which strategy at 75% masking (MAE's default) produces the highest reconstruction difficulty and most uniform coverage. Verify that random masking provides the most uniform coverage, grid masking has zero coverage variance but low reconstruction difficulty, and block masking creates the highest local reconstruction difficulty but poor coverage uniformity.