🔥

[論文解説]Consistency Policy

2024/12/24に公開

Consistency Policyについて解説します。

元論文
Consistency Policy: Accelerated Visuomotor Policies via Consistency Distillation

公式のサイト
https://consistency-policy.github.io/

公式のgithub
https://github.com/Aaditya-Prasad/Consistency-Policy/

概要

Consistency Policyは、Diffusion PolicyをTeacher Modelとし、Diffusion Policyを蒸留することで、より高速なロボットの方策を学習しました。
学習ではConsistency Trajectory Model(CTM)を使うことで、精度を維持したまま、推論ステップ数を削減することで高速化を達成しました。

先行研究との違いは

Diffusion Policy: ロボティクスや制御の分野で注目を集める一方で、従来の拡散モデルは推論に多くのdeniseステップが必要で速度が遅いというデメリットがありました。

提案手法のConsistency Policyは、Diffusion Policyを蒸留することで性能を保ちつつ推論ステップ数を削減します。

技術のキモは?

CTMを使った蒸留

基本的なアイデア

  1. Teacher Modelは通常の拡散モデルとして学習する。
  2. Student Modelは大きな時間ステップでのdenoiseを学習する。
  3. 異なる経路で同じ中間点に到達することを保証する損失関数を導入する。

lossの全体像は以下のように定義されます。

\mathcal{L}_{CP} = \beta \mathcal{L}_{DSM} + \alpha \mathcal{L}_{CTM}

ここでDSM(Denoising Score Matching) lossはTeacher ModelをDiffusio Policyの学習と同じように学習します。
CTM(Consistency Trajectory Model) lossはStudent Modelを少ないステップでTeacher Modelから蒸留します。

学習では2つのlossを具体的に確認します。

DSM(Denoising Score Matching)

Teacher Model s_{\phi}は、Diffusion policyとして学習します。
真の値x_0、ノイズが付与されたTステップ後の値x_{T}があるとき。x_{T} → x_{0}に向かう勾配を推定することでdenoiseを行います。
以下のような式で表されます。

x_0 = x_T + \int_{T}^0 \frac{dx}{dt} dt

SonyAIさんの動画がとてもわかりやすいです。

DSMの途中式を全て省略しますが、この変化量を推定するネットワークをTeacherでは学習します。
tステップかけて真の値x_0を求められるようにDSM lossを求めます。

\mathcal{L}_{DSM}(\theta) = E[d(x_0, s_{\phi}(x_t, t, 0; o))]

CTM(Consistency Trajectory Model)の計算

Student Modelはg_{\theta}はTeacher Modelよりも少ないステップで学習するように蒸留する。
以下の2つの経路を考えます:
経路A: t → s → 0

  • t → s: Student Modelが予測
  • s → 0: Student Modelが予測(stop grad)
    x_{est} = g_{sg(\theta)}(g_{\phi}(x_t, t, s), s, 0)

経路B: t → u → s → 0

  • t → u: Teacher Modelが予測(stop grad)
  • u → s: Student Modelが予測(stop grad)
  • s → 0: Student Modelが予測(stop grad)
    x_{target} = g_{sg\theta}(g_{sg(\theta)}(s_{\phi}(x_t, t, u), u, s), s, 0)

図で表すと以下のようになります。
経路Aが上、経路Bが下になります。

青はstudentがgradを計算するステップ。
緑はteacherがstopgradで計算するステップ。
オレンジはstudentがstopgradで計算するステップ。
赤がCTMのlossとなります。

両経路で得られる予測値の誤差をCTM損失として最小化します。

\mathcal{L}_{CTM} = d( x_{target}, x_{est} )

* u → s の予測はTeacher Modelの方が精度が良いように思えますが、一貫性のためStudent Modelを使用。
s → 0 の計算の必要性: CTMの原著では点sまででCTM lossを計算しているので0まではいらないはず。CTMではGAN lossとして、x_{est}x_{0}のlossを計算しているが、なぜかこれがなくなっている。

結果

Consistency Policy (CP)の性能評価は、DDPMやDDiMの拡散モデルで学習したDiffusion PolicyをBaselineとして比較し、ロボットシミュレーション(Robomimic)上の5種類のタスクで検証します。

評価指標は、

  • 成功率 (Success Rate)
  • NFE (Number of Function Evaluations): 推論ステップ数を、並列化考慮で正規化した指標

DDPMは100ステップ、DDiMは15ステップで推論を行い、それぞれ並列化に応じてNFEを計算しました。一方、CPは1ステップ or 3ステップの推論を行い、NFE上でも大きなメリットを示しています。
(ParaDiGMSの論文からDDPMとDDiMは並列化を考慮すると3.7倍と1.6倍になるため、ステップ数から割ります)

P5000 GPU(16GB)での推論時間の結果。1ステップあたり1msぐらい。(TeacherもStudentもアーキテクチャはほぼ同じなので1ステップあたりの速度は変わってなさそう)

まとめ

Consistency Policyは、Diffusion Policy(Teacher Model)を、CTMを使ってStudent Modelに蒸留することで、少ないステップで高精度なpolicyを学習することができました。

感想

  • lossの式がCTMと違うところがあり、説明がないため実験で得られた結果なのか、間違いなのかの分からなかった。
  • 見逃した可能性があるが所々説明が少なかった。

Discussion