Open13

Shortcut Models (One Step Diffusion via Shortcut Models) について調べる

PlatPlat

背景

Diffusion や Flow matching (Diffusion の特殊ケース) ではモデルはノイズからデータへの変換を学習する。モデルは 0 ~ 1 のタイムステップの中で、それぞれのタイムステップで尤もらしい・期待値の高い予測をするように学習を行うが、これはデータの平均を予測していることになる。そのため、完璧に軌跡を学習できたとしても、少ないステップ数で生成を行うとき、特にステップ数が1(=タイムステップが0)の時の予測は本来のデータ分布の 平均 を目指す予測を行うことになり、正しい生成が行えない。


https://kvfrans.com/shortcut-models/

PlatPlat

提案手法

タイムステップ条件 t だけでなく、デノイズしたいステップサイズ d も条件に追加してデノイズを行う Shortcut モデルを提案する。普通に学習すると、非常に細かいステップサイズでの予測をすることになり、これは少ないステップ数の時にうまく予測ができなくなるので、飛びたいステップ先の条件 d を考慮することによって、いい感じに次の点が予測できるようになる。
タイムステップ t 時点でのノイジーなデータ x_t から、ステップサイズ d 先のショートカット x'_{t+d} を予測する関数 s(x_t, t, d) は、

x'_{t+d} = x_t + s(x_t, t, d) d

となる。ショートカットモデル s_\theta(x_t, t, d) は全ての x_t, t, d の組み合わせからショートカットを学習することが目標。 これは、ステップ数の多い Flow-matching の一般化として考えることができ、d = 0 の時は通常の Flow-matching と等価である。

ただし、愚直に s_\theta を学習しようとするとステップサイズを小さく取ることになって、非常に計算量が必要になる。なので、代わりに「ステップサイズ 2d の 1 回のショートカットが、 ステップサイズd のショートカット 2 回分になる」ように学習を行う。つまり、

s(x_t, t, 2d) * 2d = s(x_t, t, d) * d + s(x_{t+d}, t, d) * d

を学習することであり、これは両辺を 2d で割って、

s(x_t, t, 2d) = (s(x_t, t, d) + s(x_{t+d}, t, d))/2

となる。(x_t からの 2d のショートカットが、x_t からの d のショートカット先(x_{t+d})x_{t+d} からの d のショートカットが等しくなるようにする。)


こういう感じ...?

プロジェクトページにはこういう感じに書いてある:

https://kvfrans.com/shortcut-models/

PlatPlat

ロス関数

d が 0 の時は普通に Flow-matching を学習するので、 Flow-matching のロスとショートカットのロスを組み合わせたものになる。それぞれ平均二乗誤差 MSE で計算するので、論文中の式をそのまま載せるとこうなる:

PlatPlat

学習の詳細

実際のサンプルの利用

d が 0 の時は普通に Flow-matching をやると言ったので、普通にやる。ランダムにデータとノイズのペアを作り、速度(方向を含んだ速さ)を予測する。タイムステップは 0~1 の間で均等にランダムにサンプルして問題なかったらしい。

ショートカットの決め方

ショートカットのサイズ d は事前に決めたサイズを使う。論文中では 0, 2^0, 2^1, ... 2^7 (=128) の計8種類を用いた。それぞれがショートカットするタイムステップの距離は 1, 1/2, 1/4, ... 1/128 となる。(前提として、タイムステップは 0~1 の間で取られる)。2の累乗で距離を取ってるので、d1 未満のとき (1/2 以下のとき)は、 2d で2倍を取ってもその 8 種類の内に入ることになる。

ロス関数中に出てくる、x'_{t+d} だが、これはモデル自身の予測で作られたもので、実データとノイズの補間で作れるものではないことに注意が必要。
あと、d1/128 とか小さい時は d = 0 として扱ったらしい。

PlatPlat

学習の効率化

ロス関数には、実データを用いた Flow-matching 項と自己一貫項があるが、自己一貫の項は決定論的(自身の予測を元にさらに予測するため)なので、バッチ内で Flow-matching 項の割合を高くする方が良いとしている。

特に、自己一貫項は計算するために2回の forward を行う必要があるため、なるべく少なく済むと嬉しい。論文だとバッチ内の 1/4 を自己一貫項に割り当てたそう。最終的な計算量の増加は ~16% くらいだったらしい。

PlatPlat

離散的なタイムステップ

ショートカットの学習の負荷を減らすために、ショートカットの開始地点のタイムステップ t は連続的にとらず、特定のタイムステップからのみ取るようにした。具体的には、まず d を最初にサンプルしてから td の何倍かになるように取った。

PlatPlat

Consistency Models との比較

https://arxiv.org/abs/2303.01469

Consistency Models と似てるが、一部異なる。

  1. Consistency は、経験的な x_tx_d ペアのみで学習を行う (自身を目標に自身を改善する?一方 Shortcut では x'_{t+d} を使い、これは backward パスに入らない) ため、それぞれの離れたステップで取り返しのつかない誤差が積み重なっていく (小さいステップでの学習もない)
  2. Shortcut では log_2(T) 種類のみのステップサイズを学習する (実験では T = 128)が、Consistency は T 種類学習することになる。
  3. Shortcut ではショートカットしない普通の生成も可能だが、Concsistency はそうではない(らしい)
  4. Shortcut の方がやってることがシンプル。Concistency はいろんなトリックが必要らしい?

Consistency の参考:
https://zenn.dev/discus0434/articles/484be111f7862d

PlatPlat

Classifier-Free Guidance (CFG) との組み合わせ

蒸留モデルにありがちな CFG どうするのか問題も少し言及されていた。CFG は条件なし生成と条件あり生成を比較して線形近似することで実現される。

Shotcut では小さいステップサイズ (小さい d)ではよく動作するが、大きいステップサイズ (= 少ない生成ステップ数) だと、近似に失敗した時にエラーが発生しやすかったらしい。
そのため、検証では d = 0 (= 通常の Flow-matching と等価) のときのみ CFG を適用し、それ以外では CFG を適用せずに生成を行ったらしい。

Shortcut の CFG の課題は CFG scale を学習前に指定する必要があることとされている。

PlatPlat

↑ CFG の問題で連想するのが、

1. LCM-LoRA (Latent Consistency Model LoRA)

元論文 は CFG を操作するためのモジュールが追加されているけど、diffusers による LoRA は学習時に固定して、追加モジュールなくても CFG 適用状態になるようにしているらしい
参考:
https://note.com/gcem156/n/nb23aff723431

2. Flux.1-dev の guidance 蒸留

dev は ガイダンス蒸留(内容不明) をされているのだが、その実装が Shortcut と似ている気がする...

Shortcut はタイムステップ t に加えてステップサイズ d を受け取るわけだが、そのため d を処理するための埋め込みレイヤー(タイムステップ埋め込み層; sin, cos の絶対位置エンコーディング)が一つ増えて、通常のタイムステップ埋め込みに足し合わされている。

https://github.com/smileyenot983/shortcut_pytorch/blob/7fee517af0b15dd8bd6ec89371f2e62e3947b4f0/model.py#L631-L650

一方の Flux はタイムステップ t に加えてガイダンス量 guidance を受け取り、同様にタイムステップ埋め込みレイヤーに通されて、通常のタイムステップ埋め込みに足し合わされている。

https://github.com/black-forest-labs/flux/blob/d06f82803f5727a91b0cf93fcbb09d920761fba1/src/flux/model.py#L84-L104

構造がめっちゃ似てるので、Flux の構造を真似して、ガイダンスを受け取りながら少ないステップ数で生成できるような学習も似た感じにできるんじゃないだろうか...?

PlatPlat

タイムステップ t、ノイジーデータ x_t、ステップ距離 d、CFGスケール c から、ガイダンスされたショートカット先 x'_{t+d,g} を予測する g(x_t, t, d, c) は単純に追加でガイダンスを掛けて、

x'_{t+d,g} = x_t + g(x_t, t, d, g) d c

となる。Shortcut 同様の式を作るとしたら、2回分の 1 倍ガイダンス付き d 距離ショートカットを c 倍したものが、1回分の c 倍ガイダンス付き 2d 距離を 1 倍したものに等しくなればいいので、

g(x_t, t, 2d, c) * 2d = g(x_t, t, d, 1) * d * c + g(x_{t+d,c}, t, d, 1) * d * c

なので、両辺 2d で割ると、

g(x_t, t, 2d, c) = (g(x_t, t, d, c) + g(x_{t+d,c}, t, d, c)) * c / 2

になる。

なので、単にガイダンス量 c を入力に追加して、ロス関数の自己一貫項で c を掛ければいい...?

実装するとしたら、Shortcut と Flux の混ぜて、タイムステップ埋め込みに通した距離 d の埋め込みと、タイムステップ埋め込みに通したガイダンス guidance を足せばいいのか...?

PlatPlat

いや、CFGは無条件(uncond)と条件あり(cond)を比較する必要があるから、さらに forward 回数を増やさないとダメか...?

x'_{t+d,uncond} = g(x_t, t, d, cfg=1, uncond) * d \\ x'_{t+d,cond} = g(x_t, t, d, cfg=1, cond) * d \\ g(x_t, t, d, cfg, cond) * d = (x'_{t+d,uncond} + (x'_{t+d,cond} - x'_{t+d,uncond}) * cfg) * d

となってしまうが、d が消えてしまうので、Shortcut とは別の新しいガイダンス項になるのかな

計算量が増えてしまう感じだろうか...

PlatPlat

ちょっと色々認識が違ってたようなので修正

d は距離というよりは時間で、delta timeduration で呼ぶ方が適切っぽい。公式コードだと dt で表記されていて、そのまま微小時間を表してる。

なので、dt → 0 っていうのは、「ショートカット距離がめちゃくちゃ短い」なので、つまり「ショートカットせずに普通に flow matching してるのと同じ」って感じっぽい。

実際の実装で dt を0にするときはショートカット回数が 128 回(128ステップ生成) の時を想定していて、Flow-matching のロスの時は、128 回ショートカット時を想定している感じっぽい。(連続値からのサンプルではなく、1/128 の倍数からランダムに取得する。つまり、torch.rand(low=0, high=1, size=(batch_size,)) ではなく、 torch.randint(low=0, high=128, size=(batch_size,)).float()/128 となる。)

タイムステップの取り方は ノイズ→画像1→0 の実装と、0→1 になっている実装があるっぽくて、公式実装や公式 DiT は 0→1 だったが AuraFlow は 1→0 だったので実装時は符号を変える必要があった。つまり、x_{t+d} = x_t + d(...) * d となってる場合は、0→1 でデノイズだけど、逆方向なら x_t = x_{t+d} - d(...) (移項しただけ) にする必要があると思う。多分。

なぜか知らないけど、Shortcut モデルが受け取るショートカット情報は、ショートカットする時間の長さじゃなくて、ショートカット回数に log2 を取った指数(多分、指数だけど、公式コードとかだと base= 底 = 2 で表されていてちょっとよくわからない。理解が間違っているのだろうか...) を入力しているので、ショートカット回数 1 なら log2(1) = 0 で 0 が入力され、128 ステップ生成なら log2(128) = 7 が入力されることになる。

タイムステップ埋め込みなのに指数を使うのって、なんかあんまり直感的じゃない... 普通にショートカットする時間 dt を入れてしまってはじゃダメなのか...?