World Models from Scratch
1. Why Learn a World Model?
Model-free reinforcement learning agents learn by trial and error. Every lesson requires a real interaction — crash the car, lose the game, drop the package. A model-free agent training on Atari might need 50 million frames of gameplay to learn what a human picks up in minutes. At 60 frames per second, that's over 200 hours of gameplay for a single game.
Humans don't work this way. You can imagine what happens if you push a glass off a table without actually doing it. You can rehearse a conversation in your head, plan a road trip by visualizing the route, or predict that a ball thrown upward will come back down. That mental simulation is a world model — a compressed internal representation of how the world works that lets you plan without acting.
"We explore building generative neural network models of popular reinforcement learning environments. Our world model can be trained quickly in an unsupervised manner to learn a compressed spatial and temporal representation of the environment." — Ha & Schmidhuber, 2018
In 2018, David Ha and Jürgen Schmidhuber proposed a beautifully simple architecture that brings this idea to artificial agents. They decomposed an agent's brain into three components:
The Vision model (V) compresses high-dimensional observations into a compact latent code. The Memory model (M) predicts how that latent code evolves over time given actions. And a tiny Controller (C) picks actions based on the current latent state. The revolutionary result: you can train V and M from random exploration data, then train the controller entirely inside the learned dream — the agent never touches the real environment during policy optimization.
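As a rough sketch of the data flow, one step of the V-M-C loop looks like this. The weights here are toy stand-ins, not the real models built in the sections below; the point is only the shape of the pipeline: compress, act, remember.

```python
import numpy as np

# Toy stand-ins for V, M, C to show the data flow of one agent step.
rng = np.random.default_rng(0)
W_enc = rng.normal(size=(64, 8)) * 0.1           # V: observation -> latent z
W_ctl = rng.normal(size=(8 + 16, 2)) * 0.1       # C: (z, h) -> action
W_dyn = rng.normal(size=(8 + 2 + 16, 16)) * 0.1  # M: (z, a, h) -> next h

def agent_step(obs, h):
    z = np.tanh(obs @ W_enc)                          # V compresses what we see
    a = np.tanh(np.concatenate([z, h]) @ W_ctl)       # C acts on (z, h)
    h = np.tanh(np.concatenate([z, a, h]) @ W_dyn)    # M updates its memory
    return a, h

h = np.zeros(16)
obs = rng.random(64)
a, h = agent_step(obs, h)
print(a.shape, h.shape)  # (2,) (16,)
```

Note that the controller sees both the compressed present (z) and the accumulated past (h); that pairing is what the real architecture below preserves.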
If you've followed this series, you already have the building blocks: autoencoders-from-scratch built VAEs for compression, rnn-from-scratch built LSTMs for sequence prediction, and reinforcement-learning-from-scratch built policy optimization. Today we combine all three into something greater than the sum of its parts.
2. The Vision Model — Compressing Observations
Raw observations from an environment are high-dimensional. A 64×64 RGB image has 12,288 dimensions. An agent that reasons directly in pixel space wastes capacity on irrelevant details — exact textures, lighting variations, background noise. What matters is the structure: where am I, where's the obstacle, how fast am I moving?
The Vision model is a Variational Autoencoder (VAE) that compresses observations into a compact latent vector z of roughly 32 dimensions. The encoder maps each observation to a distribution in latent space (parameterized by μ and σ), and the decoder reconstructs observations from sampled latent codes. Training maximizes the Evidence Lower Bound (ELBO): good reconstruction plus a smooth, regularized latent space via KL divergence.
Why a VAE instead of a plain autoencoder? The KL term ensures the latent space is continuous and well-organized. Similar observations map to nearby points, and you can smoothly interpolate between states. This is essential for the Memory model — it needs to predict transitions between latent states, and that only works if nearby latent codes correspond to nearby observations.
import numpy as np
class VisionModel:
"""VAE: compresses observations into a smooth latent space."""
def __init__(self, obs_dim, latent_dim=32, hidden_dim=64):
s = 0.1 # weight initialization scale
self.W1 = np.random.randn(obs_dim, hidden_dim) * s
self.W_mu = np.random.randn(hidden_dim, latent_dim) * s
self.W_logvar = np.random.randn(hidden_dim, latent_dim) * s
self.W3 = np.random.randn(latent_dim, hidden_dim) * s
self.W4 = np.random.randn(hidden_dim, obs_dim) * s
def encode(self, x):
h = np.tanh(x @ self.W1)
return h @ self.W_mu, h @ self.W_logvar # mu, log-variance
def decode(self, z):
h = np.tanh(z @ self.W3)
return 1 / (1 + np.exp(-h @ self.W4)) # sigmoid output
def forward(self, x):
mu, logvar = self.encode(x)
std = np.exp(0.5 * logvar)
z = mu + std * np.random.randn(*mu.shape) # reparameterize
return self.decode(z), mu, logvar
def vae_loss(x, x_hat, mu, logvar):
recon = np.mean((x - x_hat) ** 2) # reconstruction
kl = -0.5 * np.mean(1 + logvar - mu**2 - np.exp(logvar)) # regularization
return recon + kl
# Example: compress 64-dim observations to 8-dim latents
V = VisionModel(obs_dim=64, latent_dim=8)
obs = np.random.rand(64)
recon, mu, logvar = V.forward(obs)
print(f"Original dim: 64 -> Latent dim: {mu.shape[0]}")
print(f"Latent mean: {mu[:4].round(3)}...")
print(f"ELBO loss: {vae_loss(obs, recon, mu, logvar):.4f}")
The reparameterization trick (z = mu + std * noise) is what makes this trainable with backpropagation — we sample from the latent distribution while keeping gradients flowing through μ and σ. After training on thousands of random environment frames, the VAE learns to map each observation to a meaningful latent code: position, velocity, and other task-relevant features emerge naturally.
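A quick numerical check of why the trick works: once the noise ε is drawn and held fixed, the sample z = μ + σ·ε is an ordinary differentiable function of μ, so gradients flow through it. This standalone sketch (not part of the training loop above) verifies the gradient with finite differences:

```python
import numpy as np

# With eps held fixed, z = mu + exp(0.5 * logvar) * eps is a plain
# differentiable function of mu and logvar -- that is the whole trick.
rng = np.random.default_rng(0)
mu, logvar = 0.5, -1.0
eps = rng.standard_normal()

def sample(mu, logvar, eps):
    return mu + np.exp(0.5 * logvar) * eps

# Finite-difference gradient of z w.r.t. mu: should be exactly 1.
d = 1e-6
grad_mu = (sample(mu + d, logvar, eps) - sample(mu - d, logvar, eps)) / (2 * d)
print(round(grad_mu, 6))  # ~1.0

# Sampling z ~ N(mu, sigma) with an opaque RNG call would offer
# no such gradient path back to the encoder's parameters.
```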
3. The Memory Model — Predicting the Future
With observations compressed into latent vectors, we need a model that captures dynamics: given the current latent state z_t and an action a_t, what latent state z_{t+1} comes next? This is the Memory model — an MDN-RNN (Mixture Density Network combined with a Recurrent Neural Network).
The RNN component is an LSTM that maintains a hidden state h accumulating the history of the trajectory. At each timestep it ingests the current latent code and action, then updates its hidden state. This hidden state captures temporal context — momentum, acceleration patterns, the recent history of what happened.
The MDN component is the output head. Instead of predicting a single next latent state (which would assume the future is deterministic), it outputs the parameters of a mixture of Gaussians: multiple possible next states, each with a probability weight π, mean μ, and standard deviation σ. This is critical because environments can be stochastic — a ball at the edge of a table might fall left or right, and the model should assign probability to both outcomes.
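The MDN-RNN is trained by maximizing the likelihood of the observed next latent code under this mixture. A minimal sketch of that loss, assuming diagonal Gaussian components and using log-sum-exp for numerical stability (the `mdn_nll` name is illustrative, not from the paper):

```python
import numpy as np

def mdn_nll(pi, mu, sigma, z_next):
    """Negative log-likelihood of z_next under a diagonal Gaussian mixture.
    pi: (K,) mixture weights; mu, sigma: (K, D); z_next: (D,)."""
    # log N(z | mu_k, sigma_k), summed over the D latent dimensions
    log_comp = -0.5 * np.sum(
        ((z_next - mu) / sigma) ** 2 + 2 * np.log(sigma) + np.log(2 * np.pi),
        axis=1,
    )
    # log sum_k pi_k * N_k via log-sum-exp to avoid underflow
    log_w = np.log(pi + 1e-8) + log_comp
    m = log_w.max()
    return -(m + np.log(np.exp(log_w - m).sum()))

# A target near one of the mixture modes scores much better (lower NLL)
# than a target far from every mode.
pi = np.array([0.5, 0.5])
mu = np.array([[0.0, 0.0], [3.0, 3.0]])
sigma = np.ones((2, 2))
near = mdn_nll(pi, mu, sigma, np.zeros(2))
far = mdn_nll(pi, mu, sigma, np.full(2, 10.0))
print(near < far)  # True
```

Minimizing this loss over recorded (z_t, a_t, z_{t+1}) triples is what teaches the Memory model the environment's dynamics.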
A temperature parameter τ controls how stochastic the predictions are. At low temperature the model commits to the most likely future; at high temperature it explores diverse possibilities. This turns out to be crucial for training robust controllers, as we'll see in Section 5.
def sigmoid(x):
return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
def softmax(x):
e = np.exp(x - np.max(x))
return e / e.sum()
class MemoryModel:
"""MDN-RNN: predicts next latent state as a mixture of Gaussians."""
def __init__(self, z_dim=8, a_dim=2, h_dim=64, n_mix=3):
self.h_dim, self.n_mix, self.z_dim = h_dim, n_mix, z_dim
d_in = z_dim + a_dim
# LSTM gate weights (four gates: forget, input, output, cell candidate)
self.Wf = np.random.randn(d_in + h_dim, h_dim) * 0.05
self.Wi = np.random.randn(d_in + h_dim, h_dim) * 0.05
self.Wo = np.random.randn(d_in + h_dim, h_dim) * 0.05
self.Wc = np.random.randn(d_in + h_dim, h_dim) * 0.05
# MDN output: pi, mu, sigma for each Gaussian component
self.W_mdn = np.random.randn(h_dim, n_mix * (1 + 2 * z_dim)) * 0.05
def lstm_step(self, x, h, c):
xh = np.concatenate([x, h])
f = sigmoid(xh @ self.Wf) # forget gate
i = sigmoid(xh @ self.Wi) # input gate
o = sigmoid(xh @ self.Wo) # output gate
c_hat = np.tanh(xh @ self.Wc) # candidate cell
c = f * c + i * c_hat
h = o * np.tanh(c)
return h, c
def predict(self, z_t, a_t, h, c):
x = np.concatenate([z_t, a_t])
h, c = self.lstm_step(x, h, c)
raw = h @ self.W_mdn
K, D = self.n_mix, self.z_dim
pi = softmax(raw[:K])
mu = raw[K : K + K*D].reshape(K, D)
sigma = np.exp(raw[K + K*D :].reshape(K, D))
return pi, mu, sigma, h, c
def sample_next(self, pi, mu, sigma, temperature=1.0):
pi_t = softmax(np.log(pi + 1e-8) / temperature) # sharpen/flatten
k = np.random.choice(len(pi_t), p=pi_t) # pick component
return mu[k] + sigma[k] * np.random.randn(self.z_dim) * np.sqrt(temperature)
The key method is sample_next: it first selects which Gaussian component to sample from (weighted by π, adjusted by temperature), then draws a sample from that component. Low temperature concentrates probability on the most likely component; high temperature spreads it across all components, producing more diverse futures.
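The effect of temperature can be seen directly in isolation. This standalone sketch (the `temper` and `entropy` helper names are my own) re-weights a fixed set of mixture probabilities and measures how concentrated the result is:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def temper(pi, tau):
    """Re-weight mixture probabilities with temperature tau."""
    return softmax(np.log(pi + 1e-8) / tau)

def entropy(p):
    return -np.sum(p * np.log(p + 1e-12))

pi = np.array([0.7, 0.2, 0.1])
for tau in (0.1, 1.0, 5.0):
    p = temper(pi, tau)
    print(f"tau={tau}: weights={p.round(3)}, entropy={entropy(p):.3f}")
# Low tau -> nearly one-hot: the model commits to the likely future.
# tau = 1 -> the original weights, unchanged.
# High tau -> nearly uniform: the model explores diverse futures.
```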
4. The Controller — Acting from Dreams
Here's perhaps the most surprising design choice in the World Models paper: the controller is just a single linear layer. It takes the concatenation of the current latent state z and the RNN hidden state h, and maps it to an action through a matrix multiply and tanh activation.
Why so simple? Because all the hard work — understanding what the agent sees, remembering what happened, predicting what comes next — is done by V and M. If those models have learned good representations, the controller only needs to map "I'm here, heading this way" to "turn left." A linear mapping suffices when the features are good enough.
For training, Ha & Schmidhuber used CMA-ES (Covariance Matrix Adaptation Evolution Strategy), a gradient-free optimization algorithm. CMA-ES maintains a population of candidate parameter vectors, evaluates each one by running a full episode, then shifts the distribution toward the best performers. No backpropagation through the world model needed — just "try many controllers, keep the best ones."
class Controller:
"""Tiny linear policy: action = tanh(W @ [z, h] + b)."""
def __init__(self, z_dim=8, h_dim=64, a_dim=2):
self.W = np.zeros((a_dim, z_dim + h_dim))
self.b = np.zeros(a_dim)
def act(self, z, h):
return np.tanh(self.W @ np.concatenate([z, h]) + self.b)
def get_params(self):
return np.concatenate([self.W.ravel(), self.b])
def set_params(self, p):
n = self.W.size
self.W = p[:n].reshape(self.W.shape)
self.b = p[n:]
def cma_es(evaluate_fn, n_params, generations=50, pop=32, sigma=0.5):
"""Simplified CMA-ES: evolve parameters using rank-based selection."""
mean = np.zeros(n_params)
for gen in range(generations):
noise = np.random.randn(pop, n_params)
candidates = mean + sigma * noise
rewards = np.array([evaluate_fn(c) for c in candidates])
# Select top 25% and shift distribution toward them
elite_idx = rewards.argsort()[::-1][:pop // 4]
mean = candidates[elite_idx].mean(axis=0)
sigma *= 0.995 # slowly reduce exploration
if gen % 10 == 0:
print(f"Gen {gen}: best={rewards.max():.1f}, avg={rewards.mean():.1f}")
return mean
The controller for the original World Models car-racing task had only 867 parameters (a 3×(32+256) weight matrix plus 3 biases, one row per action dimension). Compare that to the millions of parameters in V and M. The lesson: invest in your world model, not your policy.
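To watch the selection dynamics in isolation, the same rank-based loop can be run on a toy objective. This standalone sketch (with an assumed decay schedule) maximizes -Σ(p-3)², so the mean should converge near 3 in every dimension:

```python
import numpy as np

def toy_es(evaluate, n_params, generations=100, pop=64, sigma=0.5, seed=0):
    """Rank-based evolution loop, mirroring the cma_es sketch above."""
    rng = np.random.default_rng(seed)
    mean = np.zeros(n_params)
    for _ in range(generations):
        candidates = mean + sigma * rng.standard_normal((pop, n_params))
        rewards = np.array([evaluate(c) for c in candidates])
        elite = rewards.argsort()[::-1][:pop // 4]  # keep top 25%
        mean = candidates[elite].mean(axis=0)       # shift toward elites
        sigma *= 0.97                               # shrink exploration
    return mean

# Toy objective: the best parameter vector is all 3s.
best = toy_es(lambda p: -np.sum((p - 3.0) ** 2), n_params=4)
print(best.round(2))  # near [3, 3, 3, 3]
```

The same loop, pointed at a controller's flattened parameters and a dream-episode reward instead of this toy objective, is the whole training procedure of Section 5.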
5. Dreaming — Training Inside the Model
This is where everything comes together, and it's the most remarkable part of the architecture. Once V and M are trained from random exploration data, we can generate entirely imagined experience by chaining them together:
- Start from a real observation, encode it: z_0 = V.encode(obs_0)
- Controller picks an action: a_t = C.act(z_t, h_t)
- Memory predicts next state: z_{t+1} ~ M.predict(z_t, a_t)
- Repeat for T timesteps — all inside the model
The controller is evaluated on the cumulative reward from these dreamed trajectories. CMA-ES generates many candidate controllers, each one "dreams" an episode, and the best dreamers get selected. The real environment is never consulted during this process.
Think of it as a flight simulator for RL: pilots train in simulators before flying real aircraft. The world model is the simulator, trained from data rather than hand-engineered. The controller learns to fly by practicing in the simulator.
The temperature parameter plays a crucial role here. Training in dreams that are too deterministic (low τ) produces controllers that exploit quirks of the model. Training in slightly stochastic dreams (higher τ) forces the controller to be robust — it must handle uncertainty in future states, which transfers better to the real environment.
def dream_reward(z):
"""Reward: negative distance from origin in latent space."""
return -np.sum(z ** 2)
def dream_episode(V, M, C, start_obs, horizon=50, temperature=1.0):
"""Roll out the controller inside the learned world model."""
z, _ = V.encode(start_obs)
h = np.zeros(M.h_dim)
c = np.zeros(M.h_dim)
total_reward = 0
for t in range(horizon):
action = C.act(z, h)
pi, mu, sigma, h, c = M.predict(z, action, h, c)
z = M.sample_next(pi, mu, sigma, temperature)
total_reward += dream_reward(z)
return total_reward
# Train controller entirely in dreams
def train_in_dream(V, M, controller, real_obs_buffer, n_gens=50):
def evaluate(params):
controller.set_params(params)
idx = np.random.randint(len(real_obs_buffer))
return dream_episode(V, M, controller, real_obs_buffer[idx],
horizon=50, temperature=1.2)
best = cma_es(evaluate, controller.get_params().size, generations=n_gens)
controller.set_params(best)
return controller # trained without touching the real environment!
The results of this pipeline are striking. On CarRacing-v0, the full V-M-C agent achieved an average score of 906 out of 1000, solving the task with far fewer real interactions than model-free methods require. And on the VizDoom Take Cover task, a controller trained entirely inside the dream transferred successfully back to the real environment: the world model provided enough fidelity for useful behavior to be learned purely from imagination.
Try It: World Model Playground
Click Explore to let the agent wander and learn the world's dynamics. Then click Dream to see imagined trajectories on the right grid, or Plan to find the optimal path using only the learned model.
6. From Dreams to Reality — The Sim-to-Real Gap
There's a catch: the world model isn't perfect. Every prediction carries a small error, and over a long dreamed trajectory those errors compound. After 10 steps the dream might be a reasonable approximation; after 50 steps it might bear little resemblance to reality. This is the fundamental challenge of model-based RL.
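The compounding is easy to simulate. Suppose the true dynamics hold the state fixed and the learned model is nearly perfect, off by only a small Gaussian error at each step; the imagined trajectory still drifts steadily from reality (a toy illustration, not a trained model):

```python
import numpy as np

rng = np.random.default_rng(0)

def rollout_error(horizon, step_err=0.05, dim=8, trials=200):
    """Mean drift between a perfect trajectory (fixed at the origin)
    and an imagined one that adds a small error every step."""
    drifts = []
    for _ in range(trials):
        z = np.zeros(dim)
        for _ in range(horizon):
            z = z + step_err * rng.standard_normal(dim)  # one-step model error
        drifts.append(np.linalg.norm(z))
    return np.mean(drifts)

for h in (1, 10, 50):
    print(f"horizon={h:3d}: mean drift={rollout_error(h):.3f}")
# Drift grows with horizon (roughly sqrt(h) for independent errors):
# long dreams diverge from reality even when each step is nearly right.
```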
Several strategies help bridge the gap between dreams and reality:
- Short dream horizons — limit how far ahead the model predicts, reducing error accumulation
- Model ensembles (PETS, Chua et al. 2018) — train multiple world models and average their predictions, reducing variance
- Temperature tuning — increase stochasticity to prevent the controller from exploiting model errors
- Dyna architecture (Sutton 1991) — mix real experience with model-generated experience, getting the best of both worlds
The Dyna approach is particularly elegant. For every real interaction with the environment, generate k additional imagined transitions using the world model. The policy learns from both real and dreamed data. Real experience keeps the learning grounded; dreamed experience amplifies sample efficiency.
def dyna_train(env, model, policy, buffer, episodes=200, k=5):
"""Dyna: one real step generates k imagined steps."""
for ep in range(episodes):
obs, done, ep_reward = env.reset(), False, 0
while not done:
action = policy.act(obs)
next_obs, reward, done = env.step(action)
buffer.add(obs, action, reward, next_obs)
# 1. Update world model on real transition
model.update(obs, action, reward, next_obs)
# 2. Generate k dream transitions per real one
for _ in range(k):
s = buffer.random_state()
a = policy.act(s)
s_next, r = model.imagine(s, a)
policy.learn(s, a, r, s_next)
# 3. Also learn from real experience
policy.learn(obs, action, reward, next_obs)
obs = next_obs
ep_reward += reward
if ep % 20 == 0:
print(f"Episode {ep}: reward={ep_reward:.1f}, "
f"model_loss={model.loss:.4f}, buffer={len(buffer)}")
With k=5, every real interaction drives six policy updates (one real, five imagined), multiplying sample efficiency without sacrificing the grounding that real data provides. Richard Sutton's Dyna architecture from 1991 was decades ahead of its time — the idea of mixing real and simulated experience is now central to modern model-based RL.
Try It: Dream Quality Explorer
See how prediction errors compound over time. Increase the horizon to watch dreams diverge from reality. Raise the temperature to see how stochastic dreams explore diverse futures.
7. Modern World Models
The original World Models paper sparked an explosion of research. Each subsequent architecture improved on a different limitation — longer planning horizons, richer representations, more expressive dynamics models.
Dreamer (Hafner et al. 2020) replaced CMA-ES with actor-critic training directly in latent space. Instead of evolving controllers through black-box optimization, Dreamer backpropagates analytic gradients through the learned dynamics model to optimize a policy network. DreamerV3 (2023) scaled this to 150+ diverse tasks with a single set of hyperparameters using categorical latent representations and symlog predictions.
MuZero (Schrittwieser et al. 2020) took a different approach: it doesn't even try to reconstruct observations. Instead, it learns a dynamics model that predicts rewards, values, and policies directly in a learned representation space. Combined with Monte Carlo Tree Search for planning, MuZero mastered Go, chess, shogi, and 57 Atari games — all without knowing the rules.
IRIS (Micheli et al. 2023) replaced the RNN with a Transformer, treating world modeling as sequence prediction over discrete tokens. And DIAMOND (Alonso et al. 2024) used diffusion models to generate future frames, achieving state-of-the-art results on the Atari 100K benchmark by literally "dreaming" future frames pixel by pixel.
| Model | Year | Vision | Dynamics | Planning | Innovation |
|---|---|---|---|---|---|
| World Models | 2018 | VAE | MDN-RNN | CMA-ES in dream | Dream training |
| Dreamer | 2020 | CNN | RSSM | Actor-Critic (latent) | Gradient-based imagination |
| DreamerV3 | 2023 | Symlog CNN | Categorical RSSM | Actor-Critic | Universal hyperparameters |
| MuZero | 2020 | Learned repr. | Dynamics net | MCTS | No observation reconstruction |
| IRIS | 2023 | Discrete tokens | Transformer | Actor-Critic | Transformer world model |
| DIAMOND | 2024 | Diffusion | Diffusion | Actor-Critic | Diffusion-based simulation |
The trend is clear: world models keep getting more expressive, and the line between "model" and "simulator" keeps blurring. Modern world models can generate photorealistic future frames, plan over hundreds of timesteps, and generalize across radically different tasks.
8. Conclusion
World models represent a fundamental shift in how we think about intelligent agents. Instead of learning reactive mappings from observations to actions, the agent builds an internal simulator and thinks before it acts. The V-M-C decomposition — compress, predict, act — is both simple and powerful.
The implications extend beyond games and robotics. There's a compelling argument that large language models are themselves a kind of world model — they learn the "dynamics" of text by predicting next tokens, and chain-of-thought prompting is a form of planning through the model's internal simulation. The idea of learning a world model and then reasoning inside it may be the unifying principle behind the most capable AI systems we've built.
From Ha & Schmidhuber's original VAE+RNN to DreamerV3's universal agent to DIAMOND's diffusion dreams, the field has covered remarkable ground in just six years. The next frontier is clear: world models that can learn from a handful of interactions, generalize across environments, and plan over the kind of long horizons that humans navigate effortlessly. The dream is to build agents that truly understand their world — not just react to it.
References & Further Reading
- Ha & Schmidhuber 2018 — World Models — the original paper with interactive demos
- Hafner et al. 2020 — Dream to Control: Learning Behaviors by Latent Imagination — the Dreamer architecture
- Hafner et al. 2023 — Mastering Diverse Domains through World Models — DreamerV3 scaling to 150+ tasks
- Schrittwieser et al. 2020 — Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model — MuZero
- Sutton 1991 — Dyna, an Integrated Architecture for Learning, Planning, and Reacting — the foundational Dyna framework
- Chua et al. 2018 — Deep RL in a Handful of Trials using Probabilistic Dynamics Models — PETS model ensembles
- Micheli et al. 2023 — Transformers are Sample-Efficient World Models — IRIS transformer world model
- Alonso et al. 2024 — Diffusion for World Modeling — DIAMOND diffusion-based world model