Learn AI Series (#110) - Model-Based Reinforcement Learning

What will I learn

  • You will learn the real difference between model-free and model-based RL -- and why the second family exists at all;
  • how to learn a model of the environment: a network that predicts the next state and the reward;
  • how to plan with a learned model by training on imagined trajectories instead of only on real ones;
  • Dyna-Q, the simplest model-based architecture there is, and its neural cousin;
  • World Models -- networks that learn to dream -- and MuZero, which masters games without ever being told the rules;
  • and the one trade-off that governs this entire chapter: sample efficiency bought against model accuracy.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution with NumPy and PyTorch;
  • You've followed along through episodes #106 (TD learning), #107 (DQN) and #109 (PPO) -- this one leans on all three.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#110) - Model-Based Reinforcement Learning

Solutions to Episode #109 Exercises

Before we teach an agent to dream, let's settle last episode's three exercises. All of them build on the PPONetwork, RolloutBuffer and PPOAgent classes from episode #109, so I'm assuming those are imported and sitting in scope. As before I'm leaning on gymnasium throughout (pip install gymnasium if you skipped it).

Exercise 1: Get the PPOAgent training on CartPole-v1, plot the per-rollout average reward, then run the clip ablation -- set clip_eps enormous so the clamp never fires -- and compare.

import gymnasium as gym
import numpy as np
import torch
# Assumes PPOAgent and train_ppo from episode #109.


def run_ppo(clip_eps, seed, total_steps=80_000, rollout_len=2048):
    env = gym.make("CartPole-v1")
    torch.manual_seed(seed)
    np.random.seed(seed)
    agent = PPOAgent(env.observation_space.shape[0],
                     env.action_space.n, clip_eps=clip_eps)
    history = train_ppo(env, agent, total_steps, rollout_len)
    return history


clipped = run_ppo(0.2, seed=0)      # the real PPO
unclipped = run_ppo(100.0, seed=0)  # clamp never triggers
for name, h in [("clip=0.2", clipped), ("clip=100", unclipped)]:
    print(f"{name:>9}: final avg (last 20) = {np.mean(h[-20:]):6.1f} "
          f"| peak = {max(h):6.1f}")

Plot the two reward curves side by side and the story tells itself. With clip_eps = 0.2 the curve climbs steadily and stays up at CartPole's ceiling of 500. With clip_eps = 100.0 the clamp in torch.clamp(ratio, 1 - eps, 1 + eps) never actually bites -- surr1 and surr2 become the same number, the min is a no-op, and you've quietly turned PPO back into a multi-epoch REINFORCE-with-baseline. And because we still loop over the same rollout for several epochs, those repeated unconstrained steps march the policy miles away from the data that produced it -- exactly the "no seatbelt" death spiral that opened episode #109. You'll see the unclipped curve spike and then crater, sometimes more than once per run. Same code, one disabled clamp, night-and-day stability. Tells you precisely how much work that one little clamp is doing ;-)
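You can see why clip_eps = 100.0 switches the clamp off without training anything. Here's a tiny NumPy sketch of the per-sample clipped surrogate, min(r * A, clip(r, 1 - eps, 1 + eps) * A). It's a standalone illustration, not the PPOAgent code from episode #109, and clipped_surrogate is a hypothetical helper name:

```python
import numpy as np


def clipped_surrogate(ratio, advantage, clip_eps):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1 - clip_eps, 1 + clip_eps) * advantage
    return np.minimum(unclipped, clipped)


ratios = np.array([0.5, 0.9, 1.0, 1.1, 1.5, 3.0])
adv = 1.0  # a positive advantage: this action looked good

print(clipped_surrogate(ratios, adv, clip_eps=0.2))    # gains capped at 1.2
print(clipped_surrogate(ratios, adv, clip_eps=100.0))  # identical to ratios * adv
```

With a positive advantage and eps = 0.2, a ratio above 1.2 earns no extra objective, so it gets no extra gradient either. With eps = 100.0 the output is exactly ratio * advantage.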

Exercise 2: Add an approximate-KL early stop. After each epoch estimate the mean KL with the cheap mean(old_log_probs - new_log_probs), and if it exceeds a threshold, break out of the epoch loop.

import numpy as np                                 # used for the KL average below
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
# A drop-in replacement for PPOAgent.update() from episode #109.


def update_with_kl_stop(self, last_value, target_kl=0.015):
    returns, advantages = self.buffer.compute_gae(
        last_value, self.gamma, self.lam)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    stopped_at = self.epochs                       # for logging
    for epoch in range(self.epochs):
        approx_kls = []
        for (states, actions, old_log_probs,
             b_returns, b_adv) in self.buffer.batches(
                 returns, advantages, self.batch_size):

            logits, values = self.net(states)
            dist = Categorical(logits=logits)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy().mean()

            ratio = torch.exp(new_log_probs - old_log_probs)
            surr1 = ratio * b_adv
            surr2 = torch.clamp(ratio, 1 - self.clip_eps,
                                1 + self.clip_eps) * b_adv
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.mse_loss(values.squeeze(), b_returns)
            loss = (actor_loss + self.value_coef * critic_loss
                    - self.entropy_coef * entropy)

            self.opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
            self.opt.step()
            approx_kls.append((old_log_probs - new_log_probs).mean().item())

        if np.mean(approx_kls) > target_kl:        # the early stop
            stopped_at = epoch + 1
            break
    self.buffer.clear()
    return stopped_at                              # log this across training

The estimator mean(old_log_probs - new_log_probs) is a cheap, sample-based estimate of the KL divergence between the old and new policy. Because the actions in the batch were sampled from the old policy, it is unbiased (if noisy), it needs no extra forward passes, and it's good enough to act on. Log stopped_at across a whole training run and you'll notice the early stop fires often in the early, fast-moving phase (when each rollout teaches the policy a lot, so it wants to move far) and almost never late in training (when the policy is nearly converged and barely budges). What you've built here is a little hybrid: PPO's implicit clip keeps any single step honest, and this explicit KL check keeps the accumulation of several epochs honest. That explicit KL bound is exactly the leash TRPO used (episode #109) -- we've smuggled a piece of TRPO back in on top of PPO, for about eight lines of code. Having said that, vanilla PPO mostly does fine without it -- this is a belt-and-suspenders touch for the runs that misbehave.
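You can check that mean(old_log_probs - new_log_probs) really tracks the KL with a NumPy sketch using two made-up three-action policies. The logits are arbitrary and chosen only for illustration. The code computes the exact KL(old || new) and compares it with the sample estimate over actions drawn from the old policy, which is the same situation you have inside a PPO update:

```python
import numpy as np


def softmax(logits):
    z = np.exp(logits - logits.max())  # subtract max for numerical stability
    return z / z.sum()


rng = np.random.default_rng(0)
p_old = softmax(np.array([1.0, 0.0, -1.0]))  # old policy over 3 actions
p_new = softmax(np.array([0.8, 0.2, -1.0]))  # policy after an update

# Exact KL(old || new) for two categorical distributions.
exact_kl = np.sum(p_old * (np.log(p_old) - np.log(p_new)))

# The rollout's actions were sampled from the OLD policy.
actions = rng.choice(3, size=200_000, p=p_old)
approx_kl = np.mean(np.log(p_old[actions]) - np.log(p_new[actions]))

print(f"exact KL = {exact_kl:.5f}, approx KL = {approx_kl:.5f}")
```

Individual log-ratios can be negative, because some actions become more likely under the new policy. That means a small minibatch estimate can wobble and even dip below zero. This noise is a reasonable argument for averaging the per-minibatch values over an epoch before comparing them to target_kl.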

Exercise 3: Adapt PPO to a continuous action space and run it on Pendulum-v1, swapping the Categorical for a Normal.

import torch
import torch.nn as nn
from torch.distributions import Normal


class ContinuousPPONetwork(nn.Module):
    """Actor outputs a Gaussian mean; log_std is a free parameter."""
    def __init__(self, state_dim, action_dim, hidden=64):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
        )
        self.mean = nn.Linear(hidden, action_dim)        # action mean
        self.log_std = nn.Parameter(torch.zeros(action_dim))  # state-independent
        self.critic = nn.Linear(hidden, 1)

    def forward(self, state):
        features = self.shared(state)
        return self.mean(features), self.critic(features)

    def get_action(self, state):
        state_t = torch.FloatTensor(state).unsqueeze(0)
        mean, value = self.forward(state_t)
        std = self.log_std.exp()
        dist = Normal(mean, std)
        action = dist.sample()
        # sum over action dims: independent Gaussians -> joint log-prob
        log_prob = dist.log_prob(action).sum(-1)
        return action.squeeze(0).numpy(), log_prob, value.squeeze()

Run this on Pendulum-v1 (a continuous torque-control task) and the same PPOAgent skeleton trains it -- you only have to thread Normal through the update's log-prob calls and .sum(-1) the per-dimension log-probs. Now notice carefully what changed: the distribution (Normal not Categorical), the action shape (a real-valued vector, not an index), and the actor's output layer (a mean head plus a learned log_std). And then notice everything that stayed exactly the same: the clipped surrogate, the GAE computation, the multi-epoch loop, the value loss, the entropy bonus. That invariance is the whole reason PPO travels so well -- discrete or continuous, the optimisation core does not care. It's also the reason I made you build PPO carefully last time: the investment pays off across an entire zoo of problems.
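So why is .sum(-1) the right thing to do? A Normal with a separate std per action dimension is a multivariate Gaussian with a diagonal covariance. Its joint log-density therefore factorises into a sum of per-dimension terms. This NumPy sketch checks that identity against the full multivariate formula; the mean, std and action values are arbitrary:

```python
import numpy as np


def normal_log_prob(x, mean, std):
    """Elementwise log density of independent univariate Normals."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)


def mvn_log_prob(x, mean, cov):
    """Log density of a multivariate Normal with a full covariance matrix."""
    d = len(x)
    diff = x - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (diff @ np.linalg.solve(cov, diff) + logdet + d * np.log(2 * np.pi))


mean = np.array([0.3, -1.0])
std = np.exp(np.array([-0.5, 0.2]))  # as if read off a learned log_std parameter
action = np.array([0.1, -0.4])

summed = normal_log_prob(action, mean, std).sum()      # the .sum(-1) trick
joint = mvn_log_prob(action, mean, np.diag(std ** 2))  # diagonal covariance
print(summed, joint)
```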

On to today's episode

Right -- episode 110, and we are about to cross a genuine line in this RL chapter.

Cast your eye back over everything we've done since episode #102. Q-Learning, SARSA, DQN, REINFORCE, Actor-Critic, PPO -- every single one of them is model-free. The agent reaches into the environment, pulls out rewards, and adjusts a policy or a value function accordingly. It never once asks how the environment works. It doesn't predict what the next state will be. It just reacts to whatever it observes, like a creature with very fast reflexes and absolutely no imagination.

That works, and it works famously well. But it has one nagging, expensive flaw, and today we finally confront it.

The crippling cost of model-free RL

Model-free methods are sample-hungry. Embarrassingly so. DQN learning to play a single Atari game can chew through tens of millions of frames -- the equivalent of days to weeks of nonstop play. PPO solving a robotics task in simulation might need hundreds of millions of timesteps. In a simulator, where you can run a thousand environments in parallel and time costs nothing, that's tolerable. On a real robot, where every interaction takes real seconds and a bad action can snap a real servo, it's a non-starter. You cannot crash a real drone a million times to teach it to fly.

Now think about how you learn instead. You did not need to total a thousand cars to learn to drive. You hold a mental model of how a car behaves -- turn the wheel right, the car goes right; brake hard, you lurch forward -- and you run little simulations in your head before you ever touch the road. "If I pull out now, that lorry arrives about there..." You imagine consequences, and you learn from the imagining.

Model-based RL is the attempt to give an agent that same gift. The plan is simple to state:

  1. From real experience, learn a model of the environment -- something that predicts the next state and the reward.
  2. Use that model to imagine experience -- trajectories the agent never actually lived.
  3. Train the policy on the imagined experience as if it were real.

Real interactions are precious and slow. Model queries are cheap and fast (a forward pass, or a table lookup). So you spend your scarce real data improving the model, then mint as much synthetic data as you like from it. That is the entire pitch, and when it works, the sample-efficiency gains are enormous -- ten to a hundred times fewer real interactions is not unusual.

The environment model

Everything hinges on that model. So what is it? Concretely, a function that takes the current state s and an action a and predicts two things:

  1. the next state s' -- where do we end up?
  2. the reward r -- what do we get for going there?

In a neural network, that's a small regression problem:

import torch
import torch.nn as nn
import numpy as np


class EnvironmentModel(nn.Module):
    """Learned dynamics: predicts next state and reward from (state, action)."""
    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.state_head = nn.Linear(hidden, state_dim)   # predicts delta-state
        self.reward_head = nn.Linear(hidden, 1)          # predicts reward

    def forward(self, state, action):
        # action is one-hot encoded for discrete action spaces
        x = torch.cat([state, action], dim=-1)
        features = self.net(x)
        delta_state = self.state_head(features)
        reward = self.reward_head(features)
        next_state = state + delta_state                 # predict the change
        return next_state, reward.squeeze(-1)

One detail there is doing quite some heavy lifting: the model predicts the state delta (state + delta_state) rather than the absolute next state. This is standard practice and it matters. Between one frame and the next, most of the state barely moves -- a cart's position nudges, a pole's angle ticks. Asking the network to output the full new state means re-learning all the bits that didn't change; asking it for only the change is a far easier target, with smaller, better-behaved numbers. If that reminds you of the residual connections from episode #46, good -- it's the same trick (learn the residual, not the whole function), wearing different clothes.

Training the model

The model is trained by plain old supervised learning on real transitions we've collected. No RL magic here at all -- it's the regression we've known since episode #10, just with a state-and-reward target:

class ModelTrainer:
    """Train the environment model from collected (s, a, r, s') experience."""
    def __init__(self, model, lr=1e-3):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.buffer = []                                 # (s, a, r, s') tuples

    def add_transition(self, state, action, reward, next_state):
        # action must already be one-hot encoded, as EnvironmentModel expects
        self.buffer.append((state, action, reward, next_state))

    def train_step(self, batch_size=64):
        if len(self.buffer) < batch_size:
            return None

        idx = np.random.choice(len(self.buffer), batch_size)
        batch = [self.buffer[i] for i in idx]
        states, actions, rewards, next_states = zip(*batch)

        states_t = torch.FloatTensor(np.array(states))
        actions_t = torch.FloatTensor(np.array(actions))
        rewards_t = torch.FloatTensor(rewards)
        next_states_t = torch.FloatTensor(np.array(next_states))

        pred_next, pred_reward = self.model(states_t, actions_t)

        state_loss = nn.functional.mse_loss(pred_next, next_states_t)
        reward_loss = nn.functional.mse_loss(pred_reward, rewards_t)
        loss = state_loss + reward_loss                  # joint objective

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

Notice the loss is just the sum of two MSE terms -- one for predicting the next state, one for the reward. The agent's whole "understanding" of physics boils down to getting that joint loss low. Bam, jonguh -- that's a learned world model in a few dozen lines of code.
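Model learning really is just supervised regression. To underline that, here's a NumPy-only sketch that swaps the neural network for plain linear least squares (np.linalg.lstsq). It runs on a hypothetical point-mass system of my own making, not a gymnasium environment. Like EnvironmentModel, it regresses onto the state delta and the reward, taking one-hot actions as input. This toy's dynamics happen to be linear in the chosen features, so the fit recovers them essentially exactly. A neural net on a real environment would only approximate them.

```python
import numpy as np

DT = 0.1
rng = np.random.default_rng(0)


def true_step(state, action_idx):
    """Hypothetical point-mass environment: state = [position, velocity]."""
    x, v = state
    force = -1.0 if action_idx == 0 else 1.0
    next_state = np.array([x + DT * v, v + DT * force])
    reward = -x ** 2  # reward for the current position: staying near 0 is good
    return next_state, reward


def features(states, onehots):
    """Inputs to the regression: position, velocity, position^2, one-hot action."""
    x, v = states[:, 0], states[:, 1]
    return np.column_stack([x, v, x ** 2, onehots])


# 1. Collect random (s, a, r, s') transitions from the "real" system.
states = rng.uniform(-2, 2, size=(500, 2))
actions = rng.integers(0, 2, size=500)
onehots = np.eye(2)[actions]
results = [true_step(s, a) for s, a in zip(states, actions)]
next_states = np.array([r[0] for r in results])
rewards = np.array([r[1] for r in results])

# 2. Plain supervised regression: targets are the state DELTA and the reward.
X = features(states, onehots)
W_delta, *_ = np.linalg.lstsq(X, next_states - states, rcond=None)
w_reward, *_ = np.linalg.lstsq(X, rewards, rcond=None)


def model(state, action_idx):
    """Learned model: predict the next state (state + delta) and the reward."""
    phi = features(state[None, :], np.eye(2)[[action_idx]])
    return state + (phi @ W_delta)[0], (phi @ w_reward)[0]
```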

Dyna-Q: the simplest model-based architecture there is

The cleanest way to see model-based RL in action is Dyna-Q (Sutton, 1991), and it is almost cheekily simple. After every real step, you do your normal Q-Learning update (the model-free part, exactly as in episode #106) -- and then you do a handful of extra Q-Learning updates on transitions conjured up by the model. Real learning and imagined learning, interleaved, sharing the same Q-table:

from collections import defaultdict
import numpy as np


class DynaQ:
    """Dyna-Q: tabular Q-Learning interleaved with model-based planning."""
    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99,
                 epsilon=0.1, planning_steps=5):
        self.n_actions = n_actions
        self.alpha, self.gamma, self.epsilon = alpha, gamma, epsilon
        self.planning_steps = planning_steps

        self.Q = defaultdict(lambda: np.zeros(n_actions))
        self.model = {}                  # (state, action) -> (reward, next_state, done)
        self.visited_sa = []             # which (s, a) pairs we've actually seen

    def choose_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        return int(np.argmax(self.Q[state]))

    def update(self, state, action, reward, next_state, done):
        # 1. Real Q-Learning update (model-free)
        target = reward + self.gamma * np.max(self.Q[next_state]) * (1 - done)
        self.Q[state][action] += self.alpha * (target - self.Q[state][action])

        # 2. Update the model with what really happened
        self.model[(state, action)] = (reward, next_state, done)
        if (state, action) not in self.visited_sa:
            self.visited_sa.append((state, action))

        # 3. Planning: replay imagined transitions from the model
        for _ in range(self.planning_steps):
            s, a = self.visited_sa[np.random.randint(len(self.visited_sa))]
            r, ns, d = self.model[(s, a)]
            target = r + self.gamma * np.max(self.Q[ns]) * (1 - d)
            self.Q[s][a] += self.alpha * (target - self.Q[s][a])

Look at what that buys you. With planning_steps = 5, every single real interaction triggers six Q-updates -- one from reality plus five replayed from the model. Value information propagates backwards through the state space far faster per real interaction, because the agent isn't sitting idle waiting for the environment to hand it the next experience -- it's mining the experience it already has. The price is compute: each real step now costs six updates instead of one. On a small gridworld where the tabular model is essentially perfect, Dyna-Q needs dramatically fewer real episodes than plain Q-Learning, and the gap widens the more planning steps you allow.
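Here's a small, self-contained sketch of that backward propagation. It rests on simplifying assumptions of my own: a deterministic 10-state corridor with a reward of 1 only at the far end, and an agent that always steps right so exploration doesn't muddy the picture. The function repeats the three steps of DynaQ.update (real update, model update, planning replays) with a NumPy Q-table. It runs a single episode and counts how many Q-values have picked up any value at all. dyna_q_one_episode is a hypothetical name, not part of the class above.

```python
import numpy as np


def dyna_q_one_episode(n_states=10, planning_steps=0, alpha=0.1, gamma=0.95, seed=0):
    """One episode along a corridor; returns the number of Q-values > 0."""
    rng = np.random.default_rng(seed)
    Q = np.zeros((n_states + 1, 2))  # row n_states is the terminal goal state
    model = {}                       # (s, a) -> (reward, next_state, done)
    visited = []
    state = 0
    while state < n_states:
        action = 1                   # always step right: we only study propagation
        next_state = state + 1
        done = next_state == n_states
        reward = 1.0 if done else 0.0

        # 1. Real Q-Learning update
        target = reward + gamma * Q[next_state].max() * (1 - done)
        Q[state, action] += alpha * (target - Q[state, action])

        # 2. Remember what really happened
        if (state, action) not in model:
            visited.append((state, action))
        model[(state, action)] = (reward, next_state, done)

        # 3. Planning: replay randomly chosen remembered transitions
        for _ in range(planning_steps):
            s, a = visited[rng.integers(len(visited))]
            r, ns, d = model[(s, a)]
            t = r + gamma * Q[ns].max() * (1 - d)
            Q[s, a] += alpha * (t - Q[s, a])
        state = next_state
    return int((Q > 0).sum())


print(dyna_q_one_episode(planning_steps=0), dyna_q_one_episode(planning_steps=200))
```

With planning_steps = 0, only the final state-action pair learns anything from the first episode, which is plain one-step Q-Learning. With 200 planning replays, the goal's value has already started leaking back along the corridor before the agent takes a single additional real step.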

There's a deep idea hiding in that humble loop, and it's worth saying plainly: real data improves the model, the model improves the policy. Two learning processes feeding each other, and only one of them costs you precious real-world interactions.

NB: if Dyna-Q's replay reminds you of DQN's experience replay buffer from episode #107, you've spotted something real -- both reuse past experience instead of throwing it away. The difference is that replay only ever serves up transitions that genuinely happened, whereas a learned model can generate transitions you've never seen at all. That extra reach is the model-based superpower, and (as we'll see) also its Achilles' heel.

Neural Dyna: scaling past the table

A tabular model only works when states are few and discrete. The moment the state is a vector of real numbers -- a robot's joint angles, a game screen's features -- the table explodes. The fix is the obvious one: replace the dictionary with the EnvironmentModel network from earlier, and let it generalise across states it has only sort-of seen:

class NeuralDynaAgent:
    """Dyna with a neural environment model for continuous-state problems."""
    def __init__(self, state_dim, n_actions, hidden=128,
                 planning_steps=10, planning_horizon=5):
        self.state_dim, self.n_actions = state_dim, n_actions
        self.planning_steps = planning_steps
        self.planning_horizon = planning_horizon

        self.q_net = nn.Sequential(                      # model-free component
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )
        self.q_optimizer = torch.optim.Adam(self.q_net.parameters(), lr=1e-3)

        self.env_model = EnvironmentModel(state_dim, n_actions, hidden)
        self.model_optimizer = torch.optim.Adam(
            self.env_model.parameters(), lr=1e-3)
        self.replay_buffer = []

    def imagine_rollout(self, start_state, horizon):
        """Walk forward through the LEARNED model, no real environment touched."""
        state = torch.FloatTensor(start_state).unsqueeze(0)
        imagined = []
        for _ in range(horizon):
            with torch.no_grad():
                action_idx = self.q_net(state).argmax(dim=1).item()
            action = torch.zeros(1, self.n_actions)
            action[0, action_idx] = 1.0
            with torch.no_grad():
                next_state, reward = self.env_model(state, action)
            imagined.append((state.squeeze().numpy(), action_idx,
                             reward.item(), next_state.squeeze().numpy()))
            state = next_state                           # feed prediction back in
        return imagined

    def plan(self):
        """Generate imagined experience and train the Q-network on it."""
        if len(self.replay_buffer) < 100:
            return
        for _ in range(self.planning_steps):
            start = self.replay_buffer[
                np.random.randint(len(self.replay_buffer))][0]
            for s, a, r, ns in self.imagine_rollout(start, self.planning_horizon):
                self._q_update(s, a, r, ns)

    def _q_update(self, state, action, reward, next_state):
        state_t = torch.FloatTensor(state).unsqueeze(0)
        next_state_t = torch.FloatTensor(next_state).unsqueeze(0)
        q_pred = self.q_net(state_t)[0, action]
        with torch.no_grad():
            q_target = reward + 0.99 * self.q_net(next_state_t).max().item()
        loss = (q_pred - q_target) ** 2
        self.q_optimizer.zero_grad()
        loss.backward()
        self.q_optimizer.step()

The crucial line is in imagine_rollout: state = next_state. We feed the model's own prediction back into the model as the next input, and step forward again. The agent is walking through a world entirely of its own making -- a daydream stitched together from a learned dynamics function. As long as the dream stays faithful to reality, training on it is almost free.
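To watch "feed the prediction back in" go wrong in slow motion, take a hypothetical system whose true dynamics are a 2-D rotation. Pair it with a "learned" model that gets the rotation right but is 1% too large at every step. This NumPy sketch rolls both out autoregressively from the same start state, just as imagine_rollout does:

```python
import numpy as np

theta = 0.1
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
A_true = R          # true dynamics: rotate the state a little each step
A_model = 1.01 * R  # hypothetical learned model: right direction, 1% too large


def rollout(A, s0, horizon):
    """Autoregressive rollout: each prediction becomes the next input."""
    states, s = [], s0
    for _ in range(horizon):
        s = A @ s
        states.append(s)
    return np.array(states)


s0 = np.array([1.0, 0.0])
errors = np.linalg.norm(rollout(A_model, s0, 20) - rollout(A_true, s0, 20), axis=1)
for k in (1, 5, 10, 20):
    print(f"k={k:2d}  error={errors[k - 1]:.4f}")
```

A rotation preserves length, so the gap after k steps is exactly 1.01 ** k - 1. That's a 1% slip on step one and about 22% by step 20. Real models don't drift this tidily, but the mechanism is the same.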

That phrase -- "as long as the dream stays faithful" -- is the whole ballgame, and we'll come back to it with a vengeance shortly. First, two landmark systems that pushed the dreaming idea to its limit.

World Models: an agent that trains inside its own dream

World Models (Ha & Schmidhuber, 2018) is one of those papers that's a genuine pleasure to read, because the central image is so vivid: an agent that learns a compact mental model of its world and then learns to act entirely inside that model, never touching the real environment during policy training. It splits the job into three parts:

  1. a VAE (the "Vision") that squashes each raw observation down into a small latent code z -- the gist of the scene, stripped of pixel clutter;
  2. an MDN-RNN (the "Memory") that, given the current code and action, predicts the next latent code -- a dynamics model living in latent space, not pixel space;
  3. a Controller -- a tiny policy mapping the latent code to an action.

The remarkable result: in their VizDoom experiment (the "Take Cover" scenario, where the agent must dodge fireballs), they trained the controller purely inside the learned model -- inside the "dream" -- and then dropped it into the real environment, where it still performed. The policy never saw a single real observation during its training. In effect, it learned to dodge fireballs by practising in a hallucination of the game. In their car-racing experiment, by contrast, the controller was still trained in the real environment, just on the compressed features from V and M. Transfer from the dream works only because the model was faithful enough that habits learned in it carried over to reality. When you hear people say a policy was "trained in imagination", this is the lineage they're pointing at.

MuZero: planning without being told the rules

If World Models is the elegant idea, MuZero (Schrittwieser et al., 2020) is the heavyweight result. To appreciate it, line up DeepMind's family tree. AlphaGo was handed the rules of Go. AlphaZero was handed the rules of chess, Go and shogi. MuZero is handed... nothing. It learns to play at superhuman level without knowing the rules of the game at all -- it has to figure out the consequences of its own moves from experience.

It pulls this off with three learned functions:

class MuZeroComponents(nn.Module):
    """Simplified MuZero: representation, dynamics, prediction."""
    def __init__(self, obs_dim, hidden_dim, n_actions):
        super().__init__()
        # Representation: real observation -> hidden state
        self.representation = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # Dynamics: hidden state + action -> next hidden state (+ reward)
        self.dynamics = nn.Sequential(
            nn.Linear(hidden_dim + n_actions, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.reward_pred = nn.Linear(hidden_dim, 1)
        # Prediction: hidden state -> policy + value
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def initial_inference(self, observation):
        """From a real observation: encode, then read off policy and value."""
        hidden = self.representation(observation)
        return hidden, self.policy_head(hidden), self.value_head(hidden)

    def recurrent_inference(self, hidden_state, action_onehot):
        """From a hidden state + action: imagine the next hidden state."""
        x = torch.cat([hidden_state, action_onehot], dim=-1)
        next_hidden = self.dynamics(x)
        reward = self.reward_pred(next_hidden)
        policy = self.policy_head(next_hidden)
        value = self.value_head(next_hidden)
        return next_hidden, reward, policy, value

The masterstroke is what MuZero refuses to predict. It does not try to reconstruct the next screen of pixels (hard, wasteful, and mostly irrelevant to playing well). Instead it learns an abstract hidden state -- an internal representation tuned for one job only: making the dynamics easy to predict and the planning effective. The model lives in this learned latent space, never in observation space. It doesn't need to imagine what the board looks like, only what matters about it for winning.

And then it plans with Monte Carlo Tree Search (MCTS) -- the same family of tree search we touched on in the bandits and dynamic-programming episodes (#103, #104), now run inside the learned model. From the current hidden state, MuZero grows a search tree of candidate action sequences using its dynamics function. The policy head steers which branches get explored, and each imagined line is scored with the predicted rewards plus the value head at its leaf. MuZero then picks its move from the search statistics (essentially, the action the search visited most). Search, but over a dream of the game rather than a known rulebook.
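Full MCTS is too much for a snippet, so here's a deliberately simpler stand-in. It keeps the one idea that matters here: searching over imagined futures produced by a model rather than by the real game. The sketch exhaustively tries every action sequence up to a fixed depth. It sums the predicted rewards, bootstraps the tail with a value estimate, and returns the first action of the best line. This is not MCTS, and it uses no policy prior. dynamics and value are hypothetical hand-written stand-ins for MuZero's learned networks.

```python
import itertools

import numpy as np

ACTIONS = np.array([-1.0, +1.0])  # hypothetical action set: step left or right


def dynamics(h, a_idx):
    """Stand-in for a learned dynamics function: next hidden state and reward."""
    h_next = h + ACTIONS[a_idx]
    return h_next, -abs(h_next)  # toy reward: being near 0 is good


def value(h):
    """Stand-in for a learned value head."""
    return -2.0 * abs(h)


def plan(h0, depth=3, gamma=0.9):
    """Exhaustive depth-limited lookahead inside the model; returns an action index."""
    best_return, best_first = -np.inf, None
    for seq in itertools.product(range(len(ACTIONS)), repeat=depth):
        h, ret, disc = h0, 0.0, 1.0
        for a in seq:
            h, r = dynamics(h, a)  # imagine one step; the real game is never touched
            ret += disc * r
            disc *= gamma
        ret += disc * value(h)     # bootstrap everything beyond the horizon
        if ret > best_return:
            best_return, best_first = ret, seq[0]
    return best_first


print(plan(3.0), plan(-3.0))
```

Exhaustive search costs len(ACTIONS) ** depth imagined lines, which explodes fast. MCTS exists precisely to spend a fixed simulation budget on the promising branches instead.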

The payoff: MuZero matched AlphaZero on Go, chess and shogi -- without the rules. It simultaneously set a new state of the art on the Atari benchmark, beating the best model-free agents, and a data-efficient variant (MuZero Reanalyze) also performed strongly on a far smaller budget of frames. That last bit is the whole thesis of this episode, proven at the highest level: a good learned model buys you sample efficiency that pure reaction simply cannot.

Model error: the elephant in the room

Now the catch, and it is a big one. Learned models are wrong. Not catastrophically, usually -- but wrong in small ways, every step. And in a multi-step rollout those small errors compound.

Do the arithmetic and it's sobering. Say your model is a very respectable 95% accurate on a single step, and (crudely) treat each step as an independent chance to go wrong. Chain twenty steps together in your imagination and the chance that the final state is still right is roughly 0.95 ** 20, which is about 36%. Twenty steps into the dream and you're more wrong than right. The agent, of course, doesn't know this -- it cheerfully trains on the garbage as if it were gospel.
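Here is the same arithmetic in a few lines of Python, under the same crude assumption that steps are independent and each is right with probability 0.95. It also answers a more practical question: how many steps until the dream is no better than a coin flip?

```python
import math


def chained_accuracy(per_step, k):
    """Probability that k independent steps all come out right."""
    return per_step ** k


for k in (1, 5, 10, 20):
    print(f"{k:2d} steps: {chained_accuracy(0.95, k):.3f}")

# Smallest horizon at which the chained accuracy drops below 50%.
coin_flip_horizon = math.ceil(math.log(0.5) / math.log(0.95))
print(coin_flip_horizon)
```

At 95% per step, the dream becomes worse than a coin flip after 14 steps.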

Worse still is what happens when the model has an exploitable flaw. Suppose there's some weird corner of state space where the model wrongly predicts a huge reward. A model-free agent could never be fooled by this -- it only ever sees real rewards. But a model-based agent will find that phantom jackpot and exploit it relentlessly, optimising hard for a payoff that exists only in its own buggy imagination. It's the RL equivalent of a student who learns to game the practice test instead of the subject.

The field has a toolbox for keeping the dream honest:

  • Short planning horizons -- imagine only a few steps ahead, before errors snowball;
  • Ensemble models -- train several models and treat their disagreement as a measure of uncertainty (where they argue, don't trust the dream);
  • Model-predictive control (MPC) -- replan after every real step, so the model never drifts far before reality corrects it;
  • Bounded planning -- lean on the model for short-term imagination but fall back to model-free value estimates for the long-term picture.
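The ensemble idea fits in a few lines. In this NumPy sketch the "models" are five cubic polynomials (np.polyfit), each fitted to a bootstrap resample of the same toy data. That is a much simpler stand-in for five EnvironmentModels with different seeds. The data, the sine-shaped target and the probe points are all made up for illustration. Inside the training range (x = 0) the members agree; far outside it (x = 3) they scatter.

```python
import numpy as np

rng = np.random.default_rng(0)
x_train = rng.uniform(-1.0, 1.0, size=200)
y_train = np.sin(2 * x_train) + 0.1 * rng.normal(size=200)  # toy noisy "dynamics"

ensemble = []
for _ in range(5):
    idx = rng.integers(0, len(x_train), size=len(x_train))  # bootstrap resample
    ensemble.append(np.polyfit(x_train[idx], y_train[idx], deg=3))


def disagreement(x):
    """Variance of the ensemble members' predictions at each query point."""
    preds = np.array([np.polyval(coeffs, x) for coeffs in ensemble])
    return preds.var(axis=0)


print(disagreement(np.array([0.0, 3.0])))  # small in-distribution, large outside
```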

The honest summary: model-based RL is a constant negotiation between the sample efficiency you gain and the model error you risk. Push the horizon too far and the second eats the first.

Model-based vs model-free: choosing wisely

So when do you reach for which? Here's the trade-off laid bare:

| | Model-free (DQN, PPO) | Model-based (Dyna, MuZero) |
|---|---|---|
| Sample efficiency | Low -- millions of interactions | High -- often 10-100x fewer |
| Compute per step | Low | High (model training + planning) |
| Best-case performance | Can reach optimal | Capped by model accuracy |
| Robustness | High (no model to be wrong) | Sensitive to model error |
| Moving parts | Few | Many -- more to tune and debug |

The deciding question is almost always: how expensive is a real interaction? When interactions are cheap -- games, fast simulators where you can spin up a thousand parallel worlds -- model-free wins on sheer simplicity and robustness; who cares about sample efficiency when samples are free? But when interactions are expensive or dangerous -- a physical robot arm, a drone, a treatment policy in healthcare, an industrial controller -- model-based methods stop being a nicety and become the only sane option. You can't crash a real robot ten million times. You can let it crash ten million times in a dream.

That, in one sentence, is why this family of methods exists.

So, what do you know now?

  • Model-free RL (everything since episode #102) is powerful but sample-hungry -- it reacts to the world without ever modelling it, and pays for that with millions of interactions;
  • model-based RL learns a model that predicts next states and rewards, then trains the policy on imagined experience -- trading cheap computation for scarce real-world samples;
  • Dyna-Q is the minimal example: ordinary Q-Learning plus a few extra updates per step replayed from the model -- a handful of planning steps can multiply your effective experience several times over;
  • predicting state deltas (not absolute states) makes the model far easier to learn -- the same residual-learning idea from episode #46;
  • World Models trains a policy entirely inside a learned dream, and MuZero learns the dynamics in an abstract latent space and plans with MCTS -- reaching AlphaZero-level play without being given the rules;
  • the permanent danger is compounding model error: small per-step mistakes snowball over a rollout, and an agent will gleefully exploit phantom rewards in a flawed model -- short horizons, ensembles and replanning are the defence;
  • pick model-based when real interactions are expensive or dangerous (robotics, healthcare); stick with model-free when they're cheap (games, simulators).

Exercises

Exercise 1: Implement tabular DynaQ (the class above) on a small gridworld -- FrozenLake-v1 with is_slippery=False from gymnasium does nicely. Train it with planning_steps = 0 (which is just plain Q-Learning), then 5, then 50, all under the same seed, and plot episodes-to-solve for each. Confirm with your own eyes that more planning means faster convergence -- and then explain in a sentence why the benefit eventually plateaus (hint: think about how much genuinely new information one real transition can carry).

Exercise 2: Take the EnvironmentModel and ModelTrainer, collect a few thousand random-policy transitions from CartPole-v1, and train the model on them. Then measure compounding error directly: from a real start state, roll the model forward k steps and compare its predicted state against the true environment's state for k = 1, 5, 10, 20. Plot prediction error against k. You should see roughly the snowball we discussed -- and you'll have measured the 0.95 ** k problem rather than just read about it.

Exercise 3: Build a tiny ensemble -- train five EnvironmentModels on the same data but with different random seeds (so different initialisations and minibatch orders). For a batch of states, have all five predict the next state, and compute the variance of their predictions per state. Now connect the dots: argue how you'd use that variance as an uncertainty signal to decide when to stop trusting an imagined rollout. This is the seed of how serious model-based agents keep their dreams from running off the rails.

That ensemble-disagreement idea is your stepping stone into the messier, more crowded worlds we're heading for next -- environments where the agent is no longer alone, and where "predicting what happens next" means predicting what other learning agents will do. We've now got the model-based engine built and, more importantly, we know exactly where it breaks. That honest understanding of the failure mode is worth more than the algorithm itself ;-)

Thanks, and until next time!

@scipio


