Shortcut Models (One Step Diffusion via Shortcut Models) について調べる
arXiv:
プロジェクトページ:
公式GitHub (Jax):
非公式 PyTorch 実装:
背景
Diffusion や Flow matching (Diffusion の特殊ケース) ではモデルはノイズからデータへの変換を学習する。モデルは 0 ~ 1 のタイムステップの中で、それぞれのタイムステップで尤もらしい・期待値の高い予測をするように学習を行うが、これはデータの平均を予測していることになる。そのため、完璧に軌跡を学習できたとしても、少ないステップ数で生成を行うとき、特にステップ数が1(=タイムステップが0)の時の予測は本来のデータ分布の 平均 を目指す予測を行うことになり、正しい生成が行えない。
提案手法
タイムステップ条件
タイムステップ
となる。ショートカットモデル
ただし、愚直に
を学習することであり、これは両辺を 2d で割って、
となる。(
こういう感じ...?
プロジェクトページにはこういう感じに書いてある:
https://kvfrans.com/shortcut-models/
ロス関数
学習の詳細
実際のサンプルの利用
ショートカットの決め方
ショートカットのサイズ
ロス関数中に出てくる、
あと、
学習の効率化
ロス関数には、実データを用いた Flow-matching 項と自己一貫項があるが、自己一貫の項は決定論的(自身の予測を元にさらに予測するため)なので、バッチ内で Flow-matching 項の割合を高くする方が良いとしている。
特に、自己一貫項は計算するために2回の forward
を行う必要があるため、なるべく少なく済むと嬉しい。論文だとバッチ内の
離散的なタイムステップ
ショートカットの学習の負荷を減らすために、ショートカットの開始地点のタイムステップ
Consistency Models との比較
Consistency Models と似てるが、一部異なる。
- Consistency は、経験的な
とx_t ペアのみで学習を行う (自身を目標に自身を改善する?一方 Shortcut ではx_d を使い、これはx'_{t+d} backward
パスに入らない) ため、それぞれの離れたステップで取り返しのつかない誤差が積み重なっていく (小さいステップでの学習もない) - Shortcut では
種類のみのステップサイズを学習する (実験ではlog_2(T) )が、Consistency はT = 128 種類学習することになる。T - Shortcut ではショートカットしない普通の生成も可能だが、Concsistency はそうではない(らしい)
- Shortcut の方がやってることがシンプル。Concistency はいろんなトリックが必要らしい?
Consistency の参考:
Classifier-Free Guidance (CFG) との組み合わせ
蒸留モデルにありがちな CFG どうするのか問題も少し言及されていた。CFG は条件なし生成と条件あり生成を比較して線形近似することで実現される。
Shotcut では小さいステップサイズ (小さい
そのため、検証では
Shortcut の CFG の課題は CFG scale を学習前に指定する必要があることとされている。
↑ CFG の問題で連想するのが、
1. LCM-LoRA (Latent Consistency Model LoRA)
元論文 は CFG を操作するためのモジュールが追加されているけど、diffusers による LoRA は学習時に固定して、追加モジュールなくても CFG 適用状態になるようにしているらしい
参考:
2. Flux.1-dev の guidance 蒸留
dev は ガイダンス蒸留(内容不明) をされているのだが、その実装が Shortcut と似ている気がする...
Shortcut はタイムステップ
一方の Flux はタイムステップ guidance
を受け取り、同様にタイムステップ埋め込みレイヤーに通されて、通常のタイムステップ埋め込みに足し合わされている。
構造がめっちゃ似てるので、Flux の構造を真似して、ガイダンスを受け取りながら少ないステップ数で生成できるような学習も似た感じにできるんじゃないだろうか...?
タイムステップ
となる。Shortcut 同様の式を作るとしたら、2回分の
なので、両辺
になる。
なので、単にガイダンス量
実装するとしたら、Shortcut と Flux の混ぜて、タイムステップ埋め込みに通した距離 guidance
を足せばいいのか...?
いや、CFGは無条件(
となってしまうが、
計算量が増えてしまう感じだろうか...
ちょっと色々認識が違ってたようなので修正
delta time
か duration
で呼ぶ方が適切っぽい。公式コードだと
なので、
実際の実装で torch.rand(low=0, high=1, size=(batch_size,))
ではなく、 torch.randint(low=0, high=128, size=(batch_size,)).float()/128
となる。)
タイムステップの取り方は ノイズ→画像
で 1→0
の実装と、0→1
になっている実装があるっぽくて、公式実装や公式 DiT は 0→1
だったが AuraFlow は 1→0
だったので実装時は符号を変える必要があった。つまり、0→1
でデノイズだけど、逆方向なら
なぜか知らないけど、Shortcut モデルが受け取るショートカット情報は、ショートカットする時間の長さじゃなくて、ショートカット回数に log2 を取った指数(多分、指数だけど、公式コードとかだと base
= 底 = 2 で表されていてちょっとよくわからない。理解が間違っているのだろうか...) を入力しているので、ショートカット回数 1 なら log2(1) = 0
で 0 が入力され、128 ステップ生成なら log2(128) = 7
が入力されることになる。
タイムステップ埋め込みなのに指数を使うのって、なんかあんまり直感的じゃない... 普通にショートカットする時間