Learn AI Series (#89) - Medical and Scientific Imaging

What will I learn

  • You will learn the unique challenges of medical imaging: small datasets, class imbalance, and regulatory requirements;
  • transfer learning from natural images to medical domains and why it works despite massive visual differences;
  • data augmentation strategies specific to medical images (and which standard augmentations will break your model);
  • handling class imbalance when rare conditions are what matter most;
  • explainability requirements in healthcare AI and how Grad-CAM reveals what your model actually learned;
  • regulatory considerations for deploying AI in clinical settings;
  • scientific imaging beyond healthcare: satellite data, microscopy, and astronomical image analysis.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#89) - Medical and Scientific Imaging

Solutions to Episode #88 Exercises

Exercise 1: Face embedding similarity analyzer.

import numpy as np


class EmbeddingSimilarityAnalyzer:
    """Analyze face embedding similarity
    distributions for identity verification."""

    def __init__(self, num_ids=5, photos_per=10,
                 embed_dim=512, noise_sigma=0.15,
                 seed=42):
        rng = np.random.RandomState(seed)
        self.num_ids = num_ids
        self.photos_per = photos_per
        self.labels = []
        self.embeddings = []

        for i in range(num_ids):
            centroid = rng.randn(embed_dim)
            centroid /= np.linalg.norm(centroid)
            for _ in range(photos_per):
                variant = centroid + rng.randn(
                    embed_dim) * noise_sigma
                variant /= np.linalg.norm(variant)
                self.embeddings.append(variant)
                self.labels.append(i)

        self.embeddings = np.array(self.embeddings)
        self.labels = np.array(self.labels)

    def cosine_matrix(self):
        norms = np.linalg.norm(
            self.embeddings, axis=1, keepdims=True)
        normed = self.embeddings / norms
        return normed @ normed.T

    def analyze(self):
        sim = self.cosine_matrix()
        n = len(self.labels)
        same_sims = []
        diff_sims = []
        for i in range(n):
            for j in range(i + 1, n):
                if self.labels[i] == self.labels[j]:
                    same_sims.append(sim[i, j])
                else:
                    diff_sims.append(sim[i, j])
        same_sims = np.array(same_sims)
        diff_sims = np.array(diff_sims)

        print(f"Same-identity mean: "
              f"{same_sims.mean():.4f}")
        print(f"Diff-identity mean: "
              f"{diff_sims.mean():.4f}")
        print(f"Hardest positive:   "
              f"{same_sims.min():.4f}")
        print(f"Hardest negative:   "
              f"{diff_sims.max():.4f}")
        gap = same_sims.min() - diff_sims.max()
        print(f"Gap: {gap:.4f}")

        print(f"\n{'Thresh':>7} {'TPR':>6} "
              f"{'FPR':>6} {'Acc':>6}")
        print("-" * 28)
        best_acc, best_t = 0, 0
        for t in [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
            tp = (same_sims >= t).mean()
            fp = (diff_sims >= t).mean()
            acc = ((same_sims >= t).sum()
                   + (diff_sims < t).sum()) / (
                       len(same_sims) + len(diff_sims))
            if acc > best_acc:
                best_acc = acc
                best_t = t
            print(f"{t:>7.1f} {tp:>6.3f} "
                  f"{fp:>6.3f} {acc:>6.3f}")
        print(f"\nBest: thresh={best_t}, "
              f"acc={best_acc:.3f}")
        return same_sims, diff_sims


analyzer = EmbeddingSimilarityAnalyzer()
analyzer.analyze()

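One subtlety hides in this setup: noise_sigma is a per-dimension standard deviation, so the noise vector has a norm of roughly 0.15 × √512 ≈ 3.4, more than three times the length of the unit centroid. After renormalization, same-identity pairs therefore land at a cosine similarity of only about 1 / (1 + 0.15² × 512) ≈ 0.08, barely above the ≈0.0 of different-identity pairs (random unit vectors in 512D are nearly orthogonal). The two distributions overlap, the gap between the hardest positive and hardest negative comes out negative, and essentially no pair clears even the lowest threshold (0.3), so every row of the table reports the same accuracy: the share of different-identity pairs, 1000/1225 ≈ 0.816. Scaling the noise by 1/√embed_dim (rng.randn(embed_dim) * noise_sigma / np.sqrt(embed_dim)) turns noise_sigma into noise relative to the centroid's length: same-identity similarity then rises to roughly 1 / (1 + σ²) ≈ 0.98 for σ = 0.15, the gap becomes comfortably positive, and a single threshold cleanly separates all identities. Push that relative σ toward 3 and the same-identity similarity drops back to around 0.1, the distributions overlap again, and no threshold achieves perfect accuracy. This directly demonstrates why face recognition systems demand well-lit, frontal, aligned crops: reducing noise in the embedding space widens the gap.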
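The quantity that really matters is the ratio of noise norm to centroid norm, and with per-dimension noise that ratio grows like σ√d. A quick sketch to check this numerically, assuming the same generative model as EmbeddingSimilarityAnalyzer (unit centroid plus Gaussian noise, then renormalization); the function name and the scale_by_dim flag are hypothetical:

```python
import numpy as np


def mean_same_identity_cosine(noise_sigma, embed_dim=512,
                              scale_by_dim=False,
                              n_pairs=200, seed=0):
    """Average cosine similarity between two noisy copies
    of the same unit-length centroid. With scale_by_dim,
    the per-dimension std is noise_sigma / sqrt(embed_dim),
    so noise_sigma measures noise relative to the centroid."""
    rng = np.random.RandomState(seed)
    std = (noise_sigma / np.sqrt(embed_dim)
           if scale_by_dim else noise_sigma)
    sims = []
    for _ in range(n_pairs):
        c = rng.randn(embed_dim)
        c /= np.linalg.norm(c)
        a = c + rng.randn(embed_dim) * std
        b = c + rng.randn(embed_dim) * std
        sims.append(a @ b / (np.linalg.norm(a)
                             * np.linalg.norm(b)))
    return float(np.mean(sims))


d, s = 512, 0.15
# Compare the simulation with the rough estimates
# 1 / (1 + s^2 * d) and 1 / (1 + s^2)
print(f"per-dim noise:  "
      f"{mean_same_identity_cosine(s):.3f} "
      f"(approx {1 / (1 + s ** 2 * d):.3f})")
print(f"relative noise: "
      f"{mean_same_identity_cosine(s, scale_by_dim=True):.3f} "
      f"(approx {1 / (1 + s ** 2):.3f})")
```

The same σ that looks "modest" per coordinate swamps the signal once it is summed over 512 dimensions.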

Exercise 2: Landmark-based face alignment tool.

import numpy as np


class FaceAligner:
    """Align faces using 5-point landmarks
    via affine transformation."""

    def __init__(self, target_eye_dist=70,
                 target_center=(112, 112)):
        self.target_dist = target_eye_dist
        self.target_cx = target_center[0]
        self.target_cy = target_center[1]

    def compute_alignment(self, left_eye,
                          right_eye):
        dx = right_eye[0] - left_eye[0]
        dy = right_eye[1] - left_eye[1]
        angle = np.degrees(np.arctan2(dy, dx))
        dist = np.sqrt(dx ** 2 + dy ** 2)
        scale = self.target_dist / max(
            dist, 1e-8)

        cos_a = np.cos(np.radians(-angle))
        sin_a = np.sin(np.radians(-angle))
        R = np.array([[cos_a, -sin_a],
                      [sin_a, cos_a]]) * scale

        mid_x = (left_eye[0] + right_eye[0]) / 2
        mid_y = (left_eye[1] + right_eye[1]) / 2
        tx = self.target_cx - (
            R[0, 0] * mid_x + R[0, 1] * mid_y)
        ty = self.target_cy - (
            R[1, 0] * mid_x + R[1, 1] * mid_y)

        M = np.zeros((2, 3))
        M[:2, :2] = R
        M[0, 2] = tx
        M[1, 2] = ty
        return M, angle, dist

    def apply_affine(self, points, M):
        pts = np.array(points)
        ones = np.ones((len(pts), 1))
        aug = np.hstack([pts, ones])
        return (M @ aug.T).T

    def run(self):
        rng = np.random.RandomState(42)
        print(f"{'Case':>5} {'InAngle':>8} "
              f"{'OutAngle':>9} {'InDist':>7} "
              f"{'OutDist':>8} {'EyeCenter':>14}")
        print("-" * 56)

        for i in range(5):
            angle = rng.uniform(-30, 30)
            dist = rng.uniform(40, 120)
            cx = rng.uniform(80, 160)
            cy = rng.uniform(80, 160)
            rad = np.radians(angle)
            le = (cx - dist / 2 * np.cos(rad),
                  cy - dist / 2 * np.sin(rad))
            re = (cx + dist / 2 * np.cos(rad),
                  cy + dist / 2 * np.sin(rad))

            M, in_angle, in_dist = (
                self.compute_alignment(le, re))
            landmarks = [le, re,
                         (cx, cy + 20),
                         (cx - 15, cy + 40),
                         (cx + 15, cy + 40)]
            aligned = self.apply_affine(
                landmarks, M)

            out_dx = aligned[1, 0] - aligned[0, 0]
            out_dy = aligned[1, 1] - aligned[0, 1]
            out_angle = np.degrees(
                np.arctan2(out_dy, out_dx))
            out_dist = np.sqrt(
                out_dx ** 2 + out_dy ** 2)
            eye_cx = (aligned[0, 0]
                      + aligned[1, 0]) / 2
            eye_cy = (aligned[0, 1]
                      + aligned[1, 1]) / 2

            print(f"{i + 1:>5} {in_angle:>8.1f} "
                  f"{out_angle:>9.4f} "
                  f"{in_dist:>7.1f} "
                  f"{out_dist:>8.1f} "
                  f"({eye_cx:.0f}, {eye_cy:.0f})")


aligner = FaceAligner()
aligner.run()

Every test case produces an aligned eye angle of essentially 0 degrees (within floating point precision), an inter-eye distance of exactly 70 pixels, and eye centers at (112, 112) -- regardless of the input rotation, scale, or position. This consistency is exactly why alignment matters for face recognition: the downstream embedding network sees faces in a canonical pose every time, so it can focus all of its capacity on identity-discriminative features rather than wasting parameters on learning to handle rotations and scale variations.

Exercise 3: Expression confusion matrix analyzer.

import numpy as np


class ExpressionAnalyzer:
    """Analyze expression classification
    confusion patterns."""

    EXPRESSIONS = [
        "angry", "disgusted", "fearful",
        "happy", "neutral", "sad", "surprised"
    ]

    def __init__(self, samples_per=100, seed=42):
        self.n = samples_per
        self.seed = seed

    def generate_confusion(self, correct_probs=
            None, boost=None):
        rng = np.random.RandomState(self.seed)
        if correct_probs is None:
            correct_probs = {
                "angry": 0.65, "disgusted": 0.60,
                "fearful": 0.60, "happy": 0.80,
                "neutral": 0.70, "sad": 0.68,
                "surprised": 0.72}
        if boost:
            for cls, amt in boost.items():
                correct_probs[cls] = min(
                    correct_probs[cls] + amt, 0.95)

        confusion_pairs = {
            ("angry", "disgusted"): 0.12,
            ("disgusted", "angry"): 0.12,
            ("fearful", "surprised"): 0.15,
            ("surprised", "fearful"): 0.10,
            ("sad", "neutral"): 0.08,
            ("neutral", "sad"): 0.06}

        nc = len(self.EXPRESSIONS)
        cm = np.zeros((nc, nc), dtype=int)

        for i, expr in enumerate(self.EXPRESSIONS):
            p_correct = correct_probs[expr]
            remaining = 1.0 - p_correct
            probs = np.zeros(nc)
            probs[i] = p_correct
            used = 0.0
            for (a, b), p in confusion_pairs.items():
                if a == expr:
                    j = self.EXPRESSIONS.index(b)
                    probs[j] = min(p, remaining)
                    used += probs[j]
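            leftover = remaining - used
            # Spread the leftover evenly over classes
            # with no explicit confusion entry; count
            # them BEFORE filling so every slot gets
            # the same share.
            empty = [j for j in range(nc)
                     if j != i and probs[j] == 0]
            for j in empty:
                probs[j] = leftover / max(
                    len(empty), 1)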
            probs /= probs.sum()
            cm[i] = rng.multinomial(self.n, probs)
        return cm

    def metrics(self, cm):
        nc = cm.shape[0]
        results = {}
        for i, expr in enumerate(self.EXPRESSIONS):
            tp = cm[i, i]
            fp = cm[:, i].sum() - tp
            fn = cm[i, :].sum() - tp
            prec = tp / max(tp + fp, 1)
            rec = tp / max(tp + fn, 1)
            f1 = (2 * prec * rec
                  / max(prec + rec, 1e-8))
            results[expr] = {
                "precision": prec,
                "recall": rec, "f1": f1}
        return results

    def run(self):
        cm = self.generate_confusion()

        # Print confusion matrix
        header = "     " + " ".join(
            f"{e[:4]:>5}" for e in
            self.EXPRESSIONS)
        print(header)
        for i, expr in enumerate(
                self.EXPRESSIONS):
            row = f"{expr[:4]:>5} " + " ".join(
                f"{cm[i, j]:>5}"
                for j in range(len(
                    self.EXPRESSIONS)))
            print(row)

        m = self.metrics(cm)
        print(f"\n{'Class':<12} {'Prec':>6} "
              f"{'Rec':>6} {'F1':>6}")
        print("-" * 32)
        for expr in self.EXPRESSIONS:
            d = m[expr]
            print(f"{expr:<12} "
                  f"{d['precision']:>6.3f} "
                  f"{d['recall']:>6.3f} "
                  f"{d['f1']:>6.3f}")

        # Most confused pairs
        nc = len(self.EXPRESSIONS)
        pairs = []
        for i in range(nc):
            for j in range(nc):
                if i != j:
                    pairs.append(
                        (cm[i, j], self.EXPRESSIONS[i],
                         self.EXPRESSIONS[j]))
        pairs.sort(reverse=True)
        print(f"\nTop confused pairs:")
        for cnt, a, b in pairs[:4]:
            print(f"  {a} -> {b}: {cnt}")

        ranked = sorted(m.items(),
                        key=lambda x: x[1]["f1"])
        print(f"\nDifficulty ranking (worst F1 "
              f"first):")
        for expr, d in ranked:
            print(f"  {expr}: F1={d['f1']:.3f}")

        # Boost worst two
        worst = [ranked[0][0], ranked[1][0]]
        print(f"\nBoosting {worst} by +0.05:")
        cm2 = self.generate_confusion(
            boost={w: 0.05 for w in worst})
        m2 = self.metrics(cm2)
        for w in worst:
            print(f"  {w}: F1 {m[w]['f1']:.3f}"
                  f" -> {m2[w]['f1']:.3f}")


analyzer = ExpressionAnalyzer()
analyzer.run()

The angry/disgusted and fearful/surprised pairs dominate the off-diagonal entries -- exactly as configured, and consistent with real expression recognition research. These pairs are genuinely hard because the underlying facial muscle movements overlap: anger and disgust both involve brow lowering and nose wrinkling, while fear and surprise both involve wide eyes and raised eyebrows. Happy achieves the highest F1 because smiling is the most visually distinctive expression (unique cheek raise + lip corner pull), while disgusted and fearful rank lowest due to heavy confusion with their look-alike partners. Adding 0.05 to the correct-class probability for the two worst performers (a crude stand-in for better training data) improves their F1 scores only modestly: the probability mass flowing to each class's look-alike partner stays at its configured value, so the fundamental visual similarity between confusable classes keeps capping performance.

On to today's episode

Here we go! We've spent the past twelve episodes building a thorough understanding of computer vision: from basic image processing (#77) through object detection (#78-79), segmentation (#80), pose estimation (#81), OCR (#82), video (#83), diffusion models (#84-85), image editing (#86), 3D reconstruction (#87), and face analysis (#88). That entire arc dealt with natural images -- photographs of everyday scenes, objects, people, and places.

But some of the highest-impact applications of computer vision don't involve natural photos at all. They involve medical images: chest X-rays, retinal scans, pathology slides, CT volumes. And beyond medicine, there's a whole universe of scientific imaging: satellite photos, microscopy, spectrograms, astronomical observations. These domains share a common property -- the images look nothing like ImageNet, the datasets are small and expensive to label, and getting the answer wrong can have serious consequences.

This episode covers how to adapt everything we've learned to domains where the stakes are literally life-and-death, and where the standard "download a big dataset, train a deep model, deploy" pipeline falls apart completely ;-)

Why medical imaging is fundamentally different

If you've been following along since episode #13 (evaluation) and #14 (data preparation), you already know that data quality matters more than model architecture. Medical imaging takes every data challenge from those episodes and turns the dial to 11.

Small datasets: a hospital might have 500 labeled chest X-rays showing a rare condition. Compare that to ImageNet's 14 million images. You can't just train a deeper model and hope it generalizes -- you'll overfit before the first epoch finishes.

Annotation cost: labeling a single medical image often requires a board-certified specialist spending 10-30 minutes. Pixel-level segmentation masks (delineating tumor boundaries on a pathology slide) can take hours per image. Some images need consensus from multiple experts because even specialists disagree on borderline cases. You can't crowdsource this on Mechanical Turk like you can with "is there a cat in this picture?"

Class imbalance: in screening applications, the positive rate might be 1-5%. A model that always predicts "healthy" achieves 95%+ accuracy while being completely useless. Remember the precision-recall tradeoff from episode #13? This is where it becomes a matter of life and death.

High stakes: a false negative (missed cancer) can kill. A false positive (unnecessary biopsy) causes harm, stress, and expense. The error profile matters far more than the aggregate accuracy number. Telling a doctor "my model is 97% accurate" is meaningless without specifying sensitivity and specificity separately.
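The accuracy trap is easy to reproduce. A minimal sketch with a hypothetical screening cohort of 1,000 patients at 3% prevalence, scored for a "model" that always answers healthy (screening_metrics is an illustrative helper, not a library function):

```python
import numpy as np


def screening_metrics(y_true, y_pred):
    """Accuracy, sensitivity and specificity for binary
    labels (1 = disease, 0 = healthy)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_pred == 1) & (y_true == 1))
    tn = np.sum((y_pred == 0) & (y_true == 0))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))
    accuracy = (tp + tn) / len(y_true)
    sensitivity = tp / max(tp + fn, 1)
    specificity = tn / max(tn + fp, 1)
    return accuracy, sensitivity, specificity


# Hypothetical cohort: 30 diseased, 970 healthy
y_true = np.array([1] * 30 + [0] * 970)
always_healthy = np.zeros_like(y_true)

acc, sens, spec = screening_metrics(y_true, always_healthy)
print(f"Accuracy:    {acc:.2%}")   # 97.00%
print(f"Sensitivity: {sens:.2%}")  # 0.00%
print(f"Specificity: {spec:.2%}")  # 100.00%
```

97% accuracy, zero cases caught: sensitivity and specificity expose exactly what the single accuracy number hides.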

Domain shift: a model trained on X-rays from Hospital A's GE scanner may fail on Hospital B's Siemens scanner. Different imaging protocols, sensor characteristics, patient demographics, and even the brand of contrast dye can shift the data distribution enough to break a model that worked perfectly in development. This is the same distribution shift problem from episode #14, but with much higher consequences.
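A toy illustration of the problem, not a model of any real scanner: 1D "scans" whose lesion region is brighter, plus a hypothetical second scanner that applies a gain and offset to every intensity. A threshold tuned on absolute brightness from scanner A collapses on scanner B, while a feature measured relative to each scan's own background (one common style of intensity normalization) is unaffected, because it is invariant to any positive gain and any offset. All names and numbers here are assumptions for illustration:

```python
import numpy as np

rng = np.random.RandomState(0)
LESION = slice(28, 36)  # hypothetical lesion location


def make_scans(n_normal, n_lesion, gain=1.0,
               offset=0.0):
    """Toy 64-pixel 1D 'scans' around intensity 100.
    Lesion cases get a +30 brighter patch; gain and
    offset mimic another scanner's calibration."""
    scans = rng.normal(100.0, 5.0,
                       size=(n_normal + n_lesion, 64))
    scans[n_normal:, LESION] += 30.0
    labels = np.array([0] * n_normal + [1] * n_lesion)
    return scans * gain + offset, labels


def raw_feature(scans):
    # Absolute brightness of the lesion region
    return scans[:, LESION].mean(axis=1)


def normalized_feature(scans):
    # Lesion contrast relative to the scan's own
    # background: unchanged by positive gain/offset
    background = np.ones(scans.shape[1], dtype=bool)
    background[LESION] = False
    bg = scans[:, background]
    return ((raw_feature(scans) - bg.mean(axis=1))
            / bg.std(axis=1))


def fit_threshold(feature, labels):
    # Midpoint between the two class means
    return (feature[labels == 0].mean()
            + feature[labels == 1].mean()) / 2


train_x, train_y = make_scans(200, 200)  # scanner A
test_x, test_y = make_scans(
    200, 200, gain=1.2, offset=30.0)     # scanner B

for name, feat in [("raw", raw_feature),
                   ("normalized", normalized_feature)]:
    t = fit_threshold(feat(train_x), train_y)
    acc_a = ((feat(train_x) >= t) == train_y).mean()
    acc_b = ((feat(test_x) >= t) == test_y).mean()
    print(f"{name:>10}: scanner A acc={acc_a:.3f}, "
          f"scanner B acc={acc_b:.3f}")
```

Real scanner differences are far messier than a gain and offset (noise texture, resolution, reconstruction settings), so intensity normalization alone rarely closes the gap; evaluating on data from a site the model has never seen is the honest test.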

import numpy as np


class MedicalDatasetProfiler:
    """Profile a medical imaging dataset for
    common challenges: imbalance, size, and
    annotation cost estimation."""

    def __init__(self, n_total, n_positive,
                 annotation_minutes=15):
        self.n_total = n_total
        self.n_pos = n_positive
        self.n_neg = n_total - n_positive
        self.ann_min = annotation_minutes

    def profile(self):
        ratio = self.n_pos / self.n_total
        imbalance = self.n_neg / max(
            self.n_pos, 1)

        # Estimate annotation budget
        total_hours = (
            self.n_total * self.ann_min / 60)
        cost_per_hour = 150  # specialist rate
        total_cost = total_hours * cost_per_hour

        # Effective training size accounting
        # for imbalance
        effective = min(
            self.n_pos, self.n_neg) * 2

        print(f"Total samples:     {self.n_total:,}")
        print(f"Positive:          {self.n_pos:,} "
              f"({ratio * 100:.1f}%)")
        print(f"Negative:          {self.n_neg:,} "
              f"({(1 - ratio) * 100:.1f}%)")
        print(f"Imbalance ratio:   "
              f"{imbalance:.1f}:1")
        print(f"Effective samples: "
              f"{effective:,}")
        print(f"Annotation time:   "
              f"{total_hours:,.0f} hours")
        print(f"Annotation cost:   "
              f"${total_cost:,.0f}")

        if ratio < 0.05:
            print("WARNING: Severe imbalance -- "
                  "focal loss recommended")
        if self.n_total < 1000:
            print("WARNING: Small dataset -- "
                  "transfer learning essential")
        if effective < 200:
            print("WARNING: Very few effective "
                  "samples -- consider few-shot")


# Typical medical datasets
print("=== Rare disease screening ===")
MedicalDatasetProfiler(500, 15).profile()
print()
print("=== Chest X-ray (pneumonia) ===")
MedicalDatasetProfiler(5000, 250).profile()
print()
print("=== Retinal scan (diabetic) ===")
MedicalDatasetProfiler(10000, 1500).profile()

The numbers are sobering. A rare disease screening dataset with 500 images and 15 positives gives you an effective training set of just 30 samples after accounting for imbalance. And annotating those 500 images cost someone over $18,000 in specialist time. This is the reality of medical ML -- you don't get to complain about "only" having 50,000 training images like you might in a Kaggle competition.

Transfer learning: standing on ImageNet's shoulders

Despite the visual differences between ImageNet photos (dogs, cars, landscapes) and chest X-rays (greyscale, abstract anatomical structures), transfer learning works surprisingly well for medical images. The early layers of an ImageNet-pretrained CNN detect edges, textures, gradients, and simple shapes -- these are universal visual features that appear in every type of image. Only the later layers specialize for the target domain:

import torch
import torch.nn as nn
import torchvision.models as models


class MedicalClassifier(nn.Module):
    """Medical image classifier built on top
    of an ImageNet-pretrained backbone."""

    def __init__(self, num_classes=2,
                 pretrained=True):
        super().__init__()
        self.backbone = models.resnet50(
            weights='DEFAULT' if pretrained
            else None)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, num_classes))

    def forward(self, x):
        return self.backbone(x)


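def setup_gradual_unfreezing(model,
                              lr_backbone=1e-5,
                              lr_head=1e-3):
    """Different learning rates for pretrained
    backbone vs new classification head.
    Slow updates preserve learned features,
    fast updates adapt the new head. This only
    sets the learning rates; freezing and
    unfreezing backbone blocks is a separate
    step in the training loop."""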
    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if 'fc' in name:
            head_params.append(param)
        else:
            backbone_params.append(param)

    return torch.optim.Adam([
        {'params': backbone_params,
         'lr': lr_backbone},
        {'params': head_params,
         'lr': lr_head}])


model = MedicalClassifier(num_classes=3)
optimizer = setup_gradual_unfreezing(model)
print(f"Total params (backbone + head): "
      f"{sum(p.numel() for p in model.backbone.parameters()):,}")
print(f"Head params: "
      f"{sum(p.numel() for p in model.backbone.fc.parameters()):,}")

The strategy is: freeze the backbone initially, train only the classification head for a few epochs until it converges, then gradually unfreeze backbone layers from top to bottom with a much smaller learning rate. This prevents the pretrained features from being destroyed by aggressive updates on a small medical dataset. The differential learning rate (100x difference between head and backbone) is critical -- the head needs to learn fast because it starts from random weights, while the backbone needs to adapt slowly because it already has useful features.

Medical foundation models like MedCLIP and BiomedCLIP are now available -- pretrained on large collections of medical images and clinical reports. They serve as better starting points than ImageNet for medical tasks, similar to how domain-specific language models outperform general ones (as we discussed in episode #69). Having said that, ImageNet pretraining remains a surprisingly strong baseline even for medical domains, and many published results showing "medical foundation model beats ImageNet transfer" disappear when you control carefully for training procedure and hyperparameters.

Data augmentation: what works and what breaks things

Standard augmentation needs careful adaptation for medical images. You can't just copy-paste augmentation pipelines from a dog-vs-cat classifier and expect sensible results:

import torchvision.transforms as T
import numpy as np
from PIL import Image
from scipy.ndimage import gaussian_filter
from scipy.ndimage import map_coordinates


class ElasticDeformation:
    """Elastic deformation: simulates tissue
    variation in medical images. The single
    most effective augmentation for soft-tissue
    segmentation tasks."""

    def __init__(self, alpha=50, sigma=5):
        self.alpha = alpha
        self.sigma = sigma

    def __call__(self, image):
        img = np.array(image)
        shape = img.shape[:2]
        dx = gaussian_filter(
            np.random.randn(*shape),
            self.sigma) * self.alpha
        dy = gaussian_filter(
            np.random.randn(*shape),
            self.sigma) * self.alpha
        y, x = np.meshgrid(
            np.arange(shape[0]),
            np.arange(shape[1]),
            indexing='ij')
        indices = [
            np.clip(y + dy, 0, shape[0] - 1),
            np.clip(x + dx, 0, shape[1] - 1)]
        if img.ndim == 3:
            result = np.stack([
                map_coordinates(img[:, :, c],
                                indices, order=1)
                for c in range(img.shape[2])],
                axis=-1)
        else:
            result = map_coordinates(
                img, indices, order=1)
        return Image.fromarray(
            result.astype(np.uint8))


# Safe augmentations for medical images
medical_train_transforms = T.Compose([
    T.RandomApply(
        [ElasticDeformation(alpha=50, sigma=5)],
        p=0.5),
    T.RandomRotation(degrees=15),
    T.RandomAffine(
        degrees=0,
        translate=(0.05, 0.05),
        scale=(0.95, 1.05)),
    T.RandomResizedCrop(
        224, scale=(0.85, 1.0)),
    T.RandomAdjustSharpness(
        sharpness_factor=2, p=0.3),
    T.GaussianBlur(
        kernel_size=3, sigma=(0.1, 1.0)),
    T.ColorJitter(
        brightness=0.2, contrast=0.2),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
])

# Quick reference for orientation-sensitive
# images such as chest X-rays:
print("SAFE augmentations:")
print("  Small rotations (up to ~15 degrees)")
print("  Small translations")
print("  Elastic deformation")
print("  Brightness/contrast jitter")
print("  Gaussian blur/noise")
print()
print("DANGEROUS augmentations:")
print("  Horizontal flip (heart is on the "
      "patient's LEFT)")
print("  Vertical flip (anatomy has a top)")
print("  Heavy color jitter (diagnostic info)")
print("  Random erasing (may mask pathology)")

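The horizontal flip warning is not theoretical. The human heart sits on the left side of the chest, and a horizontally flipped chest X-ray looks like situs inversus -- a rare anatomical condition where all organs are mirrored. If you augment with horizontal flips, your model sees this mirrored anatomy as routine and learns that the heart can be on either side, so it can no longer treat a right-sided heart as the abnormal finding it is. This is a subtle but real failure mode that has tripped up published research papers. The rule applies to images with a canonical orientation, though: histopathology tiles and dermoscopy photos have no anatomical "up" or "left", and flips plus 90-degree rotations are standard, safe augmentations there.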
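One way to make the orientation rule explicit is to gate flips and 90-degree rotations on the imaging modality. A minimal sketch; the modality names and their grouping into oriented vs orientation-free sets are assumptions for illustration (a histopathology tile, for example, has no canonical "up"):

```python
import numpy as np

# Hypothetical modality groups (assumption)
ORIENTED = {"chest_xray", "ct_slice", "mri_slice"}
ORIENTATION_FREE = {"histopathology_tile",
                    "dermoscopy"}


def random_flip_rotate(image, modality, rng):
    """Random flips and 90-degree rotations, applied
    only to modalities without a canonical orientation.
    `rng` is a np.random.RandomState."""
    if modality in ORIENTED:
        # Mirrored anatomy would be a different finding
        return image
    if modality not in ORIENTATION_FREE:
        raise ValueError(f"unknown modality: {modality}")
    if rng.rand() < 0.5:
        image = np.fliplr(image)
    if rng.rand() < 0.5:
        image = np.flipud(image)
    return np.rot90(image, k=rng.randint(4))


rng = np.random.RandomState(0)
tile = rng.rand(256, 256)  # stand-in histopathology tile
xray = rng.rand(512, 512)  # stand-in chest X-ray
aug_tile = random_flip_rotate(
    tile, "histopathology_tile", rng)
aug_xray = random_flip_rotate(xray, "chest_xray", rng)
print(aug_xray is xray)  # True: returned untouched
```

Raising on unknown modalities is a deliberate choice: a new data source should force someone to decide whether flips are anatomically plausible, rather than silently inheriting a default.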

Elastic deformation, on the other hand, is a goldmine for medical images. It simulates the natural biological variation in soft tissues -- muscles stretch and compress differently across patients, organs shift slightly with breathing, and pathological changes deform surrounding tissue. It's the single most effective domain-specific augmentation for segmentation tasks like tumor boundary delineation.
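For segmentation, the image and its label mask must receive the identical displacement field, or the labels drift off the anatomy they describe. The sketch below adapts the ElasticDeformation idea to image/mask pairs (2D, single-channel NumPy arrays); the function name elastic_pair is hypothetical. The mask is resampled with nearest-neighbour interpolation (order=0) so label values stay discrete, while the image uses linear interpolation:

```python
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.ndimage import map_coordinates


def elastic_pair(image, mask, alpha=50, sigma=5,
                 rng=None):
    """Apply ONE random elastic displacement field to a
    2D image and its segmentation mask."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape
    # Smooth random displacements, scaled by alpha
    dx = gaussian_filter(
        rng.standard_normal((h, w)), sigma) * alpha
    dy = gaussian_filter(
        rng.standard_normal((h, w)), sigma) * alpha
    y, x = np.meshgrid(np.arange(h), np.arange(w),
                       indexing="ij")
    coords = [np.clip(y + dy, 0, h - 1),
              np.clip(x + dx, 0, w - 1)]
    warped_img = map_coordinates(image, coords, order=1)
    warped_mask = map_coordinates(mask, coords, order=0)
    return warped_img, warped_mask


# Toy example: a bright disc ("tumor") and its mask
yy, xx = np.mgrid[:128, :128]
mask = ((yy - 64) ** 2 + (xx - 64) ** 2
        < 20 ** 2).astype(np.uint8)
image = mask * 200.0 + 20.0
img_w, mask_w = elastic_pair(
    image, mask, rng=np.random.default_rng(0))
print(np.unique(mask_w))  # labels stay in {0, 1}
```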

Handling class imbalance

When only 2% of your training images show the condition you're trying to detect, standard training produces a model that never predicts positive. Three complementary strategies:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import DataLoader


# Strategy 1: Weighted cross-entropy loss
class_counts = [9800, 200]  # healthy, diseased
weights = torch.tensor(
    [1.0 / c for c in class_counts])
weights = weights / weights.sum()
criterion = nn.CrossEntropyLoss(weight=weights)
print(f"Class weights: {weights.tolist()}")


# Strategy 2: Focal loss (Lin et al., 2017)
class FocalLoss(nn.Module):
    """Down-weight easy examples, focus on
    hard ones. Elegant solution to class
    imbalance without manual weight tuning."""

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        ce = nn.functional.cross_entropy(
            logits, targets, reduction='none')
        p = torch.exp(-ce)
        # Note: alpha is one global scale here;
        # Lin et al. use a per-class alpha_t
        focal_weight = (
            self.alpha * (1 - p) ** self.gamma)
        return (focal_weight * ce).mean()


# Strategy 3: Oversampling with
# WeightedRandomSampler
def make_balanced_loader(dataset, labels,
                         batch_size=32):
    """Each batch gets roughly equal
    representation of both classes."""
    class_counts_per = np.bincount(labels)
    sample_weights = [
        1.0 / class_counts_per[l]
        for l in labels]
    sampler = WeightedRandomSampler(
        sample_weights,
        num_samples=len(dataset))
    return DataLoader(
        dataset, batch_size=batch_size,
        sampler=sampler)


# Compare focal loss behavior
focal = FocalLoss(alpha=0.25, gamma=2.0)
logits = torch.tensor([[2.0, -2.0],
                        [0.5, 0.5],
                        [-1.0, 1.0]])
targets = torch.tensor([0, 0, 1])
loss = focal(logits, targets)
print(f"Focal loss: {loss.item():.4f}")

Focal loss is particularly elegant. When the model is already confident about an easy normal case (high p), the (1-p)^gamma term shrinks that example's loss, and with it the gradient, to nearly zero -- the model doesn't waste time getting more confident about cases it already handles. When it's uncertain about a difficult positive case (low p), the loss and gradient stay large, so the model focuses its learning capacity where it matters. No manual tuning of class weights is required -- the gamma parameter controls how aggressively easy examples are down-weighted, and gamma=2.0, the value Lin et al. found to work best in their object detection experiments, is a sensible starting point.
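Putting numbers on this: the sketch below compares plain cross-entropy with the focal term for a confident correct prediction (p = 0.9) and a badly wrong one (p = 0.1), leaving out alpha, which here would scale both examples equally. The helper name is hypothetical:

```python
import math


def ce_and_focal(p, gamma=2.0):
    """Per-example cross-entropy and focal loss (alpha
    omitted) for a correct-class probability p."""
    ce = -math.log(p)
    return ce, (1 - p) ** gamma * ce


easy_ce, easy_fl = ce_and_focal(0.9)  # confident, right
hard_ce, hard_fl = ce_and_focal(0.1)  # uncertain/wrong

print(f"CE    hard/easy ratio: {hard_ce / easy_ce:.1f}")  # 21.9
print(f"Focal hard/easy ratio: {hard_fl / easy_fl:.0f}")  # 1770
```

Cross-entropy already weights the hard example about 22x more than the easy one; the focal factor multiplies that by (0.9/0.1)² = 81, to roughly 1,770x.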

Evaluation: the metrics that actually matter

Standard accuracy is misleading for imbalanced medical data. The metrics clinicians care about:

from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
import numpy as np


def medical_evaluation(y_true, y_prob,
                       threshold=0.5):
    """Comprehensive evaluation for medical
    binary classification. Prints the metrics
    that actually matter in clinical settings."""
    y_pred = (y_prob >= threshold).astype(int)
    cm = confusion_matrix(y_true, y_pred)
    tn, fp, fn, tp = cm.ravel()

    sensitivity = tp / (tp + fn + 1e-10)
    specificity = tn / (tn + fp + 1e-10)
    ppv = tp / (tp + fp + 1e-10)
    npv = tn / (tn + fn + 1e-10)
    auc = roc_auc_score(y_true, y_prob)

    print(f"AUC: {auc:.4f}")
    print(f"Sensitivity: {sensitivity:.4f} "
          f"(missed {fn} of {tp + fn} cases)")
    print(f"Specificity: {specificity:.4f} "
          f"(false alarms: {fp})")
    print(f"PPV: {ppv:.4f}, NPV: {npv:.4f}")
    print(f"Confusion matrix:")
    print(f"  TN={tn}, FP={fp}")
    print(f"  FN={fn}, TP={tp}")

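    # Find the highest threshold that still
    # reaches 95% sensitivity: scan from high to
    # low, because sensitivity only grows as the
    # threshold drops (a low-to-high scan would
    # stop at t=0, where everything is positive)
    thresholds = np.linspace(1, 0, 1000)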
    for t in thresholds:
        preds = (y_prob >= t).astype(int)
        sens = preds[y_true == 1].sum() / max(
            (y_true == 1).sum(), 1)
        if sens >= 0.95:
            spec = (
                1 - preds[y_true == 0]).sum() / max(
                (y_true == 0).sum(), 1)
            print(f"\nAt 95% sensitivity: "
                  f"thresh={t:.3f}, "
                  f"specificity={spec:.4f}")
            break


# Simulated screening results
rng = np.random.RandomState(42)
y_true = np.array([0] * 950 + [1] * 50)
y_prob = np.where(
    y_true == 1,
    rng.beta(8, 3, len(y_true)),
    rng.beta(2, 8, len(y_true)))
medical_evaluation(y_true, y_prob)

The threshold choice is a clinical decision, not a technical one. For screening (catching every possible case -- think mammography), you want high sensitivity even at the cost of more false positives. For confirmatory diagnosis (being sure about a positive result before scheduling surgery), you want high specificity. The model doesn't change -- only the operating point on the ROC curve changes. This is why AUC is the standard metric for comparing models: it measures performance across all possible thresholds, independent of the clinical use case.
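Since the operating point is a clinical choice, it helps to read it straight off the ROC curve. A sketch using scikit-learn's roc_curve, which returns (fpr, tpr, thresholds) with thresholds in decreasing order; threshold_for_sensitivity is a hypothetical helper, and the simulated data mirrors the screening example above:

```python
import numpy as np
from sklearn.metrics import roc_curve


def threshold_for_sensitivity(y_true, y_prob,
                              target=0.95):
    """Highest-specificity operating point whose
    sensitivity (TPR) is at least `target`."""
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    # Thresholds decrease along the curve, so TPR and FPR
    # never decrease: the first index reaching the target
    # has the lowest FPR (highest specificity).
    idx = np.argmax(tpr >= target)
    return thresholds[idx], tpr[idx], 1 - fpr[idx]


rng = np.random.RandomState(42)
y_true = np.array([0] * 950 + [1] * 50)
y_prob = np.where(y_true == 1,
                  rng.beta(8, 3, len(y_true)),
                  rng.beta(2, 8, len(y_true)))

for target in (0.80, 0.95, 1.00):
    t, sens, spec = threshold_for_sensitivity(
        y_true, y_prob, target)
    print(f"target {target:.2f}: thresh={t:.3f}, "
          f"sens={sens:.3f}, spec={spec:.3f}")
```

Sweeping the target from 0.80 to 1.00 shows the trade-off directly: each extra point of sensitivity is paid for in specificity, i.e. in false alarms.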

Explainability: showing the model's reasoning

In healthcare, a black-box prediction is often unacceptable. Clinicians need to understand why the model flagged an image. Grad-CAM (Gradient-weighted Class Activation Mapping) produces heatmaps showing which regions of the image contributed most to the prediction:

import torch
import torch.nn.functional as F
import torchvision.models as models


class GradCAM:
    """Grad-CAM: visual explanations for CNN
    predictions. Highlights which image regions
    drove the classification decision."""

    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        target_layer.register_forward_hook(
            self._save_activation)
        target_layer.register_full_backward_hook(
            self._save_gradient)

    def _save_activation(self, module,
                         input, output):
        self.activations = output.detach()

    def _save_gradient(self, module,
                       grad_in, grad_out):
        self.gradients = grad_out[0].detach()

    def generate(self, input_image,
                 target_class):
        self.model.eval()
        output = self.model(input_image)
        self.model.zero_grad()
        output[0, target_class].backward()

        # Weight each channel by its
        # average gradient
        weights = self.gradients.mean(
            dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(
            dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(
            cam, size=input_image.shape[2:],
            mode='bilinear',
            align_corners=False)
        cam = cam / (cam.max() + 1e-8)
        return cam.squeeze()


# Demo with a ResNet50
backbone = models.resnet50(weights='DEFAULT')
backbone.eval()
target = backbone.layer4[-1]  # last conv layer
gradcam = GradCAM(backbone, target)

fake_input = torch.randn(1, 3, 224, 224)
heatmap = gradcam.generate(fake_input,
                            target_class=0)
print(f"Heatmap shape: {heatmap.shape}")
print(f"Range: [{heatmap.min():.3f}, "
      f"{heatmap.max():.3f}]")

The resulting heatmap highlights the region the model focused on when making its prediction. If the model correctly predicts pneumonia and the heatmap highlights the lower right lung where the consolidation is visible, that builds clinical trust. If the heatmap highlights the patient's name label in the corner of the X-ray, that reveals a data leakage problem -- the model learned to associate certain patients (who appear in both training and test sets) with certain diagnoses, rather than learning actual radiological features.
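One cheap way to turn that visual check into an automated sanity test is to measure how much of the heatmap's mass falls inside a region where no pathology can be, such as the corner where a burned-in label sits. The sketch below assumes a 2D non-negative heatmap like the one `GradCAM.generate` returns (converted to NumPy); the function name, the 40x40 corner box and the 0.2 alert level are hypothetical choices for illustration.

```python
import numpy as np


def region_attention_fraction(heatmap, box):
    """Share of the total heatmap mass inside box = (row0, row1, col0, col1),
    using half-open bounds like Python slicing."""
    heat = np.clip(np.asarray(heatmap, dtype=np.float64), 0.0, None)
    total = heat.sum()
    if total <= 0:
        return 0.0
    r0, r1, c0, c1 = box
    return float(heat[r0:r1, c0:c1].sum() / total)


# Two synthetic 224x224 "heatmaps", each a single Gaussian blob
yy, xx = np.mgrid[0:224, 0:224]


def blob(cy, cx, sigma=15.0):
    return np.exp(-((yy - cy) ** 2 + (xx - cx) ** 2) / (2 * sigma ** 2))


label_box = (0, 40, 0, 40)  # hypothetical location of a burned-in text label
ALERT_LEVEL = 0.2           # hypothetical: flag if >20% of the mass is there

for name, heat in [("lung-focused", blob(140, 150)),
                   ("label-focused", blob(12, 12))]:
    frac = region_attention_fraction(heat, label_box)
    flag = "CHECK FOR LEAKAGE" if frac > ALERT_LEVEL else "ok"
    print(f"{name:>14}: {frac:.3f} of mass in label box -> {flag}")
```

A high fraction does not prove leakage and a low one does not rule it out; it only tells you which cases deserve a human look at the full heatmap.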

This data leakage scenario is not hypothetical. A 2018 study found that a pneumonia detection model achieved suspiciously high accuracy because it learned to recognize which hospital the X-ray came from (via scanner-specific markings and formatting), and different hospitals had different patient populations with different disease prevalence. The model was predicting hospital identity, not pathology. Activation heatmaps of exactly this kind helped expose it ;-)
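The standard defense against the patient-overlap version of this problem is to split by patient (or by hospital, if you want to test generalization to an unseen site) rather than by image. Here is a minimal NumPy sketch, assuming you have one group ID per image; `group_split` is a hypothetical helper, and scikit-learn's `GroupShuffleSplit` and `GroupKFold` provide the same guarantee in practice.

```python
import numpy as np


def group_split(groups, test_frac=0.2, seed=0):
    """Return (train_idx, test_idx) such that every group ID
    (patient or hospital) lands entirely on one side of the split.
    test_frac is a fraction of groups, not of images."""
    groups = np.asarray(groups)
    rng = np.random.default_rng(seed)
    unique_groups = rng.permutation(np.unique(groups))
    n_test = max(1, int(round(test_frac * len(unique_groups))))
    test_mask = np.isin(groups, unique_groups[:n_test])
    return np.nonzero(~test_mask)[0], np.nonzero(test_mask)[0]


# Hypothetical cohort: 100 patients with 1-5 X-rays each
rng = np.random.default_rng(42)
patient_ids = np.repeat(np.arange(100), rng.integers(1, 6, size=100))

train_idx, test_idx = group_split(patient_ids, test_frac=0.2)
shared = set(patient_ids[train_idx]) & set(patient_ids[test_idx])
print(f"images: {len(train_idx)} train / {len(test_idx)} test")
print(f"patients in both sets: {len(shared)}")  # 0 by construction
```

Because whole patients are assigned, the image-level test fraction will drift a little from 20%; that is the price of a split the model cannot cheat on.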

Beyond medicine: scientific imaging at large

Medical imaging gets the most attention (and funding), but the same techniques apply across scientific disciplines. The common thread: specialized imaging modalities, expensive expert annotations, small datasets, and high-stakes decisions.

Satellite and remote sensing: images often have more channels than RGB (infrared and radar bands; multispectral sensors typically capture around 4-13 bands, hyperspectral sensors 200+), much larger spatial extents (a single Sentinel-2 tile is 10,980 x 10,980 pixels at 10m resolution), and the objects of interest are often tiny relative to the image (a single building, a few hectares of deforestation, individual vehicles). The preprocessing pipeline alone (atmospheric correction, cloud masking, radiometric calibration, georeferencing) is a field of its own.
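To get a feel for that scale, here is a small sketch that computes where to cut a 10,980 x 10,980 Sentinel-2 tile into fixed-size chips for a CNN. The 256-pixel chip size, the optional overlap and the `tile_origins` helper are illustrative assumptions; the last chip on each axis is shifted back so it ends flush with the image edge instead of running past it.

```python
def tile_origins(length, tile, overlap=0):
    """Start offsets of fixed-size windows covering [0, length) along one axis.
    Assumes length >= tile (smaller images would need padding instead)."""
    stride = tile - overlap
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] + tile < length:
        starts.append(length - tile)  # final window flush with the edge
    return starts


SIDE = 10_980   # pixels per side of a Sentinel-2 tile (10 m bands)
CHIP = 256      # hypothetical CNN input size
PIXEL_M = 10    # ground sampling distance in metres

for overlap in (0, 32):
    per_axis = len(tile_origins(SIDE, CHIP, overlap))
    print(f"overlap={overlap:>2}: {per_axis} chips per axis, "
          f"{per_axis ** 2} chips per tile, each "
          f"{CHIP * PIXEL_M / 1000:.2f} km across")
```

Even without overlap that is 43 x 43 = 1,849 chips for one tile, so a building or a small clearing occupies a tiny fraction of the pixels the model ever sees.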

Microscopy: fluorescence microscopy produces multichannel images where each channel corresponds to a different stain or fluorescent marker. The images are often very noisy (photon shot noise at low light levels), may be 3D (confocal z-stacks), and the relevant structures (cell nuclei, mitochondria, protein aggregates) range from a few pixels to thousands of pixels in size. Cell segmentation -- finding individual cell boundaries in densely packed tissue -- is one of the enduring computer vision challenges that gets harder as cells overlap and deform.
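For the 3D case, one common first step is a maximum intensity projection (MIP): collapse the z-axis of a confocal stack by keeping, for each pixel and channel, the brightest value across slices. The sketch below assumes a (channels, z, height, width) array layout and uses a synthetic Poisson-noise stack; real microscopy file formats may order the axes differently.

```python
import numpy as np


def max_intensity_projection(stack, z_axis=1):
    """Collapse a (C, Z, H, W) stack to (C, H, W) by taking the
    maximum along the z-axis for every pixel of every channel."""
    return np.asarray(stack).max(axis=z_axis)


rng = np.random.default_rng(0)
# Synthetic 2-channel stack: 16 z-slices of 64x64 pixels,
# background drawn as low-level Poisson photon counts
stack = rng.poisson(3, size=(2, 16, 64, 64)).astype(np.float32)
# A bright "nucleus" in channel 0 that is in focus only in slice 9
stack[0, 9, 20:30, 20:30] += 200

mip = max_intensity_projection(stack)
print(mip.shape)  # (2, 64, 64)
print(f"slice 0 at nucleus: {stack[0, 0, 25, 25]:.0f}, "
      f"MIP at nucleus: {mip[0, 25, 25]:.0f}")
```

The projection puts structures from every depth into one 2D image that a standard CNN can consume, but it throws depth away: cells stacked above one another merge into a single blob, and the background rises because the maximum of many noisy slices is higher than any typical slice. That is why fully 3D approaches are sometimes needed.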

Astronomy: images from telescopes deal with extreme dynamic range (a star might be millions of times brighter than the background), noise that follows Poisson statistics rather than Gaussian, and objects so faint they're barely distinguishable from noise artifacts. Galaxy morphology classification (spiral, elliptical, irregular) is one of the classic ML astronomy tasks, and the Galaxy Zoo citizen science project produced one of the first large-scale crowd-annotated astronomical datasets.
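The Poisson point is worth seeing in numbers: for photon-counting detectors the noise variance equals the expected count, so the signal-to-noise ratio grows only as the square root of the signal. The sketch below simulates this with NumPy's `Generator.poisson`; the photon levels are arbitrary illustrative values.

```python
import numpy as np


def photon_noise_stats(mean_photons, n_pixels=200_000, seed=7):
    """Simulate photon counts for pixels with the same true brightness and
    return (variance / mean, signal-to-noise ratio)."""
    rng = np.random.default_rng(seed)
    counts = rng.poisson(mean_photons, size=n_pixels)
    return counts.var() / counts.mean(), counts.mean() / counts.std()


for level in (5, 50, 500, 5_000):
    ratio, snr = photon_noise_stats(level)
    print(f"mean={level:>5}  var/mean={ratio:.3f}  "
          f"SNR={snr:6.1f}  sqrt(mean)={np.sqrt(level):6.1f}")
```

Faint regions are therefore much noisier relative to their signal than bright ones, and a single fixed-sigma Gaussian noise model misrepresents both ends of that range.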

import numpy as np


class MultiChannelNormalizer:
    """Normalize scientific images with
    arbitrary channel counts (satellite,
    microscopy, spectral). Each channel
    gets independent normalization."""

    def __init__(self):
        self.means = None
        self.stds = None

    def fit(self, images):
        """Compute per-channel statistics
        from a batch of images.
        images: (N, C, H, W)"""
        self.means = images.mean(
            axis=(0, 2, 3))
        self.stds = images.std(
            axis=(0, 2, 3))
        self.stds[self.stds < 1e-8] = 1.0
        return self

    def transform(self, image):
        """Normalize a single image.
        image: (C, H, W)"""
        normalized = np.zeros_like(
            image, dtype=np.float32)
        for c in range(image.shape[0]):
            normalized[c] = (
                (image[c] - self.means[c])
                / self.stds[c])
        return normalized

    def percentile_clip(self, image,
                        low=1, high=99):
        """Clip each channel to its percentile
        range (handles extreme values in
        satellite and astronomical data),
        then rescale it to [0, 1]."""
        clipped = np.zeros_like(
            image, dtype=np.float32)
        for c in range(image.shape[0]):
            lo = np.percentile(image[c], low)
            hi = np.percentile(image[c], high)
            # 'span' rather than 'rng' so it is
            # not confused with the global RNG
            span = hi - lo
            if span > 1e-8:
                clipped[c] = (
                    np.clip(image[c], lo, hi)
                    - lo) / span
            # else: constant channel, left at 0
        return clipped


# Simulate a 13-band satellite image
rng = np.random.RandomState(42)
# Different channels have very different
# value ranges (realistic for Sentinel-2)
sat_image = np.zeros((13, 256, 256))
for c in range(13):
    base = rng.uniform(100, 5000)
    scale = rng.uniform(50, 2000)
    sat_image[c] = rng.normal(
        base, scale, (256, 256))

norm = MultiChannelNormalizer()
batch = sat_image[np.newaxis]
norm.fit(batch)
result = norm.transform(sat_image)
clipped = norm.percentile_clip(sat_image)

print(f"Channels: {sat_image.shape[0]}")
print(f"\n{'Ch':>3} {'RawMean':>10} "
      f"{'RawStd':>10} {'NormMean':>10} "
      f"{'ClipRange':>12}")
print("-" * 48)
for c in range(sat_image.shape[0]):
    print(f"{c:>3} "
          f"{sat_image[c].mean():>10.1f} "
          f"{sat_image[c].std():>10.1f} "
          f"{result[c].mean():>10.4f} "
          f"[{clipped[c].min():.3f}, "
          f"{clipped[c].max():.3f}]")

The key insight for all scientific imaging: you can't use ImageNet normalization statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) when your images have 13 bands with value ranges spanning three orders of magnitude. Each domain needs its own preprocessing pipeline, its own normalization strategy, and often its own evaluation metrics. The model architectures we've been building throughout this series (CNNs, transformers, U-Nets) transfer well across domains -- it's the data pipeline that needs the most domain-specific engineering.
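One practical wrinkle: `MultiChannelNormalizer.fit` needs the whole (N, C, H, W) array in memory, which is rarely possible for a real satellite or microscopy archive. Below is a sketch of computing the same per-channel mean and (population) standard deviation batch by batch, using the pairwise update of Chan et al. to merge means and sums of squared deviations; the class name `RunningChannelStats` and the synthetic band levels are hypothetical.

```python
import numpy as np


class RunningChannelStats:
    """Per-channel mean/std accumulated over (N, C, H, W) batches
    without holding the full dataset in memory."""

    def __init__(self, num_channels):
        self.count = 0
        self.mean = np.zeros(num_channels)
        self.m2 = np.zeros(num_channels)  # sum of squared deviations

    def update(self, batch):
        x = np.asarray(batch, dtype=np.float64)
        n_b = x.shape[0] * x.shape[2] * x.shape[3]
        mean_b = x.mean(axis=(0, 2, 3))
        m2_b = ((x - mean_b[None, :, None, None]) ** 2).sum(axis=(0, 2, 3))
        # Merge this batch's statistics into the running totals
        delta = mean_b - self.mean
        total = self.count + n_b
        self.mean = self.mean + delta * n_b / total
        self.m2 = self.m2 + m2_b + delta ** 2 * self.count * n_b / total
        self.count = total
        return self

    @property
    def std(self):
        return np.sqrt(self.m2 / self.count)  # population std, like np.std


rng = np.random.default_rng(1)
offsets = rng.uniform(100, 5000, size=13)  # very different band levels
scales = rng.uniform(50, 2000, size=13)
batches = [offsets[None, :, None, None]
           + scales[None, :, None, None] * rng.standard_normal((4, 13, 32, 32))
           for _ in range(5)]

stats = RunningChannelStats(13)
for b in batches:
    stats.update(b)

full = np.concatenate(batches)  # only feasible because the demo is tiny
print(np.allclose(stats.mean, full.mean(axis=(0, 2, 3))))
print(np.allclose(stats.std, full.std(axis=(0, 2, 3))))
```

The resulting `stats.mean` and `stats.std` could then be assigned to the normalizer's `means` and `stds` attributes instead of calling `fit` on an in-memory array.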

The regulatory landscape

Medical AI devices require regulatory approval before clinical deployment. This isn't optional -- it's the law, and it fundamentally shapes how medical AI is developed:

FDA (US): Software as a Medical Device (SaMD) classification. AI diagnostics typically fall under Class II (usually the 510(k) pathway, showing "substantial equivalence" to an existing cleared device, or the De Novo pathway when no suitable predicate exists) or Class III (premarket approval, for higher-risk applications). The FDA has cleared hundreds of AI medical devices as of 2025, mostly in radiology and cardiology.

CE marking (EU): Under the Medical Device Regulation (MDR), AI systems must demonstrate safety and performance through clinical evaluation. The EU's AI Act adds further requirements for "high-risk" AI systems, a category that covers most medical diagnostic applications.

Both require: documented training data provenance (where did every image come from, with what consent?), validated performance on representative populations (does it work for all demographics?), post-market surveillance plans (how will you catch failures after deployment?), and clear intended use statements (this model is for X, not for Y). You can't just train a model and put it in a hospital.

The regulatory process also means that medical AI development follows a fundamentally different timeline than typical ML projects. From initial model development to regulatory clearance takes 1-3 years and costs hundreds of thousands to millions of dollars. This is why most medical AI research stays in academic papers rather than reaching patients: the gap between "works on our test set" and "approved for clinical use" is enormous.

Summary

  • Medical imaging has unique constraints that fundamentally alter the ML pipeline: small datasets (hundreds, not millions), expensive expert annotations ($150+/hour specialists), severe class imbalance (1-5% positive rates), and high-stakes errors where false negatives can be fatal;
  • transfer learning from ImageNet works for medical images despite massive visual differences; the early CNN layers learn universal edge/texture features that transfer across domains; medical foundation models (MedCLIP, BiomedCLIP) are emerging as better starting points but ImageNet remains a strong baseline;
  • augmentation must be domain-aware: elastic deformation is highly effective for simulating biological tissue variation, but horizontal flipping breaks anatomical assumptions (the heart sits on the patient's left); every augmentation choice must be validated against domain knowledge;
  • focal loss automatically shifts training toward hard examples and reduces the reliance on hand-tuned class weights (gamma, and often an alpha balancing term, still need tuning); the (1-p)^gamma modulation, with p the predicted probability of the true class, drives the loss and its gradient toward zero for easy, confidently correct cases while leaving difficult ones largely intact;
  • evaluation requires sensitivity, specificity, and AUC rather than accuracy; threshold selection is a clinical decision that trades off missed cases against false alarms, and different clinical use cases (screening vs confirmation) demand different operating points;
  • Grad-CAM provides visual explanations essential for clinical trust and debugging data leakage -- verifying that models focus on pathology rather than scanner artifacts or patient labels is non-negotiable;
  • scientific imaging beyond medicine (satellite, microscopy, astronomy) shares the same challenges of specialized modalities, small datasets, and domain-specific preprocessing but with additional complications like multi-channel inputs, extreme dynamic ranges, and non-Gaussian noise models;
  • regulatory approval (FDA, CE marking) is mandatory before clinical deployment and shapes the entire development process -- from data provenance documentation to post-market surveillance plans.

We've now covered how computer vision applies to specialized domains where getting it right really matters. The vision arc of this series has been extensive -- from raw pixels all the way to generating images, reconstructing 3D scenes, analyzing faces, and now domain-specific scientific applications. There's one more foundational concept in vision that we haven't explored yet: how machines can learn powerful visual representations without any human labels at all, using only the structure of the data itself.

Exercises

Exercise 1: Build a medical dataset augmentation validator. Create a class AugmentationValidator that: (a) generates a synthetic 64x64 "chest X-ray" as a numpy array with a circle in the left half (simulating the heart) and a smaller circle in the right half (simulating a nodule), both with different intensities against a noisy background, (b) implements check_anatomical_consistency(original, augmented) that verifies the heart-like circle is still in the left half after augmentation: compute the column-wise intensity sum for each half and check that the brighter half hasn't switched sides, (c) applies 5 different augmentations to the synthetic image: (1) a small rotation of 10 degrees, (2) horizontal flip, (3) elastic deformation with alpha=30, sigma=4, (4) brightness shift +20%, (5) Gaussian noise sigma=0.05, (d) for each augmentation, prints whether anatomical consistency is preserved, (e) computes a "signal preservation score" for each augmentation: the correlation coefficient between the original and augmented images (values near 1.0 mean the content is well preserved, values near 0.0 mean significant distortion). Verify that horizontal flip fails the anatomical consistency check while all other augmentations pass, and that elastic deformation preserves anatomy while having lower correlation than simple brightness changes.

Exercise 2: Build a class imbalance strategy comparator. Create a class ImbalanceComparator that: (a) generates a synthetic binary classification dataset with 1000 samples where only 30 are positive, with features drawn from overlapping Gaussians (positive: mean=1.0, std=1.5; negative: mean=0.0, std=1.0) in 10 dimensions, (b) trains 3 logistic regression classifiers (using sklearn): one with default settings, one with class_weight='balanced', and one using SMOTE oversampling on the training set (use imblearn.over_sampling.SMOTE or implement a simple random oversampler if imblearn is not available), (c) evaluates each on a held-out test set (20% split) using sensitivity, specificity, AUC, and F1, (d) prints a comparison table showing all 4 metrics for all 3 strategies, (e) for each strategy, finds the threshold that achieves >= 90% sensitivity and reports the corresponding specificity. Verify that the default (unweighted) classifier has high specificity but poor sensitivity, while balanced weighting and oversampling improve sensitivity at the cost of some specificity.

Exercise 3: Build a multi-channel scientific image normalizer and analyzer. Create a class ScientificImageAnalyzer that: (a) generates a synthetic 8-channel "satellite image" of size 128x128 where: channels 0-2 are visible RGB (values 0-255), channels 3-4 are near-infrared (values 500-3000), channel 5 is thermal (values 250-320, representing Kelvin), channels 6-7 are radar backscatter (values in dB, range -25 to 5), (b) implements per_channel_normalize(image) that z-score normalizes each channel independently, (c) implements percentile_clip(image, low_pct=2, high_pct=98) that clips each channel to its 2nd and 98th percentiles then rescales to [0, 1], (d) implements compute_ndvi(image) that computes the Normalized Difference Vegetation Index from the NIR and Red channels: NDVI = (NIR - Red) / (NIR + Red + 1e-8), (e) prints per-channel statistics (mean, std, min, max) before and after normalization, (f) computes and prints NDVI statistics (mean, std, fraction of pixels with NDVI > 0.3 indicating vegetation), (g) compares the correlation between channel pairs before and after normalization to verify that normalization preserves inter-channel relationships. Verify that per-channel normalization brings all channels to mean~0, std~1 regardless of their original value ranges, and that NDVI computed from the raw NIR and Red bands does not depend on which normalization you apply afterwards (computing NDVI on per-channel normalized bands would change it, because NDVI is only invariant to a scale factor shared by both bands).

Thanks for reading!

@scipio


