Learn AI Series (#107) - Deep Q-Networks (DQN)

What will I learn

  • You will learn why tabular Q-Learning falls apart the moment the state space gets big, and why a lookup table was never going to scale to pixels;
  • function approximation: handing the job of estimating Q-values over to a neural network, so unseen states still get sensible numbers;
  • experience replay: the memory buffer that breaks the temporal correlations which otherwise wreck training;
  • target networks: freezing a second copy of the network to stop the targets running away from you;
  • Double DQN: a one-line change that kills the overestimation bias baked into the plain max;
  • and the extras that turned DQN from a clever trick into a workhorse -- prioritized replay and the dueling architecture.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution with NumPy and PyTorch;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#107) - Deep Q-Networks (DQN)

Solutions to Episode #106 Exercises

Before we weave the neural networks into the picture, let's settle last episode's three exercises. All of them lean on the CliffWalking environment and the n_step_td evaluator from episode #106, so I'm assuming those are imported and sitting in scope.

Exercise 1: Build a SARSA vs Q-Learning reward tracker on CliffWalking. Return the total reward per episode for both, train for 500 episodes, print a smoothed moving average, and explain why the learner with the "worse" training curve ends up with the better greedy policy.

import numpy as np
from collections import defaultdict
# Assumes CliffWalking from episode #106.


def sarsa_tracked(env, n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    Q = defaultdict(lambda: np.zeros(env.n_actions))
    rewards_per_episode = []

    def epsilon_greedy(state):
        if np.random.random() < epsilon:
            return np.random.randint(env.n_actions)
        return int(np.argmax(Q[state]))

    for _ in range(n_episodes):
        state = env.reset()
        action = epsilon_greedy(state)
        total, done = 0.0, False
        while not done:
            next_state, reward, done = env.step(action)
            next_action = epsilon_greedy(next_state)
            target = reward + gamma * Q[next_state][next_action] * (1 - done)
            Q[state][action] += alpha * (target - Q[state][action])
            state, action = next_state, next_action
            total += reward
        rewards_per_episode.append(total)
    return Q, rewards_per_episode


def q_learning_tracked(env, n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    Q = defaultdict(lambda: np.zeros(env.n_actions))
    rewards_per_episode = []

    def epsilon_greedy(state):
        if np.random.random() < epsilon:
            return np.random.randint(env.n_actions)
        return int(np.argmax(Q[state]))

    for _ in range(n_episodes):
        state = env.reset()
        total, done = 0.0, False
        while not done:
            action = epsilon_greedy(state)
            next_state, reward, done = env.step(action)
            target = reward + gamma * np.max(Q[next_state]) * (1 - done)
            Q[state][action] += alpha * (target - Q[state][action])
            state = next_state
            total += reward
        rewards_per_episode.append(total)
    return Q, rewards_per_episode


def moving_average(x, window=10):
    return np.convolve(x, np.ones(window) / window, mode="valid")


env = CliffWalking()
_, r_sarsa = sarsa_tracked(env)
_, r_qlearn = q_learning_tracked(env)
ma_s, ma_q = moving_average(r_sarsa), moving_average(r_qlearn)

print(f"{'episode':>8}{'SARSA':>10}{'Q-Learning':>12}")
for ep in range(0, len(ma_s), 50):
    print(f"{ep:>8}{ma_s[ep]:>10.1f}{ma_q[ep]:>12.1f}")

Read off the smoothed curves and Q-Learning sits lower for the whole run -- its reward-per-episode is consistently worse than SARSA's. That looks like a damning verdict until you remember what each one is actually optimising. Q-Learning's greedy target evaluates the optimal path right along the cliff edge, so during training epsilon keeps shoving it over the edge for -100 a pop, dragging the curve down. SARSA bakes that exploration risk into its values and retreats to the safe ledge one row up, so it collects more reward while learning. The punchline: Q-Learning's greedy policy is genuinely shorter and optimal -- it just pays for that knowledge with a rougher training ride. The on-line score and the quality of the final greedy policy are two different questions, and here they point in opposite directions.

Exercise 2: Implement Expected SARSA -- bootstrap from the expectation over next actions under the epsilon-greedy policy, instead of the sampled next action (SARSA) or the max (Q-Learning).

import numpy as np
from collections import defaultdict
# Assumes CliffWalking from episode #106.


def expected_sarsa(env, n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    Q = defaultdict(lambda: np.zeros(env.n_actions))

    def epsilon_greedy(state):
        if np.random.random() < epsilon:
            return np.random.randint(env.n_actions)
        return int(np.argmax(Q[state]))

    for _ in range(n_episodes):
        state, done = env.reset(), False
        while not done:
            action = epsilon_greedy(state)
            next_state, reward, done = env.step(action)

            # expected Q under epsilon-greedy:
            # (1 - eps) on the greedy action, eps spread over all actions
            q_next = Q[next_state]
            expected_q = (1 - epsilon) * np.max(q_next) + epsilon * np.mean(q_next)

            target = reward + gamma * expected_q * (1 - done)
            Q[state][action] += alpha * (target - Q[state][action])
            state = next_state
    return dict(Q)

The trick is that (1 - epsilon) * max + epsilon * mean is exactly the expected value of Q[next_state] when the next action is drawn from epsilon-greedy (the greedy action carries the lion's share 1 - epsilon, and the leftover epsilon is smeared uniformly across all actions, the greedy one included). Because we average over next actions instead of sampling one, Expected SARSA removes the variance contributed by that random choice. On cliff walking it learns the same cautious, away-from-the-edge path as SARSA -- it is still on-policy -- but the updates are noticeably smoother, which in practice lets you crank alpha higher without the whole thing wobbling. It is the quiet middle child between SARSA and Q-Learning, and frankly it is underused ;-)
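If you want to convince yourself of that identity, here is a minimal numerical check. The helper name epsilon_greedy_probs and the four Q-values are made up for illustration; they are not part of the episode #106 code.

```python
def epsilon_greedy_probs(q_values, epsilon):
    """Action probabilities under epsilon-greedy (hypothetical helper)."""
    n = len(q_values)
    probs = [epsilon / n] * n                 # epsilon spread uniformly over all actions
    greedy = max(range(n), key=lambda a: q_values[a])
    probs[greedy] += 1.0 - epsilon            # the remaining mass goes to the greedy action
    return probs


q_next = [-3.0, -1.0, -7.0, -2.0]             # illustrative Q-values for one next state
epsilon = 0.1

probs = epsilon_greedy_probs(q_next, epsilon)

# Long way round: sum over actions of pi(a | s') * Q(s', a)
explicit = sum(p * q for p, q in zip(probs, q_next))

# Shortcut used in expected_sarsa above
shortcut = (1 - epsilon) * max(q_next) + epsilon * sum(q_next) / len(q_next)

print(probs)                # four probabilities that sum to 1
print(explicit, shortcut)   # the two agree up to floating-point rounding
```

Both routes give the same number, so the shortcut is just a cheaper way of writing the full expectation.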

Exercise 3: Take the n_step_td evaluator and run it for n in [1, 2, 4, 8, 16, 32] on a random walk, measuring the error of the learned values against a reference. You should see a U-shape -- some intermediate n beats both TD(0) and full Monte Carlo.

import numpy as np
from collections import defaultdict
import random
# Assumes n_step_td from episode #106.


class RandomWalk:
    """19-state random walk. States 1..19, terminals 0 and 20.
    Equal-probability left/right step. +1 at the right end, -1 at the left."""

    def __init__(self, n=19):
        self.n = n
        self.start = (n + 1) // 2

    def reset(self):
        self.state = self.start
        return self.state

    def step(self, action):                 # action ignored -- the walk is random
        self.state += 1 if random.random() < 0.5 else -1
        if self.state == 0:
            return self.state, -1.0, True
        if self.state == self.n + 1:
            return self.state, 1.0, True
        return self.state, 0.0, False


def random_policy(state):
    return 0


# True values are linear from -1 to +1 across the chain (analytic reference).
true_V = {s: (s - 10) / 10.0 for s in range(1, 20)}

env = RandomWalk(19)
print(f"{'n':>4}{'RMS error':>12}")
for n in [1, 2, 4, 8, 16, 32]:
    errors = []
    for _ in range(20):                     # average over seeds to smooth the curve
        V = n_step_td(env, random_policy, n_steps=n, n_episodes=10, alpha=0.1)
        sq = [(V[s] - true_V[s]) ** 2 for s in range(1, 20)]
        errors.append(np.sqrt(np.mean(sq)))
    print(f"{n:>4}{np.mean(errors):>12.4f}")

Run it and the U-shape jumps right out of the column of numbers. At n = 1 you have plain TD(0): low variance, but it bootstraps so aggressively that with only ten episodes the signal barely propagates back along the chain. At n = 32 you have something close to Monte Carlo: unbiased, but every estimate is now hostage to one long noisy return. Somewhere around n = 4 to n = 8 the two failure modes cancel and the error bottoms out. This is the exact same picture as the famous random-walk figure in Sutton and Barto, reproduced on your own machine in twenty lines -- the dial between TD and MC is real, and the sweet spot lives in the middle.

On to today's episode

Right, episode 107 -- and this is the one where reinforcement learning stops being a toy and starts being the thing you read about in the news.

At the very end of episode #106 I left a deliberate bait dangling: Q-Learning is magnificent, but it leans on a humble Q dictionary, and the moment you swap that dictionary for a neural network "a whole box of new problems" springs open. Today we open the box. The payoff is enormous -- this is the algorithm that taught a computer to play Atari games straight from the pixels, the result that kicked off the entire deep reinforcement learning era. Having said that, let's first understand why the dictionary had to go.

Where the Q-table breaks

In episode #106 Q-Learning stored a value for every (state, action) pair in a table. That is perfectly fine when the world is small -- a 4x4 grid has 16 states, our Blackjack from episode #105 has a few hundred. You can visit each one thousands of times and pin down its value by averaging.

Now consider an Atari frame: 210x160 pixels, 128 possible colours each. The number of distinct screens is not "large", it is astronomical -- vastly more than there are atoms in the observable universe. Two problems hit you at once, and both are fatal:

  • You could never store a table that big.
  • You could never visit each state often enough to estimate its value, since you'll essentially never see the exact same screen twice.
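To put a number on "astronomical", here is a back-of-the-envelope sketch. It counts every possible pixel pattern, most of which the game can never actually render, so read it as an upper bound on the table size rather than a count of reachable states. The 10^80 figure for atoms in the observable universe is the usual rough estimate.

```python
import math

width, height, colours = 160, 210, 128
pixels = width * height                       # 33,600 pixels per frame

# The number of distinct frames is colours ** pixels.
# That integer is far too long to print, so work with its base-10 logarithm.
log10_frames = pixels * math.log10(colours)
print(f"distinct frames is roughly 10^{log10_frames:,.0f}")

# Compare with the ~10^80 atoms usually quoted for the observable universe.
print("more than 10^80:", log10_frames > 80)
```

The exponent itself runs to five digits, which is why no amount of memory or training time rescues the table.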

The fix is the same idea that has powered most of this series since episode #38: replace the lookup table with a function approximator. Instead of looking up Q(s, a) in a dictionary, we feed the state s into a neural network and let it output Q-values for every action. Crucially, the network generalises -- a screen it has never seen before still gets a sensible answer, because the network has learned features (paddles, balls, enemies) that transfer across similar screens. This is the Deep Q-Network, the DQN, published by DeepMind in their 2015 Nature paper.

From table to network

A small fully-connected network is plenty for low-dimensional state vectors (think the four numbers describing a balancing pole):

import torch
import torch.nn as nn
import numpy as np
from collections import deque
import random


class QNetwork(nn.Module):
    """Neural network that approximates Q(s, a) for all actions at once."""
    def __init__(self, state_dim, n_actions, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),   # one output per action
        )

    def forward(self, state):
        return self.net(state)              # returns Q-values for all actions

Notice the shape of the output: one Q-value per action, computed in a single forward pass. That is a deliberate design choice -- you get all the action values for a state in one shot, instead of running the network once per action. For pixels, DeepMind used the convolutional stack we built up over episodes #45 and #46:

class AtariQNetwork(nn.Module):
    """CNN Q-network for raw pixel observations."""
    def __init__(self, n_actions):
        super().__init__()
        # Input: 4 stacked grayscale frames, each 84x84.
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),  nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        # x: (batch, 4, 84, 84), pixel values scaled to 0-1
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

Why four stacked frames instead of one? Because a single snapshot has no notion of motion -- from one still frame you cannot tell whether the ball in Pong is heading left or right. Stack four consecutive frames and the network can read velocity straight off the input. (A neat trick, and far cheaper than wiring in a recurrent network like the LSTMs from episode #49.)
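The stacking itself is only a few lines. Below is a minimal sketch, assuming the frames have already been converted to 84x84 grayscale arrays; the class name FrameStack and its methods are my own invention for illustration, not DeepMind's preprocessing code.

```python
from collections import deque

import numpy as np


class FrameStack:
    """Keeps the k most recent frames and returns them as one (k, H, W) array."""
    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # at episode start there is no history yet, so repeat the first frame k times
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.observation()

    def push(self, frame):
        self.frames.append(frame)             # the oldest frame drops out automatically
        return self.observation()

    def observation(self):
        return np.stack(list(self.frames), axis=0)   # ordered oldest to newest


stack = FrameStack(k=4)
obs = stack.reset(np.zeros((84, 84), dtype=np.float32))   # dummy all-black frame
obs = stack.push(np.ones((84, 84), dtype=np.float32))     # dummy all-white frame
print(obs.shape)            # (4, 84, 84)
print(obs[:, 0, 0])         # top-left pixel across the stack: three old frames, one new
```

Add a batch dimension in front and this is the (batch, 4, 84, 84) input that AtariQNetwork expects.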

Why the naive version blows up

Here is the trap, and it catches everyone the first time. Take Q-Learning, swap the dictionary for QNetwork, run the same update -- and training diverges. Not "learns slowly". Diverges, spectacularly, into nonsense. Two distinct problems gang up on you:

Problem 1 -- correlated samples. In on-line Q-Learning the agent feeds the network a stream of consecutive transitions (s, a, r, s'), and consecutive transitions are heavily correlated (the agent is walking through connected states). Neural networks hate correlated training data -- it is exactly like training an image classifier where every cat photo comes first and every dog photo comes after. The network overfits to the recent stretch and "forgets" what it learned ten seconds ago.

Problem 2 -- moving targets. The Q-Learning target is r + gamma * max Q(s', a'). But Q is the very network you are updating. Every gradient step shifts the weights, which shifts all the targets, which means you are chasing a target that jumps every time you shoot at it. The network ends up endlessly pursuing its own tail.

DeepMind's two famous innovations each kill one of these problems dead: experience replay for the correlation, and a target network for the moving target. Let's take them in turn.

Experience replay

Instead of learning from each transition once and throwing it away, we stash every transition in a fixed-size replay buffer and train on random minibatches drawn from it:

class ReplayBuffer:
    """Fixed-size buffer that stores transitions and samples minibatches."""
    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size=32):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
        )

    def __len__(self):
        return len(self.buffer)

Random sampling shatters the temporal correlation -- one minibatch might mix a transition from episode 5, another from episode 200, another from episode 47. The network now sees a diverse, decorrelated dataset, which is precisely the i.i.d.-ish input that gradient descent was designed for. Nota bene: there is a lovely bonus here too -- each stored transition can be replayed for many gradient updates, so DQN squeezes far more learning out of every interaction with the environment (a big deal when each interaction is an expensive game step).
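You can watch the decorrelation happen with nothing but the standard library. This is a toy sketch with fake data: each "transition" is just an (episode, step) tag, and the counts of 50 episodes and 20 steps are arbitrary assumptions.

```python
import random
from collections import deque

random.seed(0)
buffer = deque(maxlen=10_000)

# Fake experience: 50 episodes of 20 steps, each transition tagged (episode, step).
for episode in range(50):
    for step in range(20):
        buffer.append((episode, step))

sequential = list(buffer)[-32:]               # what on-line learning sees: the latest 32
sampled = random.sample(list(buffer), 32)     # what replay hands to the network

episodes_seq = {ep for ep, _ in sequential}
episodes_rnd = {ep for ep, _ in sampled}
print(len(episodes_seq), "distinct episodes in the last 32 transitions")
print(len(episodes_rnd), "distinct episodes in a random minibatch of 32")
```

The sequential batch is drawn from just the last two episodes, while the sampled batch is spread across a large share of the agent's history.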

The target network

To pin down the moving target, DQN keeps two copies of the Q-network:

  • the online network -- the one updated every single step;
  • the target network -- a frozen snapshot, refreshed only occasionally.

The TD target is computed with the target network, not the online one:

# target = r + gamma * max_a' Q_target(s', a')   <- frozen target network
# loss   = ( target - Q_online(s, a) )^2

Because the target network's weights stand still between refreshes, the targets stay put long enough for the online network to actually converge toward them. Every C steps (typically somewhere in the 1,000 to 10,000 range) you copy the online weights into the target network, the target lurches forward to the new estimate, and the chase resumes -- but now in stable, deliberate hops instead of a continuous wild scramble.
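To make the "hops" concrete, here is a toy sketch in which a single float stands in for a whole network's weights. The 0.01 increment is a pretend gradient step and C = 1000 is simply picked from the range mentioned above; nothing here is real training.

```python
target_update = 1000        # C: steps between hard copies
total_steps = 5000

online_weight = 0.0         # stand-in for the online network's parameters
target_weight = online_weight
target_versions = set()     # every distinct value the target held during the run

for step in range(1, total_steps + 1):
    online_weight += 0.01                     # the online "network" moves every step
    if step % target_update == 0:
        target_weight = online_weight         # hard copy, like load_state_dict(...)
    target_versions.add(target_weight)

print(len(target_versions), "distinct targets over", total_steps, "steps")
```

The online value changed 5,000 times, yet the learner only ever faced 6 different targets: the initial one plus one per refresh.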

The complete DQN agent

Bolt the pieces together -- online network, target network, replay buffer, epsilon-greedy exploration (straight from the bandit toolkit in episode #103) -- and you get the full agent:

class DQNAgent:
    """Complete DQN agent: replay buffer + target network + epsilon-greedy."""
    def __init__(self, state_dim, n_actions, lr=1e-3, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
                 buffer_size=100000, batch_size=64, target_update=1000):
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update = target_update
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.step_count = 0

        # online and target networks start out identical
        self.q_network = QNetwork(state_dim, n_actions)
        self.target_network = QNetwork(state_dim, n_actions)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)

    def choose_action(self, state):
        """Epsilon-greedy over the online network's Q-values."""
        if random.random() < self.epsilon:
            return random.randint(0, self.n_actions - 1)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0)
            return self.q_network(state_t).argmax(dim=1).item()

    def store_transition(self, state, action, reward, next_state, done):
        self.buffer.push(state, action, reward, next_state, done)

    def learn(self):
        """One gradient step on a sampled minibatch."""
        if len(self.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Q-values the online network assigns to the actions actually taken
        q_values = self.q_network(states)
        q_selected = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # TD target, using the FROZEN target network
        with torch.no_grad():
            max_next_q = self.target_network(next_states).max(dim=1)[0]
            targets = rewards + self.gamma * max_next_q * (1 - dones)

        loss = nn.functional.mse_loss(q_selected, targets)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=1.0)  # stability
        self.optimizer.step()

        # periodically refresh the target network
        self.step_count += 1
        if self.step_count % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # decay exploration over time
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
        return loss.item()

The gather call is the one bit worth squinting at: the network spits out a Q-value for every action, but the loss should only touch the action we actually took, so gather plucks out that single column per row. Everything else is the Q-Learning update from episode #106 wearing a PyTorch coat.
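If gather still feels opaque, the same selection can be replayed on a tiny hand-made batch. This sketch uses NumPy's take_along_axis as a stand-in for the PyTorch call, and all the numbers are made up.

```python
import numpy as np

# Q-values for a batch of 3 states and 2 actions.
q_values = np.array([[1.0, 5.0],
                     [2.0, 0.5],
                     [3.0, 4.0]])
actions = np.array([1, 0, 0])                 # the action actually taken in each row

# NumPy analogue of q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
q_selected = np.take_along_axis(q_values, actions[:, None], axis=1).squeeze(1)
print(q_selected)                             # one value per row: 5.0, 2.0, 3.0

# TD targets for the same batch. The (1 - dones) factor switches off the
# bootstrap term on terminal transitions, leaving only the reward.
rewards = np.array([1.0, 1.0, 1.0])
max_next_q = np.array([10.0, 10.0, 10.0])
dones = np.array([0.0, 0.0, 1.0])
gamma = 0.99
targets = rewards + gamma * max_next_q * (1 - dones)
print(targets)                                # 10.9, 10.9, and 1.0 for the terminal row
```

Row three is the one to look at: the episode ended there, so its target is the bare reward with no bootstrapped future value.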

The training loop itself is almost boringly familiar -- it is the same act-store-learn rhythm, just with a replay buffer humming in the background:

def train_dqn(env, agent, n_episodes=1000):
    """Train a DQN agent on a Gymnasium-style environment."""
    rewards_history = []

    for episode in range(n_episodes):
        state, _ = env.reset()                  # Gymnasium: reset() returns (obs, info)
        total_reward, done = 0.0, False

        while not done:
            action = agent.choose_action(state)
            # Gymnasium: step() returns five values and splits "done" in two
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # store `terminated`, not `done`: a time-limit cut-off is not a true
            # terminal state, so the target should still bootstrap from next_state
            agent.store_transition(state, action, reward, next_state, terminated)
            agent.learn()

            total_reward += reward
            state = next_state

        rewards_history.append(total_reward)
        if episode % 100 == 0:
            avg = np.mean(rewards_history[-100:])
            print(f"Episode {episode} | avg reward {avg:6.1f} | "
                  f"epsilon {agent.epsilon:.3f} | buffer {len(agent.buffer)}")

    return rewards_history

Point this at CartPole-v1 and within a few hundred episodes it typically learns to keep the pole up until the environment's 500-step cap ends the episode. That is genuinely the same algorithm -- minus the CNN and the Atari-specific preprocessing -- that conquered Atari. Wowzers.
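One thing worth checking before a long run is how quickly the exploration schedule burns out. DQNAgent.learn() multiplies epsilon by epsilon_decay once per gradient step (not once per episode), so with the defaults above the floor arrives sooner than you might expect. A quick sketch of the arithmetic:

```python
import math

epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 0.995   # DQNAgent defaults

# epsilon_start * decay**n <= epsilon_end   =>   n >= ln(end / start) / ln(decay)
steps_to_floor = math.ceil(
    math.log(epsilon_end / epsilon_start) / math.log(epsilon_decay)
)
print(steps_to_floor, "gradient steps until epsilon reaches its floor")

# Cross-check by running the schedule exactly as learn() does.
epsilon, n = epsilon_start, 0
while epsilon > epsilon_end:
    epsilon = max(epsilon_end, epsilon * epsilon_decay)
    n += 1
print(n, "steps by direct simulation")
```

That is a bit over 900 environment steps once the buffer holds its first batch. If your agent stops exploring before it has seen anything interesting, a decay closer to 1 (or decaying once per episode) is the first knob I would reach for.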

Double DQN: fixing the overestimation

Plain DQN has a sneaky flaw: it systematically overestimates Q-values. The culprit is the max in the target r + gamma * max Q(s', a'). Your Q-estimates always carry some noise (they are network outputs, of course they do), and max does not pick the genuinely best action -- it picks the action whose noise happens to point highest. Take the max of noisy numbers and, on average, you get a number biased upward. Over thousands of updates that optimism compounds into a real problem.
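The upward bias is easy to reproduce without any network at all. In this toy sketch every action is truly worth 0 and the "estimates" are that true value plus Gaussian noise. The noise level and the four actions are arbitrary assumptions.

```python
import random
import statistics

random.seed(0)
n_actions, n_trials, noise_std = 4, 10_000, 1.0

# Every action is truly worth 0, so the true max over actions is also 0.
max_estimates = []
for _ in range(n_trials):
    noisy_q = [0.0 + random.gauss(0.0, noise_std) for _ in range(n_actions)]
    max_estimates.append(max(noisy_q))

bias = statistics.fmean(max_estimates)
print(f"average max over noisy estimates: {bias:+.2f}   (true value: +0.00)")
```

Each individual estimate is unbiased, yet the max over them lands around one noise standard deviation too high. More actions or more noise make it worse.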

Double DQN (van Hasselt et al., 2016) fixes it with a delightfully cheap idea: decouple choosing the action from evaluating it.

def learn_double_dqn(self):
    """Double DQN: online network SELECTS the action, target network EVALUATES it."""
    if len(self.buffer) < self.batch_size:
        return None

    states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

    q_selected = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # online network picks which action it thinks is best
        best_actions = self.q_network(next_states).argmax(dim=1, keepdim=True)
        # target network reports the value of THAT action
        max_next_q = self.target_network(next_states).gather(1, best_actions).squeeze(1)
        targets = rewards + self.gamma * max_next_q * (1 - dones)

    loss = nn.functional.mse_loss(q_selected, targets)
    self.optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=1.0)
    self.optimizer.step()

    self.step_count += 1
    if self.step_count % self.target_update == 0:
        self.target_network.load_state_dict(self.q_network.state_dict())
    self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
    return loss.item()

The whole change lives in two lines: the online network chooses the action (argmax), the target network grades it (gather). If the online network's noise oversells some action, the target network -- a different set of weights, with different noise -- usually does not share the same delusion, so the inflated estimate gets cut back down. One small decoupling, and the upward bias largely evaporates. It improves results across very nearly every environment, which is why "DQN" in practice almost always means "Double DQN".
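Here is the decoupling in miniature, again as a toy sketch rather than a real agent. One loud assumption: the two sets of estimates carry fully independent noise. In actual DQN the target network is a lagged copy of the online one, so their errors are only partly independent and the correction is correspondingly weaker.

```python
import random
import statistics

random.seed(1)
n_actions, n_trials = 4, 10_000
true_q = [0.0] * n_actions                    # all actions are truly worth 0

single, double = [], []
for _ in range(n_trials):
    online = [q + random.gauss(0, 1) for q in true_q]   # noisy "online" estimates
    target = [q + random.gauss(0, 1) for q in true_q]   # independently noisy "target" estimates

    # plain DQN: the target estimates both select and evaluate the action
    single.append(max(target))

    # Double DQN: online estimates select, target estimates evaluate
    best = max(range(n_actions), key=lambda a: online[a])
    double.append(target[best])

single_mean = statistics.fmean(single)
double_mean = statistics.fmean(double)
print(f"select and evaluate with the same estimates: {single_mean:+.2f}")
print(f"select with one set, evaluate with another:  {double_mean:+.2f}")
```

The first number sits well above the true value of 0, the second hovers right around it.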

Prioritized experience replay

Uniform sampling from the replay buffer treats every transition as equally worth learning from. But they are not equally informative -- a transition that surprised the agent (a big TD error) carries far more signal than one the network already predicts perfectly. Prioritized experience replay samples transitions in proportion to their TD error:

class PrioritizedReplayBuffer:
    """Replay buffer that samples transitions in proportion to their TD error."""
    def __init__(self, capacity=100000, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha                       # 0 = uniform, 1 = full prioritization
        self.buffer = []
        self.priorities = np.zeros(capacity)
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        # new transitions enter at max priority so they get seen at least once
        max_priority = self.priorities[:len(self.buffer)].max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.position] = (state, action, reward, next_state, done)
        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        n = len(self.buffer)
        probs = self.priorities[:n] ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(n, batch_size, p=probs)

        # importance-sampling weights correct for the non-uniform sampling
        weights = (n * probs[indices]) ** (-beta)
        weights /= weights.max()

        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
            indices,
            torch.FloatTensor(weights),
        )

    def update_priorities(self, indices, td_errors):
        for idx, error in zip(indices, td_errors):
            self.priorities[idx] = abs(error) + 1e-6   # tiny constant so nothing hits zero

There is a subtle catch the code handles with those weights. Once you sample non-uniformly you have biased the data -- the surprising transitions are now over-represented. The importance-sampling weights (the same correction idea we met for off-policy Monte Carlo back in episode #105) scale each update back down to undo that bias. Skip them and your value estimates drift; include them and prioritized replay buys you a real speed-up, often roughly doubling the learning speed on Atari compared to uniform replay.
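A worked example with four transitions shows what the probabilities and weights look like in practice. The priorities are made up; alpha and beta are the defaults from PrioritizedReplayBuffer above. The weighted loss at the end is one common way to apply the weights, and it is an assumption on my part since DQNAgent.learn() as written does not use them.

```python
alpha, beta = 0.6, 0.4
priorities = [0.1, 0.1, 0.1, 5.0]             # |TD error| per transition; the last one is a "surprise"

scaled = [p ** alpha for p in priorities]
total = sum(scaled)
probs = [s / total for s in scaled]           # P(i) = p_i^alpha / sum_k p_k^alpha

n = len(priorities)
weights = [(n * p) ** (-beta) for p in probs] # w_i = (N * P(i))^(-beta)
max_w = max(weights)
weights = [w / max_w for w in weights]        # normalise so the largest weight is 1

for p, pr, w in zip(priorities, probs, weights):
    print(f"priority {p:4.1f} -> sampling prob {pr:.3f}, IS weight {w:.3f}")

# Assumed usage: scale each sample's squared TD error by its weight before averaging.
td_errors = [0.1, -0.1, 0.1, 5.0]
weighted_loss = sum(w * e ** 2 for w, e in zip(weights, td_errors)) / n
print(f"weighted loss: {weighted_loss:.3f}")
```

The surprising transition gets sampled far more often but carries the smallest weight, which is exactly the trade the importance-sampling correction is making.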

Dueling DQN

One last refinement, and this one touches the architecture instead of the algorithm. Dueling DQN (Wang et al., 2016) splits the network into two streams: one estimating the state value V(s) (how good is it to be here at all?) and one estimating the advantage A(s, a) (how much better is each action than the average?).

class DuelingQNetwork(nn.Module):
    """Dueling architecture: separate value and advantage streams."""
    def __init__(self, state_dim, n_actions, hidden=128):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
        )
        self.value_stream = nn.Sequential(       # how good is this state?
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        self.advantage_stream = nn.Sequential(   # how much better is each action?
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, state):
        features = self.feature(state)
        value = self.value_stream(features)            # (batch, 1)
        advantage = self.advantage_stream(features)    # (batch, n_actions)
        # recombine: subtract the mean advantage for identifiability
        return value + advantage - advantage.mean(dim=1, keepdim=True)

The intuition is that in many states it barely matters which action you pick -- the state is already wonderful or already doomed, and any action lands you in much the same place. The value stream can learn "this state is worth +8" once, cleanly, while the advantage stream only frets over the small differences between actions. Splitting the labour like this learns faster, especially when the action space is large and most actions are near-equivalent. That subtracted mean, by the way, is not cosmetic -- without it the split between V and A is ambiguous (you could add a constant to one and subtract it from the other), and pinning the advantages to zero-mean nails the decomposition down.
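The ambiguity is easiest to see with plain numbers. A small sketch for a single state with three actions; the values are made up and the helper name combine is mine.

```python
def combine(value, advantages):
    """Dueling recombination: Q = V + A - mean(A)."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


value = 8.0
advantages = [0.5, -0.25, -0.25]
c = 3.0                                       # constant shuffled between the two streams

# Naive sum Q = V + A: moving c from V into A leaves Q untouched,
# so the network has no way to tell the two decompositions apart.
naive_before = [value + a for a in advantages]
naive_after = [(value - c) + (a + c) for a in advantages]
print(naive_before == naive_after)

# With the mean subtracted, a constant added to A is cancelled out...
q = combine(value, advantages)
q_shifted = combine(value, [a + c for a in advantages])
print(q == q_shifted)

# ...and V is pinned down as the average of the Q-values.
print(sum(q) / len(q) == value)
```

All three checks come out True: the naive sum hides the shift, whereas the mean-subtracted form makes V mean "the average Q in this state" and leaves A to describe only the differences.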

What DQN actually achieved

Let me put the 2015 result in perspective, because it is easy to be blasé about it now. DeepMind trained one architecture, with one set of hyperparameters, on 49 different Atari games. The only input was raw pixels and the game score -- no hand-crafted features, no game-specific rules, nobody whispering "the ball is the important bit". The same architecture that learned Breakout also learned Space Invaders, Pong, and Boxing, trained from scratch on each game. Across the set it performed at a level comparable to a professional human games tester, scoring more than 75% of the human score on 29 of those 49 games.

That was the moment deep RL became undeniable: a single learning recipe picking up wildly different skills from raw sensory input alone. It lit the fuse for everything that followed -- AlphaGo in 2016, AlphaZero in 2017, and the human-feedback alignment machinery (episode #61) that now sits underneath every chat assistant you've used.

DQN does have a hard boundary, mind you. That max over actions, and the single Q-value-per-action output, both assume a small, discrete set of actions -- "left, right, fire". The instant the actions become continuous -- a steering angle, a joint torque, anything you cannot enumerate -- the argmax has nothing finite to range over and the whole approach stalls. Climbing over that wall means learning the policy directly instead of squeezing it out of a value function, and that shift opens up the next big family of methods. But that, as ever, is a story for another episode ;-)

So, what do you know now?

  • Tabular Q-Learning collapses in large state spaces -- you can neither store nor sufficiently visit the states -- so DQN replaces the table with a neural network that generalises across similar states;
  • naive network-plus-Q-Learning diverges, sunk by correlated samples and a target that moves every gradient step;
  • experience replay stores transitions in a buffer and samples random minibatches, breaking those correlations (and recycling each transition for many updates);
  • a target network is a frozen copy of the Q-network used to compute TD targets, holding the target still long enough to converge against;
  • Double DQN decouples action selection (online network) from evaluation (target network), largely removing the max-induced overestimation -- one of the highest value-for-effort tricks in all of RL;
  • prioritized replay samples surprising (high-TD-error) transitions more often, with importance-sampling weights to undo the resulting bias, often doubling learning speed;
  • the dueling architecture splits the network into a state-value stream and an advantage stream, learning faster when many actions are near-equivalent;
  • DQN launched the deep RL era by playing 49 Atari games from raw pixels with one architecture and one set of hyperparameters -- scoring above 75% of a professional human tester's score on 29 of them.

Exercises

Exercise 1: Wire DQNAgent up to CartPole-v1 (via gymnasium) and actually train it. Plot the per-episode reward and the running 100-episode average. Then run a small ablation: train once with the target network refreshing every 1,000 steps, and once where you copy the online weights into the target network every step (i.e. effectively no target network at all). Describe what happens to the reward curve in the second case, and connect it back to the "moving target" problem from this episode.

Exercise 2: Implement soft target updates (Polyak averaging) as an alternative to the periodic hard copy. Instead of target.load_state_dict(online.state_dict()) every C steps, blend the weights every step with theta_target = tau * theta_online + (1 - tau) * theta_target, using a small tau like 0.005. Train CartPole with both schemes and compare stability and final performance. Which one feels less twitchy, and why might a gradually drifting target be gentler than an occasional lurch?

Exercise 3: Add Double DQN to your working agent (swap learn for the learn_double_dqn logic) and measure the overestimation directly. During training, periodically log the mean predicted Q-value of the start state alongside the actual return the agent goes on to collect from there. Plot both curves for plain DQN and for Double DQN. You should see plain DQN's predicted Q sitting noticeably above the real returns, with Double DQN hugging them far more honestly.

Happy experimenting -- let me know below how your CartPole agent fares!

@scipio


