Learn AI Series (#96) - Music Generation

What will I learn

  • You will learn how AI generates music: from symbolic (MIDI) to raw audio generation;
  • representing music as sequences: piano roll encoding, token-based encoding, and audio-domain approaches;
  • Music Transformer: adapting self-attention to musical structure with relative positional encoding;
  • Jukebox (OpenAI): generating raw audio with VQ-VAE compression and autoregressive modeling at massive scale;
  • MusicGen (Meta): controllable music generation conditioned on text descriptions;
  • building a practical music generator from scratch using Markov chains on MIDI sequences;
  • evaluating generated music: the Frechet Audio Distance, human evaluation, and why music quality is fundamentally subjective.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#96) - Music Generation

Solutions to Episode #95 Exercises

Exercise 1: Spectrogram-based audio feature comparator.

import numpy as np

class AudioFeatureComparator:
    def __init__(self, sr=16000, dur=2.0):
        self.sr, self.n = sr, int(sr * dur)
        self.t = np.arange(self.n) / sr
        self.n_mels, self.n_fft, self.hop = 64, 1024, 256

    def generate_signals(self):
        rng = np.random.RandomState(42)
        pure = np.sin(2*np.pi*440*self.t)
        noise = rng.randn(self.n)
        sweep = np.sin(2*np.pi*np.cumsum(
            np.linspace(200,4000,self.n))/self.sr)
        speech = sum(np.sin(2*np.pi*f*self.t)
            for f in [150,300,900,2500]) * (
            0.5+0.5*np.sin(2*np.pi*4*self.t))
        music = sum(np.sin(2*np.pi*f*self.t)
            for f in [440,554,659]) * (
            1+0.3*np.sin(2*np.pi*5*self.t))
        return {'pure_440':pure,'white_noise':noise,
            'sweep':sweep,'speech_like':speech,'music_like':music}

    def spectral_centroid(self, sig):
        fv = np.abs(np.fft.rfft(sig))
        fr = np.fft.rfftfreq(len(sig), 1.0/self.sr)
        return float((fr*fv).sum()/max(fv.sum(),1e-12))

    def spectral_flatness(self, sig):
        fv = np.abs(np.fft.rfft(sig))+1e-12
        return float(np.exp(np.mean(np.log(fv)))/np.mean(fv))

    def temporal_variation(self, sig):
        w = np.hanning(self.n_fft)
        nf = (len(sig)-self.n_fft)//self.hop+1
        nb = self.n_fft//2+1
        spec = np.zeros((nb, nf))
        for i in range(nf):
            s = i*self.hop
            spec[:,i] = np.abs(np.fft.rfft(sig[s:s+self.n_fft]*w))
        return float(np.mean(np.abs(np.diff(spec,axis=1))))

    def run(self):
        sigs = self.generate_signals()
        print(f"{'Signal':<14} {'Centroid':>10} {'Flatness':>10} {'TempVar':>10}")
        print("-"*48)
        for name, sig in sigs.items():
            print(f"{name:<14} {self.spectral_centroid(sig):>9.1f}Hz"
                  f" {self.spectral_flatness(sig):>10.4f}"
                  f" {self.temporal_variation(sig):>10.2f}")

comp = AudioFeatureComparator()
comp.run()

White noise has the highest spectral flatness (close to 1.0) because energy is distributed evenly across all frequencies. The pure tone has the lowest. The speech-like signal shows high temporal variation from the 4 Hz amplitude modulation.

Exercise 2: Multi-label evaluation toolkit.

import numpy as np


class MultiLabelEvaluator:
    def compute_ap(self, y_true, y_scores):
        pairs = sorted(
            zip(y_scores, y_true),
            key=lambda x: -x[0])
        tp_cum = 0
        total_pos = sum(y_true)
        if total_pos == 0:
            return 0.0
        precisions = []
        recalls = []
        for i, (score, label) in enumerate(
                pairs):
            if label:
                tp_cum += 1
            prec = tp_cum / (i + 1)
            rec = tp_cum / total_pos
            precisions.append(prec)
            recalls.append(rec)
        ap = 0.0
        prev_rec = 0.0
        for prec, rec in zip(
                precisions, recalls):
            ap += prec * (rec - prev_rec)
            prev_rec = rec
        return float(ap)

    def compute_map(self, Y_true, Y_scores):
        n_classes = len(Y_true)
        aps = []
        for c in range(n_classes):
            ap = self.compute_ap(
                Y_true[c], Y_scores[c])
            aps.append(ap)
        return aps, float(np.mean(aps))

    def run(self):
        rng = np.random.RandomState(42)
        n = 20
        Y_true = []
        Y_scores = []
        # Class 0: perfect
        t0 = rng.choice(
            [0, 1], n, p=[0.5, 0.5])
        s0 = t0.astype(float) + rng.randn(
            n) * 0.01
        Y_true.append(t0.tolist())
        Y_scores.append(s0.tolist())
        # Class 1: near-perfect
        t1 = rng.choice(
            [0, 1], n, p=[0.5, 0.5])
        s1 = t1.astype(float) + rng.randn(
            n) * 0.2
        Y_true.append(t1.tolist())
        Y_scores.append(s1.tolist())
        # Class 2: moderate
        t2 = rng.choice(
            [0, 1], n, p=[0.5, 0.5])
        # Scores of positives and negatives
        # overlap, so ranking is imperfect
        s2 = t2.astype(float) * 0.3 + (
            rng.rand(n) * 0.7)
        Y_true.append(t2.tolist())
        Y_scores.append(s2.tolist())
        # Class 3: weak
        t3 = rng.choice(
            [0, 1], n, p=[0.5, 0.5])
        s3 = rng.rand(n)
        Y_true.append(t3.tolist())
        Y_scores.append(s3.tolist())
        # Class 4: inverted
        t4 = rng.choice(
            [0, 1], n, p=[0.5, 0.5])
        s4 = 1.0 - t4.astype(float) + (
            rng.randn(n) * 0.05)
        Y_true.append(t4.tolist())
        Y_scores.append(s4.tolist())

        aps, mAP = self.compute_map(
            Y_true, Y_scores)
        print(f"{'Class':>6} {'AP':>8}")
        print("-" * 16)
        for i, ap in enumerate(aps):
            print(f"{i:>6} {ap:>8.4f}")
        print(f"\nmAP: {mAP:.4f}")
        print(f"Order check: "
              f"{aps[0]:.3f} > {aps[1]:.3f} >"
              f" {aps[2]:.3f} > {aps[3]:.3f} >"
              f" {aps[4]:.3f}")


evaluator = MultiLabelEvaluator()
evaluator.run()

The AP scores follow the expected ordering: perfect predictions yield AP near 1.0, random predictions hover around the class prior, and inverted predictions fall below the prior (with half the labels positive, AP cannot drop all the way to zero). The mAP is just the arithmetic mean across all classes -- exactly what AudioSet uses as its primary metric.
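To make the AP arithmetic concrete, here is a minimal hand-checkable sketch with just four items. The function name `average_precision` and the toy labels and scores are my own illustrative assumptions, not part of the exercise solution; the sum it computes (precision at each positive, weighted by the recall step) is the same one `compute_ap` uses.

```python
def average_precision(y_true, y_scores):
    """AP as the sum of precision * recall-increment, ranked by score."""
    ranked = sorted(zip(y_scores, y_true), key=lambda p: -p[0])
    total_pos = sum(y_true)
    if total_pos == 0:
        return 0.0
    tp, ap = 0, 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label:
            tp += 1
            # Recall only moves (by 1/total_pos) when we hit a positive
            ap += (tp / rank) / total_pos
    return ap


labels = [1, 0, 1, 0]                       # toy ground truth (assumed)
good = average_precision(labels, [0.9, 0.2, 0.8, 0.1])  # positives first
bad = average_precision(labels, [0.1, 0.8, 0.2, 0.9])   # positives last
print(f"well-ranked AP: {good:.4f}")
print(f"inverted AP:    {bad:.4f}")
```

Even the fully inverted ranking scores 5/12 here, which is why "low" AP always has to be read relative to the class prior.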

Exercise 3: Sound event post-processor.

import numpy as np

class SEDPostProcessor:
    def __init__(self, sr=16000, hop=512, duration=10.0):
        self.sr, self.hop = sr, hop
        self.n_frames = int(duration*sr/hop)
        self.classes = ['Speech','Music','Dog']

    def generate_predictions(self):
        rng = np.random.RandomState(42)
        preds = rng.uniform(0,0.15,(3,self.n_frames))
        preds[0,20:120] += 0.8   # Speech
        preds[1,50:250] += 0.7   # Music
        preds[2,40:55] += 0.75   # Dog burst 1
        preds[2,100:110] += 0.75 # Dog burst 2
        preds[2,200:215] += 0.75 # Dog burst 3
        return np.clip(preds,0,1)

    def threshold(self, preds, t=0.5):
        return (preds > t).astype(int)

    def fill_gaps(self, binary, max_gap=5):
        r = binary.copy()
        for c in range(r.shape[0]):
            row = r[c]; in_gap = False; gs = 0
            for f in range(len(row)):
                if row[f]==0 and not in_gap: gs=f; in_gap=True
                elif row[f]==1 and in_gap:
                    if f-gs<=max_gap: row[gs:f]=1
                    in_gap=False
        return r

    def remove_short(self, binary, min_dur=8):
        r = binary.copy()
        for c in range(r.shape[0]):
            row = r[c]; in_ev = False; s = 0
            for f in range(len(row)):
                if row[f]==1 and not in_ev: s=f; in_ev=True
                elif row[f]==0 and in_ev:
                    if f-s<min_dur: row[s:f]=0
                    in_ev=False
            if in_ev and len(row)-s<min_dur: row[s:]=0
        return r

    def extract_events(self, binary):
        events = []
        for c in range(binary.shape[0]):
            row = binary[c]; in_ev = False; s = 0
            for f in range(len(row)):
                if row[f]==1 and not in_ev: s=f; in_ev=True
                elif row[f]==0 and in_ev:
                    events.append((self.classes[c],s,f)); in_ev=False
            if in_ev: events.append((self.classes[c],s,len(row)))
        return sorted(events, key=lambda e: e[1])

    def run(self):
        preds = self.generate_predictions()
        binary = self.threshold(preds)
        stages = [('After threshold', binary),
            ('After gap-fill', self.fill_gaps(binary)),]
        stages.append(('After min-dur', self.remove_short(stages[-1][1])))
        for label, b in stages:
            evs = self.extract_events(b)
            print(f"\n{label}:")
            for cls,s,e in evs:
                print(f"  {cls:>8} [{s:>3}-{e:>3}] "
                      f"({s*self.hop/self.sr:.2f}-{e*self.hop/self.sr:.2f}s)")

proc = SEDPostProcessor()
proc.run()

After thresholding, Dog has three separate events. Gap-filling merges events separated by 5 or fewer frames. Minimum duration filtering removes short spurious detections caused by noise.

On to today's episode

Here we go! We've been building up the audio AI toolkit over the past four episodes: fundamentals in episode #92, speech-to-text in #93, text-to-speech in #94, and classifying sounds in #95. All of those deal with understanding or reproducing existing audio. Today we're going somewhere fundamentally different: creating music from scratch. Teaching a machine not just what music sounds like, but what makes it music -- rhythm, melody, harmony, and that elusive quality of musical structure that makes a piece feel coherent over minutes rather than seconds.

This is personally one of my favorite topics in the entire series. Music generation sits at the intersection of nearly everything we've built: sequence modeling (episodes #48-51), transformers (#52-53), generative models (GANs in #55, diffusion in #84-85), VQ-VAE (which we touched on in #85 as well), and the audio representations from episode #92. If you've been following along, you've already got all the building blocks. Today we put them together in a genuinely creative way ;-)

How do you represent music?

Before we can generate music, we need to decide how to represent it. And this choice has massive consequences for what kinds of models work and what kinds of music they can produce.

There are three major representation paradigms:

1. Symbolic (MIDI/piano roll): Represent music as a sequence of discrete events -- note on, note off, pitch, velocity, timing. This is what MIDI files contain. It's compact, structured, and explicitly encodes musical concepts like chords and rhythm. But it doesn't capture timbre, expression, or the thousand subtle variations that make a real piano sound different from a synthesized one.

2. Spectrogram-based: Represent music as a Mel spectrogram (episode #92) and generate spectrograms with image-like models (CNNs, diffusion). The spectrogram captures everything about the sound -- timbre, dynamics, room acoustics -- but loses the discrete musical structure. You can't easily say "play a C major chord" to a spectrogram generator.

3. Raw audio: Generate the actual waveform samples directly. Maximum fidelity, zero information loss. But at 44,100 samples per second for CD-quality audio, a 3-minute song is nearly 8 million samples. Generating that autoregressively (one sample at a time, like we saw with WaveNet in episode #94) is computationally brutal.

import numpy as np


class MusicRepresentations:
    """Compare three ways to represent music
    for AI generation."""

    def __init__(self, sr=16000, bpm=120):
        self.sr = sr
        self.bpm = bpm
        self.beat_dur = 60.0 / bpm

    def midi_to_piano_roll(self, notes,
                            n_steps=32,
                            step_dur=None):
        """Convert MIDI note events to
        piano roll matrix."""
        if step_dur is None:
            step_dur = self.beat_dur / 4
        roll = np.zeros((128, n_steps))
        for pitch, start, dur, vel in notes:
            s = int(start / step_dur)
            e = min(int((start + dur)
                        / step_dur),
                    n_steps)
            for t in range(s, e):
                roll[pitch, t] = vel / 127.0
        return roll

    def piano_roll_to_tokens(self, roll):
        """Convert piano roll to token
        sequence (event-based encoding)."""
        tokens = []
        n_pitches, n_steps = roll.shape
        for t in range(n_steps):
            active = np.where(
                roll[:, t] > 0)[0]
            if len(active) > 0:
                tokens.append(
                    f"TIME_{t}")
                for p in active:
                    tokens.append(
                        f"NOTE_{p}_"
                        f"{roll[p,t]:.1f}")
            else:
                tokens.append(
                    f"TIME_{t}")
                tokens.append("REST")
        return tokens

    def run(self):
        # C major chord: C4, E4, G4
        notes = [
            (60, 0.0, 0.5, 100),  # C4
            (64, 0.0, 0.5, 90),   # E4
            (67, 0.0, 0.5, 80),   # G4
            (60, 0.5, 0.5, 100),  # C4
            (65, 0.5, 0.5, 90),   # F4
            (69, 0.5, 0.5, 80),   # A4
        ]
        roll = self.midi_to_piano_roll(
            notes, n_steps=16)
        tokens = self.piano_roll_to_tokens(
            roll)

        active_cells = (roll > 0).sum()
        total_cells = roll.size
        print(f"Piano roll: {roll.shape}")
        print(f"  Active cells: {active_cells}"
              f" / {total_cells} "
              f"({active_cells/total_cells:.1%})")
        print(f"  -> Extremely sparse!")
        print(f"\nToken sequence "
              f"({len(tokens)} tokens):")
        for t in tokens[:12]:
            print(f"  {t}")
        print(f"  ...")

        samples = int(
            1.0 * self.sr)
        print(f"\nRepresentation sizes "
              f"(1 second of music):")
        print(f"  MIDI events: "
              f"~{len(notes)} events")
        print(f"  Piano roll (16th notes): "
              f"128 x {int(self.bpm*4/60)}"
              f" = {128*int(self.bpm*4/60)}"
              f" values")
        print(f"  Mel spectrogram: "
              f"80 x {samples//256} = "
              f"{80*(samples//256)} values")
        print(f"  Raw audio: "
              f"{samples} samples")


rep = MusicRepresentations()
rep.run()

The sparsity of the piano roll is the key insight here. At any given time step, maybe 3-5 out of 128 possible pitches are active. That's roughly 96-98% zeros. This sparsity is why event-based token sequences (listing only the active notes) are much more efficient than the dense piano roll matrix. Most modern music transformers use token-based representations for exactly this reason.
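As a rough illustration of what an event-based encoding can look like, the sketch below turns the same `(pitch, start, duration, velocity)` tuples into NOTE_ON / NOTE_OFF / TIME_SHIFT tokens. The token names, the 0.125 s grid step and the helper `notes_to_events` are assumptions made for this example (loosely in the spirit of performance-style MIDI encodings), not the exact vocabulary of any particular model. Velocity is dropped to keep it short.

```python
def notes_to_events(notes, step=0.125):
    """Turn (pitch, start, dur, vel) notes into a flat event-token list."""
    # Collect note boundaries on the step grid. The flag is 0 for "off"
    # and 1 for "on", so offs sort before ons within the same time step.
    boundaries = []
    for pitch, start, dur, _vel in notes:
        boundaries.append((round(start / step), 1, pitch))
        boundaries.append((round((start + dur) / step), 0, pitch))
    boundaries.sort()

    tokens, now = [], 0
    for tick, is_on, pitch in boundaries:
        if tick > now:
            # Only emit time when something actually happens later
            tokens.append(f"TIME_SHIFT_{tick - now}")
            now = tick
        tokens.append(f"{'NOTE_ON' if is_on else 'NOTE_OFF'}_{pitch}")
    return tokens


notes = [(60, 0.0, 0.5, 100), (64, 0.0, 0.5, 90), (67, 0.0, 0.5, 80),
         (60, 0.5, 0.5, 100), (65, 0.5, 0.5, 90), (69, 0.5, 0.5, 80)]
events = notes_to_events(notes)
print(f"{len(events)} tokens vs {128 * 8} piano-roll cells")
print(events[:5])
```

Two chords over one second come out as a handful of tokens, while the equivalent 16th-note piano roll at 120 BPM needs 128 x 8 cells that are almost all zero.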

Markov chain music: the simplest generator

Before jumping into neural networks, let me show you the simplest possible generative music model: a Markov chain. We saw Markov models briefly when discussing language modeling in episode #57 (predicting the next word based on the previous N words). Same idea here, but with notes instead of words:

import numpy as np
from collections import defaultdict


class MarkovMusicGenerator:
    """Generate melodies using an order-N
    Markov chain (second-order by default)
    over note sequences."""

    def __init__(self, order=2):
        self.order = order
        self.transitions = defaultdict(
            lambda: defaultdict(int))

    def train(self, sequences):
        """Learn transition probabilities
        from example melodies."""
        for seq in sequences:
            for i in range(
                    len(seq) - self.order):
                context = tuple(
                    seq[i:i + self.order])
                next_note = seq[
                    i + self.order]
                self.transitions[context][
                    next_note] += 1

    def generate(self, seed, length=32):
        """Generate a melody from a seed."""
        rng = np.random.RandomState(42)
        melody = list(seed[:self.order])
        for _ in range(length):
            context = tuple(
                melody[-self.order:])
            if context not in self.transitions:
                melody.append(
                    rng.choice([60, 62, 64,
                                65, 67]))
                continue
            nexts = self.transitions[context]
            notes = list(nexts.keys())
            counts = np.array(
                list(nexts.values()),
                dtype=float)
            probs = counts / counts.sum()
            choice = rng.choice(
                notes, p=probs)
            melody.append(choice)
        return melody

    def note_name(self, midi):
        names = ['C', 'C#', 'D', 'D#', 'E',
                 'F', 'F#', 'G', 'G#', 'A',
                 'A#', 'B']
        return (f"{names[midi % 12]}"
                f"{midi // 12 - 1}")

    def synthesize(self, melody, sr=16000,
                    note_dur=0.25):
        """Convert MIDI melody to audio."""
        samples = []
        for note in melody:
            freq = 440.0 * (
                2.0 ** ((note - 69) / 12.0))
            t = np.arange(
                int(sr * note_dur)) / sr
            wave = 0.3 * np.sin(
                2 * np.pi * freq * t)
            # Simple envelope
            env = np.ones_like(wave)
            attack = int(0.01 * sr)
            release = int(0.05 * sr)
            env[:attack] = np.linspace(
                0, 1, attack)
            env[-release:] = np.linspace(
                1, 0, release)
            samples.append(wave * env)
        return np.concatenate(samples)

    def run(self):
        # Training data: simple melodies
        # in C major scale
        training = [
            [60, 62, 64, 65, 67, 65,
             64, 62, 60],
            [64, 62, 60, 62, 64, 64,
             64, 62, 62, 62],
            [67, 65, 64, 62, 60, 60,
             62, 64, 67],
            [60, 64, 67, 64, 60, 65,
             69, 67, 65, 64, 62, 60],
            [60, 60, 67, 67, 69, 69,
             67, 65, 65, 64, 64, 62,
             62, 60],
        ]
        self.train(training)
        seed = [60, 64]
        melody = self.generate(seed, 16)
        names = [self.note_name(n)
                 for n in melody]
        print(f"Seed: {names[:2]}")
        print(f"Generated: "
              f"{' '.join(names)}")

        audio = self.synthesize(melody)
        print(f"\nAudio: {len(audio)} samples"
              f" ({len(audio)/16000:.1f}s)")
        print(f"Transitions learned: "
              f"{len(self.transitions)}")


gen = MarkovMusicGenerator()
gen.run()

The Markov chain captures local patterns: after C-E, G is likely (because C-E-G is a C major arpeggio that appears in the training data). But it has zero understanding of larger-scale structure. It doesn't know that songs have verses and choruses, that melodic phrases should resolve, or that a piece needs tension and release. The melody wanders aimlessly because each next-note decision only looks at the last 2 notes. This is exactly the same limitation we identified for N-gram language models back in episode #57 -- and the same limitation that motivated the move to neural networks.
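To see how little context the chain really has, the sketch below counts transitions for a single melody at order 1 and at order 2. The helper `count_transitions` is an illustrative assumption of mine, not part of the generator class; the melody is the fourth training sequence.

```python
from collections import Counter, defaultdict


def count_transitions(seq, order):
    """Map each length-`order` context to a Counter of following notes."""
    table = defaultdict(Counter)
    for i in range(len(seq) - order):
        table[tuple(seq[i:i + order])][seq[i + order]] += 1
    return table


melody = [60, 64, 67, 64, 60, 65, 69, 67, 65, 64, 62, 60]
first = count_transitions(melody, order=1)
second = count_transitions(melody, order=2)

# After E4 (64) alone, three different continuations were observed...
print("after E4:   ", dict(first[(64,)]))
# ...but after C4-E4 the only continuation in this melody is G4 (67).
print("after C4-E4:", dict(second[(60, 64)]))
print(f"contexts: order 1 = {len(first)}, order 2 = {len(second)}")
```

Raising the order makes each prediction sharper, but the table gets sparser and the chain still never sees more than `order` notes back, so long-range structure stays out of reach.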

Music Transformer

The Music Transformer (Huang et al., 2018) applies the transformer architecture (episodes #52-53) directly to symbolic music sequences. The key innovation is relative positional encoding -- instead of absolute position embeddings ("this token is at position 47"), it encodes the distance between tokens ("these two tokens are 3 steps apart"). Music has strong relative patterns in time: a motif that returns eight bars later is still the same motif, and what a note means depends on how far it sits from the notes around it rather than on its absolute index in the piece. Absolute positions don't capture that; relative positions do:

import torch
import torch.nn as nn
import torch.nn.functional as F


class RelativeAttention(nn.Module):
    """Self-attention with relative position
    encoding for music sequences."""

    def __init__(self, d_model=256,
                 n_heads=4, max_len=512):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.max_len = max_len
        self.qkv = nn.Linear(
            d_model, 3 * d_model)
        self.rel_embed = nn.Embedding(
            2 * max_len + 1, self.d_k)
        self.out = nn.Linear(
            d_model, d_model)

    def forward(self, x):
        b, seq_len, _ = x.shape
        qkv = self.qkv(x).reshape(
            b, seq_len, 3,
            self.n_heads, self.d_k)
        q, k, v = qkv.unbind(dim=2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Content-based attention
        content = torch.matmul(
            q, k.transpose(-2, -1))

        # Relative position attention
        positions = torch.arange(
            seq_len, device=x.device)
        rel_pos = (
            positions.unsqueeze(0)
            - positions.unsqueeze(1)
            + self.max_len)
        rel_pos = rel_pos.clamp(
            0, 2 * self.max_len)
        rel_emb = self.rel_embed(rel_pos)
        # q: (b, h, seq, d_k)
        # rel: (seq, seq, d_k)
        rel_attn = torch.einsum(
            'bhqd,qkd->bhqk',
            q, rel_emb)

        scores = (content + rel_attn) / (
            self.d_k ** 0.5)
        mask = torch.triu(
            torch.ones(seq_len, seq_len,
                       device=x.device),
            diagonal=1).bool()
        scores.masked_fill_(
            mask.unsqueeze(0).unsqueeze(0),
            float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)
        out = out.transpose(1, 2).reshape(
            b, seq_len, self.d_model)
        return self.out(out)


# Demo
attn = RelativeAttention()
x = torch.randn(2, 64, 256)
out = attn(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}")
print(f"Relative embeddings: "
      f"{attn.rel_embed.weight.shape}")

The causal mask (upper triangular) ensures the model can only attend to past tokens during generation -- same as GPT (episode #58). The relative position embeddings give the model a sense of musical distance without fixing absolute positions. A pattern like "up a third, then down a second" gets the same relative encoding regardless of where it appears in the piece.
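A tiny NumPy sketch makes the "distance, not position" idea visible. It reproduces only the index arithmetic behind `rel_pos` for a toy sequence; the sizes `seq_len=4` and `max_len=8` are arbitrary values picked for readability.

```python
import numpy as np

seq_len, max_len = 4, 8  # toy sizes, chosen only for this illustration
pos = np.arange(seq_len)

# Entry [i, j] is the offset j - i, shifted by max_len so that it is a
# valid non-negative row index into the relative embedding table.
rel_index = pos[None, :] - pos[:, None] + max_len
print(rel_index)

# Every diagonal holds a single value: the index depends only on j - i,
# never on where in the sequence the pair of tokens sits.
for offset in range(-(seq_len - 1), seq_len):
    diag = np.diagonal(rel_index, offset=offset)
    assert np.all(diag == max_len + offset)
print("all diagonals are constant")
```

Because tokens (0, 2) and (1, 3) share the same table row, whatever the model learns about "two steps apart" applies everywhere in the piece.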

The Music Transformer was trained on virtuoso piano performances from the Piano-e-Competition, the same source that the MAESTRO dataset draws on (about 200 hours of performances with precise MIDI annotations). The results were genuinely impressive for 2018: multi-minute piano pieces with coherent structure, proper phrasing, and recognizable musical style. But it's limited to symbolic MIDI -- the output still needs a synthesizer to become actual audio.

Jukebox: raw audio at scale

Jukebox (Dhariwal et al., 2020, from OpenAI) took the brute-force approach: generate raw audio waveforms directly. The key enabler is VQ-VAE (Vector Quantized Variational Autoencoder), which we touched on in episode #85 when discussing image generation. VQ-VAE compresses the raw audio into a much shorter sequence of discrete codes, and then a transformer generates those codes autoregressively:

import numpy as np


class VQVAEMusicDemo:
    """Demonstrate VQ-VAE compression for
    music generation (simplified)."""

    def __init__(self, sr=16000,
                 codebook_size=512,
                 n_codes_per_sec=50):
        self.sr = sr
        self.codebook_size = codebook_size
        self.codes_per_sec = n_codes_per_sec

    def compression_analysis(self):
        """Show why VQ-VAE is essential
        for raw audio generation."""
        durations = [1, 10, 30, 180]
        print(f"{'Duration':>10} {'Raw':>12} "
              f"{'VQ codes':>10} "
              f"{'Ratio':>8}")
        print("-" * 44)
        for dur in durations:
            raw = dur * self.sr
            codes = dur * self.codes_per_sec
            ratio = raw / codes
            print(f"{dur:>8}s "
                  f"{raw:>11,} "
                  f"{codes:>10,} "
                  f"{ratio:>7.0f}x")

    def hierarchical_vqvae(self):
        """Jukebox uses 3 levels of VQ-VAE
        with different compression rates."""
        levels = {
            'Top (most compressed)': {
                'downsample': 128,
                'codebook': 2048,
                'codes_per_sec': 344},
            'Middle': {
                'downsample': 32,
                'codebook': 2048,
                'codes_per_sec': 1378},
            'Bottom (least compressed)': {
                'downsample': 8,
                'codebook': 2048,
                'codes_per_sec': 5512},
        }
        dur = 60  # 1 minute at 44.1 kHz
        raw_samples = dur * 44100
        print(f"\nJukebox hierarchical VQ-VAE")
        print(f"1 min of 44.1 kHz audio = "
              f"{raw_samples:,} samples\n")
        print(f"{'Level':>30} {'Codes':>10} "
              f"{'Compression':>12}")
        print("-" * 56)
        for name, lvl in levels.items():
            codes = dur * lvl['codes_per_sec']
            comp = raw_samples / codes
            print(f"{name:>30} "
                  f"{codes:>10,} "
                  f"{comp:>11.0f}x")

    def run(self):
        self.compression_analysis()
        self.hierarchical_vqvae()
        print(f"\nThe top level generates the"
              f" musical structure.")
        print(f"Middle and bottom levels "
              f"add detail and fidelity.")
        print(f"Each level is its OWN "
              f"transformer -- three models"
              f" total.")


demo = VQVAEMusicDemo()
demo.run()

The compression ratios here are the key. At the top level, Jukebox compresses 44,100 samples per second down to about 344 codes per second -- a 128x reduction. A 3-minute song goes from ~8 million samples to ~62,000 codes. That's a sequence length a transformer can handle (barely -- Jukebox's transformer still needed clever memory tricks to process sequences this long).
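The numbers in that paragraph are easy to re-derive. The sketch below only does the arithmetic for a 3-minute mono track at 44.1 kHz, using the three downsampling factors from the demo class; the variable names are mine.

```python
sample_rate = 44_100   # CD-quality sampling rate, mono
song_seconds = 180     # a 3-minute song

raw_samples = sample_rate * song_seconds
print(f"raw audio: {raw_samples:,} samples")

totals = {}
for name, hop in [("top", 128), ("middle", 32), ("bottom", 8)]:
    codes_per_sec = sample_rate / hop   # one code per `hop` samples
    totals[name] = codes_per_sec * song_seconds
    print(f"{name:>6}: {codes_per_sec:7.1f} codes/s"
          f" -> {totals[name]:>12,.0f} codes")
```

The top level lands at roughly 62,000 codes for the whole song, while the bottom level is still close to a million, which is why each level gets its own model.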

The generation process goes top-down: first generate the highest-level codes (capturing broad musical structure -- verse, chorus, bridge), then condition the middle-level transformer on those to add more detail, and finally the bottom level fills in the fine audio texture. Each level is a separate transformer model.

Jukebox was trained on 1.2 million songs (about 600K of them in English), paired with lyrics and with artist and genre labels as conditioning. The results are... impressive and eerie. It generates vocals that sound like garbled singing -- you can hear that it's trying to sing words, but the lyrics are mostly nonsensical. The instrumentals are more convincing. And the fact that it works directly on raw audio means it captures timbre, dynamics, and production style in ways that symbolic systems never could. But it takes about 9 hours to generate one minute of audio on a V100 GPU. Not exactly real-time ;-)

MusicGen: text-conditioned generation

MusicGen (Copet et al., 2023, from Meta) represents the current state of the art for practical music generation. It generates music from text descriptions ("epic orchestral trailer music with pounding drums") using an approach that combines the ideas from language modeling (episode #57) with the audio compression from Jukebox:

import torch
import torch.nn as nn


class SimpleMusicGen(nn.Module):
    """Simplified MusicGen-style model:
    text conditioning + audio token
    generation."""

    def __init__(self, vocab_size=2048,
                 text_dim=768, d_model=1024,
                 n_layers=24, n_heads=16,
                 max_len=1500):
        super().__init__()
        self.tok_embed = nn.Embedding(
            vocab_size, d_model)
        self.pos_embed = nn.Embedding(
            max_len, d_model)
        self.text_proj = nn.Linear(
            text_dim, d_model)

        layer = nn.TransformerDecoderLayer(
            d_model, n_heads,
            dim_feedforward=4 * d_model,
            batch_first=True)
        self.decoder = nn.TransformerDecoder(
            layer, n_layers)
        self.head = nn.Linear(
            d_model, vocab_size)

    def forward(self, audio_tokens,
                text_features):
        """Generate next audio token
        conditioned on text."""
        b, seq_len = audio_tokens.shape
        pos = torch.arange(
            seq_len,
            device=audio_tokens.device)
        x = (self.tok_embed(audio_tokens)
             + self.pos_embed(pos))

        # Text conditioning as memory
        # for cross-attention
        mem = self.text_proj(
            text_features)

        mask = nn.Transformer\
            .generate_square_subsequent_mask(
                seq_len,
                device=audio_tokens.device)
        out = self.decoder(
            x, mem, tgt_mask=mask)
        return self.head(out)


# Architecture summary
model = SimpleMusicGen()
total = sum(p.numel()
            for p in model.parameters())
print(f"Parameters: {total:,}")
print(f"  (~{total/1e6:.0f}M)")

tokens = torch.randint(0, 2048, (1, 100))
text = torch.randn(1, 10, 768)
logits = model(tokens, text)
print(f"\nInput tokens: {tokens.shape}")
print(f"Text features: {text.shape}")
print(f"Output logits: {logits.shape}")

MusicGen's clever contribution is the codebook interleaving pattern. Remember how Jukebox uses three separate VQ-VAE levels with three separate transformers? MusicGen uses a single transformer but interleaves the codebook entries from multiple quantizers. Meta's EnCodec audio codec produces 4 codebook streams in parallel. Instead of generating them separately, MusicGen generates them in a specific interleaved pattern that reduces the effective sequence length while maintaining audio quality.
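To give a feel for what interleaving buys you, here is a small sketch of a "delay"-style layout, which to my understanding is the kind of pattern MusicGen relies on: codebook k is shifted right by k steps, so one decoding step emits one entry per codebook. The function `delay_interleave`, the padding value -1 and the toy 4 x 6 code grid are assumptions made for this illustration, not Meta's implementation.

```python
import numpy as np


def delay_interleave(codes, pad=-1):
    """Shift codebook k right by k steps (a 'delay'-style layout)."""
    n_books, n_steps = codes.shape
    out = np.full((n_books, n_steps + n_books - 1), pad,
                  dtype=codes.dtype)
    for k in range(n_books):
        out[k, k:k + n_steps] = codes[k]
    return out


# 4 codebooks x 6 audio frames, filled with toy values 0..23
codes = np.arange(4 * 6).reshape(4, 6)
delayed = delay_interleave(codes)
print(delayed)

# Flattening every codebook entry into one long sequence costs K * T
# decoding steps; the delayed layout costs only T + K - 1.
print("steps if fully flattened:", codes.size)
print("steps with delay layout: ", delayed.shape[1])
```

Each column is one transformer step. The finer codebooks lag behind the coarse one by a frame or more, so they can still condition on it without paying for a sequence four times as long.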

The text conditioning comes from a frozen T5 encoder (the same text encoder family we discussed in episode #59 on BERT-style models). The text embedding is fed into the transformer via cross-attention, so every generated audio token can attend to the text description. This is conceptually identical to how Tacotron 2 conditions audio generation on text in TTS (episode #94) -- the conditioning mechanism is the same, just the conditioning signal is a free-form text description rather than a phoneme sequence.

Comparing music generation approaches

Let me put all the approaches in perspective:

class MusicGenComparison:
    """Compare music generation systems."""

    def __init__(self):
        self.systems = {
            'MuseNet (OpenAI)': dict(
                year=2019, domain='symbolic',
                output='MIDI',
                max_dur='4 min',
                params='~1.2B',
                conditioning='genre+instrument',
                realtime=False),
            'Music Transformer': dict(
                year=2018, domain='symbolic',
                output='MIDI',
                max_dur='~2 min',
                params='~41M',
                conditioning='none (unconditional)',
                realtime=True),
            'Jukebox': dict(
                year=2020, domain='raw audio',
                output='waveform',
                max_dur='~4 min',
                params='~5B',
                conditioning='artist+genre+lyrics',
                realtime=False),
            'MusicGen': dict(
                year=2023, domain='compressed',
                output='waveform',
                max_dur='30s',
                params='300M-3.3B',
                conditioning='text description',
                realtime=True),
            'MusicLM (Google)': dict(
                year=2023, domain='compressed',
                output='waveform',
                max_dur='~5 min',
                params='~600M',
                conditioning='text description',
                realtime=True),
        }

    def run(self):
        print(f"{'System':>20} {'Year':>5} "
              f"{'Domain':>12} "
              f"{'Output':>10} "
              f"{'Params':>10}")
        print("-" * 62)
        for name, s in self.systems.items():
            print(f"{name:>20} "
                  f"{s['year']:>5} "
                  f"{s['domain']:>12} "
                  f"{s['output']:>10} "
                  f"{s['params']:>10}")
        print(f"\nKey tradeoffs:")
        print(f"  Symbolic: compact, "
              f"controllable, no timbre")
        print(f"  Raw audio: full fidelity, "
              f"massive compute")
        print(f"  Compressed: best of both "
              f"- VQ-VAE reduces sequence "
              f"length")


comp = MusicGenComparison()
comp.run()

The progression is clear: symbolic models (2018-2019) gave structure but no sound. Raw audio models (2020) gave sound but were impractically slow. Compressed audio models (2023) hit the sweet spot -- generate discrete codes from a learned codebook, then decode to audio in a single pass. The VQ-VAE compression is what made practical music generation possible, just as it enabled practical image generation in the DALL-E lineage (episode #85).
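A bit of back-of-the-envelope arithmetic shows why compression matters so much. The numbers below are illustrative assumptions: a 30-second clip, 44.1 kHz raw audio, the 128x top-level compression mentioned for Jukebox, and a neural codec running at 50 frames per second with 4 codebooks (the frame rate is my assumption, so adjust it to whatever codec you actually use).

```python
def sequence_lengths(seconds, sample_rate=44_100,
                     top_compression=128,
                     frame_rate=50, n_codebooks=4):
    """Rough token counts for one clip under the
    assumed settings (all defaults are illustrative)."""
    raw = seconds * sample_rate
    frames = seconds * frame_rate
    return {
        'raw samples': raw,
        'top-level VQ-VAE codes': raw // top_compression,
        'codec frames': frames,
        'codec tokens (all codebooks)': frames * n_codebooks,
    }


lengths = sequence_lengths(30)
for name, n in lengths.items():
    print(f"{name:>30}: {n:>10,}")
```

Going from over a million raw samples to a few thousand discrete tokens is the difference between a sequence a transformer cannot realistically attend over and one it handles comfortably.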

Evaluating generated music

This is where things get philosophically interesting. How do you measure whether generated music is "good"? What even is good music? There's no equivalent of Word Error Rate (ASR) or FID score (images) that captures musical quality objectively. Two people can listen to the same generated piece and have completely opposite reactions:

import numpy as np


class MusicEvaluator:
    """Metrics for evaluating generated
    music quality."""

    def __init__(self):
        self.rng = np.random.RandomState(42)

    def frechet_audio_distance(self,
                                real_feats,
                                gen_feats):
        """FAD: lower is better. Measures
        distribution similarity between real
        and generated audio features."""
        mu_r = np.mean(real_feats, axis=0)
        mu_g = np.mean(gen_feats, axis=0)
        sigma_r = np.cov(
            real_feats, rowvar=False)
        sigma_g = np.cov(
            gen_feats, rowvar=False)

        diff = mu_r - mu_g
        mean_term = np.dot(diff, diff)

        # Trace term (simplified: sqrt of
        # the trace stands in for the trace
        # of the matrix square root, so
        # this is only a rough proxy)
        trace_term = (
            np.trace(sigma_r)
            + np.trace(sigma_g)
            - 2 * np.sqrt(
                np.abs(np.trace(
                    sigma_r @ sigma_g))))
        return float(
            mean_term + trace_term)

    def pitch_class_histogram(self, pitches):
        """Distribution over 12 pitch
        classes (C through B)."""
        hist = np.zeros(12)
        for p in pitches:
            hist[p % 12] += 1
        total = hist.sum()
        if total > 0:
            hist /= total
        return hist

    def rhythmic_regularity(self, onsets,
                             sr=16000):
        """Measure how regular the rhythm
        is. High = mechanical, low = human-
        like variation."""
        if len(onsets) < 3:
            return 0.0
        intervals = np.diff(onsets)
        mean_ioi = np.mean(intervals)
        if mean_ioi == 0:
            return 0.0
        cv = np.std(intervals) / mean_ioi
        return float(1.0 - min(cv, 1.0))

    def harmonic_consistency(self, pitches,
                              window=8):
        """Check if generated pitches stay
        within consistent key areas."""
        if len(pitches) < window:
            return 0.0
        scores = []
        # Major scale intervals
        major = {0, 2, 4, 5, 7, 9, 11}
        # +1 so the final window is scored
        # (and len == window gives a score)
        for i in range(
                len(pitches) - window + 1):
            chunk = pitches[i:i + window]
            best = 0
            for root in range(12):
                scale = {
                    (root + n) % 12
                    for n in major}
                in_key = sum(
                    1 for p in chunk
                    if p % 12 in scale)
                best = max(best, in_key)
            scores.append(best / window)
        return float(np.mean(scores))

    def run(self):
        # Simulate "real" vs "generated"
        # audio features
        real = self.rng.randn(100, 64)
        good_gen = real + self.rng.randn(
            100, 64) * 0.3
        bad_gen = self.rng.randn(
            100, 64) * 2.0

        fad_good = (
            self.frechet_audio_distance(
                real, good_gen))
        fad_bad = (
            self.frechet_audio_distance(
                real, bad_gen))
        print(f"FAD (good gen): "
              f"{fad_good:.1f}")
        print(f"FAD (bad gen):  "
              f"{fad_bad:.1f}")
        print(f"(lower = more similar "
              f"to real music)\n")

        # Musical quality metrics
        good_melody = [60, 62, 64, 65, 67,
                       65, 64, 62, 60, 64,
                       67, 72, 71, 67, 64]
        random_melody = self.rng.randint(
            48, 84, 15).tolist()

        print(f"Pitch class distribution:")
        for name, mel in [
                ('Tonal', good_melody),
                ('Random', random_melody)]:
            hist = self.pitch_class_histogram(
                mel)
            top3 = np.argsort(hist)[::-1][:3]
            names = ['C','C#','D','D#','E',
                     'F','F#','G','G#','A',
                     'A#','B']
            top = ', '.join(
                f"{names[i]}:{hist[i]:.0%}"
                for i in top3)
            print(f"  {name}: {top}")

        hc_good = self.harmonic_consistency(
            good_melody)
        hc_rand = self.harmonic_consistency(
            random_melody)
        print(f"\nHarmonic consistency:")
        print(f"  Tonal melody: {hc_good:.2f}")
        print(f"  Random notes: {hc_rand:.2f}")


evaluator = MusicEvaluator()
evaluator.run()

Frechet Audio Distance (FAD) is the audio equivalent of FID for images (episode #55). It compares the distribution of audio features (typically VGGish embeddings) between real and generated music. Lower FAD means the generated music's feature distribution is closer to real music. But a low FAD doesn't mean the music is good -- it means it's statistically similar to the training data. You could have a low FAD score on a model that produces pleasant but utterly boring elevator music.
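The full Frechet distance between two Gaussians needs the trace of a matrix square root, Tr((Σ_r Σ_g)^(1/2)). One way to get it with plain NumPy, sketched below, is to sum the square roots of the eigenvalues of Σ_r Σ_g. The function name `frechet_distance` is my own, and random vectors stand in for real VGGish embeddings, so this is a sketch of the formula and not a drop-in replacement for an official FAD tool.

```python
import numpy as np


def frechet_distance(real_feats, gen_feats):
    """Frechet distance between Gaussians fitted to two
    feature sets of shape (n_samples, n_dims)."""
    mu_r = real_feats.mean(axis=0)
    mu_g = gen_feats.mean(axis=0)
    sigma_r = np.cov(real_feats, rowvar=False)
    sigma_g = np.cov(gen_feats, rowvar=False)

    diff = mu_r - mu_g
    # Tr(sqrt(S_r S_g)) = sum of sqrt of the eigenvalues
    # of S_r S_g (real and >= 0 for covariance matrices;
    # clip tiny negative values caused by round-off)
    eigvals = np.linalg.eigvals(sigma_r @ sigma_g)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()

    return float(diff @ diff
                 + np.trace(sigma_r) + np.trace(sigma_g)
                 - 2.0 * tr_sqrt)


rng = np.random.default_rng(42)
real = rng.standard_normal((200, 8))  # stand-in features

same = frechet_distance(real, real)
shifted = frechet_distance(real, real + 1.0)
print(f"identical sets:        {same:.6f}")
print(f"mean shifted by 1/dim: {shifted:.6f}")
```

Two useful sanity checks fall out of this: identical feature sets give a distance of (numerically) zero, and shifting every dimension by 1 with unchanged covariance gives a distance equal to the number of dimensions.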

Pitch class histogram analysis checks whether the generated music uses pitch distributions that resemble real music. Tonal music concentrates on 7 out of 12 pitch classes (the notes of the key). Random note generation spreads evenly across all 12. This is a crude but useful sanity check -- if your model's pitch distribution is flat, it's not generating tonal music.
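If you want a single number for "how flat is this histogram", Shannon entropy is a reasonable candidate. This is my own addition and not a standard music metric: the hypothetical helper `pitch_class_entropy` returns 0 bits when only one pitch class is used and log2(12) bits when all twelve are used equally.

```python
import numpy as np


def pitch_class_entropy(pitches):
    """Shannon entropy (bits) of the pitch class
    distribution of a list of MIDI note numbers."""
    pcs = np.asarray(pitches, dtype=int) % 12
    hist = np.bincount(pcs, minlength=12).astype(float)
    total = hist.sum()
    if total == 0:
        return 0.0
    p = hist / total
    p = p[p > 0]  # skip unused classes: 0 * log(0) := 0
    return float(-(p * np.log2(p)).sum())


c_major = [60, 62, 64, 65, 67, 69, 71, 72]
chromatic = list(range(60, 72))

e_major = pitch_class_entropy(c_major)
e_chrom = pitch_class_entropy(chromatic)
print(f"C major scale:   {e_major:.3f} bits")
print(f"Chromatic scale: {e_chrom:.3f} bits")
print(f"Maximum (flat):  {np.log2(12):.3f} bits")
```

A model whose output sits right at the maximum is behaving like the random-note generator; tonal material should land clearly below it.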

Human evaluation remains the gold standard, just like MOS for TTS (episode #94). Studies typically ask listeners to rate generated music on dimensions like: overall quality, musicality, coherence, enjoyability, and whether it sounds like it could have been composed by a human. The challenge is that musical taste is deeply subjective -- one listener's "repetitive garbage" is another's "hypnotic minimalism."
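Aggregating such listener ratings usually boils down to a mean score plus some indication of uncertainty. A minimal sketch follows, assuming a 1-5 rating scale and a normal-approximation 95% interval. The ratings are made up purely for illustration, and `mean_opinion_score` is a hypothetical helper name.

```python
import numpy as np


def mean_opinion_score(ratings, z=1.96):
    """Return (mean, half-width of the ~95% confidence
    interval) using a normal approximation."""
    r = np.asarray(ratings, dtype=float)
    mean = r.mean()
    if len(r) < 2:
        return float(mean), float('nan')
    sem = r.std(ddof=1) / np.sqrt(len(r))
    return float(mean), float(z * sem)


# Made-up ratings from 5 listeners (illustration only)
ratings = {
    'overall quality': [4, 5, 3, 4, 4],
    'coherence':       [2, 5, 1, 4, 3],
}

for dim, scores in ratings.items():
    mos, ci = mean_opinion_score(scores)
    print(f"{dim:>16}: {mos:.2f} +/- {ci:.2f}")
```

A wide interval on one dimension is itself informative: it tells you the listeners disagree, which is the subjectivity problem showing up in the numbers.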

What music generation still gets wrong

Even the best current systems have fundamental limitations:

Long-range structure: generating 30 seconds of convincing music is mostly solved. Generating 5 minutes with proper song structure (intro, verse, chorus, bridge, outro) where the chorus sounds like a development of the verse material? Much harder. Current models tend to produce music that wanders or becomes repetitive over longer durations.

Musical intentionality: human composers make deliberate choices. A dissonant chord resolves to a consonant one because the composer wanted tension-and-release. Current models learn statistical patterns of tension-and-release but don't have the intentional, goal-directed composition that humans bring. The music sounds "right" but doesn't feel purposeful.

Originality vs. memorization: large training datasets mean the model may reproduce existing songs or combine fragments in ways that are technically plagiarism. How much of "creativity" is actually recombination of training data? This is a genuinely unsolved philosophical question, and it applies equally to human composers (who also learned from existing music) and AI models.

Cultural context: music doesn't exist in a vacuum. A minor key is "sad" in Western music but might not carry the same emotional weight in other musical traditions. Current models trained on Western music datasets reproduce Western musical assumptions -- they don't understand the cultural context that gives music meaning.

The field is moving fast though. New architectures for long-range coherence, better evaluation metrics, and larger multi-cultural training datasets are all active research areas. And the combination of symbolic and audio approaches (generate structure symbolically, render audio neurally) might be the path to truly compelling AI composition.

In summary

  • Music can be represented symbolically (MIDI events/piano rolls), as spectrograms, or as raw audio; the choice determines what models work and what quality is achievable;
  • Markov chains capture local note transitions but have zero understanding of large-scale musical structure -- they illustrate why neural sequence models are necessary;
  • the Music Transformer uses relative positional encoding (distance between tokens, not absolute position) to capture musical patterns that work regardless of transposition -- trained on the MAESTRO piano dataset;
  • Jukebox generates raw audio using hierarchical VQ-VAE compression (128x at the top level) plus autoregressive transformers at each level -- impressive quality, but it takes roughly 9 hours of compute to generate one minute of audio;
  • MusicGen combines EnCodec audio compression with a single transformer and text conditioning via cross-attention -- practical, real-time generation from text descriptions like "ambient electronic with soft pads";
  • FAD (Frechet Audio Distance) measures distributional similarity to real music; pitch class histograms and harmonic consistency check tonal quality; but human evaluation remains the gold standard because musical quality is fundamentally subjective;
  • current limitations: long-range structural coherence, musical intentionality (purpose behind compositional choices), originality vs. memorization, and cultural context -- the same deep challenges that make music composition hard for humans too.

The audio AI arc is winding down. We've covered how machines hear (episode #92), understand speech (#93), produce speech (#94), classify sounds (#95), and now create music. There's still more ground to cover in the audio domain though -- identifying who's speaking, understanding spoken intent, cleaning up noisy audio, and combining audio with vision.

Exercises

Exercise 1: Build a melody complexity analyzer. Create a class MelodyAnalyzer that: (a) takes a list of MIDI note numbers representing a melody, (b) computes interval histogram: count the frequency of each interval (in semitones) between consecutive notes, classify intervals as: unison (0), step (1-2), skip (3-4), leap (5-7), large leap (8+), and report the percentage of each category, (c) computes pitch range: the difference between highest and lowest note in semitones, (d) computes contour complexity: count the number of direction changes in the melody (going up then down or vice versa), divide by total number of intervals to get a ratio between 0 (monotonically ascending/descending) and ~1.0 (highly zigzagging), (e) tests on three melodies: (1) a C major scale ascending (60,62,64,65,67,69,71,72) -- should have 100% steps, zero contour changes, (2) "Twinkle Twinkle" opening (60,60,67,67,69,69,67,65,65,64,64,62,62,60) -- mixed intervals with direction changes, (3) a random walk of 20 notes starting at 60 with random intervals of -5 to +5 (seed 42) -- should have high contour complexity. Print a comparison table showing all metrics for all three melodies.

Exercise 2: Build a chord progression generator and analyzer. Create a class ChordProgressionGen that: (a) defines the 7 diatonic triads in C major as lists of MIDI notes (I = [60,64,67], ii = [62,65,69], iii = [64,67,71], IV = [65,69,72], V = [67,71,74], vi = [69,72,76], vii_dim = [71,74,77]), (b) implements a transition probability matrix for common progressions: I can go to any chord, IV often goes to V or I, V strongly resolves to I or vi, vi goes to IV or ii, ii goes to V, iii goes to vi or IV, vii goes to I, (c) generates an 8-chord progression starting from I using the transition matrix (seed 42), (d) for each chord in the generated progression, synthesizes the 3 notes simultaneously as sine waves (0.5 seconds each, sr=16000), (e) computes and prints: the total audio length, which chord numerals appeared, the number of "strong resolutions" (V->I or vii->I), and the consonance of each chord (measured as the ratio of frequency ratios that are close to simple fractions like 3/2, 4/3, 5/4). Verify that the V->I resolution is the most common cadence in the generated progression.

Exercise 3: Build a rhythmic pattern evaluator. Create a class RhythmEvaluator that: (a) represents rhythmic patterns as lists of onset times in seconds, (b) creates 4 test patterns: (1) "metronome" -- perfectly even 8th notes at 120 BPM (onsets every 0.25s for 2 seconds), (2) "swing" -- alternating long-short pattern (0.33s, 0.17s, 0.33s, 0.17s...) for 2 seconds, (3) "human drummer" -- the metronome pattern with Gaussian timing noise (std=0.015s, seed 42) added to each onset, (4) "random" -- 16 uniformly random onset times in [0, 2] sorted ascending (seed 42), (c) for each pattern computes: number of onsets, mean inter-onset interval (IOI), IOI coefficient of variation (std/mean -- 0 for perfectly regular), swing ratio (mean ratio of odd IOIs to even IOIs -- 1.0 for straight, ~2.0 for heavy swing), tempo stability (1 - normalized variance of IOI computed in sliding windows of 4 onsets), (d) prints a comparison table with all metrics. Verify that: the metronome has CV=0 and swing ratio=1.0, the swing pattern has a swing ratio near 2.0, the human drummer has small but nonzero CV, and the random pattern has the highest CV and lowest tempo stability.

Regards!

@scipio


