Learn AI Series (#108) - Policy Gradient Methods

What will I learn
- You will learn why directly optimizing the policy is sometimes a far better idea than learning a value function first;
- the policy gradient theorem -- the one equation that makes the whole family of methods tick;
- REINFORCE -- the simplest possible policy gradient algorithm, in about thirty lines;
- why REINFORCE is so painfully noisy, and how a baseline tames that noise without lying to you;
- Actor-Critic methods, where a value function and a policy learn side by side;
- and A2C, the Advantage Actor-Critic that ties it all together.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution with NumPy, PyTorch and Gymnasium;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
- Learn AI Series (#50) - Sequence-to-Sequence Models
- Learn AI Series (#51) - Attention Mechanisms
- Learn AI Series (#52) - The Transformer Architecture (Part 1)
- Learn AI Series (#53) - The Transformer Architecture (Part 2)
- Learn AI Series (#54) - Vision Transformers
- Learn AI Series (#55) - Generative Adversarial Networks
- Learn AI Series (#56) - Mini Project - Building a Transformer From Scratch
- Learn AI Series (#57) - Language Modeling - Predicting the Next Word
- Learn AI Series (#58) - GPT Architecture - Decoder-Only Transformers
- Learn AI Series (#59) - BERT and Encoder Models
- Learn AI Series (#60) - Training Large Language Models
- Learn AI Series (#61) - Instruction Tuning and Alignment
- Learn AI Series (#62) - Prompt Engineering - Getting the Most from LLMs
- Learn AI Series (#63) - Embeddings and Vector Search
- Learn AI Series (#64) - Retrieval-Augmented Generation (RAG) - Basics
- Learn AI Series (#65) - RAG - Advanced Techniques
- Learn AI Series (#66) - Working with LLM APIs
- Learn AI Series (#67) - Building AI Agents (Part 1) - Foundations
- Learn AI Series (#68) - Building AI Agents (Part 2) - Advanced Patterns
- Learn AI Series (#69) - Fine-Tuning Language Models
- Learn AI Series (#70) - Running Local Models
- Learn AI Series (#71) - Text Generation Techniques
- Learn AI Series (#72) - Tokenization Deep Dive
- Learn AI Series (#73) - LLM Evaluation
- Learn AI Series (#74) - The Hugging Face Ecosystem
- Learn AI Series (#75) - Multimodal Models - Text Meets Vision
- Learn AI Series (#76) - Mini Project - Your Own AI Assistant
- Learn AI Series (#77) - Image Processing Fundamentals
- Learn AI Series (#78) - Object Detection (Part 1) - Foundations
- Learn AI Series (#79) - Object Detection (Part 2) - Modern Approaches
- Learn AI Series (#80) - Image Segmentation
- Learn AI Series (#81) - Pose Estimation and Tracking
- Learn AI Series (#82) - Optical Character Recognition
- Learn AI Series (#83) - Video Understanding
- Learn AI Series (#84) - Generative Images - Diffusion Models (Part 1)
- Learn AI Series (#85) - Generative Images - Diffusion Models (Part 2)
- Learn AI Series (#86) - Image-to-Image and Editing
- Learn AI Series (#87) - 3D Vision
- Learn AI Series (#88) - Face Analysis
- Learn AI Series (#89) - Medical and Scientific Imaging
- Learn AI Series (#90) - Self-Supervised Learning for Vision
- Learn AI Series (#91) - Mini Project - Building a Visual AI System
- Learn AI Series (#92) - Audio Fundamentals for AI
- Learn AI Series (#93) - Speech Recognition
- Learn AI Series (#94) - Text-to-Speech (TTS)
- Learn AI Series (#95) - Audio Classification
- Learn AI Series (#96) - Music Generation
- Learn AI Series (#97) - Speaker Recognition and Diarization
- Learn AI Series (#98) - Natural Language Understanding for Voice
- Learn AI Series (#99) - Audio Enhancement
- Learn AI Series (#100) - Multimodal Audio-Visual Models
- Learn AI Series (#101) - Mini Project: Voice-Controlled AI Assistant
- Learn AI Series (#102) - What Is Reinforcement Learning?
- Learn AI Series (#103) - Multi-Armed Bandits
- Learn AI Series (#104) - Dynamic Programming
- Learn AI Series (#105) - Monte Carlo Methods
- Learn AI Series (#106) - Temporal Difference Learning
- Learn AI Series (#107) - Deep Q-Networks (DQN)
- Learn AI Series (#108) - Policy Gradient Methods (this post)
Solutions to Episode #107 Exercises
Before we tear the value function down and rebuild from the other direction, let's clear last episode's three exercises. All of them build on the DQNAgent from episode #107, so I'm assuming that class (with its replay buffer and target network) is imported and sitting in scope. I'm also leaning on gymnasium -- pip install gymnasium if you haven't already.
Exercise 1: Wire DQNAgent up to CartPole-v1 and train it. Plot the per-episode reward and the running 100-episode average. Then ablate the target network: train once with a refresh every 1,000 steps, and once where you copy the online weights into the target every step (effectively no target network), and explain the second curve.
```python
import gymnasium as gym
import numpy as np

# Assumes DQNAgent and train_dqn from episode #107.
def run_cartpole(target_update, n_episodes=400):
    env = gym.make("CartPole-v1")
    state_dim = env.observation_space.shape[0]  # 4 numbers
    n_actions = env.action_space.n              # left or right
    agent = DQNAgent(state_dim, n_actions, target_update=target_update)
    rewards = []
    for ep in range(n_episodes):
        state, _ = env.reset()
        total, done, trunc = 0.0, False, False
        while not (done or trunc):
            action = agent.choose_action(state)
            next_state, reward, done, trunc, _ = env.step(action)
            agent.store_transition(state, action, reward, next_state,
                                   float(done))
            agent.learn()
            state = next_state
            total += reward
        rewards.append(total)
    return rewards

def moving_average(x, window=100):
    return np.convolve(x, np.ones(window) / window, mode="valid")

stable = run_cartpole(target_update=1000)  # proper frozen target
broken = run_cartpole(target_update=1)     # target == online, every step
print(f"{'metric':>16}{'C=1000':>10}{'C=1':>10}")
print(f"{'final avg-100':>16}"
      f"{np.mean(stable[-100:]):>10.1f}{np.mean(broken[-100:]):>10.1f}")
```
The C=1000 run climbs steadily toward CartPole's ceiling of 500 and parks there. The C=1 run is the instructive disaster: with the target network copied every step, the target r + gamma * max Q(s', a') is computed from the same weights you are nudging with that very gradient step. So the target shifts the instant you move toward it -- the exact "moving target" problem from last episode, only now with the brakes off. You'll see the reward curve spike encouragingly and then collapse, sometimes oscillating wildly, sometimes diverging outright. It is a near-perfect demonstration of why DeepMind bothered freezing a second network in the first place ;-)
Exercise 2: Implement soft target updates (Polyak averaging). Instead of a hard copy every C steps, blend the weights every step with theta_target = tau * theta_online + (1 - tau) * theta_target, using a small tau like 0.005.
```python
import torch

def soft_update(target_net, online_net, tau=0.005):
    """Polyak averaging: nudge the target a tiny step toward the online net."""
    for t_param, o_param in zip(target_net.parameters(),
                                online_net.parameters()):
        t_param.data.copy_(tau * o_param.data + (1.0 - tau) * t_param.data)

# Drop-in: replace the periodic hard copy inside DQNAgent.learn() with:
#     soft_update(self.target_network, self.q_network, tau=0.005)
# and delete the `if self.step_count % self.target_update == 0` block.
```
Both schemes solve the same problem -- keeping the target slow-moving -- but they feel different. The hard copy holds the target perfectly still for a thousand steps and then makes it lurch to a brand-new estimate all at once. The soft update lets the target drift continuously, always trailing the online network by a hair. With tau = 0.005 the target moves at roughly the pace of a hard copy every 200 steps, but smoothly, so you never get that jolt of a sudden target jump. In practice Polyak averaging tends to feel less twitchy, which is exactly why the continuous-control algorithms we'll meet down the line adopted it as standard. A gentle, constant drift is kinder to gradient descent than an occasional earthquake.
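To put a number on that "roughly every 200 steps" intuition, here is a minimal NumPy sketch rather than the real DQNAgent; the helper soft_update_np and the step-change setup are illustrative assumptions. The online weights jump from 0 to 1 and stay there, and we watch how fast the Polyak-averaged target follows. After n updates the target has closed 1 - (1 - tau)^n of the gap, so 1/tau = 200 steps behaves like a time constant (about 63% of the gap closed after 200 steps).

```python
import numpy as np

def soft_update_np(target, online, tau):
    """Polyak averaging on plain arrays: target <- tau * online + (1 - tau) * target."""
    return tau * online + (1.0 - tau) * target

# Illustrative assumption: the online weights jump from 0 to 1 and then stay put,
# so target[0] is exactly the fraction of that gap the target has closed.
tau = 0.005
online = np.ones(3)
target = np.zeros(3)
closed = {}
for step in range(1, 1001):
    target = soft_update_np(target, online, tau)
    if step in (200, 400, 1000):
        closed[step] = target[0]

for step, frac in closed.items():
    # closed form for comparison: 1 - (1 - tau) ** step
    print(f"after {step:4d} steps: {frac:.1%} closed "
          f"(closed form {1 - (1 - tau) ** step:.1%})")
```

A hard copy every 200 steps and this smooth drift end up tracking the online network at a similar overall pace; the difference is that the soft version never jumps.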
Exercise 3: Add Double DQN to your agent and measure the overestimation directly: periodically log the mean predicted Q of the start state alongside the actual return the agent goes on to collect from there.
```python
import numpy as np
import torch

def measured_return(env, agent, gamma=0.99):
    """Run one greedy episode, return predicted start-Q and the real return."""
    state, _ = env.reset()
    with torch.no_grad():
        start_q = agent.q_network(
            torch.FloatTensor(state).unsqueeze(0)).max().item()
    rewards, done, trunc = [], False, False
    while not (done or trunc):
        with torch.no_grad():
            action = agent.q_network(
                torch.FloatTensor(state).unsqueeze(0)).argmax().item()
        state, reward, done, trunc, _ = env.step(action)
        rewards.append(reward)
    # discounted actual return from the start state
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
    return start_q, G

# Every 20 training episodes, call measured_return(env, agent) and stash
# (predicted_q, actual_G) for both a plain-DQN agent (learn) and a
# Double-DQN agent (learn_double_dqn). Plot predicted vs actual for each.
```
Plot the two curves and the story tells itself. Plain DQN's predicted start-state Q floats consistently above the return the agent actually collects -- that is the max-of-noisy-numbers optimism we dissected last time, made visible. The Double DQN agent, which decouples action selection from evaluation, hugs the real returns far more honestly. The gap between the predicted-Q line and the actual-return line is the overestimation bias, and you just measured it on your own machine. Satisfying, that.
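If you'd like to see the bias without training anything, the toy below reproduces it in NumPy under stated assumptions: ten actions that are all truly worth 0, and two independent sets of noisy estimates with unit Gaussian noise. Selecting and evaluating with the same noisy numbers (the plain-DQN pattern) comes out well above zero. Selecting with one set and evaluating with the other (the Double DQN idea, with two independent estimators standing in for the online and target networks) lands near zero.

```python
import numpy as np

rng = np.random.default_rng(0)
n_actions, n_trials = 10, 20_000
true_q = np.zeros(n_actions)          # every action is truly worth exactly 0

# Two independent noisy estimates of the same true values
# (unit Gaussian noise is an illustrative assumption).
est_a = true_q + rng.normal(0.0, 1.0, size=(n_trials, n_actions))
est_b = true_q + rng.normal(0.0, 1.0, size=(n_trials, n_actions))

# Single estimator (plain DQN style): select AND evaluate with the same numbers.
single = est_a.max(axis=1)

# Double estimator (Double DQN style): select with A, evaluate with B.
chosen = est_a.argmax(axis=1)
double = est_b[np.arange(n_trials), chosen]

print(f"true value of the best action : {true_q.max():.3f}")
print(f"single-estimator average      : {single.mean():.3f}")  # biased upward
print(f"double-estimator average      : {double.mean():.3f}")  # close to zero
```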
On to today's episode
Right -- episode 108, and this is the one where we stop sneaking up on the policy through a value function and just optimise the darned thing directly.
Step back and notice something about every single reinforcement learning algorithm in this arc so far. Dynamic programming (episode #104), Monte Carlo (#105), TD learning (#106), DQN (#107) -- all of them are value-based. They learn a value function, V(s) or Q(s, a), and the policy is a side effect: pick the action those values rate highest. The policy never has parameters of its own. It is squeezed out of the value function by an argmax.
Policy gradient methods throw that arrangement out. They give the policy its own parameters, its own neural network, and optimise it directly by gradient ascent on expected reward. No Q-table, no argmax, no value function required (though, as we'll see, one sneaks back in to help). So why would you ever want this? Three solid reasons.
First, continuous actions. DQN outputs one Q-value per action and takes the max, which assumes a small, countable set of actions -- "left, right, fire". But a robot joint can rotate to any angle, a throttle can open to any fraction. There is no finite list to argmax over. A policy network laughs at this: it just outputs the parameters of a distribution (say a mean and a spread) over a continuous action and samples from it.
Second, stochastic optimal policies. Value-based methods always hand you a deterministic policy (argmax is a single action). But in adversarial or partially observed settings the best policy can be genuinely random -- rock-paper-scissors is the textbook case, where any predictable strategy gets exploited and uniform randomness is provably optimal. A policy network can represent "70% rock, 30% paper" effortlessly; an argmax cannot.
Third, smoothness. A tiny tweak to the network weights produces a tiny change in the policy. Compare that to DQN, where a microscopic change in two Q-values can flip the argmax and lurch the behaviour from one action to its opposite. Smooth changes are friendlier to optimise -- something we've leaned on since the very first gradient descent back in episode #7.
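To make the first reason (continuous actions) concrete, here is a minimal sketch of a continuous-action policy in NumPy. The class LinearGaussianPolicy is hypothetical: a linear mean plus a learned log standard deviation, which is one common way to parameterise a Gaussian policy, not the only one. There is nothing to argmax over. The policy outputs a mean and a spread, samples a real-valued action, and can report that action's log-probability, which is the ingredient the policy gradient needs.

```python
import numpy as np

class LinearGaussianPolicy:
    """Hypothetical minimal continuous-action policy: a ~ N(w . s + b, sigma^2)."""

    def __init__(self, state_dim, rng):
        self.rng = rng
        self.w = np.zeros(state_dim)   # weights for the mean
        self.b = 0.0                   # bias for the mean
        self.log_std = np.log(0.5)     # learn log(sigma) so sigma stays positive

    def mean_std(self, state):
        return self.w @ state + self.b, np.exp(self.log_std)

    def sample(self, state):
        mu, sigma = self.mean_std(state)
        return self.rng.normal(mu, sigma)   # any real number is a legal action

    def log_prob(self, state, action):
        """Log-density of a Gaussian, the continuous analogue of log pi(a|s)."""
        mu, sigma = self.mean_std(state)
        return (-0.5 * ((action - mu) / sigma) ** 2
                - np.log(sigma) - 0.5 * np.log(2 * np.pi))

rng = np.random.default_rng(42)
policy = LinearGaussianPolicy(state_dim=3, rng=rng)
state = np.array([0.1, -0.2, 0.3])
action = policy.sample(state)          # e.g. a throttle setting, no argmax needed
print(f"sampled action {action:.3f}, log-prob {policy.log_prob(state, action):.3f}")
```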
The policy network
So what does a policy network look like? For a discrete action space it is almost embarrassingly familiar -- a classifier, basically, that outputs a probability distribution over actions:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

class PolicyNetwork(nn.Module):
    """Stochastic policy: maps a state to action probabilities."""
    def __init__(self, state_dim, n_actions, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, state):
        logits = self.net(state)
        return F.softmax(logits, dim=-1)  # a valid probability distribution

    def get_action(self, state):
        """Sample an action, and remember its log-probability for the update."""
        state_t = torch.FloatTensor(state).unsqueeze(0)
        probs = self.forward(state_t)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)
```
The softmax at the end is the whole trick (we first met it back in the classification episodes). The output is a genuine probability distribution, and the agent samples from it instead of always taking the most likely action. That sampling is what makes the policy stochastic: in a state where two actions look about equally good, the network spreads probability across both and the agent tries each one sometimes. Exploration, in other words, is baked right into the policy -- no separate epsilon-greedy bolted on the side like we needed for DQN.
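Here is a quick NumPy check of that claim. It uses np.random.Generator.choice as a stand-in for Categorical(probs).sample() (an assumption made so the snippet runs without PyTorch), and the logits are invented for illustration. With two near-tied logits the agent picks each of them roughly half the time, and the clearly worse action still gets an occasional try.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()      # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(7)
logits = np.array([2.0, 1.9, -1.0])   # two near-tied actions, one clearly worse
probs = softmax(logits)

# Sample many times, the way the agent samples once per step.
actions = rng.choice(len(probs), size=10_000, p=probs)
freqs = np.bincount(actions, minlength=len(probs)) / len(actions)

print("probabilities :", np.round(probs, 3))
print("sample freqs  :", np.round(freqs, 3))
```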
The policy gradient theorem
Here is the goal, written plainly. We want to maximise the expected discounted return of the policy:
J(theta) = E[ sum_t gamma^t * r_t ] # expected return, following policy pi_theta
We want to climb this J by adjusting theta. To do gradient ascent we need its gradient, grad J(theta). And this is where it gets hairy, because the thing we are differentiating is an expectation over trajectories that the parameters themselves shape -- change theta and you change which states you even visit. It looks hopeless to differentiate. The policy gradient theorem is the small miracle that says it isn't:
grad J(theta) = E[ sum_t grad log pi_theta(a_t | s_t) * G_t ]
Read that in words, because the words are the whole intuition. The gradient of expected reward equals the expected value of (the gradient of the log-probability of the action you took) times (the return you got from that point onward). Actions that were followed by big returns get their log-probability pushed up; actions followed by poor returns get pushed down. The agent quite literally makes good moves more likely and bad moves less likely, weighted by how good or bad things turned out.
That grad log pi_theta(a | s) term has a name -- the score function -- and the beautiful part is that PyTorch's autograd (episode #42) computes it for free. We never write the gradient by hand. We just construct the right loss and call .backward().
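If you want to convince yourself numerically, the NumPy sketch below works on a one-step problem with a softmax policy; the logits and per-action returns are made up for illustration. It checks two things: that the closed-form score of a softmax, onehot(a) - pi, matches a finite-difference derivative of log pi, and that when J(theta) = sum_a pi(a) * r(a), the direct gradient of J equals the expectation of score times return that the theorem promises. Autograd would compute the same score for you; finite differences are used here only so the check runs without PyTorch.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()          # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def score(theta, a):
    """Closed-form grad of log softmax(theta)[a] w.r.t. theta: onehot(a) - pi."""
    g = -softmax(theta)
    g[a] += 1.0
    return g

def finite_diff(f, theta, eps=1e-6):
    """Central-difference numerical gradient of a scalar function f."""
    grad = np.zeros_like(theta)
    for i in range(len(theta)):
        bump = np.zeros_like(theta)
        bump[i] = eps
        grad[i] = (f(theta + bump) - f(theta - bump)) / (2 * eps)
    return grad

theta = np.array([0.5, -0.3, 1.2])      # arbitrary policy logits
rewards = np.array([1.0, 0.0, 3.0])     # hypothetical one-step return per action

# 1) The score function: closed form vs numerical derivative of log pi(a=2).
err_score = np.max(np.abs(
    score(theta, 2) - finite_diff(lambda t: np.log(softmax(t)[2]), theta)))

# 2) The theorem on a one-step problem, where J(theta) = sum_a pi(a) * r(a).
def expected_return(t):
    return softmax(t) @ rewards

pi = softmax(theta)
grad_direct = finite_diff(expected_return, theta)
grad_theorem = sum(pi[a] * score(theta, a) * rewards[a] for a in range(len(pi)))
err_theorem = np.max(np.abs(grad_direct - grad_theorem))

print(f"score function error : {err_score:.2e}")
print(f"policy gradient error: {err_theorem:.2e}")
```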
REINFORCE
The most direct implementation of the theorem is REINFORCE (Williams, 1992 -- this algorithm is older than quite a few of you reading this). Play a whole episode, compute the return from each step, and nudge every action's log-probability in proportion to its return:
```python
class REINFORCE:
    """Monte Carlo policy gradient."""
    def __init__(self, state_dim, n_actions, lr=1e-3, gamma=0.99):
        self.gamma = gamma
        self.policy = PolicyNetwork(state_dim, n_actions)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.log_probs = []  # filled during the episode
        self.rewards = []

    def choose_action(self, state):
        action, log_prob = self.policy.get_action(state)
        self.log_probs.append(log_prob)
        return action

    def store_reward(self, reward):
        self.rewards.append(reward)

    def learn(self):
        """Update once, after a complete episode."""
        # discounted return from each timestep onward
        returns, G = [], 0.0
        for r in reversed(self.rewards):
            G = r + self.gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns)
        # normalise returns -- a cheap, effective variance reducer
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        # loss = -sum( log_prob * return ); negative because we MAXIMISE return
        loss = torch.stack(
            [-lp * G for lp, G in zip(self.log_probs, returns)]
        ).sum()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.log_probs, self.rewards = [], []  # clear for next episode
        return loss.item()
```
The one bit that trips people up is the minus sign in -lp * G. PyTorch optimisers minimise a loss, but we want to maximise return. Minimising the negative of a thing is the same as maximising the thing -- so we negate, and gradient descent on the loss becomes gradient ascent on J. The whole algorithm is "remember Monte Carlo returns from episode #105, but instead of averaging them into a value table, push them through the log-probabilities of the actions you took". Notice this is on-policy and Monte Carlo -- it waits for the full episode before learning anything, just like the MC methods we built earlier.
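The backwards loop in learn() is easy to misread, so here it is in isolation on a three-step episode with +1 reward per step and gamma = 0.9 (values chosen for illustration). The helper discounted_returns is a standalone copy of that loop, not part of the REINFORCE class. Note the ddof=1: PyTorch's Tensor.std() uses the unbiased estimator by default, so this mirrors the normalisation line.

```python
import numpy as np

def discounted_returns(rewards, gamma):
    """G_t = r_t + gamma * G_{t+1}, computed backwards as in REINFORCE.learn()."""
    returns, G = [], 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    return np.array(returns)

rewards = [1.0, 1.0, 1.0]                 # a three-step episode, +1 per step
G = discounted_returns(rewards, gamma=0.9)
print(G)                                  # approximately [2.71, 1.9, 1.0]

# Same normalisation as learn(); ddof=1 matches torch's unbiased std default.
normalised = (G - G.mean()) / (G.std(ddof=1) + 1e-8)
print(normalised.mean(), normalised.std(ddof=1))

# The minus sign: for one step with log-prob lp and return g, loss = -lp * g.
# d(loss)/d(lp) = -g, so a descent step on the loss RAISES lp whenever g > 0.
```

Earlier steps collect more discounted reward simply because more of the episode lies ahead of them; after normalisation the first step gets a positive weight and the last a negative one.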
The variance problem
REINFORCE is correct. It is also, out of the box, a bit of a nightmare to actually train. The trouble is variance. The return G_t comes from a single rollout, and that rollout is soaked in randomness: every action was sampled, every state transition may be stochastic. Two episodes that begin in the exact same state can hand you wildly different returns purely by luck of the draw.
So the gradient estimate jitters violently from episode to episode. The policy lurches one way, then the other, and learning crawls -- when it isn't actively going backward. The returns.mean() / returns.std() normalisation you saw above takes the edge off, but it's a sticking plaster. We need a real fix, and it's a clever one.
Baselines: less noise, no lies
Here's the key insight. Suppose we subtract some quantity b(s) from the return before scaling the gradient:
grad J(theta) = E[ sum_t grad log pi_theta(a_t | s_t) * (G_t - b(s_t)) ]
If b(s) depends only on the state and not the action, this leaves the gradient unbiased -- the expected gradient is exactly the same as before -- while potentially slashing the variance. (The proof really is one line: averaged over actions, the score function is zero, because sum_a pi_theta(a|s) * grad log pi_theta(a|s) = sum_a grad pi_theta(a|s) = grad sum_a pi_theta(a|s) = grad 1 = 0. Multiply that by any function of s alone and it still contributes nothing in expectation.)
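Here is that argument checked exactly, with no sampling noise, in a small NumPy sketch. It assumes a single state, three actions, a softmax policy with arbitrary logits, and deterministic per-action returns of 10, 11 and 12, so raw returns make every action look great. For a given baseline it computes the exact mean and total variance of the one-sample gradient estimate score(a) * (G - b). The mean is unchanged when b = V(s), while the variance drops by well over an order of magnitude.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def score(pi, a):
    """grad log pi(a) w.r.t. the logits of a softmax policy: onehot(a) - pi."""
    g = -pi.copy()
    g[a] += 1.0
    return g

theta = np.array([0.2, -0.1, 0.4])       # arbitrary policy logits
returns = np.array([10.0, 11.0, 12.0])   # every action looks "great" in raw terms
pi = softmax(theta)

def grad_stats(baseline):
    """Exact mean and total variance of the one-sample gradient estimate."""
    samples = [score(pi, a) * (returns[a] - baseline) for a in range(len(pi))]
    mean = sum(p * g for p, g in zip(pi, samples))
    second_moment = sum(p * (g @ g) for p, g in zip(pi, samples))
    return mean, second_moment - mean @ mean

# The baseline term cancels in expectation because sum_a pi(a) * score(a) == 0.
exp_score = sum(p * score(pi, a) for a, p in enumerate(pi))

mean_raw, var_raw = grad_stats(baseline=0.0)
v = pi @ returns                          # the state value V(s) used as baseline
mean_base, var_base = grad_stats(baseline=v)

print("expected score    :", np.round(exp_score, 12))
print("same mean gradient:", np.allclose(mean_raw, mean_base))
print(f"total variance, raw returns: {var_raw:.3f}")
print(f"total variance, with V(s)  : {var_base:.3f}")
```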
What's the baseline of choice? The state value function V(s) -- the expected return from state s. (Strictly, it is not the variance-minimising baseline, but it is the standard choice: intuitive, effective, and easy to learn.) Subtract it and the quantity G_t - V(s_t) gets a name we'll be saying a lot from here on: the advantage. It measures how much better this particular trajectory did compared to what was expected from that state. Positive advantage ("better than average") reinforces the action; negative advantage ("worse than average") suppresses it. That is so much more sensible than raw returns: in a state where every action yields a return of +500, the raw return screams "everything you did was brilliant!" -- but the advantage correctly whispers "meh, that was just an average day here, learn nothing".
So we add a second network -- a critic -- to learn V(s), while the policy plays the role of actor:
```python
class REINFORCEWithBaseline:
    """REINFORCE plus a learned value baseline (an actor and a critic)."""
    def __init__(self, state_dim, n_actions, lr_policy=1e-3,
                 lr_value=1e-3, gamma=0.99):
        self.gamma = gamma
        self.policy = PolicyNetwork(state_dim, n_actions)  # actor
        self.policy_opt = torch.optim.Adam(
            self.policy.parameters(), lr=lr_policy)
        self.value_net = nn.Sequential(  # critic
            nn.Linear(state_dim, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.value_opt = torch.optim.Adam(
            self.value_net.parameters(), lr=lr_value)
        self.log_probs, self.rewards, self.states = [], [], []

    def choose_action(self, state):
        self.states.append(state)
        action, log_prob = self.policy.get_action(state)
        self.log_probs.append(log_prob)
        return action

    def store_reward(self, reward):
        self.rewards.append(reward)

    def learn(self):
        returns, G = [], 0.0
        for r in reversed(self.rewards):
            G = r + self.gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns)
        states_t = torch.FloatTensor(np.array(self.states))
        # squeeze(-1): (T, 1) -> (T,), which also keeps shape (1,) for a
        # one-step episode so it still matches `returns`
        values = self.value_net(states_t).squeeze(-1)
        # advantage = return - baseline; detach so the actor doesn't
        # backprop through the critic
        advantages = returns - values.detach()
        policy_loss = torch.stack(
            [-lp * adv for lp, adv in zip(self.log_probs, advantages)]
        ).sum()
        value_loss = F.mse_loss(values, returns)  # critic regresses toward returns
        self.policy_opt.zero_grad()
        policy_loss.backward()
        self.policy_opt.step()
        self.value_opt.zero_grad()
        value_loss.backward()
        self.value_opt.step()
        self.log_probs, self.rewards, self.states = [], [], []
        return policy_loss.item(), value_loss.item()
```
Two losses, two optimisers. The critic learns to predict returns (plain mean-squared error, just like the regression all the way back in episode #10), and the actor uses the critic's prediction as its baseline. The .detach() on values is load-bearing: it stops the actor's gradient from leaking into the critic, keeping the two jobs cleanly separated. Nota bene: this actor-critic split -- one network judging states, another choosing actions -- is one of the most durable ideas in all of RL. Very nearly every modern algorithm is some flavour of it.
Actor-Critic: stop waiting for the episode to end
REINFORCE-with-baseline is better, but it still has the Monte Carlo handicap: it sits on its hands until the episode is over before it learns anything. For a game that lasts thousands of steps -- or never naturally ends at all -- that's painfully slow. We already solved this exact impatience once, back in episode #106: bootstrapping. Don't wait for the true return, estimate it from the next state's value. The one-step TD error becomes our advantage:
advantage ~= delta = r + gamma * V(s') - V(s)
That delta is precisely the TD error from temporal difference learning -- how much better the transition turned out than the critic predicted. It is a one-step, lower-variance stand-in for the full Monte Carlo advantage (the price is a little bias, since it leans on the critic's own imperfect estimate of V(s')), and it lets us update at every step:
```python
class ActorCritic:
    """One-step Actor-Critic (the core update behind A2C)."""
    def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99):
        self.gamma = gamma
        # actor and critic share a feature trunk
        self.features = nn.Sequential(nn.Linear(state_dim, 128), nn.ReLU())
        self.actor = nn.Sequential(
            nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, n_actions))
        self.critic = nn.Sequential(
            nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 1))
        params = (list(self.features.parameters())
                  + list(self.actor.parameters())
                  + list(self.critic.parameters()))
        self.optimizer = torch.optim.Adam(params, lr=lr)
        self._all_params = params

    def get_action_and_value(self, state):
        state_t = torch.FloatTensor(state).unsqueeze(0)
        feats = self.features(state_t)
        dist = Categorical(F.softmax(self.actor(feats), dim=-1))
        action = dist.sample()
        return action.item(), dist.log_prob(action), self.critic(feats).squeeze()

    def update(self, log_prob, value, reward, next_value, done):
        target = reward + self.gamma * next_value.detach() * (1 - done)
        advantage = target - value
        actor_loss = -log_prob * advantage.detach()  # policy gradient with TD advantage
        critic_loss = advantage.pow(2)               # squared TD error
        loss = actor_loss + 0.5 * critic_loss        # 0.5 balances the two gradients
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self._all_params, max_norm=0.5)  # stability
        self.optimizer.step()
        return loss.item()
```
Notice the actor and critic now share a feature trunk -- the first linear layer feeds both heads. That's common and sensible: both jobs need to understand the state, so why learn that twice? The critic loss is scaled by 0.5 to keep its gradient from bullying the actor's. And the gradient clipping (which we first reached for with the RNNs in episode #48) keeps the whole thing from blowing up on an unlucky transition.
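To see what update() actually feeds the actor, here is the TD advantage on its own with made-up critic outputs (V(s) = 9, V(s') = 10, r = 1, gamma = 0.99); td_advantage is a hypothetical standalone helper. The (1 - done) mask is what stops bootstrapping past a genuine terminal state. Gymnasium reports time-limit truncation separately from termination, and a truncated episode has not really ended, so the usual practice is to zero the bootstrap only on termination.

```python
def td_advantage(reward, value, next_value, terminated, gamma=0.99):
    """One-step TD error: delta = r + gamma * V(s') * (1 - terminated) - V(s)."""
    target = reward + gamma * next_value * (1.0 - float(terminated))
    return target - value

# Hypothetical critic outputs: V(s) = 9.0, V(s') = 10.0, reward r = 1.0.
mid_episode = td_advantage(1.0, 9.0, 10.0, terminated=False)  # 1 + 9.9 - 9
terminal = td_advantage(1.0, 9.0, 10.0, terminated=True)      # 1 - 9
print(f"mid-episode advantage: {mid_episode:.2f}")  # better than predicted
print(f"terminal advantage   : {terminal:.2f}")     # much worse than predicted
```

A positive delta says the transition beat the critic's forecast, so the sampled action gets reinforced; a negative one suppresses it.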
Training an Actor-Critic agent
The training loop is the same act-evaluate-update rhythm we've used all arc long, just with the value head riding along:
```python
def train_actor_critic(env, agent, n_episodes=2000):
    rewards_history = []
    for episode in range(n_episodes):
        state, _ = env.reset()
        total, done, trunc = 0.0, False, False
        while not (done or trunc):
            action, log_prob, value = agent.get_action_and_value(state)
            next_state, reward, done, trunc, _ = env.step(action)
            _, _, next_value = agent.get_action_and_value(next_state)
            # Only a true terminal state cuts the bootstrap; a time-limit
            # truncation (trunc) still bootstraps from V(next_state).
            agent.update(log_prob, value, reward, next_value, float(done))
            state = next_state
            total += reward
        rewards_history.append(total)
        if episode % 100 == 0:
            avg = np.mean(rewards_history[-100:])
            print(f"Episode {episode} | avg reward {avg:6.1f}")
    return rewards_history
```
Point this at CartPole-v1 and it should learn to balance the pole, same as our DQN did last episode -- but along an entirely different conceptual route. No replay buffer, no target network, no epsilon schedule. The exploration comes free from sampling the policy, and the learning comes from the advantage. Different machinery, same victory.
Value-based vs policy-based: which when?
Let me lay the two families side by side, because the contrast is the real lesson of today:
| Aspect | Value-based (DQN) | Policy-based (REINFORCE / A2C) |
|---|---|---|
| Action space | Discrete only | Discrete or continuous |
| Policy type | Deterministic (argmax) | Stochastic (sampling) |
| Stability | Can oscillate (Q flips argmax) | Smoother (small theta -> small pi change) |
| Sample efficiency | Better (off-policy replay) | Worse (on-policy, can't reuse data freely) |
| Convergence | Guaranteed for tabular Q-learning; no general guarantee for DQN | To a local optimum (stochastic gradient ascent, suitable step sizes) |
Neither column is the winner. DQN's experience replay lets it reuse each transition many times, which matters when each environment step is expensive; policy methods burn through experience faster but go where value methods simply cannot follow -- continuous control, stochastic optima, smooth behaviour. And in modern practice the line has blurred almost to nothing: many of the algorithms that power real robotics and game-playing agents are hybrids, borrowing the actor-critic skeleton you just built and bolting on cleverer ways of taking the policy-gradient step without it exploding. Understanding these pure forms is exactly what lets the hybrids make sense -- which is precisely where we're headed next ;-)
So, what do you know now?
- Policy gradient methods parameterise the policy directly and climb expected return by gradient ascent -- no value function required to act;
- the policy gradient theorem says the gradient is the expected score function grad log pi(a|s) weighted by the return, so good actions get made more likely and bad ones less likely;
- REINFORCE is the bare-bones Monte Carlo version -- correct, but cursed with high variance because each return is a single noisy rollout;
- subtracting a state-only baseline (typically the learned value function V(s)) gives the advantage G_t - V(s_t), cutting variance with zero added bias;
- Actor-Critic replaces the Monte Carlo return with the one-step TD error as its advantage, so it learns every step instead of only at episode's end;
- the actor (policy) and critic (value) usually share a feature trunk, with the critic's loss down-weighted so it doesn't drown out the actor;
- policy methods natively handle continuous actions, stochastic policies, and smooth optimisation -- three places value-based methods struggle or fail outright.
Exercises
Exercise 1: Implement plain REINFORCE (no baseline) on CartPole-v1 and train it for 1,000 episodes, logging the 100-episode moving average of reward. Then run it three times with different random seeds and plot all three curves on one chart. The point is to see the variance problem with your own eyes: describe how much the three runs differ from one another, and contrast that with how tightly clustered three DQN runs from episode #107 would be.
Exercise 2: Take your REINFORCE agent and add the learned value baseline (turn it into REINFORCEWithBaseline). Train both versions on CartPole-v1 under the same seeds and plot their reward curves together. Quantify the improvement: roughly how many episodes does each take to first reach an average reward of 195, and how does the wobble of the two curves compare? Tie your answer back to why subtracting V(s) reduces variance without biasing the gradient.
Exercise 3: Add an entropy bonus to the ActorCritic agent. Compute the entropy of the action distribution at each step (dist.entropy()) and add -beta * entropy to the loss (so the optimiser is rewarded for keeping the policy uncertain), with a small beta like 0.01. Train CartPole with beta = 0 and beta = 0.01 and compare. Explain why nudging the policy toward higher entropy discourages it from collapsing prematurely onto one action -- and connect this to the exploration-exploitation tension we first met with the bandits in episode #103.
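As a starting point (a hint, not a solution), here is the quantity the exercise asks you to reward, computed by hand in NumPy; categorical_entropy is a hypothetical helper that mirrors the Shannon entropy, in nats, that Categorical.entropy() returns. Entropy is highest for a uniform distribution (log 2 for two actions) and falls to 0 as the policy collapses onto one action, so subtracting beta * entropy from the loss makes that collapse cost something.

```python
import numpy as np

def categorical_entropy(probs):
    """H(p) = sum_a p(a) * log(1 / p(a)), in nats; zero-probability actions add 0."""
    probs = np.asarray(probs, dtype=float)
    nz = probs[probs > 0]                 # 0 * log(1/0) is taken as 0
    return float(np.sum(nz * np.log(1.0 / nz)))

for p in ([0.5, 0.5], [0.9, 0.1], [0.99, 0.01], [1.0, 0.0]):
    print(p, round(categorical_entropy(p), 4))
```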