Learn AI Series (#107) - Deep Q-Networks (DQN)

What will I learn
- You will learn why tabular Q-Learning falls apart the moment the state space gets big, and why a lookup table was never going to scale to pixels;
- function approximation: handing the job of estimating Q-values over to a neural network, so unseen states still get sensible numbers;
- experience replay: the memory buffer that breaks the temporal correlations which otherwise wreck training;
- target networks: freezing a second copy of the network to stop the targets running away from you;
- Double DQN: a small change that tames the overestimation bias baked into the plain max;
- and the extras that turned DQN from a clever trick into a workhorse -- prioritized replay and the dueling architecture.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution with NumPy and PyTorch;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
- Learn AI Series (#50) - Sequence-to-Sequence Models
- Learn AI Series (#51) - Attention Mechanisms
- Learn AI Series (#52) - The Transformer Architecture (Part 1)
- Learn AI Series (#53) - The Transformer Architecture (Part 2)
- Learn AI Series (#54) - Vision Transformers
- Learn AI Series (#55) - Generative Adversarial Networks
- Learn AI Series (#56) - Mini Project - Building a Transformer From Scratch
- Learn AI Series (#57) - Language Modeling - Predicting the Next Word
- Learn AI Series (#58) - GPT Architecture - Decoder-Only Transformers
- Learn AI Series (#59) - BERT and Encoder Models
- Learn AI Series (#60) - Training Large Language Models
- Learn AI Series (#61) - Instruction Tuning and Alignment
- Learn AI Series (#62) - Prompt Engineering - Getting the Most from LLMs
- Learn AI Series (#63) - Embeddings and Vector Search
- Learn AI Series (#64) - Retrieval-Augmented Generation (RAG) - Basics
- Learn AI Series (#65) - RAG - Advanced Techniques
- Learn AI Series (#66) - Working with LLM APIs
- Learn AI Series (#67) - Building AI Agents (Part 1) - Foundations
- Learn AI Series (#68) - Building AI Agents (Part 2) - Advanced Patterns
- Learn AI Series (#69) - Fine-Tuning Language Models
- Learn AI Series (#70) - Running Local Models
- Learn AI Series (#71) - Text Generation Techniques
- Learn AI Series (#72) - Tokenization Deep Dive
- Learn AI Series (#73) - LLM Evaluation
- Learn AI Series (#74) - The Hugging Face Ecosystem
- Learn AI Series (#75) - Multimodal Models - Text Meets Vision
- Learn AI Series (#76) - Mini Project - Your Own AI Assistant
- Learn AI Series (#77) - Image Processing Fundamentals
- Learn AI Series (#78) - Object Detection (Part 1) - Foundations
- Learn AI Series (#79) - Object Detection (Part 2) - Modern Approaches
- Learn AI Series (#80) - Image Segmentation
- Learn AI Series (#81) - Pose Estimation and Tracking
- Learn AI Series (#82) - Optical Character Recognition
- Learn AI Series (#83) - Video Understanding
- Learn AI Series (#84) - Generative Images - Diffusion Models (Part 1)
- Learn AI Series (#85) - Generative Images - Diffusion Models (Part 2)
- Learn AI Series (#86) - Image-to-Image and Editing
- Learn AI Series (#87) - 3D Vision
- Learn AI Series (#88) - Face Analysis
- Learn AI Series (#89) - Medical and Scientific Imaging
- Learn AI Series (#90) - Self-Supervised Learning for Vision
- Learn AI Series (#91) - Mini Project - Building a Visual AI System
- Learn AI Series (#92) - Audio Fundamentals for AI
- Learn AI Series (#93) - Speech Recognition
- Learn AI Series (#94) - Text-to-Speech (TTS)
- Learn AI Series (#95) - Audio Classification
- Learn AI Series (#96) - Music Generation
- Learn AI Series (#97) - Speaker Recognition and Diarization
- Learn AI Series (#98) - Natural Language Understanding for Voice
- Learn AI Series (#99) - Audio Enhancement
- Learn AI Series (#100) - Multimodal Audio-Visual Models
- Learn AI Series (#101) - Mini Project: Voice-Controlled AI Assistant
- Learn AI Series (#102) - What Is Reinforcement Learning?
- Learn AI Series (#103) - Multi-Armed Bandits
- Learn AI Series (#104) - Dynamic Programming
- Learn AI Series (#105) - Monte Carlo Methods
- Learn AI Series (#106) - Temporal Difference Learning
- Learn AI Series (#107) - Deep Q-Networks (DQN) (this post)
Solutions to Episode #106 Exercises
Before we weave the neural networks into the picture, let's settle last episode's three exercises. All of them lean on the CliffWalking environment and the n_step_td evaluator from episode #106, so I'm assuming those are imported and sitting in scope.
Exercise 1: Build a SARSA vs Q-Learning reward tracker on CliffWalking. Return the total reward per episode for both, train for 500 episodes, print a smoothed moving average, and explain why the "worse" learner ends up with the better optimal policy.
```python
import numpy as np
from collections import defaultdict

# Assumes CliffWalking from episode #106.

def sarsa_tracked(env, n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    Q = defaultdict(lambda: np.zeros(env.n_actions))
    rewards_per_episode = []

    def epsilon_greedy(state):
        if np.random.random() < epsilon:
            return np.random.randint(env.n_actions)
        return int(np.argmax(Q[state]))

    for _ in range(n_episodes):
        state = env.reset()
        action = epsilon_greedy(state)
        total, done = 0.0, False
        while not done:
            next_state, reward, done = env.step(action)
            next_action = epsilon_greedy(next_state)
            target = reward + gamma * Q[next_state][next_action] * (1 - done)
            Q[state][action] += alpha * (target - Q[state][action])
            state, action = next_state, next_action
            total += reward
        rewards_per_episode.append(total)
    return Q, rewards_per_episode


def q_learning_tracked(env, n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    Q = defaultdict(lambda: np.zeros(env.n_actions))
    rewards_per_episode = []

    def epsilon_greedy(state):
        if np.random.random() < epsilon:
            return np.random.randint(env.n_actions)
        return int(np.argmax(Q[state]))

    for _ in range(n_episodes):
        state = env.reset()
        total, done = 0.0, False
        while not done:
            action = epsilon_greedy(state)
            next_state, reward, done = env.step(action)
            target = reward + gamma * np.max(Q[next_state]) * (1 - done)
            Q[state][action] += alpha * (target - Q[state][action])
            state = next_state
            total += reward
        rewards_per_episode.append(total)
    return Q, rewards_per_episode


def moving_average(x, window=10):
    return np.convolve(x, np.ones(window) / window, mode="valid")


env = CliffWalking()
_, r_sarsa = sarsa_tracked(env)
_, r_qlearn = q_learning_tracked(env)
ma_s, ma_q = moving_average(r_sarsa), moving_average(r_qlearn)

print(f"{'episode':>8}{'SARSA':>10}{'Q-Learning':>12}")
for ep in range(0, len(ma_s), 50):
    print(f"{ep:>8}{ma_s[ep]:>10.1f}{ma_q[ep]:>12.1f}")
```
Read off the smoothed curves and Q-Learning sits lower for the whole run -- its reward-per-episode is consistently worse than SARSA's. That looks like a damning verdict until you remember what each one is actually optimising. Q-Learning's greedy target evaluates the optimal path right along the cliff edge, so during training epsilon keeps shoving it over the edge for -100 a pop, dragging the curve down. SARSA bakes that exploration risk into its values and retreats to the safe ledge one row up, so it collects more reward while learning. The punchline: Q-Learning's greedy policy is genuinely shorter and optimal -- it just pays for that knowledge with a rougher training ride. The on-line score and the quality of the final greedy policy are two different questions, and here they point in opposite directions.
Exercise 2: Implement Expected SARSA -- bootstrap from the expectation over next actions under the epsilon-greedy policy, instead of the sampled next action (SARSA) or the max (Q-Learning).
```python
import numpy as np
from collections import defaultdict

# Assumes CliffWalking from episode #106.

def expected_sarsa(env, n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    Q = defaultdict(lambda: np.zeros(env.n_actions))

    def epsilon_greedy(state):
        if np.random.random() < epsilon:
            return np.random.randint(env.n_actions)
        return int(np.argmax(Q[state]))

    for _ in range(n_episodes):
        state, done = env.reset(), False
        while not done:
            action = epsilon_greedy(state)
            next_state, reward, done = env.step(action)
            # expected Q under epsilon-greedy:
            # (1 - eps) on the greedy action, eps spread over all actions
            q_next = Q[next_state]
            expected_q = (1 - epsilon) * np.max(q_next) + epsilon * np.mean(q_next)
            target = reward + gamma * expected_q * (1 - done)
            Q[state][action] += alpha * (target - Q[state][action])
            state = next_state
    return dict(Q)
```
The trick is that (1 - epsilon) * max + epsilon * mean is exactly the expected value of Q[next_state] when the next action is drawn from epsilon-greedy (the greedy action carries the lion's share 1 - epsilon, and the leftover epsilon is smeared uniformly across all actions). Because we average over next actions instead of sampling one, Expected SARSA removes the variance contributed by that random choice. On cliff walking it learns the same cautious, away-from-the-edge path as SARSA -- it is still on-policy -- but the updates are noticeably smoother, which in practice lets you crank alpha higher without the whole thing wobbling. It is the quiet middle child between SARSA and Q-Learning, and frankly it is underused ;-)
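If you want to convince yourself that the shortcut really is the expectation, here is a small numeric check. It writes out the epsilon-greedy action probabilities explicitly (epsilon / n_actions on every action, plus an extra 1 - epsilon on the greedy one) and compares the weighted sum with the one-liner used in expected_sarsa. The helper name epsilon_greedy_probs and the Q-values are made up for illustration:

```python
import numpy as np


def epsilon_greedy_probs(q_values, epsilon):
    """Hypothetical helper: action probabilities of an epsilon-greedy policy."""
    n = len(q_values)
    probs = np.full(n, epsilon / n)            # exploration mass, spread uniformly
    probs[np.argmax(q_values)] += 1 - epsilon  # greedy action gets the remainder
    return probs


q_next = np.array([1.0, -2.0, 0.5, 3.0])  # made-up Q-values for a next state
epsilon = 0.1

# Expectation written out term by term: sum over a' of pi(a'|s') * Q(s', a')
explicit = float(np.dot(epsilon_greedy_probs(q_next, epsilon), q_next))
# The shortcut from expected_sarsa
shortcut = (1 - epsilon) * np.max(q_next) + epsilon * np.mean(q_next)

print(explicit, shortcut)  # both come out at about 2.7625
```

Both routes give the same number because the mean already includes the greedy action's own epsilon / n share.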
Exercise 3: Take the n_step_td evaluator and run it for n in [1, 2, 4, 8, 16, 32] on a random walk, measuring the error of the learned values against a reference. You should see a U-shape -- some intermediate n beats both TD(0) and full Monte Carlo.
```python
import numpy as np
from collections import defaultdict
import random

# Assumes n_step_td from episode #106.

class RandomWalk:
    """19-state random walk. States 1..19, terminals 0 and 20.
    Equal-probability left/right step. +1 at the right end, -1 at the left."""

    def __init__(self, n=19):
        self.n = n
        self.start = (n + 1) // 2

    def reset(self):
        self.state = self.start
        return self.state

    def step(self, action):  # action ignored -- the walk is random
        self.state += 1 if random.random() < 0.5 else -1
        if self.state == 0:
            return self.state, -1.0, True
        if self.state == self.n + 1:
            return self.state, 1.0, True
        return self.state, 0.0, False


def random_policy(state):
    return 0


# True values V(s) = (s - 10) / 10: linear from -0.9 to +0.9 across states 1..19 (analytic reference).
true_V = {s: (s - 10) / 10.0 for s in range(1, 20)}

env = RandomWalk(19)
print(f"{'n':>4}{'RMS error':>12}")
for n in [1, 2, 4, 8, 16, 32]:
    errors = []
    for _ in range(20):  # average over seeds to smooth the curve
        V = n_step_td(env, random_policy, n_steps=n, n_episodes=10, alpha=0.1)
        sq = [(V[s] - true_V[s]) ** 2 for s in range(1, 20)]
        errors.append(np.sqrt(np.mean(sq)))
    print(f"{n:>4}{np.mean(errors):>12.4f}")
```
Run it and the U-shape jumps right out of the column of numbers. At n = 1 you have plain TD(0): low variance, but it bootstraps so aggressively that with only ten episodes the signal barely propagates back along the chain. At n = 32 you have something close to Monte Carlo: unbiased, but every estimate is now hostage to one long noisy return. Somewhere in between (around n = 4 to n = 8 at this step size) the two failure modes cancel and the error bottoms out. This is the same lesson as the famous 19-state random-walk figure in Sutton and Barto, reproduced on your own machine in a few dozen lines -- the dial between TD and MC is real, and the sweet spot lives in the middle.
On to today's episode
Right, episode 107 -- and this is the one where reinforcement learning stops being a toy and starts being the thing you read about in the news.
At the very end of episode #106 I left a deliberate bait dangling: Q-Learning is magnificent, but it leans on a humble Q dictionary, and the moment you swap that dictionary for a neural network "a whole box of new problems" springs open. Today we open the box. The payoff is enormous -- this is the algorithm that taught a computer to play Atari games straight from the pixels, the result that kicked off the entire deep reinforcement learning era. Having said that, let's first understand why the dictionary had to go.
Where the Q-table breaks
In episode #106 Q-Learning stored a value for every (state, action) pair in a table. That is perfectly fine when the world is small -- a 4x4 grid has 16 states, our Blackjack from episode #105 has a few hundred. You can visit each one thousands of times and pin down its value by averaging.
Now consider an Atari frame: 210x160 pixels, 128 possible colours each. The number of distinct screens is not "large", it is astronomical -- vastly more than the number of atoms in the observable universe. Two problems hit you at once, and both are fatal:
- You could never store a table that big.
- You could never visit each state often enough to estimate its value, since you'll essentially never see the exact same screen twice.
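To put a rough number on "astronomical", here is the back-of-the-envelope count for a raw frame. It works in log10 because the integer itself would have tens of thousands of digits. It treats every pixel as free to take any of the 128 colours independently, which overcounts the screens a real game can produce, but the order of magnitude is what matters here:

```python
import math

height, width, n_colours = 210, 160, 128  # raw Atari frame, as described above
n_pixels = height * width                 # 33,600 pixels per screen

# Distinct screens = n_colours ** n_pixels; take log10 so the result stays printable.
log10_screens = n_pixels * math.log10(n_colours)

print(f"about 10^{log10_screens:,.0f} possible screens")  # about 10^70,802
print("atoms in the observable universe: roughly 10^80")
```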
The fix is the same idea that has powered most of this series since episode #38: replace the lookup table with a function approximator. Instead of looking up Q(s, a) in a dictionary, we feed the state s into a neural network and let it output Q-values for every action. Crucially, the network generalises -- a screen it has never seen before still gets a sensible answer, because the network has learned features (paddles, balls, enemies) that transfer across similar screens. This is the Deep Q-Network, the DQN, published by DeepMind in their 2015 Nature paper.
From table to network
A small fully-connected network is plenty for low-dimensional state vectors (think the four numbers describing a balancing pole):
```python
import torch
import torch.nn as nn
import numpy as np
from collections import deque
import random


class QNetwork(nn.Module):
    """Neural network that approximates Q(s, a) for all actions at once."""

    def __init__(self, state_dim, n_actions, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),  # one output per action
        )

    def forward(self, state):
        return self.net(state)  # returns Q-values for all actions
```
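A quick shape check makes the "all actions in one pass" point concrete. So that it runs on its own, this sketch rebuilds the same three-layer layout as QNetwork inline. It assumes a 4-number state and two actions, as in the balancing-pole example, and feeds the network a batch of random made-up states:

```python
import torch
import torch.nn as nn

# Same layout as QNetwork, rebuilt inline so this snippet is self-contained.
state_dim, n_actions, hidden = 4, 2, 128  # e.g. a pole-balancing state, two actions
q_net = nn.Sequential(
    nn.Linear(state_dim, hidden), nn.ReLU(),
    nn.Linear(hidden, hidden), nn.ReLU(),
    nn.Linear(hidden, n_actions),
)

batch = torch.randn(32, state_dim)       # 32 made-up states
q_values = q_net(batch)                  # ONE forward pass for the whole batch
print(q_values.shape)                    # torch.Size([32, 2])

greedy_actions = q_values.argmax(dim=1)  # best action per state, shape (32,)
```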
Notice the shape of the output: one Q-value per action, computed in a single forward pass. That is a deliberate design choice -- you get all the action values for a state in one shot, instead of running the network once per action. For pixels, DeepMind used the convolutional stack we built up over episodes #45 and #46:
```python
class AtariQNetwork(nn.Module):
    """CNN Q-network for raw pixel observations."""

    def __init__(self, n_actions):
        super().__init__()
        # Input: 4 stacked grayscale frames, each 84x84.
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),   # 84x84 -> 20x20
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),  # 20x20 -> 9x9
            nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),  # 9x9   -> 7x7
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        # x: (batch, 4, 84, 84), pixel values scaled to 0-1
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * 7 * 7)
        return self.fc(x)
```
Why four stacked frames instead of one? Because a single snapshot has no notion of motion -- from one still frame you cannot tell whether the ball in Pong is heading left or right. Stack four consecutive frames and the network can read velocity straight off the input. (A neat trick, and far cheaper than wiring in a recurrent network like the LSTMs from episode #49.)
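Here is a minimal sketch of that frame stacking. It assumes frames have already been preprocessed to 84x84 grayscale arrays scaled to 0-1, which is the input AtariQNetwork expects. FrameStack is a hypothetical helper, not part of any library. A deque with maxlen=4 handles the "drop the oldest frame" bookkeeping, and at the start of an episode the first frame is simply repeated to fill the stack:

```python
from collections import deque

import numpy as np


class FrameStack:
    """Hypothetical helper: keeps the last k preprocessed frames and stacks them."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # No history yet at episode start, so repeat the first frame k times.
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.observation()

    def step(self, frame):
        self.frames.append(frame)  # the oldest frame falls off automatically
        return self.observation()

    def observation(self):
        return np.stack(list(self.frames), axis=0)  # shape (k, 84, 84), oldest first


# Usage with made-up 84x84 frames:
stack = FrameStack(k=4)
obs = stack.reset(np.zeros((84, 84), dtype=np.float32))
for t in range(1, 6):
    obs = stack.step(np.full((84, 84), t / 10, dtype=np.float32))
print(obs.shape)  # (4, 84, 84)
```

After five steps only the four most recent frames (0.2 through 0.5) remain, in time order. That ordering is what lets the network read off motion.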
Why the naive version blows up
Here is the trap, and it catches everyone the first time. Take Q-Learning, swap the dictionary for QNetwork, run the same update -- and training diverges. Not "learns slowly". Diverges, spectacularly, into nonsense. Two distinct problems gang up on you:
Problem 1 -- correlated samples. In on-line Q-Learning the agent feeds the network a stream of consecutive transitions (s, a, r, s'), and consecutive transitions are heavily correlated (the agent is walking through connected states). Neural networks hate correlated training data -- it is exactly like training an image classifier where every cat photo comes first and every dog photo comes after. The network overfits to the recent stretch and "forgets" what it learned ten seconds ago.
Problem 2 -- moving targets. The Q-Learning target is r + gamma * max Q(s', a'). But Q is the very network you are updating. Every gradient step shifts the weights, which shifts all the targets, which means you are chasing a target that jumps every time you shoot at it. The network ends up endlessly pursuing its own tail.
DeepMind's two famous innovations each kill one of these problems dead: experience replay for the correlation, and a target network for the moving target. Let's take them in turn.
Experience replay
Instead of learning from each transition once and throwing it away, we stash every transition in a fixed-size replay buffer and train on random minibatches drawn from it:
```python
class ReplayBuffer:
    """Fixed-size buffer that stores transitions and samples minibatches."""

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size=32):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
        )

    def __len__(self):
        return len(self.buffer)
```
Random sampling shatters the temporal correlation -- one minibatch might mix a transition from episode 5, another from episode 200, another from episode 47. The network now sees a diverse, decorrelated dataset, which is precisely the i.i.d.-ish input that gradient descent was designed for. Nota bene: there is a lovely bonus here too -- each stored transition can be replayed for many gradient updates, so DQN squeezes far more learning out of every interaction with the environment (a big deal when each interaction is an expensive game step).
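To see the decorrelation in numbers, here is a toy illustration. A made-up one-dimensional "state" that drifts like a random walk stands in for an agent moving through connected states. We measure how strongly each sample correlates with the next one, first in on-line order, then after drawing from the same data uniformly at random the way a replay buffer does:

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up 1-D "state" that drifts step by step, like an agent walking through connected states.
trajectory = np.cumsum(rng.normal(size=10_000))


def lag1_correlation(x):
    """Correlation between each sample and the one that follows it in the stream."""
    return np.corrcoef(x[:-1], x[1:])[0, 1]


sequential = trajectory[:5_000]                                     # on-line order
replayed = trajectory[rng.integers(0, len(trajectory), size=5_000)]  # uniform replay

print(f"sequential: {lag1_correlation(sequential):.3f}")
print(f"replayed:   {lag1_correlation(replayed):.3f}")
```

The on-line stream is almost perfectly correlated from one sample to the next, while the replayed stream sits close to zero.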
The target network
To pin down the moving target, DQN keeps two copies of the Q-network:
- the online network -- the one updated every single step;
- the target network -- a frozen snapshot, refreshed only occasionally.
The TD target is computed with the target network, not the online one:
```python
# target = r + gamma * max_a' Q_target(s', a') * (1 - done)   <- frozen target network
# loss   = ( target - Q_online(s, a) )^2
```
Because the target network's weights stand still between refreshes, the targets stay put long enough for the online network to actually converge toward them. Every C steps (typically somewhere in the 1,000 to 10,000 range) you copy the online weights into the target network, the target lurches forward to the new estimate, and the chase resumes -- but now in stable, deliberate hops instead of a continuous wild scramble.
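Here is a tiny sketch of that hard-update schedule. A single nn.Linear layer stands in for the Q-network, and a dummy loss exists only to make the online weights move. The sync period C = 5 is a toy value chosen so the pattern fits on screen:

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
online = nn.Linear(4, 2)        # stand-in for the online Q-network
target = copy.deepcopy(online)  # frozen snapshot, starts out identical
for p in target.parameters():
    p.requires_grad_(False)     # the target is never trained directly

optimizer = torch.optim.SGD(online.parameters(), lr=0.1)
C = 5  # toy sync period; real DQN uses thousands of steps


def weights_equal(a, b):
    return all(torch.equal(x, y) for x, y in zip(a.parameters(), b.parameters()))


status = []
for step in range(1, 13):
    loss = online(torch.randn(8, 4)).pow(2).mean()  # dummy loss, only there to move the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % C == 0:
        target.load_state_dict(online.state_dict())  # hard update: copy every weight across
    status.append(weights_equal(online, target))
    print(step, "in sync" if status[-1] else "target lagging")
```

The target catches up only on steps 5 and 10 and lags everywhere else. That is the stepwise movement described above.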
The complete DQN agent
Bolt the pieces together -- online network, target network, replay buffer, epsilon-greedy exploration (straight from the bandit toolkit in episode #103) -- and you get the full agent:
```python
class DQNAgent:
    """Complete DQN agent: replay buffer + target network + epsilon-greedy."""

    def __init__(self, state_dim, n_actions, lr=1e-3, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
                 buffer_size=100000, batch_size=64, target_update=1000):
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update = target_update
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.step_count = 0

        # online and target networks start out identical
        self.q_network = QNetwork(state_dim, n_actions)
        self.target_network = QNetwork(state_dim, n_actions)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)

    def choose_action(self, state):
        """Epsilon-greedy over the online network's Q-values."""
        if random.random() < self.epsilon:
            return random.randint(0, self.n_actions - 1)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0)
            return self.q_network(state_t).argmax(dim=1).item()

    def store_transition(self, state, action, reward, next_state, done):
        self.buffer.push(state, action, reward, next_state, done)

    def learn(self):
        """One gradient step on a sampled minibatch."""
        if len(self.buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Q-values the online network assigns to the actions actually taken
        q_values = self.q_network(states)
        q_selected = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # TD target, using the FROZEN target network
        with torch.no_grad():
            max_next_q = self.target_network(next_states).max(dim=1)[0]
            targets = rewards + self.gamma * max_next_q * (1 - dones)

        loss = nn.functional.mse_loss(q_selected, targets)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=1.0)  # stability
        self.optimizer.step()

        # periodically refresh the target network
        self.step_count += 1
        if self.step_count % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # decay exploration over time
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
        return loss.item()
```
The gather call is the one bit worth squinting at: the network spits out a Q-value for every action, but the loss should only touch the action we actually took, so gather plucks out that single column per row. Everything else is the Q-Learning update from episode #106 wearing a PyTorch coat.
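If gather still feels abstract, here it is on a made-up batch of three states and four actions, next to the equivalent fancy-indexing form:

```python
import torch

# Made-up Q-values for 3 states (rows) and 4 actions (columns).
q_values = torch.tensor([[0.1, 0.5, 0.2, 0.9],
                         [1.0, 0.0, 0.3, 0.4],
                         [0.6, 0.7, 0.8, 0.2]])
actions = torch.tensor([3, 0, 1])  # action actually taken in each state

# gather along dim 1 picks one column per row, using the action as the column index
q_selected = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
print(q_selected)  # tensor([0.9000, 1.0000, 0.7000])

# Equivalent fancy-indexing form, handy as a cross-check
same = q_values[torch.arange(len(actions)), actions]
```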
The training loop itself is almost boringly familiar -- it is the same act-store-learn rhythm, just with a replay buffer humming in the background:
```python
def train_dqn(env, agent, n_episodes=1000):
    """Train a DQN agent on a Gymnasium-style environment."""
    rewards_history = []
    for episode in range(n_episodes):
        state, _ = env.reset()  # Gymnasium: reset() returns (obs, info)
        total_reward, done = 0.0, False
        while not done:
            action = agent.choose_action(state)
            # Gymnasium: step() returns (obs, reward, terminated, truncated, info)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # store `terminated` only: a time-limit cut-off is not a real terminal state,
            # so the TD target should still bootstrap from next_state
            agent.store_transition(state, action, reward, next_state, terminated)
            agent.learn()
            total_reward += reward
            state = next_state
        rewards_history.append(total_reward)
        if episode % 100 == 0:
            avg = np.mean(rewards_history[-100:])
            print(f"Episode {episode} | avg reward {avg:6.1f} | "
                  f"epsilon {agent.epsilon:.3f} | buffer {len(agent.buffer)}")
    return rewards_history
```
Point this at CartPole-v1 and within a few hundred episodes it typically learns to keep the pole up for the full 500-step episode limit. That is essentially the same algorithm -- minus the CNN -- that conquered Atari. Wowzers.
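One setting worth sanity-checking before you hit run: learn() multiplies epsilon by epsilon_decay once per gradient step, not once per episode. Using the DQNAgent defaults, this snippet works out how many steps it takes for exploration to bottom out, once in closed form and once by replaying the exact update from learn():

```python
import math

epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 0.995  # DQNAgent defaults

# Closed form: smallest k with epsilon_start * epsilon_decay**k <= epsilon_end
k_closed = math.ceil(math.log(epsilon_end / epsilon_start) / math.log(epsilon_decay))

# Same thing by simulating the update line in DQNAgent.learn()
eps, k_sim = epsilon_start, 0
while eps > epsilon_end:
    eps = max(epsilon_end, eps * epsilon_decay)
    k_sim += 1

print(k_closed, k_sim)  # 919 919
```

So exploration is essentially switched off after a little over nine hundred gradient steps. Early CartPole episodes are short, so that may be only a few dozen episodes. If your agent stops exploring before it has found anything useful, a common adjustment is a slower decay or decaying once per episode.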
Double DQN: fixing the overestimation
Plain DQN has a sneaky flaw: it systematically overestimates Q-values. The culprit is the max in the target r + gamma * max Q(s', a'). Your Q-estimates always carry some noise (they are network outputs, of course they do), and max does not pick the genuinely best action -- it picks the action whose noise happens to point highest. Take the max of noisy numbers and you get a number biased upward, every single time. Over thousands of updates that optimism compounds into a real problem.
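You can watch this bias appear with nothing but NumPy. In this toy setup, assumed purely for illustration, every action is truly worth exactly 0 and the estimates are unbiased but noisy. Each estimate averages out to zero on its own, yet their max does not:

```python
import numpy as np

rng = np.random.default_rng(42)
n_actions, n_trials = 4, 100_000

true_q = np.zeros(n_actions)  # every action is truly worth 0
noisy_q = true_q + rng.normal(0.0, 1.0, size=(n_trials, n_actions))  # unbiased, noisy estimates

print("mean of each estimate:", noisy_q.mean(axis=0).round(3))   # all close to 0
print("mean of the max:      ", noisy_q.max(axis=1).mean().round(3))  # clearly above 0
```

With four actions and unit-variance noise, the max averages out at roughly +1.03. That is pure optimism, manufactured from noise.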
Double DQN (van Hasselt et al., 2016) fixes it with a delightfully cheap idea: decouple choosing the action from evaluating it.
```python
def learn_double_dqn(self):
    """Double DQN: online network SELECTS the action, target network EVALUATES it."""
    if len(self.buffer) < self.batch_size:
        return None

    states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
    q_selected = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # online network picks which action it thinks is best
        best_actions = self.q_network(next_states).argmax(dim=1, keepdim=True)
        # target network reports the value of THAT action
        max_next_q = self.target_network(next_states).gather(1, best_actions).squeeze(1)
        targets = rewards + self.gamma * max_next_q * (1 - dones)

    loss = nn.functional.mse_loss(q_selected, targets)
    self.optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=1.0)
    self.optimizer.step()

    self.step_count += 1
    if self.step_count % self.target_update == 0:
        self.target_network.load_state_dict(self.q_network.state_dict())

    self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
    return loss.item()


# Written as a drop-in replacement for DQNAgent.learn, for example:
# DQNAgent.learn = learn_double_dqn
```
The whole change lives in two lines: the online network chooses the action (argmax), the target network grades it (gather). If the online network's noise oversells some action, the target network -- a different set of weights, with different noise -- usually does not share the same delusion, so the inflated estimate gets cut back down. One small decoupling, and the upward bias largely evaporates. It improves results in most environments it has been tried on, which is why "DQN" in practice usually means "Double DQN".
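The same toy setup shows why decoupling works. Here the "online" and "target" estimates are two independent noisy copies of all-zero true values: one picks the action, the other grades it. Real online and target networks are not fully independent, because the target is an older snapshot of the online weights. In practice the bias therefore shrinks rather than vanishing completely:

```python
import numpy as np

rng = np.random.default_rng(7)
n_actions, n_trials = 4, 100_000
true_q = np.zeros(n_actions)  # again, every action is truly worth 0

# Two independent noisy estimates, playing the roles of "online" and "target".
online_q = true_q + rng.normal(size=(n_trials, n_actions))
target_q = true_q + rng.normal(size=(n_trials, n_actions))

single = online_q.max(axis=1)                 # pick AND grade with the same numbers
best = online_q.argmax(axis=1)                # online estimate picks the action...
double = target_q[np.arange(n_trials), best]  # ...target estimate grades it

print(f"single estimator: {single.mean():.3f}")  # biased upward
print(f"double estimator: {double.mean():.3f}")  # close to the true value 0
```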
Prioritized experience replay
Uniform sampling from the replay buffer treats every transition as equally worth learning from. But they are not equally informative -- a transition that surprised the agent (a big TD error) carries far more signal than one the network already predicts perfectly. Prioritized experience replay samples transitions in proportion to their TD error:
```python
class PrioritizedReplayBuffer:
    """Replay buffer that samples transitions in proportion to their TD error."""

    def __init__(self, capacity=100000, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha  # 0 = uniform, 1 = full prioritization
        self.buffer = []
        self.priorities = np.zeros(capacity)
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        # new transitions enter at max priority so they get seen at least once
        max_priority = self.priorities[:len(self.buffer)].max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.position] = (state, action, reward, next_state, done)
        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        n = len(self.buffer)
        probs = self.priorities[:n] ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(n, batch_size, p=probs)

        # importance-sampling weights correct for the non-uniform sampling
        weights = (n * probs[indices]) ** (-beta)
        weights /= weights.max()

        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
            indices,
            torch.FloatTensor(weights),
        )

    def update_priorities(self, indices, td_errors):
        for idx, error in zip(indices, td_errors):
            self.priorities[idx] = abs(error) + 1e-6  # tiny constant so nothing hits zero
```
There is a subtle catch the code handles with those weights. Once you sample non-uniformly you have biased the data -- the surprising transitions are now over-represented. The importance-sampling weights (the same correction idea we met for off-policy Monte Carlo back in episode #105) scale each update back down to undo that bias. Skip them and your value estimates drift; include them and prioritized replay buys you a real speed-up, often roughly doubling the learning speed on Atari compared to uniform replay.
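The learn() method that pairs with this buffer is not shown above, so here is a hedged sketch of the piece that changes: the loss. weighted_td_loss is a hypothetical helper. It multiplies each squared TD error by its importance-sampling weight before averaging, then hands back the raw TD errors so they can become the new priorities:

```python
import torch


def weighted_td_loss(q_selected, targets, weights):
    """Hypothetical helper: importance-sampling-weighted squared TD error.

    Returns the scalar loss to backpropagate and the detached TD errors,
    which is what update_priorities() expects.
    """
    td_errors = targets - q_selected
    loss = (weights * td_errors.pow(2)).mean()
    return loss, td_errors.detach()


# Toy batch of three sampled transitions with made-up values.
q_selected = torch.tensor([1.0, 0.5, -0.2], requires_grad=True)
targets = torch.tensor([1.5, 0.5, 0.8])
weights = torch.tensor([1.0, 0.4, 0.7])  # in the shape PrioritizedReplayBuffer.sample returns

loss, td_errors = weighted_td_loss(q_selected, targets, weights)
loss.backward()  # gradients are scaled down for heavily over-sampled transitions
print(loss.item(), td_errors.tolist())
```

Inside a prioritized learn(), this loss would replace mse_loss. You would then call buffer.update_priorities(indices, td_errors.numpy()) with the indices that sample() returned, so freshly surprising transitions move up the queue.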
Dueling DQN
One last refinement, and this one touches the architecture instead of the algorithm. Dueling DQN (Wang et al., 2016) splits the network into two streams: one estimating the state value V(s) (how good is it to be here at all?) and one estimating the advantage A(s, a) (how much better is each action than the average?).
```python
class DuelingQNetwork(nn.Module):
    """Dueling architecture: separate value and advantage streams."""

    def __init__(self, state_dim, n_actions, hidden=128):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
        )
        self.value_stream = nn.Sequential(  # how good is this state?
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        self.advantage_stream = nn.Sequential(  # how much better is each action?
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, state):
        features = self.feature(state)
        value = self.value_stream(features)          # (batch, 1)
        advantage = self.advantage_stream(features)  # (batch, n_actions)
        # recombine: subtract the mean advantage for identifiability
        return value + advantage - advantage.mean(dim=1, keepdim=True)
```
The intuition is that in many states it barely matters which action you pick -- the state is already wonderful or already doomed, and any action lands you in much the same place. The value stream can learn "this state is worth +8" once, cleanly, while the advantage stream only frets over the small differences between actions. Splitting the labour like this learns faster, especially when the action space is large and most actions are near-equivalent. That subtracted mean, by the way, is not cosmetic -- without it the split between V and A is ambiguous (you could add a constant to one and subtract it from the other), and pinning the advantages to zero-mean nails the decomposition down.
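The identifiability point is easy to check numerically. This sketch uses made-up stream outputs for a single state with three actions. Without the mean subtraction, you can move a constant from A into V and get exactly the same Q-values. With it, the average Q over actions is forced to equal V(s):

```python
import numpy as np

value = np.array([[8.0]])                 # made-up V(s) for one state, shape (1, 1)
advantage = np.array([[1.0, -1.0, 3.0]])  # made-up advantage-stream output, shape (1, 3)


def combine_naive(v, a):
    return v + a  # no centering


def combine_dueling(v, a):
    return v + a - a.mean(axis=1, keepdims=True)  # same recombination as DuelingQNetwork


# Without centering, shifting a constant from A into V leaves Q unchanged: V is not identifiable.
c = 5.0
print(np.allclose(combine_naive(value, advantage),
                  combine_naive(value + c, advantage - c)))  # True

# With centering, the mean of Q over actions always equals V(s).
q = combine_dueling(value, advantage)
print(q, q.mean(axis=1))
```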
What DQN actually achieved
Let me put the 2015 result in perspective, because it is easy to be blasé about it now. DeepMind trained one architecture, with one set of hyperparameters, on 49 different Atari games. The only inputs were raw pixels and the game score (as the reward) -- no hand-crafted features, no game-specific rules, nobody whispering "the ball is the important bit". The same recipe that learned Breakout also learned Space Invaders, Pong, and Boxing, with a fresh copy of the network trained on each game. On 29 of those 49 games it reached at least 75% of a professional human games tester's score, and on a good number of them (Breakout and Boxing among them) it beat the human outright.
That was the moment deep RL became undeniable: a single learning algorithm picking up wildly different skills from raw sensory input alone. It lit the fuse for everything that followed -- AlphaGo in 2016, AlphaZero in 2017, and the human-feedback alignment machinery (episode #61) that now sits underneath most modern chat assistants.
DQN does have a hard boundary, mind you. That max over actions, and the one-Q-value-per-action output, both assume a small, discrete set of actions -- "left, right, fire". The instant the actions become continuous -- a steering angle, a joint torque, anything you cannot enumerate -- the argmax has nothing finite to range over and the whole approach stalls. Getting over that wall means learning the policy directly instead of squeezing it out of a value function, and that shift opens up the next big family of methods. But that, as ever, is a story for another episode ;-)
So, what do you know now?
- Tabular Q-Learning collapses in large state spaces -- you can neither store nor sufficiently visit the states -- so DQN replaces the table with a neural network that generalises across similar states;
- naive network-plus-Q-Learning diverges, sunk by correlated samples and a target that moves every gradient step;
- experience replay stores transitions in a buffer and samples random minibatches, breaking those correlations (and recycling each transition for many updates);
- a target network is a frozen copy of the Q-network used to compute TD targets, holding the target still long enough to converge against;
- Double DQN decouples action selection (online network) from evaluation (target network), largely removing the overestimation caused by the max -- one of the highest value-for-effort tricks in all of RL;
- prioritized replay samples surprising (high-TD-error) transitions more often, with importance-sampling weights to undo the resulting bias, often doubling learning speed;
- the dueling architecture splits the network into a state-value stream and an advantage stream, learning faster when many actions are near-equivalent;
- DQN launched the deep RL era by learning 49 Atari games from raw pixels with one network architecture and one set of hyperparameters -- reaching at least 75% of a professional human tester's score on 29 of them.
Exercises
Exercise 1: Wire DQNAgent up to CartPole-v1 (via gymnasium) and actually train it. Plot the per-episode reward and the running 100-episode average. Then run a small ablation: train once with the target network refreshing every 1,000 steps, and once where you copy the online weights into the target network every step (i.e. effectively no target network at all). Describe what happens to the reward curve in the second case, and connect it back to the "moving target" problem from this episode.
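For the plot, one way to get the running 100-episode average is with a cumulative sum. In this sketch the helper name `running_average` is hypothetical. For the first 99 episodes it averages over however many episodes exist so far, which is one convention among several.

```python
import numpy as np


def running_average(episode_rewards, window=100):
    """Mean reward over the last `window` episodes at every episode index.

    Early on, fewer than `window` values exist, so it averages what is there.
    """
    rewards = np.asarray(episode_rewards, dtype=np.float64)
    csum = np.concatenate(([0.0], np.cumsum(rewards)))   # csum[k] = sum of the first k rewards
    end = np.arange(1, len(rewards) + 1)                 # window ends (exclusive index into csum)
    start = np.maximum(end - window, 0)                  # window starts, clipped at 0
    return (csum[end] - csum[start]) / (end - start)
```

Drawing this on top of the raw per-episode rewards keeps both the noise and the trend visible. That matters for the ablation, where the interesting signal is how ragged the curve becomes.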
Exercise 2: Implement soft target updates (Polyak averaging) as an alternative to the periodic hard copy. Instead of target.load_state_dict(online.state_dict()) every C steps, blend the weights every step with theta_target = tau * theta_online + (1 - tau) * theta_target, using a small tau like 0.005. Train CartPole with both schemes and compare stability and final performance. Which one feels less twitchy, and why might a gradually drifting target be gentler than an occasional lurch?
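The update rule itself is short. This minimal PyTorch sketch assumes the online and target networks share an identical architecture, so their `parameters()` line up one-to-one. The name `soft_update` is hypothetical. Buffers such as BatchNorm running statistics are not blended here.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(target: nn.Module, online: nn.Module, tau: float = 0.005) -> None:
    """Polyak averaging: theta_target <- tau * theta_online + (1 - tau) * theta_target."""
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(1.0 - tau).add_(o_param, alpha=tau)   # in-place blend, no autograd tracking


# Toy usage with two small linear layers standing in for the Q-networks.
online, target = nn.Linear(4, 2), nn.Linear(4, 2)
target.load_state_dict(online.state_dict())   # start from identical weights
soft_update(target, online, tau=0.005)        # call once per training step
```

Setting tau = 1 reproduces the hard copy exactly, which makes a handy sanity check before you start comparing the two schemes.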
Exercise 3: Add Double DQN to your working agent (swap learn for the learn_double_dqn logic) and measure the overestimation directly. During training, periodically log the mean predicted Q-value of the start state alongside the actual return the agent goes on to collect from there. Plot both curves for plain DQN and for Double DQN. You should see plain DQN's predicted Q sitting noticeably above the real returns, with Double DQN hugging them far more honestly.
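One detail worth getting right: the Q-value estimates a discounted return, so the "actual return" you plot against it should be discounted with the same gamma. Here is a small helper for that, assuming your agent uses gamma = 0.99; adjust it to whatever DQNAgent actually uses. The name `discounted_return` is hypothetical.

```python
def discounted_return(rewards, gamma=0.99):
    """G_0 = r_0 + gamma * r_1 + gamma^2 * r_2 + ..., accumulated back to front."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


# A full 500-step CartPole-v1 episode (reward 1 per step) with gamma = 0.99:
perfect = discounted_return([1.0] * 500, gamma=0.99)
# equals (1 - 0.99**500) / (1 - 0.99), just under 100
```

Even a perfect episode therefore puts the honest start-state Q just under 100 at this gamma. A predicted Q sitting well above the discounted return the agent actually collects is the overestimation the exercise is looking for.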