🥞

Just image Transformer: ピクセル空間で実画像を予測するフローマッチングモデル

に公開

概要

https://arxiv.org/abs/2511.13720

  • JiT (Just image Transformer) は VAE を使わず、ピクセル空間上で flow-matching を行う
  • モデルは速度 (velocity) v を予測するよりも実画像 x を予測した方が性能が良い (x-pred)
  • ただしロスは、実画像 x とノイズ画像 z から作れる速度 v でロスを計算すると良い (v-loss)

はじめに

拡散による画像生成モデルは Stable Diffusion を筆頭として、U-Net ベースのモデルが主流でした。 派生の SDXL は、その取り回しの良さから 2025 年 12 月現在でもいまだに使われるベースモデルだと思います。

最近ではそれらに加えて、DiT から始まった Transformer をメインに用いた拡散モデルである、 Flux.1Qwen-ImageZ-Image がその生成画像の品質の高さからよく使われている印象があります。

これらの拡散モデルに共通しているのは、VAE を用いて潜在空間上で(潜在)画像を生成しているという点があります。

一方で今回紹介する Just image Transformer は VAE を使用せず、ピクセル空間上で実画像を予測 するモデルとなっています。この記事では基本的な前提知識をおさらいしながら、なぜこのようなことができたのか、どういう仕組みなのか説明していきます。

前提知識のおさらい

軽く前提となる技術や関連する手法について確認します。

  • 画像生成

    • DDPM: Denoising Diffusion Probabilistic Models
      • → 拡散モデルの提案
      • U-Net を採用
      • 画像がノイズになる過程の逆(逆拡散過程)を学習 (diffusion)
      • モデルはピクセル空間上でノイズ \epsilon を予測する (eps-pred)
        • 課題: 高解像度生成の計算量が多い
      • 予測したノイズと正解ノイズ \epsilon でロスを取る (eps-loss)
    • LDM: Latent Diffusion Model
      • VAE で画像を潜在空間に圧縮することで計算量削減!
      • U-Net に加えて、VAE を採用
      • 画像がノイズになる過程の逆(逆拡散過程)を学習 (diffusion)
      • モデルは潜在空間上のノイズ \epsilon を予測する (eps-pred)
      • 予測したノイズと正解ノイズ \epsilon でロスを取る (eps-loss)
    • 速度予測: Imagen Video, SD2.1, NAI Diffusion V3
      • → ノイズの代わりに速度を予測
      • U-Net に加えて、VAE を採用
      • 画像がノイズになる過程の逆(逆拡散過程)を学習 (diffusion)
      • モデルは潜在空間上の速度 v を予測する (v-pred)
      • 予測したノイズと正解速度 v でロスを取る (v-loss)
    • DiT: Scalable Diffusion Models with Transformers
      • → LDM の U-Net を Transformer にした
      • 改造した Transformer を採用、VAE も続投
      • VAE を採用して潜在空間に圧縮
      • 画像がノイズになる過程の逆(逆拡散過程)を学習 (diffusion)
      • モデルは潜在空間上のノイズ \epsilon を予測する (eps-pred)
      • 予測したノイズと正解ノイズ \epsilon でロスを取る (eps-loss)
    • MMDiT登場以降: SD3, Flux.1, Qwen-Image, Z-Image
      • → 計算式シンプルな flow-matching が流行
      • 各々で改造した Transformer を採用、VAE 続投
      • VAE を採用して潜在空間に圧縮
      • ノイズから画像になるフローを学習 (flow-matching)
      • モデルは潜在空間上の速度 v を予測する (v-pred)
      • 予測した速度と正解速度 v でロスを取る (v-loss)
  • 画像認識

なぜ VAE が使われる?

LDM から始まった VAE の採用は、高解像度画像生成での計算量削減を目的として導入されました。VAE はピクセル空間上の画像を潜在空間に圧縮し、その後モデルは潜在空間上で予測、デノイズを行います。

どれくらい圧縮されるのかというと、例えば SD-VAE は、画像を圧縮するとチャンネル数が 3 から 4 に増えますが、縦横サイズは 1 / 8 になります。結果的に 1 / 8 \times 1/8 \times 4/3 = 1/48 で、元の計算量の約 2% に抑えることができます。
また、最近頻繁に採用される Flux.1 VAE は、同様に縦横は 1/8 になりますがチャンネル数は 16 になるので、1/8 \times 1/8 \times 16/3 = 1/12 となり約 8% となります。

このように VAE を用いることで計算量を削減できるので、VAE を使わない場合と比べると同じ計算量で高解像度を生成できるようになります。当初の目的は計算量を削減することにありましたが、そもそもの VAE の品質が悪いと潜在空間からピクセル画像にデコードする際にボトルネックになり、本体がどう頑張っても細かい部分が潰れてしまうことがあります。そのため、VAE 自体の性能も生成される画像の品質に大きく影響を与えます。Flux.2 では VAE の品質改善を頑張っているみたいです。

Just image Transformer

https://arxiv.org/abs/2511.13720

https://github.com/LTH14/JiT


JiT-H/32で生成された512x512画像

論文の著者について紹介すると、Tianhong Li は MAR やそのピクセル空間版である FractalMAR の筆頭著者で、Kaiming He は ResNet の著者です。

以下、JiT でやっていることや仕組みを説明します。

多様体仮説

モデルの根幹の設計に関わるので、まず多様体仮説について説明します。


JiT論文 Figure 1より

多様体仮説は、(画像の文脈で言えば)高次元のピクセルスペースの中で、自然画像は低次元である「多様体」上に存在すると主張する仮説です。ノイズのないクリーンな画像 x は多様体上に存在するとみなせますが、ノイズ \epsilon や速度場 v (v = x - \epsilon で表される) は多様体上にないと考えることができます。

要は、私たちが一般に目にするような普通の画像は、広大なピクセルの組み合わせが考えられる中のある一部分にしか分布してないのに対して、ランダムなノイズやノイズの関わる表現はそうではない、ということだと思います。ランダムにピクセルを生成しても意味ある絵が出てくることはないですしね。

参考:
https://xtech.nikkei.com/dm/atcl/mag/15/00144/00031/

なので、高次元なノイズを予測するよりも実は低次元しかない実画像を予測した方がモデルにとって簡単なのではないか、という考え方が大事になってきます。

JiT ではこの発想が活かされています。

アーキテクチャ


シンプルなアーキテクチャ図

JiT では VAE を使わないので、このようにシンプルな構造になっています。

入力されるノイジーな画像はまず複数のパッチに分割されます。その後 Linear Embed (いわゆる Patch Embedding レイヤー)を通りチャンネル数が増やされた後、Transformer Block を何回か通り、最後に Linear Predict 層で元々のチャンネル数に変換してから、予測されたパッチを画像の形に戻します。

まずパッチ化でどのようなことをしているのか説明します。

パッチ化 (patchify)

パッチ化は画像を Transformer で効率的に扱うために行う処理で、ViT (Vision Transformer) という、Transformer で画像分類を行うモデルで採用されました。この処理自体は DiT でも用いられています


DiT論文 Figure 4より

画像のように、I \times I サイズの画像を、一辺 p のパッチに分割し、パッチをそれぞれトークンとして扱います。

もしパッチ化せずに1ピクセルを1トークンとして扱った場合は I^2 トークンになり、Transformer (というより Attention) はシーケンス長 T の 2 乗の計算量がかかるので、これは非常に処理が重くなることがわかります。
例えば 256x256 の画像であれば、T = 256 \times 256 = 65,536 トークン、さらに計算量は 65,536^2 = 4,294,967,296 となります。

一方でパッチに分割すればシーケンス長を短くすることができます。パッチサイズ p = 16 とすれば、256x256 解像度の画像は 16x16 個のパッチとして扱えるので、T = 16 \times 16 = 256 で、計算量は 256^2 = 65,536 で済みます。

ViT ではパッチサイズ 16 や 14 がよく使われます。DiT ではパッチサイズ 8, 4, 2 を試しており、パッチサイズを小さくするほど性能が上がる (計算量は増える)ことを実験で示しました。
個人的な解釈としては、DiT は既に VAE で 1/8 に圧縮されているので、そこからさらにパッチ化すると学習の難易度が上がるのかもしれません。


DiT論文 Figure 7を改変

MMDiT 以降の SD3, Flux.1[1], Qwen-Image, CogView4 などもこれに倣い、パッチサイズ 2 を採用しています。

対して、VAE を使わない JiT では主に パッチサイズ 16 で実験が行われました。この場合、256x256 解像度画像では 16 \times 16 = 256 トークンとなります。(An Image is Worth 16x16 Words というわけですね)

ボトルネック層

パッチ化した後、Transformer に通すために隠れ次元を揃える必要がありますが、直接隠れ次元まで射影するのではなく、一度ボトルネックとなる小さい次元に射影してから隠れ次元まで射影します。


各パッチごとの次元の変化。RGBの3チャンネルを持つ256x256の画像は、まず3チャンネル16x16のパッチ256個に分けられる。各パッチは チャンネルxパッチサイズxパッチサイズ (3x16x16) に展開され、ボトルネック次元 128 に射影される。その後最終的に隠れ次元 768 に射影する。(展開時の次元と隠れ次元が一致しているのは偶然)

つまり Patch Embedding 層で 1 層のみの線型層で隠れ次元に射影するのではなく、2層用意したうち 1 層目で一度低ランクに射影してから 2 層目で隠れ層に戻します。最終的に隠れ次元になるのは同じですが、ボトルネックがあることでランクが制限されることになります。

ただし、公式のコードでは Conv 層を使って計算しています:

ボトルネックのコード

やってることは上で説明していることと同じですが、Gemini 曰く、Conv を使ったほうが GPU で計算するときに効率がいいそうです。

https://github.com/LTH14/JiT/blob/cbc743a2ada5e9762697da2c83f8c4f8379e8c17/model_jit.py#L17-L37

特に注目すべきは proj1proj2 の間に活性化関数がないことです。MLP であれば活性化関数を噛ませるのが普通ですが、ここでは挟んでいないので単に低ランクな行列で射影していることになります。
MobileNetV2 によると、活性化関数を挟まないことで多様体を破壊せずに済むという利点があるそうです。

以下のグラフはボトルネックのサイズを変更した際の品質の変化を表しています。(FIDは低いほど性能が意味します。)


JiT論文 Figure 4より。点線が引かれているのがボトルネックを使用しない場合のベースライン。青いグラフで表されているのは、ボトルネックのサイズを変更した際のFID。

JiT-B/16 をベースにしているので、16x16 パッチはそのまま扱うと 768 次元になります。この際、ボトルネックのサイズを 16 まで落とした時、多少の性能の劣化は発生したものの壊滅的な崩壊をすることはなく、また、ボトルネックサイズが 32~512 の時は、ボトルネックを使わない場合よりも性能が向上しました。

次元を低くした方が性能が良くなるというのは一見直感に反しますが、低次元表現を学習する際にはよく使われる手法だそうです。
個人的には LoRA が連想されて、暗記を防いで汎化しやすくなる効果がありそうな気がしました。また、MobileNetV2 で採用された線形ボトルネックの手法と共通していることからも、「実は自然データ表現の分解に必要な次元はそんなに多くない」ということが言えそうです。これも多様体仮説に沿っていて、実画像の分布が低次元で表し切れるものだということが示唆されます。

MobileNetV2 参考:

https://deepsquare.jp/2020/06/mobilenet-v2/#outline__2_2_2

また、これを支持する実験として、モデルサイズ base と画像シーケンス長を固定し、パッチサイズと画像解像度を変更した時の比較を以下に示します:


JiT論文 Table 5より。画像解像度とパッチサイズ (\text{B}/1616 の部分) のみを変えた時の FID の変化。

1024x1024 解像度でパッチサイズ 64 を採用する場合、ボトルネック (128 次元) を通る前の1パッチ自体の次元数は 3 \times 64 \times 64 = 12288 となります。パッチサイズ 1632 の時と比べて圧倒的に次元が高くなるわけですが、それにも関わらず FID は許容できる程度しか劣化してなく、\epsilon,v\text{-pred} のような崩壊を起こしていません

多少劣化するものの、同じような計算量でより大きい解像度の画像を生成できるわけですから、パッチサイズを大きくするのは有力な選択肢となるでしょう。

実画像を予測するフローマッチング

JiT はフローマッチングを行うモデルです。

通常のフローマッチング

フローマッチングでは、ノイズ \epsilon の分布から実画像 x の分布へのフローを学習します。フローの途中の時刻 t におけるノイジーな画像 z_t は、ノイズ \epsilon と実画像 x の線形補完で表せて、

z_t = tx + (1 - t) \epsilon

となります。つまり、時刻 t = 0z_t は完全なノイズ、時刻 t = 1z_t はクリーンな実画像になります。ここで、一般的なフローマッチングモデルでは、時刻 t におけるノイズから実画像への速度場 v = x - \epsilon を学習するのが主流です。速度場 v はノイジー画像 z_t を実画像 t で微分して、

\begin{aligned} v = z_t' &= t' x + (1 - t)' \epsilon \\ &= x - \epsilon \end{aligned}

としても求まります。

小学生にもわかるように説明して

ノイズ \epsilon から出発して、ゴール地点である画像 x まで 1 の時間 t をかけて移動することを考えます。その時の移動する距離は x - \epsilon です。
速度と距離、時間の関係式は \text{速度} = \text{距離}/\text{時間} であり、ここで距離は x - \epsilon、時間 t = 1 なので移動速度 v

\begin{aligned} v &= (x - \epsilon) / t \\ &= (x - \epsilon) / 1 \\ &= x - \epsilon \end{aligned}

となります。

現在の位置 z_t で速度 v が分かれば、x にたどり着くための向きと速さがわかるので z_t + v (1 - t) を計算すれば、

\begin{aligned} z_t + v (1 - t) &= \left[tx + (1 - t)\epsilon\right] + (1 - t)(x - \epsilon) \\ &= (t +1 - t) x + \left[1 - t - (1 - t)\right] \epsilon \\ &= x \end{aligned}

としてクリーンな画像 x を得ることができます。

そういうわけで、フローマッチングでは時刻 t 時点のノイジー画像 z_t から 速度 v を求めることができれば解くことができます。それに合わせてモデルが速度 v を予測、つまりロス関数 \mathcal{L}_{t,x,z} はモデルの予測速度 v_0 と正解の速度 v の二乗誤差をとった、

\mathcal{L}_{t,x,z} = \mathbb{E}\left\|v_\theta(z_t, t) - v\right\|^2

とするのが素直な考え方です。

しかし、これまでの議論からそうすべきではない、と言いたいことが伝わるのではないでしょうか。多様体仮説が本当なら、速度 v を予測するよりも実画像 x を予測すべきですが、本当に簡単になるのか確かめてみたいです。

モデルの予測対象の実験

モデルが実データ x を予測する方が本当に簡単なのかを確かめるため、JiTでは単純なトイデータセットを使った実験がされています。(ここで使っているロスは全て v\text{-loss} です (後述))

以下は ReLU を用いた 5 層の MLP (隠れ次元 256) を用いて、多様体を想定した D 次元の合成データを学習した際の生成結果を示しています。

データ設定
  • 多様体データ: \hat{x} \in \mathbb{R}^d (dD よりも小さく、常に低次元。実験では d = 2)
  • 観測できるデータ: x = P\hat{x} \in \mathbb{R}^D (D 次元であり、実験では D \in \{2, 8, 16, 512\})

P_{d \times D}d 次元から D 次元に射影する行列ですが、P^\top P = I_{d \times d} となる列直行の行列です。この行列は学習中はランダムに初期化されて固定されます。

このように設定することで、実際は低次元だが観測できるのは高次元空間、という多様体仮説の仮定を再現しています。


JiT論文 Figure 2より。D 次元が小さい時は、x\text{-pred}, \epsilon\text{-pred}, v\text{-pred} 関わらず成功しているが、D 次元が大きくなるにつれ \epsilon\text{-pred}, v\text{-pred} は崩壊し、x\text{-pred} のみが納得のいく生成ができている。

D = 512 の時、MLP の隠れ次元 256 よりも大きくなっており、表現力不足になりそうにも関わらず、x\text{-pred} はきちんと生成できており、これは実データ予測が多様体上の低次元な空間を予測することに長けていることを支持する結果になっています。

さまざまなロスの取り方

実データ x を予測すると良さそうなのがわかったのですが、必ずしも x でロスを取る必要はない ありません。具体的には実画像 x を予測して、それらからノイズ \epsilon や速度 vを計算することができます


JiT論文 Table 1より。モデル \text{net}_\theta が予測するものが実画像 x_\theta, \epsilon_\theta, v_\theta だった時のロスの組み合わせ(他のパラメータの作り方)を表している。

xθ から vθ を計算する例​

まず先ほど、

  • z_t = tx + (1 - t) \epsilon
  • v = x - \epsilon

であることを確認しています。まず不明なパラメータであるノイズ \epsilon の式を作ります。(これはモデルの入力にないため、推論時は知り得ない情報になります)
z_t の式を変形すると、

\begin{aligned} z_t &= tx + (1 - t) \epsilon \\ (1 - t) \epsilon &= z_t - tx \\ \epsilon &= (z_t - tx) / (1 - t) \end{aligned}

となります。これを2つ目の式に代入します。

\begin{aligned} v_\theta &= x - (z_t - tx) / (1 - t) \\ &= \frac{x(1-t)}{1 - t} - \frac{z_t - tx}{1 - t} \\ &= \frac{x(1 - t + t) - z_t}{1 - t} \\ &= (x - z_t) / (1 - t) \end{aligned}

xx_\theta に置き換えれば、

v_\theta = (x_\theta - z_t) / (1 - t)

を得ることができます。ちゃんと表の式と同じになりました。

このように、モデルが実画像、ノイズ、速度のいずれかを予測したとき、既知の情報のみから実画像、ノイズ、速度を導出することが可能 というわけです。ロスターゲットも同様に計算できるので、任意のターゲットとロスを計算することができます。そのため、モデルは必ずしも直接知りたいパラメータを直接予測しなくても、別のパラメータを予測して間接的に同じ計算を実現することができるのです。

しかし、どのような組み合わせが良いのでしょうか? 実画像を予測した方が簡単と何度も言ってきたので、実画像を予測して実画像でロスを取るのが良いのでしょうか?

実は少し違います。以下は、JiT-B をベースに ImageNet を使って、さまざまなモデル予測とロスの組み合わせで学習した時の FID スコアを比較したものになります。


JiT論文 Table 2より。(b) 64x64解像度 ImageNet をパッチサイズ 4 の JiT-B/4 で学習したところ、x\text{-pred}, \epsilon\text{-pred}, v\text{-pred} かかわらず一様に低いFID を獲得しており、特に \epsilon\text{-loss}, v\text{-loss} の FID が低くなっていることがわかる。(a) 一方、256x256解像度 ImageNet をパッチサイズ 16 のJiT-B/16 で学習した場合、\epsilon\text{-pred}, v\text{-pred}x\text{-pred}に比べて非常に FID が高くなっており、学習が壊滅的に失敗していることを示している。 x\text{-pred} では v\text{-loss} でさらに FID が低くなっているが、これは \epsilon\text{-pred}, v\text{-pred} の崩壊を完全に防ぐことはできていない。

この表を読むと、

  • モデルの予測はノイズ予測 \epsilon\text{-pred} や速度予測 v\text{-pred} よりも、実画像予測 x\text{-pred} の方が圧倒的に性能が良い
  • ロスのターゲットは、そのまま実画像 x やノイズ \epsilon を使うよりも、速度 v でロスをとった方が多少性能が良くなる

ということがわかります。また、v\text{-loss} は多少性能を上げるにしても \epsilon\text{-pred}, v\text{-pred} を崩壊から救うほどではない というのは、トイデータセットでの実験結果にも合致しています。

しかし、なぜこうなるのでしょうか?予測する分には実画像の方が多様体上にあって簡単、というのは何度も言ってきた通りでしたが、ロスは速度 v の時だけ他2つよりも良い 結果になっています。

先ほどのロスターゲットを計算する式を表した表を確認してみると、ロスのターゲットはそれぞれ以下のように表されます:

  • モデル予測(実画像): x_\theta = \text{net}_\theta(z_t, t) (t は時刻、z_t はその時のノイジー画像)
  • 実画像: x_\theta (そのまま)
  • ノイズ: \epsilon_\theta = (z_t - t x_\theta) / (1 - t)
  • 速度: v_\theta = (x_\theta - z_t) / (1 - t)

論文の著者によると、速度ロスの場合はモデルの入力であるノイジー画像 z_t と差分を計算している部分がモデルの入力から出力までの長い残差接続として見ることができ、それを時刻 t でスケーリングしていると考えることができるそうです。

そう言うならノイズロス \epsilon_\theta = (z_t - t x_\theta) / (1 - t) でも z_t が入ってるじゃないかと思ったんですが、こっちはモデル予測 x_\theta に時刻 t がかかっているせいで、t \sim 0 の時はどれだけテキトーな予測しても正解になっちゃって効率が悪いからなんじゃないかと個人的に考えています。

ロスの実装

以上から、モデルの予測は実画像 xロスは速度 v で計算すればよいことがわかりました。論文で示されているロス関数は最終的に以下のようになります:

\mathcal{L} = \mathbb{E}_{t,x,\epsilon}\left\|v_\theta(z_t, t) - v\right\|^2, \\ \text{where:} \quad v_\theta(z_t, t) = (\mathtt{net}_\theta(z_t, t) - z_t) / (1 - t)

PyTorch での実装は以下のようになります:

https://github.com/LTH14/JiT/blob/cbc743a2ada5e9762697da2c83f8c4f8379e8c17/denoiser.py#L49-L65

個人的に、めちゃくちゃシンプルな変更だけで済んでいて良いなと思いました。

ここで、v_pred を作っている部分に注目すると、

v_pred = (x_pred - z) / (1 - t).clamp_min(self.t_eps)

.clamp_min(self.t_eps) をしており、(1 - t) の値が最低でも self.t_eps となるようにされています。これは、t \sim 1 の時に分母が 0 近くなって計算が不安定になるのを防ぐためで、論文中では self.t_eps = 0.05 が使われていました。

学習中は clamp することで安定して学習することができますが、推論時は clamp しないほうが品質が良くなるそうです:

https://github.com/LTH14/JiT/issues/24

これは推論時のステップ数を多くした際に、一度にデノイズする時刻が 0.05 よりも小さい時に誤って 0.05clamp してしまうと、正しい距離のデノイズとならないからだと思います。(25ステップ生成であれば、1ステップで 1/25 = 0.04 タイムステップ分デノイズする)

アーキテクチャ内部の変更点

JiT では DiT をベースにしていますが、以下の要素が取り入れられています:

  • SwiGLU
  • RMSNorm
  • RoPE
  • QKNorm
  • CFG interval
  • in-context クラストークン

以下はそれぞれを採用した際の性能の比較です:


JiT論文 Table 4より。ベースライン(SwiGLU, RMSNorm)に対して、RoPE, QKNorm, in-context class tokens の採用で FID が下がっている。カッコ内は CFG interval の適用ありの場合。

あまり見かけない CFG interval と in-context クラストークン について説明します。

CFG interval

https://arxiv.org/abs/2404.07724

https://github.com/kynkaat/guidance-interval

簡潔に言えば、CFG を適用する時刻を限定することで、高い CFG scale を使っても色が飽和したり多様性が失われることがなく、より詳細に生成されるようになる という手法です。


公式GitHubより

元の研究は diffusion 向けでしたが、flow-matching でも使えるようで、JiT 論文中と公式実装では時刻 t \in [0.1, 1] のときのみ CFG を適用するようになっています。

https://github.com/LTH14/JiT/blob/cbc743a2ada5e9762697da2c83f8c4f8379e8c17/denoiser.py#L101-L102

in-context クラストークン

ViT ではパッチシーケンスと一緒に、画像のクラスを表すクラストークンを1つだけ追加して Transformer に入力します。(シーケンスに含めるのを in-context と言う)。また、DiT では同様に in-context でクラスや時刻情報を処理する方法も試されていましたが、最終的には AdaLayerNorm-Zero でのみ処理するようになっています。

これらに対して JiT では、クラストークンを何回か複製し、専用の位置埋め込みを足し合わせた上でシーケンスに含める ということを行なっています。論文中ではクラストークンを 32 個繰り返すようになっています。

https://github.com/LTH14/JiT/blob/cbc743a2ada5e9762697da2c83f8c4f8379e8c17/model_jit.py#L349-L351

これらに加えて、Transformer ブロックの最初のブロックからこれらの条件を加えるのではなく、途中から差し込むと効果が良かったそうです。(この方法を in-context start block と呼ぶ)。実験ではモデルのブロック数に応じておよそ 1/3 のブロック番目に差し込むようになっています。
が、論文中ではこのトリックの有無による比較は見つかりませんでした。

学習中の工夫

タイムステップサンプリング

JiT ではタイムステップ t を以下のようにサンプリングしています:

https://github.com/LTH14/JiT/blob/cbc743a2ada5e9762697da2c83f8c4f8379e8c17/denoiser.py#L45-L47

論文中では P_mean = -0.8P_std = 0.8 が使われています。

分布をグラフにしてみると以下のようになります:


ChatGPTが上のサンプリング関数を元に作成

可視化コード by ChatGPT
# Visualize the distribution of t = sigmoid(z) where z ~ N(P_mean, P_std^2)

import torch
import matplotlib.pyplot as plt

def sample_t(n: int, P_std: float, P_mean: float, device=None) -> torch.Tensor:
    z = torch.randn(n, device=device) * P_std + P_mean
    return torch.sigmoid(z)

# Parameters
P_std = 0.8
P_mean = -0.8
n = 100000

# Sample
t = sample_t(n, P_std, P_mean)

# Plot
plt.figure()
plt.hist(t.numpy(), bins=100, density=True)
plt.xlabel("t = sigmoid(z)")
plt.ylabel("Density")
plt.title("Distribution of t for P_std=0.8, P_mean=-0.8")
plt.show()

グラフを見るとわかるように t が左寄りになっている → ノイズに近い方を重点的に学習 する、となっています。実画像に近い方は予測が簡単そうなので、難しいノイズ寄りを優先して学習していると言うことだと思います。

ノイズスケール

学習や推論で使用するノイズは基本的に正規分布からサンプリングしますが、解像度に応じてスケールが行われています:

https://github.com/LTH14/JiT/blob/cbc743a2ada5e9762697da2c83f8c4f8379e8c17/denoiser.py#L53

self.noise_scale1.0 \times \text{image\_size} / 256 となっており、256x256 解像度であればそのまま 1.0 ですが、512x512 解像度なら 2.0 になります。

まとめ

JiT は多様体仮説に基づいて、以下を示しました:

  • 実データ予測 (x\text{-pred}) が最重要
  • 速度ロス (v\text{-loss}) は効果的 だが、クリティカルではない (\epsilon\text{-pred}v\text{-pred} を救うことはできない)
  • ボトルネックは効果的

感想

個人的に、今まで「数式に沿っていればモデルの予測やロスターゲットがどうであれ差が出ない」「フローマッチングは速度を予測して当然」「実画像予測は難しいからやってないだけ」「VAE ないと高解像度生成は計算量増えすぎて無理」と思い込んでたものが全部ひっくり返されたので、とても面白いです。
また、モダンな改善を取り入れながらもシンプルな変更だけでピクセル空間での生成を実現しているのも最高にクールですね。同時期に PixelDiT などのピクセル空間生成の手法がいくつか出ているのですが、それらに比べてもシンプルな手法でピクセル空間での生成が実現できているのはカッコいい。

脚注
  1. 公開コード等では patch_size = 1 のようにも見えるが、packing という別の名で patch_size = 2同等の処理をしている。 ↩︎

GitHubで編集を提案

Discussion