If Model-Free methods from previous chapters are "learn by doing"—
directly optimizing policies or value functions through extensive
trial-and-error, then Model-Based methods are "think before doing"—
learning environment dynamics models to plan futures in
imagination, dramatically improving sample efficiency. Human and animal
intelligence heavily relies on internal world models: chess grandmasters
simulate dozens of future moves mentally, infants predict object
trajectories through physical intuition. DeepMind's AlphaGo plans
through Monte Carlo Tree Search in simulated games, OpenAI's Dota 2
agents use environment simulators to "rehearse" strategies during
training. The core advantage of Model-Based RL is sample efficiency — in
scenarios where real environment interaction is expensive (like robot
control, autonomous driving), generating virtual experiences through
learned models can achieve Model-Free performance with 1/10 or even
1/100 of the samples. From classic Dyna architecture to World Models
combining deep learning, from MuZero's implicit planning to Dreamer
series learning in latent dream spaces, Model-Based RL has demonstrated
enormous potential in sample efficiency, generalization, and
interpretability. This chapter systematically traces this evolution,
deeply analyzing the design motivation, mathematical principles, and
implementation details of each algorithm.
Model-Based vs Model-Free: Two Paths Compared
Core Differences
Model-Free RL: Directly learns a policy $\pi(a \mid s)$ or value function $Q(s, a)$ from experience, without explicitly modeling environment dynamics.
Advantages:
- Simple algorithms, easy to implement
- No assumptions about the environment needed
- Performs well in tasks where the dynamics are complex or unknown
Disadvantages:
- Low sample efficiency: requires millions of interactions to learn well
- Hard to generalize to new tasks: the policy is only valid for the training environment
- Lacks interpretability: we can't tell why the agent behaves a certain way
Model-Based RL: Explicitly learns an environment model $p(s' \mid s, a)$ and reward function $r(s, a)$, then uses the model for planning or for generating virtual data.
Advantages:
- High sample efficiency: can learn in imagination, reducing real interactions
- Strong generalization: dynamics knowledge in the model can transfer to new tasks
- Good interpretability: we can visualize the agent's understanding of the world
Disadvantages:
- Model error accumulation: prediction errors grow exponentially with planning steps
- Computationally expensive: extra model networks must be trained, and planning itself costs compute
- Hard to model complex environments: dynamics models for high-dimensional observations (like pixels) are difficult to learn
Sample Efficiency Comparison
In classic control tasks, Model-Based methods show clear sample
advantages:
| Algorithm Type | Environment | Samples to Expert Level |
| --- | --- | --- |
| DQN (Model-Free) | Atari Pong | 10M frames |
| PPO (Model-Free) | MuJoCo HalfCheetah | 1M steps |
| Dreamer (Model-Based) | DMControl Walker | 100K steps |
| MBPO (Model-Based) | MuJoCo HalfCheetah | 100K steps |
Model-Based methods typically need only 1/10 the samples of
Model-Free methods!
When to Use Model-Based Methods
Suitable scenarios:
- High real-interaction cost (robotics, autonomous driving)
- Relatively simple, learnable environment dynamics (physics simulators, board games)
- Need quick adaptation to new tasks (transfer learning, meta-learning)
- Need interpretability and safety guarantees

Unsuitable scenarios:
- Highly stochastic environments (like stock markets)
- High-dimensional observations with complex dynamics (like complex 3D games)
- Abundant free simulators available (like Atari, which is itself a simulator)
Dyna Architecture: Integrating Learning and Planning
Historical Background
Dyna architecture was proposed by Richard Sutton in
1990 as the earliest systematic Model-Based RL framework. Core idea:
alternate between real experience learning and simulated
experience learning.
Dyna contains three core components:

1. Direct Learning: Collect $(s, a, r, s')$ from the real environment and update the policy/value function (Model-Free)
2. Model Learning: Learn an environment model $\hat{p}(s' \mid s, a)$ and $\hat{r}(s, a)$ from real experience
3. Planning: Use the learned model to generate simulated experience and further update the policy/value function
Dyna-Q Algorithm
Dyna-Q is Dyna architecture instantiated with Q-Learning:
Algorithm flow:

1. Initialize Q-table $Q(s, a)$ and model $\mathrm{Model}(s, a)$
2. Each timestep:
   - (a) Real interaction: Execute action $a$ in the environment, observe $r, s'$
   - (b) Direct learning: Update the Q-value (standard Q-Learning):
     $$Q(s, a) \leftarrow Q(s, a) + \alpha \big[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \big]$$
   - (c) Model learning: Update the model: $\mathrm{Model}(s, a) \leftarrow (r, s')$ (a tabular model simply records the observed transition)
   - (d) Planning: Repeat $n$ times:
     - Randomly sample $(s, a)$ from previously visited state-action pairs
     - Generate simulated experience with the model: $(r, s') = \mathrm{Model}(s, a)$
     - Update the Q-value with the simulated experience (the same Q-Learning update as in (b))
Advantages:
- Simple and intuitive, easy to implement
- Significantly improved sample efficiency: each real sample can drive $n$ virtual updates through planning
- Excellent in simple environments (like GridWorld)
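The tabular Dyna-Q loop is easy to sketch. Below is a minimal, self-contained illustration on an assumed toy environment (the 9-state corridor, `env_step`, and all hyperparameter values are ours for illustration, not from Sutton's paper): only reaching the rightmost state gives reward, and each real step is followed by many planning updates on the recorded model.

```python
import random
from collections import defaultdict

# Toy deterministic corridor: states 0..8, action 0 = left, 1 = right.
# Reaching state 8 gives reward 1 and ends the episode. (Invented example.)
GOAL = 8

def env_step(s, a):
    s2 = max(0, s - 1) if a == 0 else min(GOAL, s + 1)
    return s2, (1.0 if s2 == GOAL else 0.0), s2 == GOAL

def greedy(Q, s, rng):
    q0, q1 = Q[(s, 0)], Q[(s, 1)]
    return rng.randrange(2) if q0 == q1 else (0 if q0 > q1 else 1)

def dyna_q(episodes=50, n_planning=30, alpha=0.5, gamma=0.95, eps=0.1, seed=0):
    rng = random.Random(seed)
    Q = defaultdict(float)   # Q[(s, a)]
    model = {}               # model[(s, a)] = (r, s', done): tabular, deterministic
    for _ in range(episodes):
        s = 0
        for _ in range(200):
            # epsilon-greedy action selection
            a = rng.randrange(2) if rng.random() < eps else greedy(Q, s, rng)
            s2, r, done = env_step(s, a)                 # (a) real interaction
            target = r if done else r + gamma * max(Q[(s2, 0)], Q[(s2, 1)])
            Q[(s, a)] += alpha * (target - Q[(s, a)])    # (b) direct learning
            model[(s, a)] = (r, s2, done)                # (c) model learning
            for _ in range(n_planning):                  # (d) planning
                ps, pa = rng.choice(list(model))
                pr, ps2, pdone = model[(ps, pa)]
                pt = pr if pdone else pr + gamma * max(Q[(ps2, 0)], Q[(ps2, 1)])
                Q[(ps, pa)] += alpha * (pt - Q[(ps, pa)])
            s = s2
            if done:
                break
    return Q

Q = dyna_q()
policy = [greedy(Q, s, random.Random(1)) for s in range(GOAL)]
print(policy)  # the greedy policy should steer right (action 1) toward the goal
```

Each real transition here fuels 30 simulated Q-Learning updates, which is exactly why Dyna-Q needs far fewer real episodes than plain Q-Learning on the same task.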
Problem 1: Model Error Accumulation
When the environment model $\mathrm{Model}(s, a)$ is inaccurate, updating Q-values with simulated experience introduces bias. As the number of planning updates grows, this bias compounds and can lead to policy collapse.
Example: In FrozenLake, if model incorrectly
believes transition probability to trap is 0, agent will overestimate
state value during planning, leading to frequent failures in real
environment.
Problem 2: Computational Efficiency
After each real interaction, we need to perform $n$ planning steps, so the computational overhead is roughly $n$ times that of a Model-Free method. In large state spaces, enumerating all $(s, a)$ pairs is impractical.
Problem 3: Distribution Shift
States sampled during planning may differ from states actually
visited by policy, wasting computation on unimportant regions.
MBPO: Mitigating Model Error Through Short-Horizon Imagination
Motivation: The Curse of Model Error
In deep RL, environment models are usually neural networks. The problem:
model prediction errors grow exponentially with planning
length.
Assume the single-step prediction error is $\epsilon$. After planning $H$ steps, the cumulative error ranges from approximately $O(H\epsilon)$ (linear, if errors are independent) to $O\big((1+\epsilon)^H\big)$ (worst case, if errors compound). When $H$ is large (like 50 steps), even if the single-step error is small, the cumulative error makes planning completely fail.
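A tiny numerical experiment makes this concrete. Here we assume a toy 1-D linear system and a "learned" model whose coefficient is off by 2% (both invented for illustration); rolling both forward with identical actions shows the prediction error compounding with horizon:

```python
import numpy as np

# Toy dynamics s' = 0.99*s + a, versus a slightly biased model s' = 1.01*s + a.
# Both coefficients are assumed for illustration; the 2% gap is the "model error".
rng = np.random.default_rng(0)
true_coef, model_coef = 0.99, 1.01
s_true = s_model = 1.0
errors = []
for t in range(50):
    a = rng.normal(scale=0.1)            # same action sequence for both
    s_true = true_coef * s_true + a
    s_model = model_coef * s_model + a
    errors.append(abs(s_model - s_true))  # state prediction error at step t+1

print(f"error after  1 step : {errors[0]:.4f}")
print(f"error after 10 steps: {errors[9]:.4f}")
print(f"error after 50 steps: {errors[49]:.4f}")
```

The first-step error is exactly the 2% coefficient gap, but by step 50 the error has grown by well over an order of magnitude, even though every individual step was almost right.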
Model-Based Policy Optimization (MBPO) (Janner et
al., NeurIPS 2019) proposed a simple yet effective solution:
only plan in short-horizon imagination.
Core Idea
MBPO's key insight:
- Long-horizon planning (H=50) is easily derailed by model errors
- Short-horizon planning (H=1-5) keeps model errors controllable
- Combining short-horizon model imagination with Model-Free optimization achieves both sample efficiency and robustness
Algorithm flow:

1. Collect real data: Execute policy $\pi$ in the environment, store experience in the real buffer $\mathcal{D}_{\text{env}}$
2. Train the model: Fit an ensemble dynamics model $\hat{p}_\theta(s', r \mid s, a)$ on $\mathcal{D}_{\text{env}}$
3. Generate short rollouts: Sample an initial state $s$ from $\mathcal{D}_{\text{env}}$, roll out $k$ steps in the model (usually 1-5), and store the virtual experience $(s, a, r, s')$ in the model buffer $\mathcal{D}_{\text{model}}$
4. Update the policy: Train with SAC (or another Model-Free algorithm) on the mixed data from $\mathcal{D}_{\text{env}} \cup \mathcal{D}_{\text{model}}$

Key detail: The rollout length $k$ is determined through theoretical analysis and experiments, typically $k \in [1, 5]$. Janner et al. proved the existence of an optimal $k$ that balances model bias against real-data variance.
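The data-generation part of this loop can be sketched as follows. The `model_step` and `policy` functions below are placeholder stand-ins (in real MBPO they would be a learned neural-network ensemble and a SAC policy); the point is the branching structure: short rollouts that each start from a state sampled out of the real buffer.

```python
import numpy as np

rng = np.random.default_rng(0)

def model_step(s, a):
    # Stand-in for the learned dynamics/reward model (assumed toy dynamics).
    s2 = s + 0.1 * a + rng.normal(scale=0.01, size=s.shape)
    r = -np.sum(s2 ** 2)
    return s2, r

def policy(s):
    # Stand-in for the SAC policy: here just a random Gaussian action.
    return rng.normal(size=s.shape)

real_buffer = [rng.normal(size=2) for _ in range(100)]   # states seen in the real env
model_buffer = []

K = 3            # short rollout horizon (MBPO typically uses 1-5)
M = 50           # number of branched rollouts per iteration
for _ in range(M):
    s = real_buffer[rng.integers(len(real_buffer))]      # branch from a real state
    for _ in range(K):
        a = policy(s)
        s2, r = model_step(s, a)
        model_buffer.append((s, a, r, s2))               # virtual experience
        s = s2

print(len(model_buffer))   # M * K virtual transitions per iteration
# A SAC update would now sample from the real and model buffers, e.g. at a 1:4 ratio.
```

Because every rollout restarts from a real state, the model is never asked to predict more than $K$ steps ahead, which is what keeps the compounding error bounded.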
Mathematical Analysis: Why Short-Horizon Works
Define the gap between the policy's true return $\eta[\pi]$ and its model-predicted return $\hat{\eta}[\pi]$:
$$\big|\eta[\pi] - \hat{\eta}[\pi]\big| \le \frac{2\gamma\, r_{\max}\, \epsilon_m}{(1-\gamma)^2}$$
where $\epsilon_m = \mathbb{E}\big[D_{TV}\big(p(\cdot \mid s, a)\,\|\,\hat{p}(\cdot \mid s, a)\big)\big]$ is the average model error and $D_{TV}$ is the total variation distance (a simplified form of the bound in Janner et al.).
This bound tells us: the smaller the discount factor $\gamma$ (the more short-sighted the agent), the smaller the impact of model errors. MBPO limits the rollout length $k$, which is equivalent to using a smaller effective $\gamma$, thereby reducing the negative impact of model errors.
Experimental Results

- Hopper: Reaches 3000 points in 100K steps, 10x the sample efficiency of SAC
- Humanoid: First to learn standing and walking within 100K steps

Key findings:
1. Rollout length $k = 1$ is optimal for most tasks (the shortest horizon is best!)
2. An ensemble model (5 networks averaged) significantly improves robustness
3. A virtual-to-real data ratio of 4:1 is optimal
MBPO performs excellently in low-dimensional state spaces (like MuJoCo joint angles), but directly modeling $p(s' \mid s, a)$ over high-dimensional observations (like Atari's $210 \times 160 \times 3$ pixels) is nearly impossible: pixel-level prediction requires learning RGB values for every pixel, incurs enormous computational and storage overhead, and such models easily overfit.
World Models: Learning Dynamics in Latent Space

World Models (Ha & Schmidhuber, NeurIPS 2018) proposed an elegant solution: don't model in raw pixel space; learn the dynamics in a compressed latent space instead.
1. Vision Module (V): VAE

Compress the high-dimensional observation $o_t$ into a low-dimensional latent vector $z_t$ with a Variational Autoencoder.

2. Memory Module (M): MDN-RNN

Predict the next state in latent space: $P(z_{t+1} \mid z_t, a_t, h_t)$, where $h_t$ is the RNN hidden state. $P$ is parameterized by a Mixture Density Network (MDN-RNN), outputting a Gaussian mixture distribution for $z_{t+1}$.
3. Controller Module (C): Policy Network
Select an action from the latent state: $a_t = W_c\,[z_t, h_t] + b_c$. The Controller is a simple linear policy or a small MLP.
Training Procedure
Stage 1: Data Collection
Collect trajectories $\{(o_t, a_t, r_t)\}$ in the environment with a random policy.
Stage 2: Train V Module
Train VAE with all observations, learn compressed representation.
Stage 3: Train M Module
Train the MDN-RNN on the encoded trajectories to learn the latent dynamics $P(z_{t+1} \mid z_t, a_t, h_t)$.
Stage 4: Train C Module
Train the policy in the "dream": roll out entirely inside the learned model, never interacting with the real environment:
$$\max_\theta\ \mathbb{E}\Big[\sum_t r_t\Big]$$
where the trajectory $(z_1, a_1, z_2, a_2, \dots)$ is generated by the model. The Controller is optimized using evolution strategies (CMA-ES) or reinforcement learning algorithms.
```python
# ========== Training Procedure ==========
def train_world_models(env, episodes=1000):
    """Train World Models"""
    # 1. Collect data
    print("Collecting data...")
    trajectories = collect_random_trajectories(env, episodes)

    # 2. Train VAE
    print("Training VAE...")
    vae = VAE()
    optimizer_vae = torch.optim.Adam(vae.parameters(), lr=1e-4)
    for epoch in range(50):
        total_loss = 0
        for obs_batch in get_obs_batches(trajectories):
            recon, mu, logvar = vae(obs_batch)
            loss = vae.loss_function(recon, obs_batch, mu, logvar)
            optimizer_vae.zero_grad()
            loss.backward()
            optimizer_vae.step()
            total_loss += loss.item()
        print(f"VAE Epoch {epoch}, Loss: {total_loss:.2f}")

    # 3. Encode all observations to latent space
    print("Encoding observations...")
    latent_trajectories = encode_trajectories(vae, trajectories)

    # 4. Train MDN-RNN
    print("Training MDN-RNN...")
    mdn_rnn = MDNRNN(latent_dim=32, action_dim=env.action_space.shape[0])
    optimizer_rnn = torch.optim.Adam(mdn_rnn.parameters(), lr=1e-3)
    for epoch in range(50):
        total_loss = 0
        for z_seq, a_seq in get_latent_batches(latent_trajectories):
            mu, sigma, pi, _ = mdn_rnn(z_seq[:, :-1], a_seq[:, :-1])
            # Negative log-likelihood loss
            loss = mdn_nll_loss(z_seq[:, 1:], mu, sigma, pi)
            optimizer_rnn.zero_grad()
            loss.backward()
            optimizer_rnn.step()
            total_loss += loss.item()
        print(f"RNN Epoch {epoch}, Loss: {total_loss:.2f}")

    # 5. Train Controller in dream
    print("Training Controller in dream...")
    controller = Controller(latent_dim=32, hidden_dim=256,
                            action_dim=env.action_space.shape[0])
    # Optimize controller with CMA-ES or PPO (omitted)
    return vae, mdn_rnn, controller
```
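The `mdn_nll_loss` helper called above is not spelled out. A minimal NumPy version of the Gaussian-mixture negative log-likelihood it computes might look like this (the tensor shapes here are our assumption; actual training would use the PyTorch equivalent so gradients can flow):

```python
import numpy as np

def mdn_nll(z_next, mu, sigma, pi):
    """Negative log-likelihood of z_next under a diagonal Gaussian mixture.

    z_next: (B, D) targets; mu, sigma: (B, K, D) component parameters;
    pi: (B, K) mixture weights (each row sums to 1). Shapes are assumed.
    """
    z = z_next[:, None, :]                                   # (B, 1, D)
    # log N(z | mu, sigma^2) with diagonal covariance, summed over D
    log_prob = -0.5 * (((z - mu) / sigma) ** 2
                       + 2 * np.log(sigma) + np.log(2 * np.pi)).sum(axis=-1)
    # log sum_k pi_k N_k, computed via log-sum-exp for numerical stability
    log_mix = np.log(pi) + log_prob                          # (B, K)
    m = log_mix.max(axis=1, keepdims=True)
    log_lik = m[:, 0] + np.log(np.exp(log_mix - m).sum(axis=1))
    return -log_lik.mean()

# Sanity check: a single-component "mixture" centred on the target should
# score a lower NLL than one centred far away.
B, K, D = 4, 1, 2
z = np.zeros((B, D))
good = mdn_nll(z, np.zeros((B, K, D)), np.ones((B, K, D)), np.ones((B, K)))
bad = mdn_nll(z, np.full((B, K, D), 3.0), np.ones((B, K, D)), np.ones((B, K)))
print(good < bad)
```

Minimizing this quantity over the encoded trajectories is what trains the MDN-RNN to assign high probability to the latent states that actually follow.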
Experimental Results
Ha & Schmidhuber demonstrated World Models' powerful capabilities on the CarRacing game:
- The Controller has only 867 parameters (vs DQN's 1.7M)
- It was trained for 10000 virtual episodes inside the model, entirely without real environment interaction
- It reaches 900+ points (near human level)
- Visualization shows the model learned to predict the track, background and vehicle behavior
More remarkably, even in environments with visual glitches, the agent can complete tasks by relying on its internal model, demonstrating the robustness of world models.
Dreamer: End-to-End World Models

One limitation of World Models is staged training: first train the VAE, then the RNN, finally the Controller, with each stage optimized independently. This causes:
- The VAE-learned representation may not be suitable for planning
- RNN prediction errors can't backpropagate to the VAE to improve the representation
- Hyperparameters must be tuned manually for each stage
Dream to Control (Dreamer) (Hafner et al., ICLR
2020) proposed end-to-end training of world models: all
components (representation, dynamics, policy) jointly optimized,
learning policy directly in latent space "dreams".
1. Representation Learning Loss

$$\mathcal{L}_{\text{model}} = \mathbb{E}\Big[\sum_t \big(-\ln p(o_t \mid s_t) - \ln p(r_t \mid s_t) + \beta\, D_{\mathrm{KL}}\big(q(s_t \mid s_{t-1}, a_{t-1}, o_t)\,\big\|\,p(s_t \mid s_{t-1}, a_{t-1})\big)\big)\Big]$$

This loss simultaneously trains:
- Observation reconstruction (first term)
- Reward prediction (second term)
- Transition model and representation consistency (KL term)
2. Behavior Learning Loss (optimize policy in
dream)
Roll out $H$ steps in latent space and compute the return in the dream:
$$V(s_t) = \mathbb{E}\Big[\sum_{\tau=t}^{t+H} \gamma^{\tau-t}\, r_\tau\Big]$$
where $s_\tau$ is the latent state and $r_\tau$ is predicted by the reward model. The policy $\pi_\phi$ maximizes this value: $\max_\phi V(s_t)$.
3. Value Learning Loss
Train the value function $v_\psi(s)$ to fit the dream returns:
$$\min_\psi\ \mathbb{E}\Big[\tfrac{1}{2}\big(v_\psi(s_\tau) - V(s_\tau)\big)^2\Big]$$
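One common way to compute these dream returns, in the spirit of Dreamer's $V_\lambda$ estimator, is a backward TD($\lambda$)-style recursion over the imagined rollout. This is a simplified sketch; the exact bootstrapping details differ slightly in the paper:

```python
import numpy as np

def lambda_returns(rewards, values, gamma=0.99, lam=0.95):
    """TD(lambda)-style returns over an imagined H-step rollout.

    rewards: predicted r_1..r_H from the reward model;
    values: v(s_1)..v(s_H) from the critic.
    Backward recursion: G_t = r_t + gamma * ((1-lam)*v_{t+1} + lam*G_{t+1}).
    """
    H = len(rewards)
    G = np.zeros(H)
    G[-1] = rewards[-1] + gamma * values[-1]     # bootstrap at the horizon
    for t in reversed(range(H - 1)):
        G[t] = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * G[t + 1])
    return G

# Imagined 15-step rollout with constant reward 1 and a zero-valued critic:
G = lambda_returns(np.ones(15), np.zeros(15))
print(G[0], G[-1])
```

The actor loss pushes the policy to increase `G[0]` through the differentiable latent rollout, while the critic regresses `values` toward `G`.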
Complete Algorithm Flow
```
1. Initialize all module parameters
2. Collect initial data to replay buffer
3. Loop:
   a. Sample sequence data from buffer
   b. Encode latent states (h, z) through representation model
   c. Update world model (representation learning loss)
   d. Rollout H steps in dream:
      - Predict future latent states with transition model
      - Predict rewards with reward model
      - Compute returns
   e. Update policy (actor loss) and value function (critic loss)
   f. Execute policy in real environment, collect new data
```
Subsequent versions DreamerV2 (2021) and DreamerV3 (2023) improved the approach further; DreamerV3 achieved breakthroughs in open-world games like Minecraft, demonstrating the enormous potential of world models.
MuZero: Planning with an Implicit Model

Previous methods (World Models, Dreamer) all try to learn models that reconstruct observations, i.e., predict $o_{t+1}$. But for decision-making, do we really need a perfect reconstruction of every pixel?

DeepMind proposed: we only need to predict decision-relevant information (like value, policy, reward), not complete observations.
MuZero (Schrittwieser et al., Nature 2020)
implemented this idea: learn an implicit model that doesn't predict
observations, only predicts value, policy and reward — all information
needed for MCTS planning.
Core Architecture
MuZero contains three networks:
1. Representation Function

Encode the observation history into a hidden state:
$$s^0 = h_\theta(o_1, \dots, o_t)$$
Note: $s^0$ is not a physical state but an abstract "planning state".

2. Dynamics Function

Predict the next hidden state and reward:
$$(s^k, r^k) = g_\theta(s^{k-1}, a^k)$$
Similarly, $s^k$ is abstract and doesn't correspond to any physical meaning.

3. Prediction Function

Predict the policy and value:
$$(p^k, v^k) = f_\theta(s^k)$$
where $p^k$ is the action probability distribution and $v^k$ is the state value.
Training Objective
MuZero plans in hidden state space through MCTS, then trains with
self-play data. Loss function contains three terms:where:
-is value target from MCTS
search -is true reward -is MCTS-improved policy
Key insight: Model doesn't need to predict
observations, only needs to predicttriplet needed for MCTS. This greatly reduces model learning
difficulty.
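The unrolled structure of this loss can be illustrated with toy stand-ins for $h_\theta$, $g_\theta$ and $f_\theta$ (random linear maps below, with a single target action standing in for the full MCTS policy distribution; both simplifications are ours):

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4                                     # toy hidden-state size

# Toy stand-ins for MuZero's three networks; in MuZero these are deep
# networks trained jointly, here they are fixed random linear maps.
W_h = rng.normal(size=(D, 8))             # representation h: obs -> s^0
W_g = rng.normal(size=(D, D + 1)) * 0.1   # dynamics g: (s, a) -> s'
w_r = rng.normal(size=D + 1) * 0.1        # reward head of g
w_v = rng.normal(size=D) * 0.1            # value head of f
W_p = rng.normal(size=(2, D)) * 0.1       # policy head of f (2 actions)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def unrolled_loss(obs, actions, z, u, pi_targets):
    """Sum of value, reward and policy losses over K unrolled steps."""
    s = W_h @ obs                                 # s^0 = h(o)
    loss = 0.0
    for k, a in enumerate(actions):
        p = softmax(W_p @ s)                      # p^k, v^k = f(s^k)
        v = w_v @ s
        loss += (v - z[k]) ** 2                   # value vs MCTS target z_k
        loss += -np.log(p[pi_targets[k]])         # policy vs MCTS policy pi_k
        sa = np.append(s, a)
        r = w_r @ sa                              # r^k, s^{k+1} = g(s^k, a^k)
        loss += (r - u[k]) ** 2                   # reward vs observed u_k
        s = np.tanh(W_g @ sa)
    return loss

loss = unrolled_loss(obs=rng.normal(size=8), actions=[0, 1, 0],
                     z=[0.5, 0.4, 0.3], u=[0.0, 1.0, 0.0],
                     pi_targets=[0, 1, 0])
print(loss)
```

Note that gradients from step $k$ flow back through every earlier application of $g_\theta$ to $h_\theta$, which is what forces the hidden state to carry exactly the information MCTS needs.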
Comparison with AlphaZero
| Feature | AlphaZero | MuZero |
| --- | --- | --- |
| Environment Model | Needs perfect simulator (like Go rules) | Learns implicit model from data |
| State Representation | Explicit physical state | Abstract hidden state |
| Applicability | Only games with known rules | Any environment (including Atari) |
| Sample Efficiency | Unlimited simulation | Needs real interaction |
Experimental Results
MuZero achieved breakthroughs across multiple domains:
- Go: Comparable to AlphaZero (ELO ~5000)
- Chess: Slightly below Stockfish but far exceeds humans
- Atari 57 games: Average score exceeds R2D2; reaches human level on half the games
- DMControl: Surpasses Dreamer within 100K steps

Most remarkably: the same algorithm with the same hyperparameters achieves SOTA on completely different task types (board games, Atari, continuous control), demonstrating its generality.
Q1: Why do model errors accumulate over long rollouts?

A: Assume the single-step prediction error is $\epsilon$. After predicting $H$ steps, the error grows to somewhere between $O(H\epsilon)$ and $O\big((1+\epsilon)^H\big)$, depending on whether errors are correlated. This is because each step's prediction depends on the previous predicted state, so errors propagate layer by layer. Imagine deviating 1 degree at each step on a map: after 10 steps you might be headed in completely the wrong direction.
Q2: How to evaluate model quality?
A: Don't just look at reconstruction error (MSE)! What matters more is whether model-predicted trajectories lead to good decisions. Useful metrics:
- Policy evaluation difference: $\big|\hat{V}^\pi - V^\pi\big|$, the gap between the value estimated in the model and the true value
- Planning quality: Performance of the model-planned policy in the real environment
- Dynamics consistency: Whether long-term trajectory distributions match
Q3: Can Model-Based be used in partially observable
environments?
A: Yes! Dreamer is an example: it uses the RNN hidden state $h_t$ to encode history and plans in latent space. In POMDPs, Model-Based methods actually have an advantage, because the model can learn to "remember" important historical information.
Q4: How to mitigate model errors?
A: Several methods:
1. Short-horizon planning (MBPO): Only roll out 1-5 steps
2. Ensemble models: Train multiple models, then average or randomly sample among them
3. Uncertainty estimation: Use Bayesian neural networks or Dropout to estimate prediction uncertainty, and avoid planning in uncertain regions
4. Model regularization: Penalize overconfident predictions
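Method 2 is the easiest to demonstrate: train several models on bootstrapped resamples of the same data and use their disagreement (the standard deviation of their predictions) as an uncertainty signal. A minimal sketch with assumed 1-D data and linear models:

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed ground-truth dynamics for illustration: y = sin(3x) + noise,
# observed only on x in [-1, 1].
xs = rng.uniform(-1, 1, size=200)
ys = np.sin(3 * xs) + rng.normal(scale=0.05, size=200)

# "Ensemble" of 5 tiny models: linear fits on bootstrap resamples.
coefs = []
for _ in range(5):
    idx = rng.integers(0, len(xs), size=len(xs))   # bootstrap resample
    a, b = np.polyfit(xs[idx], ys[idx], 1)
    coefs.append((a, b))

def ensemble_predict(x):
    preds = np.array([a * x + b for a, b in coefs])
    return preds.mean(), preds.std()    # mean prediction + disagreement

_, std_in = ensemble_predict(0.5)       # inside the training distribution
_, std_out = ensemble_predict(10.0)     # far outside it
print(std_in, std_out)
```

Inside the training distribution the members agree closely; far outside, their predictions fan out, and a planner can use that disagreement to avoid untrustworthy regions of the model.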
Q5: What's the difference between Dyna, MBPO,
Dreamer?
A:
- Dyna: Tabular; the model perfectly records transitions; suitable for small environments
- MBPO: Learns a neural-network model of the state space, uses short-horizon rollouts, combines with SAC
- Dreamer: Learns a latent-space model, uses longer rollouts, trains end-to-end

Common point: all follow the "use the model to generate data, learn the policy with a Model-Free algorithm" framework.
Q6: Why doesn't MuZero reconstruct observations?
A: Reconstructing high-dimensional observations (pixels) is hard, and
many details are irrelevant to decisions (like background color). MuZero
only predicts decision-needed information, greatly reducing learning
difficulty. Analogy: playing Go, you don't need to predict opponent's
clothing color, only where they'll move.
Q7: Can Model-Based transfer to new tasks?
A: Theoretically yes, but practically very hard. If new task has
different dynamics (like transferring from walking to running), model
needs retraining or fine-tuning. Meta-learning and conditional models
are active research directions.
Q8: On which tasks is Model-Free better?
A:
- The environment has a cheap simulator (like Atari, Go)
- The dynamics are extremely complex (like social interaction, stock markets)
- Absolute peak performance is needed (Model-Based model errors may limit the ceiling)
Q9: What is Dreamer's "dream"?
A: The dream refers to imagined trajectories rolled out in the learned latent-space model. The agent doesn't interact with the real environment; instead it:
1. Starts from the current latent state $s_t$
2. Selects an action $a_t$ with the policy
3. Predicts $s_{t+1}, r_t$ with the model
4. Repeats for about 15 steps, collecting virtual experience
5. Uses this experience to update the policy
Like human daydreaming: simulating possible futures mentally,
learning from them without actually experiencing.
Q10: Future research directions for Model-Based
RL?
A:
- Composable world models: Learn reusable modules (like "objects", "physics laws") and compose them in new scenarios
- Causal world models: Not just predict "what will happen" but understand "why it happens"
- Lifelong learning: Continuously update the model to adapt to dynamically changing environments
- Safe planning: Verify policy safety in the model before execution
Recommended Papers
Dyna Architecture
Integrated Architectures for Learning, Planning, and Reacting
Based on Approximating Dynamic Programming (Sutton, 1990)
Classic textbook: Reinforcement Learning: An Introduction
Chapter 8
MBPO
When to Trust Your Model: Model-Based Policy Optimization
(Janner et al., NeurIPS 2019)
Model-Based reinforcement learning achieves "think before doing"
intelligent decision-making through learning environment models,
dramatically improving sample efficiency. From classic Dyna architecture
to deep learning era's MBPO, World Models, Dreamer and MuZero, each
improvement stems from deep understanding of model error, representation
learning and planning efficiency. Model-Based methods demonstrate
enormous potential in sample-limited real applications (like robotics,
autonomous driving), but model error accumulation, computational
overhead and generalization remain challenges.
Future research may integrate symbolic reasoning, causal modeling and
neural networks to build more interpretable, composable, transferable
world models. Just as humans understand physical laws, social rules and
others' intentions through internal models, next-generation AI systems
may also need similar "world understanding" capabilities — this is not
only a core problem in reinforcement learning, but also a key path
toward artificial general intelligence.
From Model-Based planning perspective, the next chapter will enter
AlphaGo and Monte Carlo Tree Search — through combining deep learning
with classic search algorithms, revealing how to achieve superhuman
level in perfect information games, and exploring deep connections
between search and learning.
Post title: Reinforcement Learning (5): Model-Based RL and World Models
Post author: Chen Kai
Create time: 2024-08-30 09:15:00
Post link: https://www.chenk.top/reinforcement-learning-5-model-based-rl-and-world-models/
Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.