Learn AI Series (#93) - Speech Recognition




What will I learn

  • You will learn the ASR pipeline: converting raw audio signals into text using deep learning;
  • CTC (Connectionist Temporal Classification): aligning audio frames to characters without explicit alignment labels;
  • attention-based ASR: the listen, attend, and spell paradigm for sequence-to-sequence speech recognition;
  • Whisper: OpenAI's robust multilingual speech recognition model and why training data scale matters more than architecture;
  • building practical speech-to-text systems and evaluating them with Word Error Rate;
  • fine-tuning pre-trained ASR models for domain-specific audio.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#93) - Speech Recognition

Solutions to Episode #92 Exercises

Exercise 1: Frequency content analyzer.

import numpy as np


class FrequencyAnalyzer:
    def __init__(self, sr=16000, duration=2.0):
        self.sr = sr
        self.duration = duration
        self.n_samples = int(sr * duration)
        self.t = np.arange(self.n_samples) / sr

    def generate_signals(self):
        pure = np.sin(2 * np.pi * 440 * self.t)
        chord = (
            np.sin(2 * np.pi * 440 * self.t)
            + np.sin(2 * np.pi * 554 * self.t)
            + np.sin(2 * np.pi * 659 * self.t)
        ) / 3.0
        rng = np.random.RandomState(42)
        noise = rng.randn(self.n_samples)
        return {
            "pure_440": pure,
            "chord": chord,
            "white_noise": noise,
        }

    def compute_spectrum(self, signal):
        fft_vals = np.fft.rfft(signal)
        magnitudes = np.abs(fft_vals)
        freqs = np.fft.rfftfreq(
            len(signal), 1.0 / self.sr)
        return freqs, magnitudes

    def spectral_centroid(self, freqs, mags):
        total = mags.sum()
        if total < 1e-12:
            return 0.0
        return float((freqs * mags).sum() / total)

    def spectral_bandwidth(self, freqs, mags,
                            centroid):
        total = mags.sum()
        if total < 1e-12:
            return 0.0
        variance = (
            mags * (freqs - centroid) ** 2
        ).sum() / total
        return float(np.sqrt(variance))

    def spectral_rolloff(self, freqs, mags,
                          percentile=85):
        total_energy = (mags ** 2).sum()
        target = total_energy * percentile / 100
        cumulative = np.cumsum(mags ** 2)
        idx = np.searchsorted(cumulative, target)
        idx = min(idx, len(freqs) - 1)
        return float(freqs[idx])

    def run(self):
        signals = self.generate_signals()
        print(f"{'Signal':<14} {'Centroid':>10} "
              f"{'Bandwidth':>10} {'Rolloff':>10}")
        print("-" * 48)
        for name, sig in signals.items():
            freqs, mags = self.compute_spectrum(sig)
            c = self.spectral_centroid(freqs, mags)
            bw = self.spectral_bandwidth(
                freqs, mags, c)
            ro = self.spectral_rolloff(freqs, mags)
            print(f"{name:<14} {c:>9.1f}Hz "
                  f"{bw:>9.1f}Hz {ro:>9.1f}Hz")


analyzer = FrequencyAnalyzer()
analyzer.run()

The pure 440 Hz tone has its centroid right at 440 Hz with near-zero bandwidth -- all energy is concentrated at a single frequency. White noise has the highest centroid (near 4000 Hz) and the largest bandwidth because its energy is spread uniformly across all frequencies.

Exercise 2: Mel filterbank visualizer and verifier.

import numpy as np


class MelFilterbankAnalyzer:
    def __init__(self, sr=16000, n_fft=1024,
                 n_mels=40, fmin=0, fmax=8000):
        self.sr = sr
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax

    def hz_to_mel(self, hz):
        return 2595.0 * np.log10(1.0 + hz / 700.0)

    def mel_to_hz(self, mel):
        return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)

    def build_filterbank(self):
        n_bins = self.n_fft // 2 + 1
        mel_lo = self.hz_to_mel(self.fmin)
        mel_hi = self.hz_to_mel(self.fmax)
        mel_pts = np.linspace(
            mel_lo, mel_hi, self.n_mels + 2)
        hz_pts = self.mel_to_hz(mel_pts)
        bin_pts = np.round(
            hz_pts * self.n_fft / self.sr
        ).astype(int)
        bin_pts = np.clip(bin_pts, 0, n_bins - 1)

        fb = np.zeros((self.n_mels, n_bins))
        centers, bandwidths, peaks = [], [], []
        for m in range(self.n_mels):
            lo, mid, hi = (bin_pts[m],
                bin_pts[m + 1], bin_pts[m + 2])
            centers.append(hz_pts[m + 1])
            bandwidths.append(
                hz_pts[m + 2] - hz_pts[m])
            if mid > lo:
                for k in range(lo, mid):
                    fb[m, k] = (k - lo) / (mid - lo)
            fb[m, mid] = 1.0
            if hi > mid:
                for k in range(mid, hi + 1):
                    if k < n_bins:
                        fb[m, k] = (
                            (hi - k) / (hi - mid))
            peaks.append(fb[m].max())
        return fb, centers, bandwidths, peaks

    def run(self):
        fb, centers, bws, peaks = (
            self.build_filterbank())
        n_bins = self.n_fft // 2 + 1
        fmin_bin = int(
            self.fmin * self.n_fft / self.sr)
        fmax_bin = min(
            int(self.fmax * self.n_fft / self.sr),
            n_bins - 1)
        covered = sum(
            1 for k in range(fmin_bin, fmax_bin + 1)
            if fb[:, k].sum() > 0)
        total_bins = fmax_bin - fmin_bin + 1
        print(f"Coverage: "
              f"{covered/total_bins*100:.1f}%")
        lo_bw = np.mean(bws[:10])
        hi_bw = np.mean(bws[-10:])
        print(f"Avg BW low 10:  {lo_bw:.1f} Hz")
        print(f"Avg BW high 10: {hi_bw:.1f} Hz")
        print(f"Ratio: {hi_bw/lo_bw:.1f}x")
        show = [0, 9, 19, 29, 39]
        print(f"\n{'Idx':>4} {'Center':>8} "
              f"{'BW':>8} {'Peak':>6}")
        print("-" * 30)
        for i in show:
            if i < self.n_mels:
                print(f"{i:>4} {centers[i]:>7.1f}Hz"
                      f" {bws[i]:>7.1f}Hz"
                      f" {peaks[i]:>6.3f}")


analyzer = MelFilterbankAnalyzer()
analyzer.run()

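Low-frequency filters are narrow (roughly 90-160 Hz across the first ten) because the Mel scale is close to linear there and compresses the high frequencies instead, so equally spaced Mel points land close together in Hz at the bottom of the spectrum. The high-frequency filters are much wider, reaching about 1000 Hz for the last one. The widest filter is roughly 11 times wider than the narrowest; averaged over the top and bottom ten filters, as the printed ratio does, the factor comes out closer to 6x.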

Exercise 3: Time-frequency resolution tradeoff analyzer.

import numpy as np


class ResolutionAnalyzer:
    def __init__(self, sr=16000, duration=1.0):
        self.sr = sr
        self.n = int(sr * duration)
        t = np.arange(self.n) / sr
        self.signal = np.zeros(self.n)
        mid = self.n // 2
        self.signal[:mid] = np.sin(
            2 * np.pi * 500 * t[:mid])
        self.signal[mid:] = np.sin(
            2 * np.pi * 600 * t[mid:]
            - 2 * np.pi * 600 * t[mid])

    def stft(self, n_fft):
        hop = n_fft // 4
        window = np.hanning(n_fft)
        n_frames = (self.n - n_fft) // hop + 1
        n_bins = n_fft // 2 + 1
        spec = np.zeros((n_bins, n_frames))
        for i in range(n_frames):
            start = i * hop
            frame = (self.signal[start:start + n_fft]
                     * window)
            spec[:, i] = np.abs(np.fft.rfft(frame))
        return spec, hop, n_frames

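    def run(self):
        mid = self.n // 2  # sample where 500 Hz -> 600 Hz
        print(f"{'WinSize':>8} {'FreqRes':>8} "
              f"{'TimeRes':>10} {'Resolved':>9} "
              f"{'TransFrames':>12} {'Smear':>9}")
        print("-" * 62)
        for n_fft in [256, 512, 1024, 2048]:
            _, hop, nf = self.stft(n_fft)
            freq_res = self.sr / n_fft   # bin spacing
            time_res = hop / self.sr     # step per frame
            # crude check: bin spacing below tone gap
            resolved = freq_res < 100.0
            # frames whose window contains the switch
            trans = sum(
                1 for i in range(nf)
                if i * hop < mid < i * hop + n_fft)
            # each frame spans the full window length
            smear_ms = n_fft / self.sr * 1000
            print(f"{n_fft:>8} {freq_res:>7.1f}Hz "
                  f"{time_res*1000:>8.1f}ms "
                  f"{'yes' if resolved else 'no':>9}"
                  f" {trans:>12} {smear_ms:>7.1f}ms")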


analyzer = ResolutionAnalyzer()
analyzer.run()

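Small windows (256 samples) step every 4 ms, but their 62.5 Hz bin spacing is only slightly finer than the 100 Hz gap between our tones. The crude `freq_res < 100` check therefore reports them as resolved, yet a Hann window's main lobe spans about four bins (~250 Hz), so the spectral peaks of the two tones overlap heavily. Large windows (2048 samples) separate the tones easily at 7.8 Hz per bin, but every frame now spans 128 ms, so the 500 -> 600 Hz switch is smeared over a much longer stretch of time. This is the time-frequency uncertainty principle, signal processing's own version of Heisenberg.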

On to today's episode

Here we go! Last episode we learned how to turn sound into numbers -- waveforms, spectra, Mel spectrograms, MFCCs. All the raw representation machinery that converts vibrating air molecules into numpy arrays. But having a beautiful spectrogram sitting in memory doesn't tell you what was said. That's the problem we're tackling today: automatic speech recognition, or ASR. The task of making a computer listen to someone talking and produce the corresponding text.

This is one of the oldest problems in AI. Bell Labs built the first speech recognizer in 1952: a machine called "Audrey" that could recognize the digits 0-9 spoken by a single specific person. That was it -- ten words, one speaker. Seventy-something years later, Whisper transcribes speech in close to a hundred languages with near-human accuracy from noisy podcasts, phone calls, and YouTube videos. The journey between those two points involved several paradigm shifts, and we're going to focus on the deep learning era that made modern ASR actually usable ;-)

The alignment problem

Speech recognition has a challenge that most of the classification tasks we've studied don't have: the input and output sequences are different lengths, and the mapping between them is unknown. Think about it -- a 5-second audio clip at 16 kHz produces 80,000 samples, which gets compressed to maybe 500 spectrogram frames. But the transcription might only be 25 characters long. Which frames correspond to which characters? Nobody knows upfront, and manually aligning audio to text for every training sample is insanely expensive.

This is the alignment problem, and it's fundamentally different from image classification (one image -> one label) or even sequence classification (one sequence -> one label). Here we need to go from one variable-length sequence to another variable-length sequence, with no explicit correspondence between the elements.

Two approaches emerged to solve this:

  1. CTC (Connectionist Temporal Classification): output a character probability at every audio frame and let the math marginalize over all possible alignments
  2. Attention-based encoder-decoder: let the decoder learn to attend to the right audio frames for each output token (same architecture we used for machine translation in episode #50)

CTC: Connectionist Temporal Classification

CTC (Graves et al., 2006) is the same technique we saw applied to OCR in episode #82 -- and this is actually its original domain! The idea is elegant: the model outputs a probability distribution over characters at every single time frame. Most frames will output a special blank token meaning "no character here." The few frames that correspond to actual characters will output the right letter. CTC then marginalizes over all possible alignments between frames and characters to compute the loss.

Here's a CTC-based ASR model:

import torch
import torch.nn as nn


class CTCASRModel(nn.Module):
    """CTC-based speech recognition.
    Audio spectrogram in, character logits
    out at every time frame."""

    def __init__(self, n_mels=80,
                 hidden_dim=512,
                 vocab_size=29):
        # vocab: 26 letters + space +
        # apostrophe + CTC blank
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32,
                      kernel_size=(3, 3),
                      padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32,
                      kernel_size=(3, 3),
                      stride=(2, 1),
                      padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        cnn_out = 32 * (n_mels // 2)
        self.rnn = nn.LSTM(
            cnn_out, hidden_dim,
            num_layers=3,
            batch_first=True,
            bidirectional=True,
            dropout=0.1)
        self.fc = nn.Linear(
            hidden_dim * 2, vocab_size)

    def forward(self, mel_spec):
        # mel_spec: (batch, 1, n_mels, time)
        x = self.cnn(mel_spec)
        b, c, f, t = x.shape
        x = x.permute(0, 3, 1, 2).reshape(
            b, t, c * f)
        x, _ = self.rnn(x)
        logits = self.fc(x)
        return logits


model = CTCASRModel()
criterion = nn.CTCLoss(
    blank=0, zero_infinity=True)

mel_input = torch.randn(2, 1, 80, 400)
logits = model(mel_input)
log_probs = logits.log_softmax(
    dim=-1).permute(1, 0, 2)

targets = torch.tensor(
    [8, 5, 12, 12, 15, 23, 15, 18, 12, 4])
target_lengths = torch.tensor([5, 5])
input_lengths = torch.tensor([400, 400])

loss = criterion(
    log_probs, targets,
    input_lengths, target_lengths)
print(f"CTC loss: {loss.item():.4f}")

The architecture is straightforward: CNNs extract local patterns from the Mel spectrogram (similar to how we used CNNs on images in episodes #45-47), then a bidirectional LSTM (episode #49) captures temporal context in both directions, and finally a linear layer outputs character probabilities at each frame. The second convolution's stride of 2 along the frequency axis halves the frequency dimension (80 -> 40 Mel bins), reducing computation while leaving the time dimension intact, so the model emits one prediction per input frame.
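Keeping every frame is a design choice rather than a hard rule of CTC. CTC only needs at least one output frame per character, plus a blank between each pair of identical neighbouring characters (otherwise they would collapse into one). A minimal sketch of that arithmetic, assuming a 10 ms hop (160 samples at 16 kHz) and ignoring edge padding; `min_ctc_frames` and `n_frames` are hypothetical helpers, not part of PyTorch:

```python
def min_ctc_frames(text):
    """Fewest output frames CTC needs to emit `text`: one per
    character, plus a forced blank between identical neighbours."""
    repeats = sum(1 for a, b in zip(text, text[1:]) if a == b)
    return len(text) + repeats


def n_frames(duration_s, sr=16000, hop=160, downsample=1):
    """Rough frame count for a clip: samples // hop, divided by any
    extra time downsampling in the encoder (edge padding ignored)."""
    return int(duration_s * sr) // hop // downsample


for text in ["hello", "hello world", "bookkeeper"]:
    print(f"{text!r:>14}: needs >= {min_ctc_frames(text)} frames")
print("5 s clip, no downsampling:", n_frames(5.0), "frames")
print("5 s clip, 4x downsampling:", n_frames(5.0, downsample=4), "frames")
```

Even with 4x time downsampling, a 5-second clip still leaves 125 frames, far more than a 25-character transcript needs, which is why many CTC encoders can afford to downsample in time as well.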

The CTCLoss function handles the hard part: computing the probability of the target sequence summed over ALL possible alignments. For "hello" (writing - for the blank token), valid alignments include --hh-ee-ll-ll-oo--, or h---e---l---l---o, or any other arrangement that produces "hello" after collapsing duplicates and removing blanks. CTC sums over all of them using dynamic programming (the forward-backward algorithm) -- you don't need to tell it which frames correspond to which characters.
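To see that the dynamic program really equals "sum over all alignments", here is a minimal NumPy sketch (not PyTorch's implementation) that computes the CTC probability of a target two ways: by brute-force enumeration of every frame path, and by the forward (alpha) recursion. It works in plain probability space for readability; real implementations such as `nn.CTCLoss` operate on log probabilities to avoid underflow. All function names here are illustrative:

```python
import itertools

import numpy as np


def collapse(path, blank=0):
    """CTC collapse rule: merge repeated labels, then drop blanks."""
    out, prev = [], None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return out


def ctc_brute_force(probs, target, blank=0):
    """Sum the probability of every length-T path that collapses
    to `target`. Exponential in T: only for tiny sanity checks."""
    T, C = probs.shape
    total = 0.0
    for path in itertools.product(range(C), repeat=T):
        if collapse(path, blank) == list(target):
            total += np.prod(probs[np.arange(T), list(path)])
    return total


def ctc_forward(probs, target, blank=0):
    """Same sum via the CTC forward (alpha) recursion, O(T * S)."""
    ext = [blank]
    for c in target:                 # interleave blanks: - c1 - c2 - ...
        ext += [c, blank]
    T, S = probs.shape[0], len(ext)
    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, ext[0]]
    if S > 1:
        alpha[0, 1] = probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]              # stay on the same symbol
            if s >= 1:
                a += alpha[t - 1, s - 1]     # advance by one symbol
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]     # skip the blank in between
            alpha[t, s] = a * probs[t, ext[s]]
    # valid paths end on the last label or on the trailing blank
    return alpha[-1, -1] + (alpha[-1, -2] if S > 1 else 0.0)


rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 3))     # 5 frames, vocab {blank, a, b}
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
for target in ([1], [1, 2], [1, 1]):
    print(target, ctc_forward(probs, target),
          ctc_brute_force(probs, target))
```

Both columns agree, but the brute force visits C^T paths while the recursion only fills a T x (2U + 1) table, which is what makes CTC practical for 500-frame utterances.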

CTC decoding

Converting frame-level CTC outputs into text is called decoding. The simplest approach is greedy decoding:

def ctc_greedy_decode(logits, blank=0):
    """Greedy CTC decoding: argmax at each
    frame, collapse duplicates, remove
    blanks."""
    predictions = logits.argmax(dim=-1)
    idx_to_char = {
        0: '', 1: 'a', 2: 'b', 3: 'c',
        4: 'd', 5: 'e', 6: 'f', 7: 'g',
        8: 'h', 9: 'i', 10: 'j', 11: 'k',
        12: 'l', 13: 'm', 14: 'n', 15: 'o',
        16: 'p', 17: 'q', 18: 'r', 19: 's',
        20: 't', 21: 'u', 22: 'v', 23: 'w',
        24: 'x', 25: 'y', 26: 'z', 27: ' ',
        28: "'"}
    texts = []
    for pred in predictions:
        chars = []
        prev = blank
        for p in pred:
            p = p.item()
            if p != blank and p != prev:
                chars.append(
                    idx_to_char.get(p, '?'))
            prev = p
        texts.append(''.join(chars))
    return texts


import torch
dummy_logits = torch.randn(1, 20, 29)
decoded = ctc_greedy_decode(dummy_logits)
print(f"Decoded: '{decoded[0]}'")
print("(random logits = garbage text)")

Greedy decoding takes the argmax at each frame, then collapses consecutive duplicates and removes blank tokens. The sequence [blank, blank, h, h, h, blank, e, e, l, l, l, l, o] becomes [h, e, l, o] -> "helo". Notice what this means for repeated characters: to produce "hello" with two l's, the model must emit a blank BETWEEN them, for example [h, e, l, blank, l, o]. Without that blank, the two consecutive l's collapse into one, and a model that hasn't learned to insert it reliably will output "helo".

For better results, you'd use beam search with a language model -- the language model scores partial hypotheses and helps the decoder pick "hello" over "helo" because "hello" is an actual English word. This is where the language modeling concepts from episode #57 become directly relevant.
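A minimal, unoptimized sketch of CTC prefix beam search in NumPy follows. No language model is included; the `lm` argument is a hypothetical hook showing where one would reweight each new character. Unlike greedy decoding, the search merges all frame paths that collapse to the same prefix before ranking, and that alone can change the answer:

```python
from collections import defaultdict

import numpy as np


def ctc_prefix_beam_search(probs, beam_width=8, blank=0, lm=None):
    """CTC prefix beam search over a (T, C) matrix of per-frame
    probabilities. Each beam tracks two scores for its prefix:
    paths ending in blank (pb) and ending in a label (pnb).
    `lm(prefix, c)` is an optional, hypothetical hook returning a
    multiplicative weight for appending label c to prefix."""
    beams = {(): (1.0, 0.0)}
    for t in range(probs.shape[0]):
        nxt = defaultdict(lambda: [0.0, 0.0])
        for prefix, (pb, pnb) in beams.items():
            for c in range(probs.shape[1]):
                p = probs[t, c]
                if c == blank:
                    nxt[prefix][0] += (pb + pnb) * p
                    continue
                w = lm(prefix, c) if lm else 1.0
                new = prefix + (c,)
                if prefix and prefix[-1] == c:
                    # repeat with no blank in between: same prefix
                    nxt[prefix][1] += pnb * p
                    # repeat after a blank: a genuinely new label
                    nxt[new][1] += pb * p * w
                else:
                    nxt[new][1] += (pb + pnb) * p * w
        ranked = sorted(nxt.items(), key=lambda kv: -sum(kv[1]))
        beams = dict(ranked[:beam_width])
    best, (pb, pnb) = max(beams.items(), key=lambda kv: sum(kv[1]))
    return list(best), pb + pnb


# Two frames, vocab {blank, a}; each frame: P(blank)=0.6, P(a)=0.4
probs = np.array([[0.6, 0.4], [0.6, 0.4]])
greedy = probs.argmax(axis=1)            # blank, blank -> ""
print("greedy path:", greedy.tolist())
print("beam search:", ctc_prefix_beam_search(probs))
```

Greedy decoding picks blank at both frames and outputs an empty string, backed by a single path of probability 0.36. The beam search sees that three different paths (a-, -a, aa) all collapse to "a", for a total of 0.64. A language model plugged into `lm` would additionally reweight each extension, which is how "hello" can beat "helo". Real decoders also work in log space.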

CTC has one fundamental limitation though: it assumes frame-level outputs are conditionally independent given the hidden states. It can't natively model dependencies like "after 'q' usually comes 'u'". The RNN hidden states capture some context, but the CTC loss itself doesn't enforce output-level dependencies. That's where attention-based models come in.

Attention-based ASR: Listen, Attend, and Spell

The encoder-decoder architecture with attention (episodes #50-51) was the next paradigm shift. Instead of outputting a character at every frame (like CTC), the decoder generates one token at a time, attending to relevant parts of the encoder output. The attention mechanism implicitly learns the alignment -- which audio frames correspond to each output token:

class AttentionASR(nn.Module):
    """Simplified attention-based ASR.
    Listen (encoder), Attend (attention),
    Spell (decoder)."""

    def __init__(self, n_mels=80,
                 enc_dim=256, dec_dim=256,
                 vocab_size=5000):
        super().__init__()
        self.encoder = nn.LSTM(
            n_mels, enc_dim, num_layers=3,
            batch_first=True,
            bidirectional=True)
        self.enc_proj = nn.Linear(
            enc_dim * 2, dec_dim)
        self.embed = nn.Embedding(
            vocab_size, dec_dim)
        self.decoder = nn.LSTM(
            dec_dim * 2, dec_dim,
            num_layers=1,
            batch_first=True)
        self.attention = (
            nn.MultiheadAttention(
                dec_dim, num_heads=4,
                batch_first=True))
        self.output = nn.Linear(
            dec_dim, vocab_size)

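    def forward(self, mel_features, targets):
        # Listen: encode the whole utterance
        enc_out, _ = self.encoder(
            mel_features)
        enc_out = self.enc_proj(enc_out)
        # Teacher forcing: feed the previous
        # ground-truth token at each step
        tgt_embed = self.embed(
            targets[:, :-1])
        ctx = torch.zeros_like(
            tgt_embed[:, :1])
        h = None
        outputs = []
        for t in range(tgt_embed.shape[1]):
            # Spell: update the decoder state from
            # the previous token + previous context
            lstm_in = torch.cat(
                [tgt_embed[:, t:t + 1], ctx],
                dim=-1)
            out, h = self.decoder(lstm_in, h)
            # Attend: the decoder state is the
            # query over the encoder frames
            ctx, _ = self.attention(
                out, enc_out, enc_out)
            outputs.append(out + ctx)
        return self.output(
            torch.cat(outputs, dim=1))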


model = AttentionASR()
mel = torch.randn(2, 200, 80)
targets = torch.randint(0, 5000, (2, 30))
logits = model(mel, targets)
print(f"Output shape: {logits.shape}")
# (batch=2, seq_len=29, vocab=5000)

The "Listen" part is the encoder (bidirectional LSTM processing the audio features). The "Attend" part is the multi-head attention computing which encoder frames are relevant for the current decoding step. The "Spell" part is the decoder LSTM generating the text token by token.

This approach has a significant advantage over CTC: the decoder is autoregressive -- each token is generated conditioned on all previous tokens. So it naturally models output dependencies. After generating "q", it knows "u" is likely next. CTC can't do this without an external language model.

The tradeoff? Attention-based models are slower at inference because they generate tokens sequentially (can't parallelize the decoding steps). CTC can decode all frames in parallel. In practice, modern systems often combine both: use CTC as an auxiliary loss during training to help the encoder learn good alignments, and use attention for the final decoding.
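The combined setup is usually trained with a weighted sum of the two losses. A sketch of that objective, where λ is a tuning weight between 0 and 1 (its value is left open here; treat any specific choice as an assumption of your own setup):

```latex
\mathcal{L} = \lambda \, \mathcal{L}_{\text{CTC}} + (1 - \lambda) \, \mathcal{L}_{\text{att}}, \qquad 0 \le \lambda \le 1
```

Here L_CTC is computed on the encoder outputs (as with `nn.CTCLoss` earlier), and L_att is the token-level cross-entropy of the attention decoder. Because both terms share the same encoder, the CTC term's monotonic alignment constraint shapes the features that the attention decoder reads.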

Whisper: why data scale beats architecture

Whisper (Radford et al., 2022) is the model that made speech recognition "just work" for most practical purposes. Its architecture is nothing special -- a standard encoder-decoder transformer, the same thing we built in episodes #52-53. What makes Whisper exceptional is the training data: 680,000 hours of multilingual audio paired with (imperfect) transcripts scraped from the internet.

import numpy as np


class WhisperArchitecture:
    """Summarize Whisper model variants."""

    def __init__(self):
        self.models = {
            "tiny": {
                "params": "39M",
                "enc_layers": 4,
                "dec_layers": 4,
                "d_model": 384,
                "english_wer": 7.6,
                "multi_wer": 12.1},
            "base": {
                "params": "74M",
                "enc_layers": 6,
                "dec_layers": 6,
                "d_model": 512,
                "english_wer": 5.0,
                "multi_wer": 9.1},
            "small": {
                "params": "244M",
                "enc_layers": 12,
                "dec_layers": 12,
                "d_model": 768,
                "english_wer": 4.2,
                "multi_wer": 7.6},
            "medium": {
                "params": "769M",
                "enc_layers": 24,
                "dec_layers": 24,
                "d_model": 1024,
                "english_wer": 3.8,
                "multi_wer": 6.7},
            "large": {
                "params": "1550M",
                "enc_layers": 32,
                "dec_layers": 32,
                "d_model": 1280,
                "english_wer": 3.0,
                "multi_wer": 5.1},
        }

    def compare(self):
        print(f"{'Model':>8} {'Params':>8} "
              f"{'Enc':>4} {'Dec':>4} "
              f"{'d_model':>8} "
              f"{'EN WER':>7} {'Multi':>6}")
        print("-" * 50)
        for name, m in self.models.items():
            print(
                f"{name:>8} {m['params']:>8} "
                f"{m['enc_layers']:>4} "
                f"{m['dec_layers']:>4} "
                f"{m['d_model']:>8} "
                f"{m['english_wer']:>6.1f}%"
                f" {m['multi_wer']:>5.1f}%")
        print(f"\nHuman WER on LibriSpeech: "
              f"~5.8%")
        print(f"Whisper-medium beats humans "
              f"on clean English ;-)")


arch = WhisperArchitecture()
arch.compare()

The key insight from Whisper is something we've seen before in this series: scale of training data matters more than architectural cleverness. The same transformer architecture trained on 680K hours of diverse, noisy, multilingual audio produces a model that generalizes to accents, background noise, and domains it was never explicitly trained on. Compare this to earlier ASR systems that were trained on maybe 1,000 hours of clean read speech (like LibriSpeech) and fell apart the moment someone spoke with a non-standard accent or in a noisy room.

Using Whisper in practice is pretty simple:

# Using OpenAI's whisper package
# pip install openai-whisper
import whisper

model = whisper.load_model("base")
result = model.transcribe(
    "meeting_recording.wav")
print(f"Text: {result['text']}")
print(f"Language: {result['language']}")

# With segment-level timestamps
# (pass word_timestamps=True for per-word times)
for segment in result['segments']:
    print(f"[{segment['start']:.1f}s - "
          f"{segment['end']:.1f}s] "
          f"{segment['text']}")

# Or use Hugging Face transformers
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base")
result = asr(
    "meeting_recording.wav",
    return_timestamps=True)
print(result['text'])

Whisper's decoder uses special tokens to control its behavior. <|transcribe|> tells it to transcribe (keep the original language). <|translate|> tells it to translate to English. <|en|> specifies the language. This multi-task conditioning through special tokens is the same idea as T5's text-to-text framing that we touched on when discussing language models in episode #57 -- a single model handles multiple tasks by changing the prompt.
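A plain-Python sketch of how those control tokens are laid out at the start of the decoder sequence, following the format shown in the Whisper paper (start token, language, task, then either timestamps or `<|notimestamps|>`). It only builds strings, since the real tokenizer maps each token to an ID, and `whisper_prompt` is a hypothetical helper. With the `whisper` package you would simply pass `task="translate"` (and optionally `language=...`) to `model.transcribe`:

```python
def whisper_prompt(language="en", task="transcribe", timestamps=False):
    """Illustrative special-token prefix that Whisper's decoder is
    conditioned on before it generates any text."""
    if task not in ("transcribe", "translate"):
        raise ValueError("task must be 'transcribe' or 'translate'")
    tokens = ["<|startoftranscript|>", f"<|{language}|>", f"<|{task}|>"]
    if not timestamps:
        tokens.append("<|notimestamps|>")
    return tokens


print("".join(whisper_prompt("en", "transcribe")))
# French audio in, English text out:
print("".join(whisper_prompt("fr", "translate")))
```

Swapping one token in this prefix switches the model between tasks, with no change to the weights.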

Evaluating ASR: Word Error Rate

You can't improve what you can't measure. ASR quality is measured by Word Error Rate (WER): the edit distance between the hypothesis (what the model produced) and the reference (the correct transcription), divided by the reference length:

import numpy as np


def word_error_rate(reference, hypothesis):
    """Compute WER using Levenshtein distance
    on word sequences."""
    ref = reference.lower().split()
    hyp = hypothesis.lower().split()
    n, m = len(ref), len(hyp)

    d = np.zeros((n + 1, m + 1), dtype=int)
    for i in range(n + 1):
        d[i, 0] = i
    for j in range(m + 1):
        d[0, j] = j

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref[i - 1] == hyp[j - 1]:
                d[i, j] = d[i - 1, j - 1]
            else:
                d[i, j] = 1 + min(
                    d[i - 1, j],
                    d[i, j - 1],
                    d[i - 1, j - 1])

    edits = d[n, m]
    # Backtrace for error breakdown
    subs, ins, dels = 0, 0, 0
    i, j = n, m
    while i > 0 or j > 0:
        if (i > 0 and j > 0
                and ref[i-1] == hyp[j-1]):
            i -= 1
            j -= 1
        elif (i > 0 and j > 0
              and d[i,j] == d[i-1,j-1] + 1):
            subs += 1
            i -= 1
            j -= 1
        elif (j > 0
              and d[i,j] == d[i,j-1] + 1):
            ins += 1
            j -= 1
        else:
            dels += 1
            i -= 1

    return {
        "wer": edits / max(n, 1),
        "substitutions": subs,
        "insertions": ins,
        "deletions": dels,
    }


cases = [
    ("the quick brown fox jumps over "
     "the lazy dog",
     "the quick brown box jumps over "
     "a lazy dog"),
    ("hello world", "hello world"),
    ("i went to the store yesterday",
     "i went to store yesterday"),
]
for ref, hyp in cases:
    r = word_error_rate(ref, hyp)
    print(f"WER: {r['wer']:.1%} "
          f"(S={r['substitutions']}, "
          f"I={r['insertions']}, "
          f"D={r['deletions']})")
    print(f"  Ref: {ref}")
    print(f"  Hyp: {hyp}\n")

WER counts three types of errors: substitutions S (wrong word), insertions I (extra word), and deletions D (missing word), and divides their sum by the number of reference words N: WER = (S + D + I) / N. In our first example, "fox" -> "box" is a substitution and "the" -> "a" is another, giving WER = 2/9 = 22.2%. The second case is a perfect transcription, so its WER is 0%. The third case has one deletion ("the" is missing), giving WER = 1/6 = 16.7%. Because insertions count too, WER is not capped at 100%.

WER has quirks worth knowing about. A 5% WER sounds great, but consider: if every 20th word is wrong, that's potentially one error per sentence, which can completely change meaning ("the patient has NO heart disease" vs "the patient has heart disease" -- one deletion, catastrophic consequence). For safety-critical applications like medical transcription, even 1% WER may not be good enough.
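A compact WER function makes both quirks concrete. This is a sketch that uses a two-row version of the same word-level edit distance and returns only the rate; `wer` is a hypothetical name, not a library call:

```python
def wer(reference, hypothesis):
    """Word error rate: word-level Levenshtein distance divided by
    the number of reference words (two-row dynamic programming)."""
    ref, hyp = reference.lower().split(), hypothesis.lower().split()
    prev = list(range(len(hyp) + 1))        # distances for empty ref
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (r != h))  # substitution or match
        prev = cur
    return prev[-1] / max(len(ref), 1)


medical = wer("the patient has no heart disease",
              "the patient has heart disease")
chatty = wer("yes", "yes i think so yes")
print(f"one deletion: {medical:.1%}")    # 1 error, 6 reference words
print(f"four insertions: {chatty:.0%}")  # 4 errors, 1 reference word
```

The meaning-flipping deletion costs only 16.7%, while a harmless but verbose hypothesis against a one-word reference scores 400%. WER measures word edits, not meaning.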

Fine-tuning Whisper for specific domains

Out of the box, Whisper handles general speech well. But for specialized domains -- medical dictation, legal proceedings, heavy accents, technical jargon -- fine-tuning on domain-specific data makes a significant difference:

from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)


def setup_whisper_finetuning(
        model_name="openai/whisper-small"):
    """Set up Whisper fine-tuning."""
    processor = (WhisperProcessor
        .from_pretrained(model_name))
    model = (WhisperForConditionalGeneration
        .from_pretrained(model_name))

    # Freeze encoder for efficiency
    model.freeze_encoder()

    total = sum(
        p.numel()
        for p in model.parameters())
    trainable = sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad)
    print(f"Model: {model_name}")
    print(f"Total params: {total:,}")
    print(f"Trainable: {trainable:,} "
          f"({trainable/total*100:.1f}%)")

    args = Seq2SeqTrainingArguments(
        output_dir="./whisper-finetuned",
        per_device_train_batch_size=8,
        learning_rate=1e-5,
        num_train_epochs=3,
        fp16=True,  # needs a CUDA GPU; set False on CPU
        predict_with_generate=True,
        # named evaluation_strategy in older
        # transformers releases
        eval_strategy="steps",
        eval_steps=500,
        save_steps=500,
        warmup_steps=100,
    )
    return model, processor, args


model, processor, args = (
    setup_whisper_finetuning())

The trick of freezing the encoder and only fine-tuning the decoder is worth remembering. The encoder has already learned excellent audio representations from 680K hours of data, and your small domain-specific dataset (maybe 100 hours of medical dictation) isn't going to improve those representations much. But the decoder needs to learn new vocabulary and language patterns specific to your domain. Freezing the encoder substantially cuts the number of trainable parameters (for whisper-small, by roughly a third, because the decoder also holds the large token-embedding matrix) and significantly reduces overfitting risk.

This is the same transfer learning principle we saw with vision models in episode #90 on self-supervised pre-training and episode #69 on fine-tuning: keep the general-purpose representations, adapt only the task-specific layers.
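How much does freezing the encoder actually save? Here is a back-of-the-envelope estimate, assuming the layer counts and `d_model` values from the Whisper table earlier in this episode and the multilingual vocabulary of 51,865 tokens. `whisper_param_estimate` is a hypothetical helper, and the formula skips biases, LayerNorms and the fixed sinusoidal encoder positions, so totals land a few percent under the published sizes:

```python
def whisper_param_estimate(d_model, n_layers, n_mels=80,
                           vocab=51865, text_ctx=448):
    """Rough encoder/decoder weight counts for a Whisper-style
    encoder-decoder transformer (biases and LayerNorms ignored)."""
    attn = 4 * d_model ** 2              # Q, K, V, output projections
    mlp = 2 * d_model * (4 * d_model)    # two linears, 4x hidden width
    stem = 3 * n_mels * d_model + 3 * d_model ** 2   # two width-3 convs
    encoder = stem + n_layers * (attn + mlp)
    decoder = (n_layers * (2 * attn + mlp)   # self- and cross-attention
               + vocab * d_model             # token embedding (tied output)
               + text_ctx * d_model)         # learned text positions
    return encoder, decoder


sizes = {"tiny": (384, 4), "base": (512, 6), "small": (768, 12),
         "medium": (1024, 24), "large": (1280, 32)}
for name, (d, n) in sizes.items():
    enc, dec = whisper_param_estimate(d, n)
    total = enc + dec
    print(f"{name:>7}: ~{total / 1e6:6.0f}M total, "
          f"decoder share {dec / total:.0%}")
```

For whisper-small the decoder holds roughly 64% of the weights, about 40M of them in the token embedding alone, so freezing the encoder removes a bit over a third of the trainable parameters. The saving ranges from about a fifth for tiny to about 40% for large, and the `trainable` printout in the setup function above reports the exact figure for whichever checkpoint you load.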

Building a practical transcription pipeline

For a real application, you need more than just a model. You need audio loading, preprocessing, chunking (for long recordings), and post-processing:

import numpy as np


class TranscriptionPipeline:
    """Complete speech-to-text pipeline
    with chunking for long audio."""

    def __init__(self, chunk_seconds=30,
                 overlap_seconds=5,
                 sr=16000):
        self.chunk_size = chunk_seconds * sr
        self.overlap = overlap_seconds * sr
        self.sr = sr

    def chunk_audio(self, audio):
        """Split long audio into overlapping
        chunks."""
        chunks = []
        start = 0
        while start < len(audio):
            end = min(
                start + self.chunk_size,
                len(audio))
            chunks.append({
                "audio": audio[start:end],
                "start_sample": start,
                "end_sample": end,
                "start_time": start / self.sr,
                "end_time": end / self.sr,
            })
            if end >= len(audio):
                break
            start += (
                self.chunk_size - self.overlap)
        return chunks

    def merge_transcripts(self, chunks,
                           transcripts):
        """Merge overlapping transcripts."""
        if len(transcripts) <= 1:
            return transcripts[0] if transcripts else ""
        merged = [transcripts[0]]
        for i in range(1, len(transcripts)):
            words = transcripts[i].split()
            skip = max(1, len(words) // 6)
            merged.append(
                ' '.join(words[skip:]))
        return ' '.join(merged)

    def post_process(self, text):
        """Basic text cleanup."""
        text = ' '.join(text.split())
        if text:
            text = text[0].upper() + text[1:]
        if text and text[-1] not in '.!?':
            text += '.'
        return text

    def transcribe(self, audio):
        """Full pipeline: chunk -> transcribe
        -> merge -> post-process."""
        chunks = self.chunk_audio(audio)
        print(f"Audio: {len(audio)/self.sr:.1f}s"
              f" -> {len(chunks)} chunks")

        # In production you'd call whisper here
        # on each chunk. We simulate:
        fake_transcripts = []
        for i, chunk in enumerate(chunks):
            dur = (chunk['end_time']
                   - chunk['start_time'])
            n_words = int(dur * 3)
            fake_transcripts.append(
                f"chunk {i} with "
                f"{n_words} words of speech")

        merged = self.merge_transcripts(
            chunks, fake_transcripts)
        return self.post_process(merged)


pipe = TranscriptionPipeline()
for duration in [10, 45, 120]:
    audio = np.random.randn(duration * 16000)
    result = pipe.transcribe(audio)
    print(f"  Result: {result[:60]}...\n")

The chunking strategy matters for practical reasons. Whisper processes 30-second windows, so for a 2-hour podcast you need to split the audio into chunks, transcribe each one, and merge the results. The overlap ensures no words are lost at chunk boundaries, but you then have to handle the duplicated text in the overlap region. The simplest approach, used in `merge_transcripts` above, is to skip the first few words of each subsequent chunk: with a 5-second overlap on 30-second chunks, that is about one sixth of the words (`len(words) // 6`). Since speaking rate varies, this heuristic can drop or duplicate a word or two at each boundary.
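A slightly more careful merge than skipping a fixed fraction of words is to look for the longest run of words shared by the end of one chunk and the start of the next. A minimal sketch with hypothetical helper names, assuming both chunks transcribe the overlap region identically:

```python
def merge_by_overlap(prev_words, next_words, max_overlap=30):
    """Append next_words to prev_words, dropping the longest run of
    words that both ends prev_words and starts next_words."""
    limit = min(max_overlap, len(prev_words), len(next_words))
    best = 0
    for k in range(1, limit + 1):
        if prev_words[-k:] == next_words[:k]:
            best = k
    return prev_words + next_words[best:]


def merge_chunks(transcripts, max_overlap=30):
    """Fold merge_by_overlap over a list of chunk transcripts."""
    words = []
    for text in transcripts:
        words = merge_by_overlap(words, text.split(), max_overlap)
    return " ".join(words)


chunks = [
    "we trained the model on ten thousand hours",
    "ten thousand hours of labeled audio from",
    "labeled audio from many different speakers",
]
print(merge_chunks(chunks))
```

Exact matching breaks down when the two chunks transcribe the overlap slightly differently (a word cut off at the boundary, a different spelling), and it can wrongly swallow a genuinely repeated phrase. Aligning on timestamps or using fuzzy matching is more robust in those cases.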

Production systems like AssemblyAI and Deepgram add more sophistication: speaker diarization (who said what), punctuation restoration, entity recognition, and topic segmentation. But the core loop is always the same: load audio -> preprocess -> chunk -> transcribe -> merge -> post-process.

Model size vs accuracy vs speed

Choosing the right Whisper variant depends on your constraints:

import numpy as np


class ModelSelector:
    """Help choose the right ASR model."""

    def __init__(self):
        self.models = {
            "whisper-tiny": {
                "params_m": 39,
                "wer_clean": 7.6,
                "wer_noisy": 13.2,
                "rtf_cpu": 0.4,
                "rtf_gpu": 0.02,
                "ram_gb": 0.5},
            "whisper-base": {
                "params_m": 74,
                "wer_clean": 5.0,
                "wer_noisy": 9.8,
                "rtf_cpu": 0.7,
                "rtf_gpu": 0.03,
                "ram_gb": 0.8},
            "whisper-small": {
                "params_m": 244,
                "wer_clean": 4.2,
                "wer_noisy": 7.9,
                "rtf_cpu": 2.1,
                "rtf_gpu": 0.06,
                "ram_gb": 1.5},
            "whisper-medium": {
                "params_m": 769,
                "wer_clean": 3.8,
                "wer_noisy": 6.2,
                "rtf_cpu": 6.5,
                "rtf_gpu": 0.12,
                "ram_gb": 4.5},
            "whisper-large": {
                "params_m": 1550,
                "wer_clean": 3.0,
                "wer_noisy": 5.1,
                "rtf_cpu": 14.0,
                "rtf_gpu": 0.25,
                "ram_gb": 9.0},
        }

    def compare(self):
        # Header widths match the row format below
        # (56 characters per line in total)
        print(f"{'Model':>16} {'Params':>7} "
              f"{'Clean':>6} {'Noisy':>6} "
              f"{'CPU':>5} {'GPU':>5} "
              f"{'RAM':>5}")
        print("-" * 56)
        for name, m in self.models.items():
            print(
                f"{name:>16} "
                f"{m['params_m']:>6}M "
                f"{m['wer_clean']:>5.1f}% "
                f"{m['wer_noisy']:>5.1f}% "
                f"{m['rtf_cpu']:>4.1f}x "
                f"{m['rtf_gpu']:>4.2f}x "
                f"{m['ram_gb']:>4.1f}G")
        print("\nRTF = Real-Time Factor")
        print("RTF < 1.0 = faster than "
              "real-time")
        print("RTF > 1.0 = slower, batch "
              "only")


selector = ModelSelector()
selector.compare()

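The Real-Time Factor (RTF) is processing time divided by audio duration: how many seconds it takes to process one second of audio. RTF < 1.0 means the model is faster than real-time and can keep up with live speech; RTF > 1.0 means it's slower than real-time and only suitable for batch processing. In this table, even whisper-large runs at an RTF of about 0.25 on a GPU (4x faster than real-time), so GPU availability is the deciding factor for most applications. On CPU only, tiny (0.4) and base (0.7) are the only variants that stay under 1.0, so you're limited to those for real-time use cases. Treat all of these figures as rough: RTF depends heavily on the specific hardware, batch size, and inference library.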
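Because the RTF figures in the table are approximations, it's worth timing the model on your own hardware. Here's a minimal sketch: `transcribe_fn` is a hypothetical stand-in for whatever function calls your model, and the single warm-up run plus averaging over `n_runs` are assumptions about a reasonable measurement, not a standard benchmark protocol.

```python
import time


def measure_rtf(transcribe_fn, audio, sr=16000, n_runs=3):
    """Real-Time Factor = processing time / audio duration.

    transcribe_fn is a hypothetical callable that takes a 1-D
    audio sequence and returns text (e.g. your own Whisper wrapper).
    """
    audio_seconds = len(audio) / sr
    # Warm-up call: the first run often includes one-off costs
    # (loading weights, allocating buffers) that would inflate RTF
    transcribe_fn(audio)
    start = time.perf_counter()
    for _ in range(n_runs):
        transcribe_fn(audio)
    per_run = (time.perf_counter() - start) / n_runs
    return per_run / audio_seconds


def processing_hours(rtf, hours_of_audio):
    """Wall-clock hours needed to transcribe a batch at a given RTF."""
    return rtf * hours_of_audio


def fake_transcribe(audio):
    # Placeholder model so the sketch runs without any ASR library
    return "hello world"


print(measure_rtf(fake_transcribe, [0.0] * 16000 * 5))
# Table figures for whisper-large on 10 hours of audio:
print(processing_hours(14.0, 10))  # CPU -> 140.0
print(processing_hours(0.25, 10))  # GPU -> 2.5
```

The second half makes the batch trade-off concrete: at the table's CPU figure, ten hours of audio through whisper-large takes almost six days of compute, while the GPU finishes in an afternoon.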

What ASR still gets wrong

It's worth being honest about where current ASR systems struggle, even Whisper:

  • Proper nouns and rare words: "Kubernetes" might become "Cooper net tees" if the model hasn't seen it enough in training data. Technical jargon, brand names, and unusual names are the most common error source.
  • Homophones without context: "their/there/they're", "to/too/two" -- the model makes the right choice most of the time using language model knowledge, but not always.
  • Heavy accents: Whisper handles accents far better than older models (thanks to multilingual training data), but thick regional accents with unusual phonological patterns still cause elevated WER.
  • Crosstalk and overlapping speakers: when multiple people talk at the same time, all current ASR systems degrade significantly. This is a hard open problem.
  • Very long silences and non-speech audio: Whisper sometimes hallucinates text during long pauses or music segments. The "hallucination" problem is shared with LLMs (episode #73 on LLM evaluation) -- the model is trained to always produce output, so silence can produce phantom text.
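One common mitigation for the silence case is to avoid sending near-silent chunks to the model at all. The sketch below is a crude energy gate, assuming float samples in [-1, 1]; the -40 dB threshold, 30 ms frames, and 10% active-frame ratio are arbitrary assumptions you would tune for your recordings. It is not a real voice activity detector: it measures loudness only, so music passes straight through.

```python
import numpy as np


def frame_levels_db(audio, sr=16000, frame_ms=30):
    """RMS level of each non-overlapping frame, in dB relative to
    full scale (samples assumed to lie in [-1, 1])."""
    frame_len = int(sr * frame_ms / 1000)
    n_frames = len(audio) // frame_len
    if n_frames == 0:
        return np.array([])
    frames = np.asarray(audio[:n_frames * frame_len], dtype=np.float64)
    frames = frames.reshape(n_frames, frame_len)
    rms = np.sqrt(np.mean(frames ** 2, axis=1))
    # Small epsilon keeps log10 finite on digital silence
    return 20 * np.log10(rms + 1e-10)


def is_mostly_silence(audio, sr=16000, threshold_db=-40.0,
                      min_active_ratio=0.1):
    """True if fewer than min_active_ratio of the frames are louder
    than threshold_db. Loudness only: music is NOT filtered out."""
    levels = frame_levels_db(audio, sr)
    if levels.size == 0:
        return True
    active_ratio = np.mean(levels > threshold_db)
    return bool(active_ratio < min_active_ratio)


sr = 16000
t = np.arange(2 * sr) / sr
tone = 0.5 * np.sin(2 * np.pi * 440 * t)
silence = np.zeros(2 * sr)
print(is_mostly_silence(tone))     # False
print(is_mostly_silence(silence))  # True
```

In a chunked pipeline you would run this check on each chunk and skip (or emit an empty string for) the ones it flags, before the model ever gets a chance to invent text for them.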

These limitations matter for choosing where to deploy ASR. Transcribing a clear podcast recording? Whisper is nearly flawless. Transcribing a noisy meeting with six people talking over each other in heavy accents? You'll need human review.

Summary

  • Automatic speech recognition converts variable-length audio into variable-length text; the core challenge is the alignment problem -- which audio frames correspond to which characters;
  • CTC (Connectionist Temporal Classification) outputs a character probability distribution at every audio frame and marginalizes over all possible alignments using dynamic programming; greedy decoding collapses consecutive duplicates and then removes blank tokens, so a repeated character (the double 'l' in "hello") survives only if the model emits a blank between its two copies;
  • attention-based encoder-decoder models (Listen, Attend, and Spell) generate text autoregressively, learning alignment implicitly through attention weights; they model output dependencies naturally but are slower at inference;
  • Whisper is an encoder-decoder transformer trained on 680,000 hours of multilingual audio; its quality comes from data scale, not architectural innovation; it handles transcription, translation, and language identification through special conditioning tokens;
  • Word Error Rate (WER) measures ASR quality as the edit distance between hypothesis and reference, counting substitutions, insertions, and deletions; Whisper-medium achieves ~3.8% WER on clean English, beating the human baseline of ~5.8%;
  • fine-tuning Whisper for domain-specific audio (freezing the encoder, training only the decoder) adapts it to specialized vocabulary while limiting the risk of overfitting a small dataset;
  • practical systems need chunking (for long audio), overlap handling (to avoid losing words at boundaries), and post-processing (capitalization, punctuation) -- the model is just one component of the full pipeline.
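As a worked instance of the WER definition in the summary, note that WER is normalized by the number of reference words N, not by the hypothesis length, so insertions can push it above 100%. Take the reference "recognize speech" (N = 2) and the hypothesis "wreck a nice beach" (one of the pairs in Exercise 2 below). No reference word appears in the hypothesis, so the word-level edit distance equals the longer sequence length, 4: two substitutions plus two insertions.

```latex
\text{WER} = \frac{S + D + I}{N}
           = \frac{2 + 0 + 2}{2}
           = 2.0 = 200\%
```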

We've gone from understanding raw audio (last episode) to turning speech into text. Next up we'll go the other direction: generating speech from text. That brings its own set of challenges around naturalness, prosody, and the uncanny valley of synthetic voices.

Exercises

Exercise 1: Build a CTC alignment visualizer. Create a class CTCAlignmentVisualizer that: (a) creates a simulated CTC output for the word "hello" over 50 time frames: a (50, 29) numpy array where each row is a probability distribution over 29 tokens (26 letters + space + apostrophe + blank), (b) designs a realistic alignment where: frames 0-8 are blank, frames 9-14 emit 'h' with high probability, frames 15-18 are blank, frames 19-24 emit 'e', frames 25-28 are blank, frames 29-34 emit 'l', frame 35 is blank (critical for the double-l!), frames 36-40 emit 'l' again, frames 41-44 are blank, frames 45-49 emit 'o', (c) adds noise to the probabilities (don't make them perfectly one-hot -- add small random probability mass to other tokens to simulate a real model's uncertainty), (d) implements greedy decoding on this simulated output and verifies it produces "hello" (with both l's), (e) implements a remove_blank_and_collapse function separately and demonstrates the difference between collapsing WITH the blank separator between the two l's (correct: "hello") and WITHOUT it (incorrect: "helo"), (f) prints frame-by-frame the top-3 most probable tokens and their probabilities.

Exercise 2: Build a WER error analyzer. Create a class WERAnalyzer that: (a) implements compute_wer(reference, hypothesis) returning WER, substitution count, insertion count, deletion count, AND the alignment (which words matched, which were substituted, inserted, or deleted), (b) implements error_type_distribution(pairs) that takes a list of (reference, hypothesis) pairs and computes: overall WER, the percentage of errors that are substitutions vs insertions vs deletions, (c) tests on at least 5 realistic ASR error pairs including: (1) a clean transcription with no errors, (2) a substitution-heavy example ("recognize speech" vs "wreck a nice beach"), (3) a deletion example (missing filler words), (4) an insertion example (hallucinated words), (5) a mixed-error example, (d) prints a detailed alignment for each pair showing matched/substituted/inserted/deleted words, (e) prints the aggregate error type distribution and identifies which error type is most common. Check whether substitutions dominate in your pairs; in real ASR output they typically do.

Exercise 3: Build a Whisper model size advisor. Create a class WhisperAdvisor that: (a) stores specifications for all 5 Whisper variants (tiny through large): parameter count, approximate WER on LibriSpeech clean/other, real-time factor on CPU and GPU, VRAM requirement, and supported languages, (b) implements compute_throughput(model, hardware, hours_of_audio) that calculates how long it would take to transcribe a given amount of audio on CPU vs GPU, (c) implements recommend(constraints) that takes a dict of constraints (max_wer, max_rtf, max_ram_gb, need_multilingual) and returns the best model that fits ALL constraints, (d) tests the advisor with 4 scenarios: (1) edge device with 1GB RAM, English only, real-time required, (2) laptop with 8GB RAM, multilingual, no real-time requirement, (3) GPU server with 16GB VRAM, lowest possible WER, batch processing, (4) mobile phone with 2GB RAM, English only, must be 2x faster than real-time, (e) prints the recommendation and reasoning for each scenario, including the estimated transcription throughput. Verify that tiny is recommended for the mobile scenario (with the figures in the table above, it is the only variant at least 2x faster than real-time on CPU) and large for server batch processing. Note that a strict "lowest WER that satisfies every constraint" rule applied to that table picks base for the edge device and medium for the laptop, so decide whether your advisor should also weigh throughput if you want it to prefer tiny and small in those scenarios.

Greetings!

@scipio


