World Models 中面向更长时域的基于梯度规划
Gradient-based Planning for World Models at Longer Horizons
Michael Psenka、Mike Rabbat、Aditi Krishnapriyan、Yann LeCun、Amir Bar提出GRASP,用collocation提升virtual states、对state iterates加Gaussian noise,并阻断state gradients保留action Jacobians。Push-T中H=60成功率26.2%、中位用时49.1s,高于CEM、GD、LatCo。
.grasp-results-table table { font-size: 0.875rem; line-height: 1.35; width: 100%; } .grasp-results-table th, .grasp-results-table td { padding: 0.35rem 0.5rem; } /* 主要章节之间保持一致的空白(本文较长且大量使用 hr) */ article.post-content h2 { margin-top: 2.75rem; margin-bottom: 0.75rem; } article.post-content h2:first-of-type { margin-top: 2.25rem; } article.post-content h3 { margin-top: 1.65rem; margin-bottom: 0.5rem; } article.post-content hr { margin-top: 2.5rem; margin-bottom: 2.5rem; } GRASP 是一种新的基于 gradient(梯度)的 learned dynamics(学习到的动力学,也称“world model”)planner,它通过以下方式让 long-horizon planning 变得可行:(1) 将 trajectory 提升到 virtual states,使 optimization 可以在时间维度上并行;(2) 直接向 state iterates 加入 stochasticity 用于 exploration;(3) 重塑 gradients,使 actions 获得清晰信号,同时避免通过高维视觉模型传播脆弱的“state-input” gradients。
大型 learned world models 的能力正在持续增强。它们可以在高维视觉空间中预测很长的未来 observation 序列,并以几年前难以想象的方式跨任务泛化。随着这些模型扩大规模,它们看起来越来越不像任务特定的 predictor,而更像通用 simulator。
但拥有强大的 predictive model,并不等于能够有效地将其用于 control/learning/planning。在实践中,使用现代 world models 进行 long-horizon planning 仍然很脆弱:optimization 会变得 ill-conditioned,non-greedy 结构会产生糟糕的 local minima,高维 latent spaces 会引入微妙的 failure modes。
在这篇 blog post 中,我会介绍这个项目的动机以及我们的解决思路:为什么用现代 world models 做 planning 会出人意料地脆弱,为什么 long horizons 才是真正的压力测试,以及我们做了哪些修改来让 gradient-based planning 更稳健。
这篇 blog post 讨论的是我与 Mike Rabbat、Aditi Krishnapriyan、Yann LeCun 和 Amir Bar(* 表示共同指导)合作完成的工作,我们在其中提出了 GRASP。
What is a world model?
如今,“world model”这个术语的含义相当宽泛。根据上下文,它既可以指显式的 dynamics model,也可以指某个 generative model 所依赖的隐式、可靠的 internal state(例如,当 LLM 生成国际象棋走法时,内部是否存在棋盘的某种表示)。下面给出我们宽松的工作定义。
假设你采取 actions $a_t \in \mathcal{A}$,并观察 states $s_t \in \mathcal{S}$(图像、latent vectors、本体感知)。world model 是一个 learned model:给定当前 state 和一段未来 actions,它预测接下来会发生什么。形式上,它在一段 observed states $s_{t-h:t}$ 和当前 action $a_t$ 上定义 predictive distribution:
[P_\theta(s_{t+1} \mid s_{t-h:t},; a_t)]
它近似环境的真实条件分布 $P(s_{t+1} \mid s_{t-h:t},; a_t)$。为简单起见,本文假设一个 Markovian model $P(s_{t+1} \mid s_{t-h:t},; a_t)$(这里的所有结果都可以扩展到更一般的情形);当模型是 deterministic 时,它退化为 states 上的映射:
[s_{t+1} = F_\theta(s_t, a_t).]
实践中,state $s_t$ 通常是 learned latent representation(例如由 pixels 编码而来),因此模型运行在一个(理论上)紧凑、可微的空间中。关键在于,world model 给了你一个 differentiable simulator;你可以在假设的 action sequences 下向前 rollout,并通过预测进行 backpropagate。
Planning: choosing actions by optimizing through the model
给定起点 $s_0$ 和 goal $g$,最简单的 planner 会通过 rollout 模型并最小化 terminal error 来选择 action sequence $\mathbf{a}=(a_0,\dots,a_{T-1})$:
[\min_{\mathbf{a}} ; | s_T(\mathbf{a}) - g |2^2, \quad \text{where } s_T(\mathbf{a}) = \mathcal{F}{\theta}^{T}(s_0,\mathbf{a}).]
这里我们用 $\mathcal{F}^T$ 作为 world model 完整 rollout 的简写(对模型参数 $\theta$ 的依赖是隐含的):
[\mathcal{F}{\theta}^{T}(s_0, \mathbf{a}) = F\theta(F_\theta(\cdots F_\theta(s_0, a_0), \cdots, a_{T-2}), a_{T-1}).]
在短 horizon 和低维系统中,这种方法通常效果不错。但随着 horizon 变长、模型变得更大且表达能力更强,它的弱点会被放大。那么,为什么它不能直接在 scale 上工作?
Why long-horizon planning is hard (even when everything is differentiable)
对于更一般的 world model,有两个独立的痛点,外加一个 learned、deep learning-based models 特有的问题。
- Long-horizon rollouts create deep, ill-conditioned computation graphs
熟悉 backprop through time(BPTT)的人可能会注意到,我们是在对一个反复应用到自身的模型求导,这会导致 exploding/vanishing gradients 问题。也就是说,如果我们对较早的 actions(例如 $a_0$)求导(注意我们是在对 vector-valued functions 求导,因此得到的是 Jacobians,记作 $D_x (\cdots)$):
[D_{a_0} \mathcal{F}{\theta}^{T}(s_0, \mathbf{a}) = \Bigl(\prod{t=1}^T D_s F_\theta(s_t, a_t)\Bigr) D_{a_0}F_\theta(s_0, a_0).]
可以看到,Jacobian 的 conditioning 会随时间 $T$ 指数级缩放:
[\sigma_{\text{max/min}}(D_{a_0}\mathcal{F}{\theta}^{T}) \sim \sigma{\text{max/min}}(D_s F_\theta)^{T-1},]
从而导致 exploding 或 vanishing gradients。
- The landscape is non-greedy and full of traps
在短 horizon 下,greedy solution 通常已经足够好,也就是每一步都直接朝 goal 移动。如果你只需要向前规划几步,最优 trajectory 通常不会明显偏离每一步“朝 $g$ 前进”的策略。
随着 horizon 变长,会发生两件事。第一,更长的任务更可能需要 non-greedy 行为:绕过一堵墙、在推动前重新定位、后退以选择更好的路径。并且随着 horizon 变长,通常需要更多这样的 non-greedy steps。第二,optimization space 本身也随 horizon 扩大:$\mathrm{dim}(\mathcal{A} \times \cdots \times \mathcal{A}) = T\mathrm{dim}(\mathcal{A})$,进一步扩大了 optimization problem 的 local minima 空间。
沿最优路径到 goal 的距离并不单调,由此产生的 loss landscape 可能很粗糙。
A long-horizon fix: lifting the dynamics constraint
假设我们把 dynamics constraint $s_{t+1} = F_{\theta}(s_t, a_t)$ 作为 soft constraint,并改为同时对 actions $(a_0,\ldots,a_{T-1})$ 和 states $(s_0,\ldots,s_T)$ 优化下面的 penalty function:
[\min_{\mathbf{s},\mathbf{a}} \mathcal{L}(\mathbf{s}, \mathbf{a}) = \sum_{t=0}^{T-1} \big|F_\theta(s_t,a_t) - s_{t+1}\big|_2^2, \quad \text{with } s_0 \text{ fixed and } s_T=g.]
在 planning/robotics 文献中,这有时也被称为 collocation。注意,lifted formulation 与原始 rollout objective 共享相同的 global minimizers(两者恰好在 trajectory dynamically feasible 时为零)。但它们的 optimization landscapes 非常不同,并且我们立刻获得两个好处:
每次 world model evaluation $F_{\theta}(s_t,a_t)$ 只依赖局部变量,因此所有 $T$ 项都可以在时间维度上并行计算,这对更长的 horizons 带来很大的加速,并且
你不再需要通过单个很深的 $T$-step composition 进行 backpropagate 来获得 learning signal,因为之前的 Jacobians 乘积现在拆成了求和,例如:
[D_{a_0} \mathcal{L} = 2(F_\theta(s_0, a_0) - s_1).]
能够直接优化 states 也有助于 exploration,因为我们可以暂时穿过不符合物理的区域来寻找最优 plan:
基于 collocation 的 planning 允许我们直接扰动 states,并更有效地探索中间点。
不过,世上没有免费的午餐。事实上,尤其对于 deep learning-based world models,上述 optimization 在实践中有一个关键问题,会让它变得相当困难。
An issue for deep learning-based world models: sensitivity of state-input gradients
本节的 tl;dr 是:通过 deep learning-based $F_{\theta}$ 直接优化 states 极其脆弱,类似 adversarial robustness 问题。即使你在较低维的 state space 中训练 world model,world model 的训练过程也会让未见过的 state landscapes 变得非常尖锐,无论这是一个未见过的 state 本身,还是 data manifold 的 normal/orthogonal 方向。
Adversarial robustness and the “dimpled manifold” model
Adversarial robustness 最初研究 classification models $f_\theta : \mathbb{R}^{w\times h \times c} \to \mathbb{R}^K$,并表明如果从 base image $x$(不属于类别 $k$)出发,沿某个 logit 的 gradient $\nabla f_\theta^k$ 移动,那么不需要在 $x’ = x + \epsilon\nabla f_\theta^k$ 上移动很远,就可以让 $f_\theta$ 将 $x’$ 分类为 $k$(Szegedy et al., 2014;Goodfellow et al., 2015):
(Goodfellow et al., 2015)中经典示例的示意图。
后续工作给出了一个几何图景来解释正在发生什么:对于接近低维 manifold $\mathcal{M}$ 的数据,训练过程会控制 tangential directions 上的行为,但不会 regularize orthogonal directions 上的行为,因此导致敏感行为(Stutz et al., 2019)。换句话说:如果只考虑 data manifold $\mathcal{M}$ 的 tangential directions,$f_\theta$ 具有合理的 Lipschitz constant;但在 normal directions 上,它可能具有非常高的 Lipschitz constants。事实上,模型通常会从这些 normal directions 上更尖锐的行为中受益,因为这样可以更精确地拟合更复杂的函数。
因此,即便对于单个给定模型,这类 adversarial examples 也极其常见。此外,这并不只是 computer vision 现象;adversarial examples 也出现在 LLMs(Wallace et al., 2019)和 RL(Gleave et al., 2019)中。虽然有一些方法可以训练出 adversarially robust models,但 model performance 与 adversarial robustness 之间存在已知的 trade-off(Tsipras et al., 2019):尤其是在存在许多弱相关变量时,模型必须更尖锐才能达到更高性能。事实上,无论是在 computer vision 还是 LLMs 中,大多数现代 training algorithms 都不会把 adversarial robustness 训练掉。因此,至少在 deep learning 出现重大范式变化之前,这是一个我们不得不面对的问题。
Why is adversarial robustness an issue for world model planning?
考虑我们在 lifted state approach 中优化的 dynamics loss 的一个单独分量:
[\min_{s_t, a_t, s_{t+1}} |F_\theta(s_t, a_t) - s_{t+1}|_2^2]
进一步只关注 base state:
[\min_{s_t} |F_\theta(s_t, a_t) - s_{t+1}|_2^2.]
由于 world models 通常是在 state/action trajectories $(s_1, a_1, s_2, a_2, \ldots)$ 上训练的,$F_{\theta}$ 的 state-data manifold 维度受到 action space 的限制:
[\mathrm{dim}(\mathcal{M}_s) \le \mathrm{dim}(\mathcal{A}) + 1 + \mathrm{dim}(\mathcal{R}),]
其中 $\mathcal{R}$ 是某个可选的 augmentations 空间(例如 translations/rotations)。因此,我们通常可以预期 $\mathrm{dim}(\mathcal{M}_s)$ 远低于 $\mathrm{dim}(\mathcal{S})$,于是:
非常容易找到 adversarial examples,把任意 state hack 成任意期望 state。
因此,dynamics optimization
[\sum_{t=0}^{T-1} \big|F_\theta(s_t,a_t) - s_{t+1}\big|_2^2]
会显得极其“sticky”,因为 base points $s_t$ 很容易欺骗 $F_{\theta}$,让它以为已经达到了局部目标。1
- 这个 adversarial robustness 问题在 lifted-state approaches 中尤其严重,但并非它们独有。即使对于通过完整 rollout map $\mathcal{F}^T$ 进行优化的 serial optimization methods,也可能进入未见过的 states,在那里很容易让 normal component 输入到 $D_s F_{\theta}$ 的敏感 normal components 中。action Jacobian 的 chain rule expansion 是
[\Bigl(\prod_{t=1}^T D_s F_\theta(s_t, a_t)\Bigr) D_{a_0}F_\theta(s_0, a_0).]
看看如果乘积的任意阶段包含任意 data manifold 的 normal component,会发生什么。↩
Our fix
这就是我们的新 planner GRASP 登场的地方。主要观察是:虽然 $D_s F_{\theta}$ 不可靠且具有 adversarial 性质,但 action space 通常是低维且被充分训练的,因此 $D_a F_{\theta}$ 实际上适合用于优化,并且不会遭遇 adversarial robustness 问题!action input 通常维度更低、训练更密集(模型见过每个 action direction),因此 action gradients 的行为要好得多。
GRASP 的核心是构建一个 first-order lifted state / collocation-based planner,它只依赖通过 world model 得到的 action Jacobians。因此,我们利用 learned world models $F_{\theta}$ 的 differentiability,同时不受 state Jacobians $D_s F_{\theta}$ 固有敏感性的影响。
GRASP: Gradient RelAxed S tochastic P lanner
如前所述,我们从 collocation planning objective 开始,将 states 提升,并把 dynamics 放松为一个 penalty:
[\min_{\mathbf{s},\mathbf{a}} \mathcal{L}(\mathbf{s}, \mathbf{a}) = \sum_{t=0}^{T-1} \big|F_\theta(s_t,a_t) - s_{t+1}\big|_2^2, \quad \text{with } s_0 \text{ fixed and } s_T=g.]
然后我们加入两个关键组件。
Ingredient 1: Exploration by noising the state iterates
即使 objective 更平滑,planning 仍然是 nonconvex。我们通过在 optimization 期间向 virtual state updates 注入 Gaussian noise 来引入 exploration。一个简单版本是:
[s_t \leftarrow s_t - \eta_s \nabla_{s_t}\mathcal{L} + \sigma_{\text{state}} \xi, \qquad \xi\sim\mathcal{N}(0,I).]
Actions 仍然通过 non-stochastic descent 更新:
[a_t \leftarrow a_t - \eta_a \nabla_{a_t}\mathcal{L}.]
state noise 帮助你在 lifted space 中的 basin 之间“跳跃”,而 actions 仍由 gradients 引导。我们发现,在这里专门对 states 加噪(而不是 actions)能在 exploration 和寻找更尖锐 minima 的能力之间取得良好平衡。2
- 因为我们只对 states 加噪(而不是 actions),对应的 dynamics 并不是真正的 Langevin dynamics。↩
Ingredient 2: Reshape gradients: stop brittle state-input gradients, keep action gradients
如前所述,脆弱的路径是流入 world model 的 state input 的 gradient,即 (D_s F_{\theta})。最直接的初始做法就是直接阻断进入 (F_{\theta}) 的 state gradients:
令 $\bar{s}_t$ 与 $s_t$ 取值相同,但停止 gradients。定义 stop-gradient dynamics loss:
[\mathcal{L}{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a}) = \sum{t=0}^{T-1} \big|F_\theta(\bar{s}t, a_t) - s{t+1}\big|_2^2.]
仅靠这一点行不通。注意,现在 states 只会跟随前一个 state 的 step,而没有任何东西迫使 base states 去追逐后面的 states。因此,会出现一些 trivial minima:先停在 origin,然后只有最后一个 action 试图一步到达 goal。
Dense goal shaping
我们可以把上述问题看作 goal 的信号被完全从之前的 states 中切断。一种修复方法是在整个 prediction 中简单加入 dense goal term:
[\mathcal{L}{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a}) = \sum{t=0}^{T-1} \big|F_\theta(\bar{s}_t, a_t) - g\big|_2^2.]
在常规设置中,这会过度偏向直接追逐 goal 的 greedy solution;但在我们的设置中,这会被 stop-gradient dynamics loss 对 feasible dynamics 的偏置所平衡。
最终 objective 如下:
[\mathcal{L}(\mathbf{s},\mathbf{a}) = \mathcal{L}{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a}) + \gamma , \mathcal{L}{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a}).]
得到的 planning optimization objective 不再依赖 state gradients。
Periodic “sync”: briefly return to true rollout gradients
lifted stop-gradient objective 非常适合快速、有引导的 exploration,但它仍然是对原始 serial rollout objective 的近似。因此,每隔 $K_{\text{sync}}$ 次 iteration,GRASP 会执行一个短暂的 refinement phase:
使用当前 actions $\mathbf{a}$ 从 $s_0$ 开始 rollout,并在原始 serial loss 上做几步小的 gradient steps:
[\mathbf{a} \leftarrow \mathbf{a} - \eta_{\text{sync}},\nabla_{\mathbf{a}},|s_T(\mathbf{a})-g|_2^2.]
lifted-state optimization 仍然提供 optimization 的核心,而这个 refinement step 提供一些辅助,使 states 和 actions 更贴近真实 trajectories。
当然,这个 refinement step 可以替换为你选择的 serial planner(例如 CEM);核心思想是在主要使用 lifted-state planning 优势的同时,仍然获得 serial planners 的 full-path synchronization 的部分好处。
How GRASP addresses long-range planning
Collocation-based planners 为 long-horizon planning 提供了一个自然的修复方向,但由于 adversarial robustness 问题,通过现代 world models 进行这种 optimization 相当困难。GRASP 为更平滑的 collocation-based planner 提出了一种简单解法,并结合稳定的 stochasticity 用于 exploration。因此,longer-horizon planning 不仅更容易成功,而且成功所需时间也更短:
Push-T demo:使用 GRASP 进行 longer-horizon planning。
Horizon CEM GD LatCo GRASP
H=40 61.4% / 35.3s 51.0% / 18.0s 15.0% / 598.0s 59.0% / 8.5s
H=50 30.2% / 96.2s 37.6% / 76.3s 4.2% / 1114.7s 43.4% / 15.2s
H=60 7.2% / 83.1s 16.4% / 146.5s 2.0% / 231.5s 26.2% / 49.1s
H=70 7.8% / 156.1s 12.0% / 103.1s 0.0% / — 16.0% / 79.9s
H=80 2.8% / 132.2s 6.4% / 161.3s 0.0% / — 10.4% / 58.9s
Push-T 结果。Success rate (%) / median time to success。Bold = 每行最佳。注意,median success time 在 success rate 更高时会偏高;尽管 success rate 更高,GRASP 仍然更快。
What’s next?
现代 world model planners 仍有大量工作要做。我们希望利用 learned world models 的 gradient structure,而 collocation(lifted-state optimization)是 long-horizon planning 的自然方法;但理解这里典型的 gradient structure 至关重要:action gradients 平滑且信息丰富,state gradients 脆弱。
我们把 GRASP 看作这类 planners 的一个初始版本。扩展到 diffusion-based world models(更深的 latent timesteps 可以看作 world model 本身的平滑版本)、更复杂的 optimizers 和 noising strategies,以及将 GRASP 集成到 closed-loop system 或 RL policy learning 中以实现 adaptive long-horizon planning,都是自然且有意思的下一步。
我确实认为,现在是研究 world model planners 的好时机。这是一个有趣的交汇点:背景文献(planning 和 control 总体上)已经非常成熟和完善,但当前设置(在现代大规模 world models 上进行纯 planning optimization)仍然远未被充分探索。不过,一旦我们找到所有正确思路,world model planners 很可能会像 RL 一样常见。
更多细节请阅读完整论文或访问项目网站。
Citation
@article { psenka2026grasp , title = {Parallel Stochastic Gradient-Based Planning for World Models} , author = {Michael Psenka and Michael Rabbat and Aditi Krishnapriyan and Yann LeCun and Amir Bar} , year = {2026} , eprint = {2602.00475} , archivePrefix = {arXiv} , primaryClass = {cs.LG} , url = {https://arxiv.org/abs/2602.00475} }