Reinforcement Learning (5): Model-Based RL and World Models
Chen Kai

If the Model-Free methods of previous chapters are "learn by doing" (directly optimizing policies or value functions through extensive trial and error), then Model-Based methods are "think before doing": they learn a model of the environment dynamics and use it to plan futures in imagination, dramatically improving sample efficiency. Human and animal intelligence relies heavily on internal world models: chess grandmasters mentally simulate dozens of future moves, and infants predict object trajectories through physical intuition. DeepMind's AlphaGo plans via Monte Carlo Tree Search over simulated games, and OpenAI's Dota 2 agents "rehearse" strategies in an environment simulator during training. The core advantage of Model-Based RL is sample efficiency: in scenarios where real interaction is expensive (robot control, autonomous driving), generating virtual experience from a learned model can match Model-Free performance with 1/10 or even 1/100 of the samples. From the classic Dyna architecture to World Models built on deep learning, from MuZero's implicit planning to the Dreamer series learning in latent "dream" spaces, Model-Based RL has demonstrated enormous potential in sample efficiency, generalization, and interpretability. This chapter traces this evolution systematically, analyzing the design motivation, mathematical principles, and implementation details of each algorithm.

Model-Based vs Model-Free: Two Paths Compared

Core Differences

Model-Free RL: directly learns a policy $\pi(a|s)$ or value function $Q(s,a)$ from experience, without explicitly modeling environment dynamics.

Advantages:

  • Simple algorithms, easy to implement
  • No assumptions about the environment needed
  • Performs well in tasks with complex or unknown dynamics

Disadvantages:

  • Low sample efficiency: requires millions of interactions to learn well
  • Hard to generalize to new tasks: the policy is only valid for the training environment
  • Lacks interpretability: hard to understand why the agent behaves a certain way

Model-Based RL: explicitly learns an environment model $\hat{p}(s'|s,a)$ and a reward function $\hat{r}(s,a)$, then uses the model for planning or for generating virtual data.

Advantages:

  • High sample efficiency: can learn in imagination, reducing real interactions
  • Strong generalization: dynamics knowledge in the model can transfer to new tasks
  • Good interpretability: the agent's understanding of the world can be visualized

Disadvantages:

  • Model error accumulation: prediction errors compound as the planning horizon grows
  • Computationally expensive: extra model networks must be trained, and planning itself costs computation
  • Hard to model complex environments: dynamics models for high-dimensional observations (like pixels) are difficult to learn

Sample Efficiency Comparison

In classic control tasks, Model-Based methods show clear sample advantages:

| Algorithm | Type | Environment | Samples to Expert Level |
|---|---|---|---|
| DQN | Model-Free | Atari Pong | 10M frames |
| PPO | Model-Free | MuJoCo HalfCheetah | 1M steps |
| Dreamer | Model-Based | DMControl Walker | 100K steps |
| MBPO | Model-Based | MuJoCo HalfCheetah | 100K steps |

Model-Based methods typically need only 1/10 the samples of Model-Free methods!

When to Use Model-Based Methods

Suitable scenarios:

  • High real interaction cost (robotics, autonomous driving)
  • Relatively simple, learnable environment dynamics (physics simulators, board games)
  • Need for quick adaptation to new tasks (transfer learning, meta-learning)
  • Need for interpretability and safety guarantees

Unsuitable scenarios:

  • Highly stochastic environments (like stock markets)
  • High-dimensional observations with complex dynamics (like complex 3D games)
  • Abundant free simulators available (like Atari, which is itself a simulator)

Dyna Architecture: Integrating Learning and Planning

Historical Background

The Dyna architecture was proposed by Richard Sutton in 1990 as the earliest systematic Model-Based RL framework. Its core idea: alternate between learning from real experience and learning from simulated experience.

Dyna contains three core components:

  1. Direct Learning: collect transitions $(s, a, r, s')$ from the real environment and update the policy/value function (Model-Free)
  2. Model Learning: learn an environment model $\hat{p}(s'|s,a)$, $\hat{r}(s,a)$ from real experience
  3. Planning: use the learned model to generate simulated experience and further update the policy/value function

Dyna-Q Algorithm

Dyna-Q is the Dyna architecture instantiated with Q-Learning.

Algorithm flow:

  1. Initialize the Q-table $Q(s,a)$ and the model $M(s,a)$
  2. Each timestep:
    • (a) Real interaction: execute action $a$ in the environment, observe $r, s'$
    • (b) Direct learning: update the Q-value (standard Q-Learning): $Q(s,a) \leftarrow Q(s,a) + \alpha\,[r + \gamma \max_{a'} Q(s',a') - Q(s,a)]$
    • (c) Model learning: update the model: $M(s,a) \leftarrow (r, s')$ (a tabular model simply records the observed transition)
    • (d) Planning: repeat $n$ times:
      • Randomly sample $(s, a)$ from previously visited state-action pairs
      • Generate simulated experience with the model: $(r, s') = M(s, a)$
      • Update the Q-value with the simulated experience (the same Q-Learning update as in (b))

Code Implementation (Tabular Dyna-Q)

```python
import numpy as np
import gym

class DynaQ:
    """Dyna-Q algorithm (tabular)"""
    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.95,
                 epsilon=0.1, planning_steps=5):
        self.n_states = n_states
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.planning_steps = planning_steps

        # Q-table
        self.Q = np.zeros((n_states, n_actions))

        # Model: stores (r, s') for each (s, a)
        self.model = {}

        # Visited (s, a) pairs, sampled from during planning
        self.visited = set()

    def select_action(self, state):
        """Epsilon-greedy action selection"""
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.n_actions)
        return np.argmax(self.Q[state])

    def update_q(self, s, a, r, s_next):
        """Q-Learning update"""
        target = r + self.gamma * np.max(self.Q[s_next])
        self.Q[s, a] += self.alpha * (target - self.Q[s, a])

    def learn(self, state, action, reward, next_state):
        """Dyna-Q learning step"""
        # (a) Direct learning
        self.update_q(state, action, reward, next_state)

        # (b) Model learning
        self.model[(state, action)] = (reward, next_state)
        self.visited.add((state, action))

        # (c) Planning (generate simulated experience from the model)
        for _ in range(self.planning_steps):
            if not self.visited:
                break
            # Randomly sample a visited (s, a)
            s, a = list(self.visited)[np.random.randint(len(self.visited))]
            r, s_next = self.model[(s, a)]
            # Update the Q-value with the simulated experience
            self.update_q(s, a, r, s_next)

def train_dyna_q(env_name='FrozenLake-v1', episodes=500):
    """Train Dyna-Q"""
    env = gym.make(env_name, is_slippery=False)
    n_states = env.observation_space.n
    n_actions = env.action_space.n

    agent = DynaQ(n_states, n_actions, planning_steps=10)

    rewards_history = []
    for episode in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.learn(state, action, reward, next_state)

            state = next_state
            total_reward += reward

        rewards_history.append(total_reward)

        if (episode + 1) % 50 == 0:
            avg_reward = np.mean(rewards_history[-50:])
            print(f"Episode {episode+1}, Avg Reward: {avg_reward:.2f}")

    return agent

# Run example
agent = train_dyna_q()
```

Advantages and Problems

Advantages:

  • Simple and intuitive, easy to implement
  • Significantly improved sample efficiency: each real sample generates $n$ additional virtual updates through planning
  • Excellent in simple environments (like GridWorld)

Problem 1: Model Error Accumulation

When the environment model $\hat{p}(s'|s,a)$ is inaccurate, updating Q-values with simulated experience introduces bias. As the number of planning steps increases, errors accumulate, which can lead to policy collapse.

Example: In FrozenLake, if model incorrectly believes transition probability to trap is 0, agent will overestimate state value during planning, leading to frequent failures in real environment.

Problem 2: Computational Efficiency

After each real interaction, $n$ planning steps must be performed, so the computational overhead is roughly $n+1$ times that of the Model-Free update alone. In large state spaces, enumerating all $(s,a)$ pairs is impractical.

Problem 3: Distribution Shift

States sampled during planning may differ from states actually visited by policy, wasting computation on unimportant regions.

MBPO: Mitigating Model Error Through Short-Horizon Imagination

Motivation: The Curse of Model Error

In deep RL, environment models are usually neural networks. The problem: model prediction errors grow exponentially with planning length.

Assume the single-step prediction error is $\epsilon$. After planning $H$ steps, the cumulative error ranges from roughly $O(H\epsilon)$ (linear growth) up to exponential in $H$ in the worst case, since each step's prediction feeds the next. When $H$ is large (say 50 steps), even a small single-step error makes planning fail completely.
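This compounding can be seen in a toy experiment (our own illustration, not from any paper): iterate a slightly wrong linear model alongside the true dynamics and watch the open-loop prediction error grow with the horizon.

```python
import numpy as np

# True dynamics: a pure rotation (norm-preserving). Learned model: the same
# rotation with a small error in one entry. Rolling both forward from the
# same start state, the open-loop prediction error compounds step by step.
theta = 0.05
A = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
A_hat = A.copy()
A_hat[0, 0] += 0.01  # small single-step model error

x_true = np.array([1.0, 0.0])
x_pred = x_true.copy()
errors = []
for _ in range(50):
    x_true = A @ x_true
    x_pred = A_hat @ x_pred
    errors.append(np.linalg.norm(x_true - x_pred))

print(f"error after  1 step : {errors[0]:.4f}")
print(f"error after 10 steps: {errors[9]:.4f}")
print(f"error after 50 steps: {errors[49]:.4f}")
```

The one-step error is only 0.01, but the 50-step prediction drifts far more: the model's mistakes are applied again and again to its own (already wrong) predictions.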

Model-Based Policy Optimization (MBPO) (Janner et al., NeurIPS 2019) proposed a simple yet effective solution: only plan in short-horizon imagination.

Core Idea

MBPO's key insights:

  • Long-horizon planning (H = 50) is easily corrupted by model errors
  • Short-horizon planning (H = 1-5) keeps model errors controllable
  • Combining short-horizon model imagination with Model-Free optimization achieves both sample efficiency and robustness

Algorithm flow:

  1. Collect real data: execute policy $\pi$ in the environment, store experience in the real buffer $\mathcal{D}_{env}$
  2. Train dynamics model: train the model $\hat{p}_\theta(s'|s,a)$ and reward function $\hat{r}_\theta(s,a)$ on $\mathcal{D}_{env}$
  3. Generate virtual data:
    • Sample an initial state from $\mathcal{D}_{env}$
    • Roll out $k$ steps in the model (usually 1-5)
    • Store the virtual experience in the virtual buffer $\mathcal{D}_{model}$
  4. Update policy: train with SAC (or another Model-Free algorithm) on the mixed buffer

Key detail: the rollout length $k$ is determined through theoretical analysis and experiments, typically $k \in [1, 5]$. Janner et al. proved the existence of an optimal $k$ that balances model bias against real-data variance.

Mathematical Analysis: Why Short-Horizon Works

Consider the gap between the policy's true return $\eta[\pi]$ and its return under the learned model $\hat{\eta}[\pi]$. Janner et al. derive a bound of the form (simplified)

$$\eta[\pi] \;\geq\; \hat{\eta}[\pi] \;-\; \frac{2\,\gamma\, r_{\max}\,\epsilon_m}{(1-\gamma)^2},$$

where $\epsilon_m = \mathbb{E}_{s,a}\big[D_{TV}\big(p(\cdot|s,a)\,\|\,\hat{p}(\cdot|s,a)\big)\big]$ is the average model error measured in total variation distance.

This bound tells us: the smaller the discount factor $\gamma$ (the more short-sighted the agent), the smaller the impact of model errors. MBPO limits the rollout length $k$, which acts like using a smaller effective $\gamma$, thereby reducing the negative impact of model errors.
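To see why a small effective discount factor helps, plug numbers into the model-error term (a sketch using the simplified penalty $2\gamma r_{\max}\epsilon_m/(1-\gamma)^2$; the constants are illustrative, not from the paper):

```python
def model_error_penalty(gamma, r_max=1.0, eps_m=0.01):
    """Model-error term of a simplified MBPO-style return bound."""
    return 2 * gamma * r_max * eps_m / (1 - gamma) ** 2

# The same 1% model error costs vastly more at long effective horizons
for gamma in (0.99, 0.95, 0.9, 0.5):
    print(f"gamma={gamma}: penalty={model_error_penalty(gamma):.2f}")
```

At $\gamma = 0.99$ a 1% model error costs about 198 in the bound, while at $\gamma = 0.5$ it costs only 0.04: shortening the rollout horizon tames the penalty quadratically.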

Complete Code Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# ========== Dynamics Model ==========
class EnsembleDynamicsModel(nn.Module):
    """Ensemble dynamics model (predicts next state and reward)"""
    def __init__(self, state_dim, action_dim, hidden_dim=256, n_models=5):
        super().__init__()
        self.n_models = n_models
        self.models = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, state_dim + 1)  # predict delta_state and reward
            ) for _ in range(n_models)
        ])

    def forward(self, state, action):
        """
        Return predictions from every ensemble member.
        state: (B, state_dim)
        action: (B, action_dim)
        Returns: (n_models, B, state_dim + 1)
        """
        x = torch.cat([state, action], dim=-1)
        return torch.stack([model(x) for model in self.models])

    def predict(self, state, action):
        """Randomly select one ensemble member for prediction (for rollouts)"""
        with torch.no_grad():
            preds = self.forward(state, action)
            # Randomly select one model's prediction
            idx = np.random.randint(self.n_models)
            pred = preds[idx]

            delta_state = pred[:, :-1]
            reward = pred[:, -1:]

            next_state = state + delta_state  # the model predicts the state change
            return next_state, reward

    def train_model(self, replay_buffer, batch_size=256, epochs=5):
        """Train the dynamics model on real transitions"""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        for epoch in range(epochs):
            states, actions, rewards, next_states, _ = replay_buffer.sample(batch_size)

            # Regression target: true state change plus reward
            delta_states = next_states - states
            targets = torch.cat([delta_states, rewards.unsqueeze(1)], dim=-1)

            # Predictions from all ensemble members
            preds = self.forward(states, actions)

            # MSE loss, each ensemble member trained independently
            total_loss = sum(F.mse_loss(preds[i], targets)
                             for i in range(self.n_models))

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        return total_loss.item()

# ========== MBPO Algorithm ==========
# Note: ReplayBuffer and SAC are assumed to be defined elsewhere
# (e.g., in previous chapters); only their obvious interfaces are used.
class MBPO:
    """Model-Based Policy Optimization"""
    def __init__(self, env, state_dim, action_dim, rollout_length=5):
        self.env = env
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.rollout_length = rollout_length

        # Real environment buffer
        self.env_buffer = ReplayBuffer(state_dim, action_dim, max_size=1_000_000)

        # Virtual (model-generated) buffer
        self.model_buffer = ReplayBuffer(state_dim, action_dim, max_size=1_000_000)

        # Dynamics model
        self.dynamics_model = EnsembleDynamicsModel(state_dim, action_dim)

        # Policy (SAC here; other Model-Free algorithms also work)
        self.policy = SAC(state_dim, action_dim)

    def collect_real_data(self, n_steps=1000):
        """Collect real environment data"""
        state = self.env.reset()
        for _ in range(n_steps):
            action = self.policy.select_action(state)
            next_state, reward, done, _ = self.env.step(action)
            self.env_buffer.add(state, action, reward, next_state, done)

            state = next_state if not done else self.env.reset()

    def generate_virtual_data(self, n_rollouts=1000):
        """Generate virtual data with the learned model"""
        self.model_buffer.clear()

        for _ in range(n_rollouts):
            # Sample an initial state from the real buffer
            states, _, _, _, _ = self.env_buffer.sample(1)
            state = states[0]

            # Roll out k steps in the model
            for _ in range(self.rollout_length):
                action = self.policy.select_action(state.cpu().numpy())
                action_tensor = torch.FloatTensor(action).unsqueeze(0)

                # Predict with the model
                next_state, reward = self.dynamics_model.predict(
                    state.unsqueeze(0), action_tensor)

                # Store in the virtual buffer
                self.model_buffer.add(
                    state.cpu().numpy(),
                    action,
                    reward.item(),
                    next_state.squeeze(0).cpu().numpy(),
                    False  # virtual transitions never set done
                )

                state = next_state.squeeze(0)

    def train(self, total_steps=100_000):
        """MBPO training loop"""
        # Initial random data collection
        self.collect_real_data(n_steps=5000)

        for step in range(total_steps):
            # 1. Collect a small amount of real data
            self.collect_real_data(n_steps=100)

            # 2. Periodically retrain the dynamics model
            if step % 250 == 0:
                model_loss = self.dynamics_model.train_model(self.env_buffer)
                print(f"Step {step}, Model Loss: {model_loss:.4f}")

            # 3. Generate virtual data
            self.generate_virtual_data(n_rollouts=1000)

            # 4. Train the policy on a mixture of real and virtual data
            for _ in range(20):
                # Sample from both buffers (the ratio can be adjusted)
                real_batch = self.env_buffer.sample(128)
                model_batch = self.model_buffer.sample(128)

                # Mixed batch (simplified here as two separate updates)
                self.policy.train_step(real_batch)
                self.policy.train_step(model_batch)

            # 5. Evaluation
            if step % 1000 == 0:
                eval_reward = self.evaluate()
                print(f"Step {step}, Eval Reward: {eval_reward:.2f}")

    def evaluate(self, n_episodes=10):
        """Evaluate the current policy"""
        total_rewards = []
        for _ in range(n_episodes):
            state = self.env.reset()
            episode_reward = 0
            done = False
            while not done:
                action = self.policy.select_action(state, deterministic=True)
                state, reward, done, _ = self.env.step(action)
                episode_reward += reward
            total_rewards.append(episode_reward)
        return np.mean(total_rewards)
```

Experimental Results

Janner et al.'s paper shows MBPO achieved breakthroughs on MuJoCo continuous control tasks:

  • HalfCheetah: Reaches 10000 points in 100K steps (SAC needs 1M steps)
  • Hopper: Reaches 3000 points in 100K steps, 10x sample efficiency of SAC
  • Humanoid: First to learn standing and walking within 100K steps

Key findings:

  1. Very short rollout lengths ($k$ of 1 to a few steps) are optimal for most tasks (the shortest horizons work best!)
  2. An ensemble model (5 networks) significantly improves robustness
  3. A virtual-to-real data ratio of 4:1 works best

Paper link: arXiv:1906.08253

World Models: Dreaming in Compressed Latent Space

Motivation: Pixel Space Model Learning Dilemma

MBPO performs excellently in low-dimensional state spaces (like MuJoCo joint angles), but directly modeling $p(o_{t+1} \mid o_t, a_t)$ for high-dimensional observations (like Atari's $210 \times 160 \times 3$ pixel frames) is nearly impossible: pixel-level prediction requires learning RGB values for every pixel, the computational and storage overhead is enormous, and such models overfit easily.

World Models (Ha & Schmidhuber, NeurIPS 2018) proposed an elegant solution: don't model in raw pixel space, but learn dynamics in compressed latent space.

Core Architecture

World Models contains three core components:

1. Vision Module (V): Variational Autoencoder (VAE)

Compress the high-dimensional observation $o_t$ (pixels) into a low-dimensional latent code $z_t \in \mathbb{R}^{32}$. The VAE is trained by maximizing the ELBO:

$$\mathcal{L} = \mathbb{E}_{q(z|o)}\big[\log p(o|z)\big] - D_{KL}\big(q(z|o)\,\|\,p(z)\big)$$

2. Memory Module (M): Recurrent Neural Network (RNN/LSTM)

Predict the next latent state: $p(z_{t+1} \mid z_t, a_t, h_t)$, where $h_t$ is the RNN hidden state. The distribution is parameterized by a Mixture Density Network (MDN-RNN) that outputs a Gaussian mixture over $z_{t+1}$.

3. Controller Module (C): Policy Network

Select an action from the latent state: $a_t = W_c\,[z_t; h_t] + b_c$. The controller is a simple linear policy or a small MLP.

Training Procedure

Stage 1: Data Collection

Collect trajectories $\{(o_t, a_t)\}$ in the environment with a random policy.

Stage 2: Train V Module

Train the VAE on all collected observations to learn the compressed representation.

Stage 3: Train M Module

Train the MDN-RNN on the encoded trajectories $\{(z_t, a_t, z_{t+1})\}$ to learn the latent dynamics $p(z_{t+1} \mid z_t, a_t, h_t)$.
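The negative log-likelihood minimized in this stage can be sketched as follows (a minimal helper of our own; it assumes `mu` and `sigma` are flattened as (B, T, K·D) and `pi` as (B, T, K), matching a typical MDN-RNN parameterization):

```python
import torch

def mdn_nll_loss(z_next, mu, sigma, pi):
    """Negative log-likelihood of z_next under a Gaussian mixture.
    z_next: (B, T, D); mu, sigma: (B, T, K*D); pi: (B, T, K)"""
    B, T, D = z_next.shape
    K = pi.size(-1)
    mu = mu.view(B, T, K, D)
    sigma = sigma.view(B, T, K, D)
    z = z_next.unsqueeze(2)  # (B, T, 1, D), broadcasts against the K components

    # Per-component diagonal-Gaussian log-density, summed over latent dims
    log_prob = -0.5 * (((z - mu) / sigma) ** 2
                       + 2 * torch.log(sigma)
                       + torch.log(torch.tensor(2 * torch.pi))).sum(-1)

    # Log-sum-exp over mixture components, weighted by pi
    log_mix = torch.logsumexp(torch.log(pi + 1e-8) + log_prob, dim=-1)
    return -log_mix.mean()

# Shape check with random tensors (D=32 latents, K=5 components)
z_next = torch.randn(2, 10, 32)
mu = torch.randn(2, 10, 5 * 32)
sigma = torch.rand(2, 10, 5 * 32) + 0.5
pi = torch.softmax(torch.randn(2, 10, 5), dim=-1)
loss = mdn_nll_loss(z_next, mu, sigma, pi)
print(loss.shape, loss.item())
```

The log-sum-exp keeps the mixture likelihood numerically stable even when individual component densities underflow.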

Stage 4: Train C Module

Train the policy entirely in the "dream": roll out only inside the learned model, never interacting with the real environment, and maximize the expected cumulative model-predicted reward $\mathbb{E}\big[\sum_t \hat{r}_t\big]$, where the trajectory is generated by the model. The controller is optimized with evolution strategies (CMA-ES) or a reinforcement learning algorithm.

Code Implementation (Simplified)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ========== VAE Encoder ==========
class VAE(nn.Module):
    """Variational Autoencoder (assumes 64x64 input images)"""
    def __init__(self, obs_channels=3, latent_dim=32):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # Latent distribution parameters
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, obs_channels, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Encode to latent distribution parameters"""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode a latent code back to an image"""
        h = self.fc_decode(z)
        h = h.view(-1, 256, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss_function(self, recon, x, mu, logvar):
        """VAE loss (reconstruction + KL divergence)"""
        recon_loss = F.mse_loss(recon, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kl_loss

# ========== MDN-RNN ==========
class MDNRNN(nn.Module):
    """Mixture Density Network RNN"""
    def __init__(self, latent_dim, action_dim, hidden_dim=256, n_gaussians=5):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_gaussians = n_gaussians

        # LSTM
        self.lstm = nn.LSTM(latent_dim + action_dim, hidden_dim, batch_first=True)

        # MDN parameter outputs
        self.mdn_mu = nn.Linear(hidden_dim, latent_dim * n_gaussians)
        self.mdn_sigma = nn.Linear(hidden_dim, latent_dim * n_gaussians)
        self.mdn_pi = nn.Linear(hidden_dim, n_gaussians)

    def forward(self, z, a, h=None):
        """
        Predict the distribution of the next latent state.
        z: (B, T, latent_dim)
        a: (B, T, action_dim)
        h: hidden state (optional)
        """
        x = torch.cat([z, a], dim=-1)
        lstm_out, h_next = self.lstm(x, h)

        # MDN parameters
        mu = self.mdn_mu(lstm_out)
        sigma = torch.exp(self.mdn_sigma(lstm_out))
        pi = F.softmax(self.mdn_pi(lstm_out), dim=-1)

        return mu, sigma, pi, h_next

    def sample(self, mu, sigma, pi):
        """Sample the next latent from the Gaussian mixture"""
        B, T, K = pi.shape
        latent_dim = mu.size(-1) // K
        mu = mu.view(B, T, K, latent_dim)
        sigma = sigma.view(B, T, K, latent_dim)

        # Select one mixture component per (batch, time) step
        k = torch.multinomial(pi.reshape(-1, K), 1).view(B, T, 1, 1)
        k = k.expand(-1, -1, 1, latent_dim)

        mu_k = mu.gather(2, k).squeeze(2)
        sigma_k = sigma.gather(2, k).squeeze(2)

        return mu_k + sigma_k * torch.randn_like(mu_k)

# ========== Controller ==========
class Controller(nn.Module):
    """Simple linear policy"""
    def __init__(self, latent_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim + hidden_dim, action_dim)

    def forward(self, z, h):
        """Output an action from (z, h)"""
        return torch.tanh(self.fc(torch.cat([z, h], dim=-1)))

# ========== Training Procedure ==========
# Note: collect_random_trajectories, get_obs_batches, encode_trajectories,
# get_latent_batches and mdn_nll_loss are helpers assumed to exist elsewhere.
def train_world_models(env, episodes=1000):
    """Train World Models"""
    # 1. Collect data
    print("Collecting data...")
    trajectories = collect_random_trajectories(env, episodes)

    # 2. Train VAE
    print("Training VAE...")
    vae = VAE()
    optimizer_vae = torch.optim.Adam(vae.parameters(), lr=1e-4)

    for epoch in range(50):
        total_loss = 0
        for obs_batch in get_obs_batches(trajectories):
            recon, mu, logvar = vae(obs_batch)
            loss = vae.loss_function(recon, obs_batch, mu, logvar)

            optimizer_vae.zero_grad()
            loss.backward()
            optimizer_vae.step()

            total_loss += loss.item()
        print(f"VAE Epoch {epoch}, Loss: {total_loss:.2f}")

    # 3. Encode all observations to latent space
    print("Encoding observations...")
    latent_trajectories = encode_trajectories(vae, trajectories)

    # 4. Train MDN-RNN
    print("Training MDN-RNN...")
    mdn_rnn = MDNRNN(latent_dim=32, action_dim=env.action_space.shape[0])
    optimizer_rnn = torch.optim.Adam(mdn_rnn.parameters(), lr=1e-3)

    for epoch in range(50):
        total_loss = 0
        for z_seq, a_seq in get_latent_batches(latent_trajectories):
            mu, sigma, pi, _ = mdn_rnn(z_seq[:, :-1], a_seq[:, :-1])
            # Negative log-likelihood of the next latent under the mixture
            loss = mdn_nll_loss(z_seq[:, 1:], mu, sigma, pi)

            optimizer_rnn.zero_grad()
            loss.backward()
            optimizer_rnn.step()

            total_loss += loss.item()
        print(f"RNN Epoch {epoch}, Loss: {total_loss:.2f}")

    # 5. Train the Controller in the dream
    print("Training Controller in dream...")
    controller = Controller(latent_dim=32, hidden_dim=256,
                            action_dim=env.action_space.shape[0])
    # Optimize the controller with CMA-ES or PPO (omitted)

    return vae, mdn_rnn, controller
```
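The omitted controller-optimization step can be sketched with a plain evolution-strategies loop (a simplified stand-in for full CMA-ES; `dream_return` below is a placeholder fitness we define for illustration: in the real system it would roll the Controller out inside the MDN-RNN dream and return the cumulative predicted reward):

```python
import numpy as np

def dream_return(params):
    """Placeholder fitness: a quadratic with its optimum at 1.0 in every
    dimension. In World Models this would be a dream rollout's return."""
    return -np.sum((params - 1.0) ** 2)

def evolution_strategies(n_params, iterations=300, pop_size=32,
                         sigma=0.1, lr=0.03):
    """Gaussian evolution strategies: estimate the fitness gradient from
    a population of perturbed parameter vectors and step along it."""
    rng = np.random.default_rng(0)
    params = np.zeros(n_params)
    for _ in range(iterations):
        noise = rng.standard_normal((pop_size, n_params))
        fitness = np.array([dream_return(params + sigma * n) for n in noise])
        # Normalize fitness, then move params along the fitness-weighted noise
        ranks = (fitness - fitness.mean()) / (fitness.std() + 1e-8)
        params = params + lr / (pop_size * sigma) * noise.T @ ranks
    return params

best = evolution_strategies(n_params=8)
print("final dream return:", dream_return(best))  # close to 0 near the optimum
```

Because the gradient estimate needs only fitness evaluations, this black-box approach works even though the dream rollout (sampling from a Gaussian mixture) is not differentiable end to end, which is why Ha & Schmidhuber used CMA-ES for the 867-parameter controller.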

Experimental Results

Ha & Schmidhuber demonstrated World Models' powerful capabilities on CarRacing game:

  • Controller has only 867 parameters (vs DQN's 1.7M)
  • Trained for 10000 virtual episodes in model, completely without real environment interaction
  • Reaches 900+ points (near human level)
  • Visualization shows model learned to predict tracks, background and vehicle behavior

More remarkably, even in environments with visual glitches the agent can complete the task by relying on its internal model, demonstrating the robustness of world models.

Paper link: arXiv:1803.10122

Dreamer: End-to-End Learning in Latent Space

From World Models to Dreamer

One limitation of World Models is its staged training: first the VAE, then the RNN, finally the Controller, with each stage optimized independently. This causes several problems:

  • The VAE-learned representation may not be suited to planning
  • RNN prediction errors cannot backpropagate into the VAE to improve the representation
  • Hyperparameters must be tuned manually for each stage

Dream to Control (Dreamer) (Hafner et al., ICLR 2020) proposed end-to-end training of world models: all components (representation, dynamics, policy) jointly optimized, learning policy directly in latent space "dreams".

Dreamer Architecture

Dreamer contains four core components:

1. Representation Model

Infer the latent state from the observation and history:

$$z_t \sim q(z_t \mid h_t, o_t), \qquad h_t = f(h_{t-1}, z_{t-1}, a_{t-1})$$

where $h_t$ is the deterministic recurrent state (RNN hidden state) and $z_t$ is a stochastic latent variable (as in a VAE).

2. Transition Model

Predict the future purely in latent space, without observations: $z_t \sim p(z_t \mid h_t)$

3. Observation Model

Decode the latent state back to an observation: $o_t \sim p(o_t \mid h_t, z_t)$

4. Reward Model

Predict the reward: $r_t \sim p(r_t \mid h_t, z_t)$

Training Objectives

Dreamer jointly optimizes three objectives:

1. Representation Learning Loss

$$\mathcal{L}_{model} = \mathbb{E}\Big[\sum_t -\log p(o_t \mid h_t, z_t) - \log p(r_t \mid h_t, z_t) + \beta\, D_{KL}\big(q(z_t \mid h_t, o_t)\,\|\,p(z_t \mid h_t)\big)\Big]$$

This loss simultaneously trains:

  • Observation reconstruction (first term)
  • Reward prediction (second term)
  • Consistency between the transition model and the representation (KL term)

2. Behavior Learning Loss (optimize policy in dream)

Roll out $H$ steps in latent space and compute the imagined return $V_\lambda(s_\tau)$, where $s_\tau = (h_\tau, z_\tau)$ is the latent state and rewards are predicted by the reward model. The policy $\pi_\phi$ maximizes the imagined value:

$$\max_\phi\; \mathbb{E}\Big[\sum_{\tau=t}^{t+H} V_\lambda(s_\tau)\Big]$$

3. Value Learning Loss

Train the value function $v_\psi$ to fit the imagined returns:

$$\mathcal{L}_{critic} = \mathbb{E}\Big[\sum_{\tau=t}^{t+H} \tfrac{1}{2}\big(v_\psi(s_\tau) - V_\lambda(s_\tau)\big)^2\Big]$$
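For reference, the $\lambda$-return used as the critic target is the standard TD($\lambda$) bootstrap, computed backwards over the imagined trajectory:

$$V_\lambda(s_\tau) = r_\tau + \gamma\Big[(1-\lambda)\,v_\psi(s_{\tau+1}) + \lambda\,V_\lambda(s_{\tau+1})\Big], \qquad V_\lambda(s_{t+H}) = v_\psi(s_{t+H})$$

With $\lambda = 0$ this reduces to a one-step TD target; with $\lambda = 1$ it becomes the full Monte Carlo imagined return.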

Complete Algorithm Flow

```text
1. Initialize all module parameters
2. Collect initial data into the replay buffer
3. Loop:
   a. Sample sequence data from the buffer
   b. Encode latent states (h, z) through the representation model
   c. Update the world model (representation learning loss)
   d. Roll out H steps in the dream:
      - Predict future latent states with the transition model
      - Predict rewards with the reward model
      - Compute returns
   e. Update the policy (actor loss) and value function (critic loss)
   f. Execute the policy in the real environment, collect new data
```

Code Framework (Simplified)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: Encoder and Decoder are conv nets assumed to be defined elsewhere;
# Encoder maps an observation to a latent_dim feature vector.
class Dreamer:
    """Dreamer algorithm (simplified)"""
    def __init__(self, obs_shape, action_dim, hidden_dim=200, latent_dim=30):
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Representation model
        self.encoder = Encoder(obs_shape, latent_dim)
        self.rnn = nn.GRUCell(latent_dim + action_dim, hidden_dim)
        # Posterior takes [h, obs_feat], outputs mu and std
        self.posterior = nn.Linear(hidden_dim + latent_dim, latent_dim * 2)

        # Transition model (prior)
        self.prior = nn.Linear(hidden_dim, latent_dim * 2)

        # Observation and reward models
        self.decoder = Decoder(latent_dim + hidden_dim, obs_shape)
        self.reward_model = nn.Linear(latent_dim + hidden_dim, 1)

        # Policy and value
        self.actor = nn.Linear(latent_dim + hidden_dim, action_dim)
        self.critic = nn.Linear(latent_dim + hidden_dim, 1)

    def encode_obs(self, obs, h_prev, z_prev, action):
        """Representation learning: encode an observation into the latent state"""
        # Deterministic RNN update
        h = self.rnn(torch.cat([z_prev, action], dim=-1), h_prev)

        # Posterior distribution q(z | h, o)
        obs_feat = self.encoder(obs)
        post_mu, post_std = torch.chunk(
            self.posterior(torch.cat([h, obs_feat], dim=-1)), 2, dim=-1)
        post_std = F.softplus(post_std)
        z = post_mu + post_std * torch.randn_like(post_mu)

        # Prior distribution p(z | h)
        prior_mu, prior_std = torch.chunk(self.prior(h), 2, dim=-1)
        prior_std = F.softplus(prior_std)

        return h, z, (post_mu, post_std, prior_mu, prior_std)

    def imagine_trajectory(self, h, z, horizon=15):
        """Roll out a trajectory in the dream"""
        states, actions, rewards = [], [], []

        for _ in range(horizon):
            # Policy selects an action
            state = torch.cat([h, z], dim=-1)
            action = torch.tanh(self.actor(state))

            # Transition to the next latent state using the prior only
            h = self.rnn(torch.cat([z, action], dim=-1), h)
            prior_mu, prior_std = torch.chunk(self.prior(h), 2, dim=-1)
            prior_std = F.softplus(prior_std)
            z = prior_mu + prior_std * torch.randn_like(prior_mu)

            # Predict the reward
            reward = self.reward_model(torch.cat([h, z], dim=-1))

            states.append(state)
            actions.append(action)
            rewards.append(reward)

        return states, actions, rewards

    def update(self, batch):
        """Single update step"""
        obs_seq, action_seq, reward_seq = batch
        B, T = obs_seq.size(0), obs_seq.size(1)

        # === Representation Learning ===
        h = torch.zeros(B, self.hidden_dim)
        z = torch.zeros(B, self.latent_dim)

        recon_loss = reward_loss = kl_loss = 0

        for t in range(T):
            h, z, (post_mu, post_std, prior_mu, prior_std) = self.encode_obs(
                obs_seq[:, t], h, z, action_seq[:, t]
            )

            # Reconstruct the observation
            recon = self.decoder(torch.cat([h, z], dim=-1))
            recon_loss = recon_loss + F.mse_loss(recon, obs_seq[:, t])

            # Predict the reward
            pred_reward = self.reward_model(torch.cat([h, z], dim=-1))
            reward_loss = reward_loss + F.mse_loss(
                pred_reward.squeeze(-1), reward_seq[:, t])

            # KL divergence between posterior and prior
            kl = self.kl_divergence(post_mu, post_std, prior_mu, prior_std)
            kl_loss = kl_loss + kl.mean()

        world_model_loss = recon_loss + reward_loss + 0.1 * kl_loss

        # === Behavior Learning (in the dream) ===
        # Start imagination from the current posterior state; the rollout
        # stays differentiable so gradients reach the actor through the model.
        states, actions, rewards = self.imagine_trajectory(
            h.detach(), z.detach(), horizon=15)

        # Compute lambda-returns (TD(lambda)) as targets
        returns = self.compute_lambda_return(states, rewards, gamma=0.99, lam=0.95)

        # Actor loss: maximize imagined returns
        # (gradient flows through the differentiable dream rollout)
        actor_loss = -returns.mean()

        # Critic loss: regress value estimates toward the lambda-returns
        values = torch.stack([self.critic(s.detach()) for s in states]).squeeze(-1)
        critic_loss = F.mse_loss(values, returns.detach())

        # Total loss
        total_loss = world_model_loss + actor_loss + critic_loss
        return total_loss

    def kl_divergence(self, mu1, std1, mu2, std2):
        """KL divergence between two diagonal Gaussians"""
        return torch.log(std2 / std1) + (std1**2 + (mu1 - mu2)**2) / (2 * std2**2) - 0.5

    def compute_lambda_return(self, states, rewards, gamma=0.99, lam=0.95):
        """Compute the lambda-return (TD(lambda)) over the imagined trajectory"""
        values = torch.stack([self.critic(s) for s in states]).squeeze(-1)
        rewards = torch.stack(rewards).squeeze(-1)

        returns = []
        g = 0
        for t in reversed(range(len(rewards))):
            next_v = values[t + 1] if t + 1 < len(values) else 0
            delta = rewards[t] + gamma * next_v - values[t]
            g = delta + gamma * lam * g
            returns.insert(0, g + values[t])

        return torch.stack(returns)
```

Experimental Results

Dreamer achieved SOTA results on DMControl and Atari tasks:

  • DMControl Walker: Reaches 900 points in 100K steps (SAC needs 1M steps)
  • Atari Breakout: Reaches 400 points in 200K steps, 5x sample efficiency of Rainbow
  • Average across 20 tasks: Surpasses all baselines (including D4PG, SAC, PlaNet)

Key findings:

  1. End-to-end training yields representations of significantly better quality than staged training
  2. A 15-step dream rollout is optimal (balancing model error against long-horizon planning)
  3. KL regularization is crucial: it prevents overfitting and mode collapse

Paper link: arXiv:1912.01603

The subsequent versions DreamerV2 (2021) and DreamerV3 (2023) improved further; DreamerV3 achieved breakthroughs in open-world games like Minecraft, demonstrating the enormous potential of world models.

MuZero: Planning Without Explicit Models

Motivation: Why Perfect Reconstruction Unnecessary

Previous methods (World Models, Dreamer) all tried to learn models that reconstruct observations, i.e., predict $\hat{o}_{t+1}$ from $(o_t, a_t)$. But for decision-making, do we really need to reconstruct every pixel perfectly?

DeepMind's answer: we only need to predict decision-relevant information (value, policy, reward), not complete observations.

MuZero (Schrittwieser et al., Nature 2020) implemented this idea: learn an implicit model that never predicts observations, only value, policy and reward, which is all the information MCTS planning needs.

Core Architecture

MuZero contains three networks:

1. Representation Function

Encode the observation history into a hidden state: $s^0 = h_\theta(o_1, \dots, o_t)$. Note: $s^0$ is not the physical state, but an abstract "planning state".

2. Dynamics Function

Predict the next hidden state and reward: $s^k, r^k = g_\theta(s^{k-1}, a^k)$. Similarly, $s^k$ is abstract and doesn't correspond to any physical meaning.

3. Prediction Function

Predict policy and value: $p^k, v^k = f_\theta(s^k)$, where $p^k$ is the action probability distribution and $v^k$ is the state value.

Training Objective

MuZero plans in the hidden state space through MCTS, then trains on self-play data. The loss function contains three terms:

$$\ell_t(\theta) = \sum_{k=0}^{K} \Big[ \big(v_t^k - z_{t+k}\big)^2 + \big(r_t^k - u_{t+k}\big)^2 + \ell^p\big(\pi_{t+k}, p_t^k\big) \Big]$$

where:

- $z_{t+k}$ is the value target from MCTS search
- $u_{t+k}$ is the true reward
- $\pi_{t+k}$ is the MCTS-improved policy (with $\ell^p$ a cross-entropy term)

Key insight: the model doesn't need to predict observations, only the $(p, v, r)$ triplet that MCTS needs. This greatly reduces the difficulty of model learning.
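The three functions can be sketched in a few lines. The following toy implementation (random weights, made-up dimensions, numpy matrices standing in for real networks) only illustrates the data flow: the unroll stays entirely in hidden space and never reconstructs an observation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative dimensions (assumptions, not MuZero's actual sizes)
OBS, ACT, HID = 8, 4, 16
Wh = rng.normal(size=(OBS, HID))            # representation h
Wg = rng.normal(size=(HID + ACT, HID + 1))  # dynamics g: next hidden state + reward
Wf = rng.normal(size=(HID, ACT + 1))        # prediction f: policy logits + value

def unroll(obs, actions):
    """Unroll the implicit model from one observation and a list of one-hot actions."""
    s = np.tanh(obs @ Wh)                       # s^0 = h(o)
    outs = []
    for a in actions:
        out = np.concatenate([s, a]) @ Wg
        s, r = np.tanh(out[:-1]), out[-1]       # s^k, r^k = g(s^{k-1}, a^k)
        pv = s @ Wf
        outs.append((pv[:-1], pv[-1], r))       # (p^k, v^k, r^k) for MCTS
    return outs

traj = unroll(rng.normal(size=OBS), [np.eye(ACT)[1]] * 5)
```

In real MuZero these outputs feed the MCTS node statistics; the point of the sketch is that five planning steps require only one encoded observation.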

Comparison with AlphaZero

Feature | AlphaZero | MuZero
Environment model | Needs a perfect simulator (e.g., Go rules) | Learns an implicit model from data
State representation | Explicit physical state | Abstract hidden state
Applicability | Only games with known rules | Any environment (including Atari)
Sample efficiency | Unlimited simulation | Needs real interaction

Experimental Results

MuZero achieved breakthroughs across multiple domains:

  • Go: Comparable to AlphaZero (ELO ~5000)
  • Chess: Slightly below Stockfish but far exceeds humans
  • Atari 57 games: Average score exceeds R2D2, reaches human level on half the games
  • DMControl: Surpasses Dreamer in 100K steps

Most remarkably: the same algorithm with the same hyperparameters achieves SOTA on completely different task types (board games, Atari, continuous control), demonstrating its generality.

Paper link: arXiv:1911.08265

Theoretical Analysis and Open Problems

Q&A: Common Model-Based RL Questions

Q1: Why do model errors accumulate?

A: Suppose the single-step prediction error is $\epsilon$. After predicting $H$ steps, the accumulated error grows to roughly between $\sqrt{H}\epsilon$ and $H\epsilon$ (depending on whether the errors are correlated). Each step's prediction feeds on the previous predicted state, so errors propagate layer by layer. Imagine navigating with a map while deviating 1 degree at every step: after 10 steps you might be heading in a completely wrong direction.
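The compounding effect is easy to demonstrate numerically. In this toy 1-D example (linear dynamics and a 1% multiplicative model bias, both made up for illustration), the rollout error grows superlinearly with the horizon:

```python
def rollout(a, x0, steps):
    """Roll out the 1-D linear dynamics x' = a * x for `steps` steps."""
    xs = [x0]
    for _ in range(steps):
        xs.append(a * xs[-1])
    return xs

true = rollout(1.00, 1.0, 20)
pred = rollout(1.01, 1.0, 20)   # model with a 1% per-step error
errors = [abs(p - t) for p, t in zip(pred, true)]
# errors[k] = 1.01**k - 1: about 0.01 at k=1, but about 0.22 at k=20
```

A 1% one-step error becomes a 22% 20-step error; this is exactly why MBPO keeps rollouts short.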

Q2: How to evaluate model quality?

A: Don't just look at reconstruction error (MSE)! What matters more is whether model-predicted trajectories lead to good decisions. Useful criteria:

- Policy evaluation difference: $|V^\pi_{\hat{M}} - V^\pi_M|$
- Planning quality: how well a policy planned with the model performs in the real environment
- Dynamics consistency: whether long-horizon trajectory distributions match
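For small MDPs the policy evaluation difference can be computed exactly via $V^\pi = (I - \gamma P)^{-1} r$. A toy sketch with made-up 3-state transition matrices (state 2 is absorbing with reward 1):

```python
import numpy as np

gamma = 0.9
r = np.array([0.0, 0.0, 1.0])
# True dynamics P and a learned approximation P_hat (both invented for illustration)
P     = np.array([[0.9, 0.1, 0.0], [0.0, 0.8, 0.2], [0.0, 0.0, 1.0]])
P_hat = np.array([[0.8, 0.2, 0.0], [0.0, 0.7, 0.3], [0.0, 0.0, 1.0]])

def policy_eval(P):
    """Exact policy evaluation: V = (I - gamma * P)^{-1} r."""
    return np.linalg.solve(np.eye(3) - gamma * P, r)

gap = np.abs(policy_eval(P) - policy_eval(P_hat)).max()
```

Here the one-step transition error is small, yet the value gap is sizable: a reminder that low MSE on transitions does not guarantee faithful policy evaluation.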

Q3: Can Model-Based be used in partially observable environments?

A: Yes! Dreamer is an example: it uses an RNN hidden state $h_t$ to encode history and plans in latent space. In POMDPs, Model-Based methods actually have an advantage, because the model can learn to "remember" important historical information.

Q4: How to mitigate model errors?

A: Several methods:

1. Short-horizon planning (MBPO): only roll out 1-5 steps
2. Ensemble models: train multiple models, then average or randomly sample among them
3. Uncertainty estimation: use Bayesian neural networks or Dropout to estimate prediction uncertainty, and avoid planning in uncertain regions
4. Model regularization: penalize overconfident predictions
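Method 2 can be sketched with a minimal ensemble wrapper. The interface is an assumption (each member is any callable mapping (state, action) to a predicted next state, e.g. a separately trained network); the standard deviation across members serves as an epistemic-uncertainty signal:

```python
import numpy as np

class EnsembleDynamics:
    """Estimate uncertainty as disagreement among an ensemble of dynamics models."""
    def __init__(self, models):
        self.models = models  # list of callables: (state, action) -> next_state

    def predict(self, s, a):
        preds = np.stack([m(s, a) for m in self.models])
        # High std means the members disagree: avoid planning through such states
        return preds.mean(axis=0), preds.std(axis=0)

# Toy "models" with different biases standing in for trained networks
models = [lambda s, a, b=b: s + a + b for b in (0.0, 0.1, 0.2)]
ens = EnsembleDynamics(models)
mean, std = ens.predict(np.array([1.0]), np.array([0.5]))
# mean is about [1.6]; std > 0 flags epistemic disagreement
```

MBPO-style usage would instead pick a random ensemble member at each rollout step, which injects the same disagreement as trajectory noise.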

Q5: What's the difference between Dyna, MBPO, Dreamer?

A:

- Dyna: tabular; the model perfectly records transitions; suitable for small environments
- MBPO: learns a neural-network model of the state space; short-horizon rollouts; combined with SAC
- Dreamer: learns a latent-space model; long-horizon rollouts; end-to-end training

Common point: all follow "use model to generate data, learn policy with Model-Free algorithm" framework.
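That shared framework is most visible in Dyna-Q. A minimal self-contained sketch on a made-up 5-state chain (not code from earlier chapters): each real step does a direct Q-update, stores the transition in the model, then runs several planning updates on simulated experience.

```python
import random
random.seed(0)

N, GAMMA, ALPHA = 5, 0.9, 0.5
Q = {(s, a): 0.0 for s in range(N) for a in (0, 1)}   # a: 0 = left, 1 = right
model = {}                                            # deterministic model: (s, a) -> (s', r)

def step(s, a):
    """Toy chain: reward 1 for being at the rightmost state."""
    s2 = max(0, s - 1) if a == 0 else min(N - 1, s + 1)
    return s2, (1.0 if s2 == N - 1 else 0.0)

def q_update(s, a, r, s2):
    Q[(s, a)] += ALPHA * (r + GAMMA * max(Q[(s2, 0)], Q[(s2, 1)]) - Q[(s, a)])

for _ in range(200):
    s, a = random.randrange(N), random.randrange(2)
    s2, r = step(s, a)
    q_update(s, a, r, s2)          # (1) direct RL from real experience
    model[(s, a)] = (s2, r)        # (2) model learning
    for _ in range(10):            # (3) planning: replay simulated transitions
        (ps, pa), (ps2, pr) = random.choice(list(model.items()))
        q_update(ps, pa, pr, ps2)
```

After training, moving right (toward the reward) dominates moving left at every state; the 10 planning updates per real step are what give Dyna its sample-efficiency edge.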

Q6: Why doesn't MuZero reconstruct observations?

A: Reconstructing high-dimensional observations (pixels) is hard, and many details are irrelevant to decisions (like background color). MuZero only predicts decision-needed information, greatly reducing learning difficulty. Analogy: playing Go, you don't need to predict opponent's clothing color, only where they'll move.

Q7: Can Model-Based transfer to new tasks?

A: Theoretically yes, but practically very hard. If new task has different dynamics (like transferring from walking to running), model needs retraining or fine-tuning. Meta-learning and conditional models are active research directions.

Q8: On which tasks is Model-Free better?

A: Model-Free tends to be the better choice when:

- the environment has a cheap simulator (like Atari, Go)
- the dynamics are extremely complex (like social interaction, stock markets)
- absolute peak performance is required (Model-Based model errors may limit the ceiling)

Q9: What is Dreamer's "dream"?

A: The dream refers to imagined trajectories rolled out in the learned latent-space model. The agent doesn't interact with the real environment; instead it:

1. Starts from the current latent state $s_t$
2. Samples an action $a_t$ from the policy
3. Uses the world model to predict $s_{t+1}$ and $r_t$
4. Repeats for 15 steps, collecting virtual experience
5. Uses this experience to update the policy

Like human daydreaming: simulating possible futures mentally, learning from them without actually experiencing.
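The loop can be sketched end to end. Here random numpy weights stand in for the trained policy, dynamics and reward heads (all dimensions are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)

LAT, ACT, H = 16, 2, 15                          # latent size, action size, dream horizon
Wdyn = rng.normal(size=(LAT + ACT, LAT)) * 0.1   # stand-in for the latent dynamics model
Wrew = rng.normal(size=LAT)                      # stand-in for the reward head
Wpi  = rng.normal(size=(LAT, ACT)) * 0.1         # stand-in for the policy head

s = rng.normal(size=LAT)       # current latent state s_t (from encoding a real observation)
dream = []
for _ in range(H):
    a = np.tanh(s @ Wpi)                          # a_t from the policy
    s = np.tanh(np.concatenate([s, a]) @ Wdyn)    # s_{t+1} from the learned model
    r = s @ Wrew                                  # predicted reward r_t
    dream.append((s, a, r))                       # virtual experience for policy updates
```

No environment step ever occurs inside the loop; in Dreamer the resulting 15-step trajectories are what the actor and critic train on.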

Q10: Future research directions for Model-Based RL?

A:

- Composable world models: learn reusable modules (like "objects", "physical laws") and compose them for new scenarios
- Causal world models: not just predict "what will happen", but understand "why it happens"
- Lifelong learning: continuously update the model to adapt to dynamically changing environments
- Safe planning: verify policy safety in the model before execution

  1. Dyna Architecture
    • Integrated Architectures for Learning, Planning, and Reacting Based on Approximating Dynamic Programming (Sutton, 1990)
    • Classic textbook: Reinforcement Learning: An Introduction Chapter 8
  2. MBPO
    • When to Trust Your Model: Model-Based Policy Optimization (Janner et al., NeurIPS 2019)
    • Link: arXiv:1906.08253
  3. World Models
    • World Models (Ha & Schmidhuber, 2018)
    • Link: arXiv:1803.10122
  4. PlaNet
    • Learning Latent Dynamics for Planning from Pixels (Hafner et al., ICML 2019)
    • Link: arXiv:1811.04551
  5. Dreamer
    • Dream to Control: Learning Behaviors by Latent Imagination (Hafner et al., ICLR 2020)
    • Link: arXiv:1912.01603
  6. DreamerV2
    • Mastering Atari with Discrete World Models (Hafner et al., ICLR 2021)
    • Link: arXiv:2010.02193
  7. DreamerV3
    • Mastering Diverse Domains through World Models (Hafner et al., 2023)
    • Link: arXiv:2301.04104
  8. MuZero
    • Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model (Schrittwieser et al., Nature 2020)
    • Link: arXiv:1911.08265
  9. STEVE-1 (Minecraft with World Models)
    • See, Think, Explore, Search, and Interact: Open-World AGI in Minecraft (2023)
    • Link: arXiv:2311.08845
  10. Model-Based RL Survey
    • Model-Based Reinforcement Learning: A Survey (Moerland et al., 2023)
    • Link: arXiv:2006.16712

Core Formula Summary

Dyna-Q Update

$$Q(s, a) \leftarrow Q(s, a) + \alpha \big[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \big]$$

(applied to both real transitions and transitions sampled from the learned model)

MBPO Rollout

$$a_t \sim \pi_\phi(\cdot \mid s_t), \qquad s_{t+1} \sim \hat{p}_\theta(\cdot \mid s_t, a_t)$$

(short $k$-step rollouts branched from real states)

Dreamer Representation Learning

$$\mathcal{L} = \mathbb{E}_q \Big[ \sum_t \ln p(o_t \mid s_t) + \ln p(r_t \mid s_t) - \beta \, \mathrm{KL}\big( q(s_t \mid s_{t-1}, a_{t-1}, o_t) \,\big\|\, p(s_t \mid s_{t-1}, a_{t-1}) \big) \Big]$$

Dreamer Behavior Learning

$$\max_\phi \, \mathbb{E} \Big[ \sum_{\tau=t}^{t+H} V_\lambda(s_\tau) \Big], \qquad \min_\psi \, \mathbb{E} \Big[ \sum_{\tau=t}^{t+H} \tfrac{1}{2} \big( v_\psi(s_\tau) - V_\lambda(s_\tau) \big)^2 \Big]$$

MuZero Prediction

$$s^0 = h_\theta(o_{1:t}), \qquad s^k, r^k = g_\theta(s^{k-1}, a^k), \qquad p^k, v^k = f_\theta(s^k)$$

Summary and Outlook

Model-Based reinforcement learning achieves "think before doing" intelligent decision-making through learning environment models, dramatically improving sample efficiency. From classic Dyna architecture to deep learning era's MBPO, World Models, Dreamer and MuZero, each improvement stems from deep understanding of model error, representation learning and planning efficiency. Model-Based methods demonstrate enormous potential in sample-limited real applications (like robotics, autonomous driving), but model error accumulation, computational overhead and generalization remain challenges.

Future research may integrate symbolic reasoning, causal modeling and neural networks to build more interpretable, composable, transferable world models. Just as humans understand physical laws, social rules and others' intentions through internal models, next-generation AI systems may also need similar "world understanding" capabilities — this is not only a core problem in reinforcement learning, but also a key path toward artificial general intelligence.

From the Model-Based planning perspective, the next chapter moves on to AlphaGo and Monte Carlo Tree Search: by combining deep learning with classic search algorithms, it reveals how superhuman play is achieved in perfect-information games, and explores the deep connections between search and learning.

  • Post title: Reinforcement Learning (5): Model-Based RL and World Models
  • Post author: Chen Kai
  • Create time: 2024-08-30 09:15:00
  • Post link: https://www.chenk.top/reinforcement-learning-5-model-based-rl-and-world-models/
  • Copyright notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.