Just image Transformer: ピクセル空間で実画像を予測するフローマッチングモデル
概要
- JiT (Just image Transformer) は VAE を使わず、ピクセル空間上で flow-matching を行う
- モデルは速度 (velocity)
を予測するよりも実画像v を予測した方が性能が良い (x x-pred) - ただしロスは、実画像
とノイズ画像x から作れる速度z でロスを計算すると良い (v v-loss)
はじめに
拡散による画像生成モデルは Stable Diffusion を筆頭として、U-Net ベースのモデルが主流でした。 派生の SDXL は、その取り回しの良さから 2025 年 12 月現在でもいまだに使われるベースモデルだと思います。
最近ではそれらに加えて、DiT から始まった Transformer をメインに用いた拡散モデルである、 Flux.1 や Qwen-Image、 Z-Image がその生成画像の品質の高さからよく使われている印象があります。
これらの拡散モデルに共通しているのは、VAE を用いて潜在空間上で(潜在)画像を生成しているという点があります。
一方で今回紹介する Just image Transformer は VAE を使用せず、ピクセル空間上で実画像を予測 するモデルとなっています。この記事では基本的な前提知識をおさらいしながら、なぜこのようなことができたのか、どういう仕組みなのか説明していきます。
前提知識のおさらい
軽く前提となる技術や関連する手法について確認します。
-
画像生成
- DDPM: Denoising Diffusion Probabilistic Models
- → 拡散モデルの提案
- U-Net を採用
- 画像がノイズになる過程の逆(逆拡散過程)を学習 (diffusion)
- モデルはピクセル空間上でノイズ
を予測する (\epsilon eps-pred)- 課題: 高解像度生成の計算量が多い
- 予測したノイズと正解ノイズ
でロスを取る (\epsilon eps-loss)
- LDM: Latent Diffusion Model
- → VAE で画像を潜在空間に圧縮することで計算量削減!
- U-Net に加えて、VAE を採用
- 画像がノイズになる過程の逆(逆拡散過程)を学習 (diffusion)
- モデルは潜在空間上のノイズ
を予測する (\epsilon eps-pred) - 予測したノイズと正解ノイズ
でロスを取る (\epsilon eps-loss)
- 速度予測: Imagen Video, SD2.1, NAI Diffusion V3 等
- → ノイズの代わりに速度を予測
- U-Net に加えて、VAE を採用
- 画像がノイズになる過程の逆(逆拡散過程)を学習 (diffusion)
- モデルは潜在空間上の速度
を予測する (v v-pred) - 予測したノイズと正解速度
でロスを取る (v v-loss)
- DiT: Scalable Diffusion Models with Transformers
- → LDM の U-Net を Transformer にした
- 改造した Transformer を採用、VAE も続投
- VAE を採用して潜在空間に圧縮
- 画像がノイズになる過程の逆(逆拡散過程)を学習 (diffusion)
- モデルは潜在空間上のノイズ
を予測する (\epsilon eps-pred) - 予測したノイズと正解ノイズ
でロスを取る (\epsilon eps-loss)
- MMDiT登場以降: SD3, Flux.1, Qwen-Image, Z-Image 等
- → 計算式シンプルな flow-matching が流行
- 各々で改造した Transformer を採用、VAE 続投
- VAE を採用して潜在空間に圧縮
- ノイズから画像になるフローを学習 (flow-matching)
- モデルは潜在空間上の速度
を予測する (v v-pred) - 予測した速度と正解速度
でロスを取る (v v-loss)
- DDPM: Denoising Diffusion Probabilistic Models
-
画像認識
- ViT: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- 画像認識タスクで Transformer を採用した
- ピクセル空間の画像を小さなパッチに分割してから処理
- ViT: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
なぜ VAE が使われる?
LDM から始まった VAE の採用は、高解像度画像生成での計算量削減を目的として導入されました。VAE はピクセル空間上の画像を潜在空間に圧縮し、その後モデルは潜在空間上で予測、デノイズを行います。
どれくらい圧縮されるのかというと、例えば SD-VAE は、画像を圧縮するとチャンネル数が
また、最近頻繁に採用される Flux.1 VAE は、同様に縦横は
このように VAE を用いることで計算量を削減できるので、VAE を使わない場合と比べると同じ計算量で高解像度を生成できるようになります。当初の目的は計算量を削減することにありましたが、そもそもの VAE の品質が悪いと潜在空間からピクセル画像にデコードする際にボトルネックになり、本体がどう頑張っても細かい部分が潰れてしまうことがあります。そのため、VAE 自体の性能も生成される画像の品質に大きく影響を与えます。Flux.2 では VAE の品質改善を頑張っているみたいです。
Just image Transformer
論文の著者について紹介すると、Tianhong Li は MAR やそのピクセル空間版である FractalMAR の筆頭著者で、Kaiming He は ResNet の著者です。
以下、JiT でやっていることや仕組みを説明します。
多様体仮説
モデルの根幹の設計に関わるので、まず多様体仮説について説明します。
多様体仮説は、(画像の文脈で言えば)高次元のピクセルスペースの中で、自然画像は低次元である「多様体」上に存在すると主張する仮説です。ノイズのないクリーンな画像
要は、私たちが一般に目にするような普通の画像は、広大なピクセルの組み合わせが考えられる中のある一部分にしか分布してないのに対して、ランダムなノイズやノイズの関わる表現はそうではない、ということだと思います。ランダムにピクセルを生成しても意味ある絵が出てくることはないですしね。
参考:
なので、高次元なノイズを予測するよりも実は低次元しかない実画像を予測した方がモデルにとって簡単なのではないか、という考え方が大事になってきます。
JiT ではこの発想が活かされています。
アーキテクチャ
JiT では VAE を使わないので、このようにシンプルな構造になっています。
入力されるノイジーな画像はまず複数のパッチに分割されます。その後 Linear Embed (いわゆる Patch Embedding レイヤー)を通りチャンネル数が増やされた後、Transformer Block を何回か通り、最後に Linear Predict 層で元々のチャンネル数に変換してから、予測されたパッチを画像の形に戻します。
まずパッチ化でどのようなことをしているのか説明します。
パッチ化 (patchify)
パッチ化は画像を Transformer で効率的に扱うために行う処理で、ViT (Vision Transformer) という、Transformer で画像分類を行うモデルで採用されました。この処理自体は DiT でも用いられています。
画像のように、
もしパッチ化せずに1ピクセルを1トークンとして扱った場合は
例えば 256x256 の画像であれば、
一方でパッチに分割すればシーケンス長を短くすることができます。パッチサイズ
ViT ではパッチサイズ 16 や 14 がよく使われます。DiT ではパッチサイズ
個人的な解釈としては、DiT は既に VAE で
MMDiT 以降の SD3, Flux.1[1], Qwen-Image, CogView4 などもこれに倣い、パッチサイズ 2 を採用しています。
対して、VAE を使わない JiT では主に パッチサイズ 16 で実験が行われました。この場合、256x256 解像度画像では
ボトルネック層
パッチ化した後、Transformer に通すために隠れ次元を揃える必要がありますが、直接隠れ次元まで射影するのではなく、一度ボトルネックとなる小さい次元に射影してから隠れ次元まで射影します。

各パッチごとの次元の変化。RGBの3チャンネルを持つ256x256の画像は、まず3チャンネル16x16のパッチ256個に分けられる。各パッチは チャンネルxパッチサイズxパッチサイズ (3x16x16) に展開され、ボトルネック次元 128 に射影される。その後最終的に隠れ次元 768 に射影する。(展開時の次元と隠れ次元が一致しているのは偶然)
つまり Patch Embedding 層で 1 層のみの線型層で隠れ次元に射影するのではなく、2層用意したうち 1 層目で一度低ランクに射影してから 2 層目で隠れ層に戻します。最終的に隠れ次元になるのは同じですが、ボトルネックがあることでランクが制限されることになります。
ただし、公式のコードでは Conv 層を使って計算しています:
ボトルネックのコード
やってることは上で説明していることと同じですが、Gemini 曰く、Conv を使ったほうが GPU で計算するときに効率がいいそうです。
特に注目すべきは proj1 と proj2 の間に活性化関数がないことです。MLP であれば活性化関数を噛ませるのが普通ですが、ここでは挟んでいないので単に低ランクな行列で射影していることになります。
MobileNetV2 によると、活性化関数を挟まないことで多様体を破壊せずに済むという利点があるそうです。
以下のグラフはボトルネックのサイズを変更した際の品質の変化を表しています。(FIDは低いほど性能が意味します。)

JiT論文 Figure 4より。点線が引かれているのがボトルネックを使用しない場合のベースライン。青いグラフで表されているのは、ボトルネックのサイズを変更した際のFID。
JiT-B/16 をベースにしているので、16x16 パッチはそのまま扱うと 768 次元になります。この際、ボトルネックのサイズを 16 まで落とした時、多少の性能の劣化は発生したものの壊滅的な崩壊をすることはなく、また、ボトルネックサイズが 32~512 の時は、ボトルネックを使わない場合よりも性能が向上しました。
次元を低くした方が性能が良くなるというのは一見直感に反しますが、低次元表現を学習する際にはよく使われる手法だそうです。
個人的には LoRA が連想されて、暗記を防いで汎化しやすくなる効果がありそうな気がしました。また、MobileNetV2 で採用された線形ボトルネックの手法と共通していることからも、「実は自然データ表現の分解に必要な次元はそんなに多くない」ということが言えそうです。これも多様体仮説に沿っていて、実画像の分布が低次元で表し切れるものだということが示唆されます。
MobileNetV2 参考:
また、これを支持する実験として、モデルサイズ base と画像シーケンス長を固定し、パッチサイズと画像解像度を変更した時の比較を以下に示します:

JiT論文 Table 5より。画像解像度とパッチサイズ (
1024x1024 解像度でパッチサイズ
多少劣化するものの、同じような計算量でより大きい解像度の画像を生成できるわけですから、パッチサイズを大きくするのは有力な選択肢となるでしょう。
実画像を予測するフローマッチング
JiT はフローマッチングを行うモデルです。
通常のフローマッチング
フローマッチングでは、ノイズ
となります。つまり、時刻
としても求まります。
小学生にもわかるように説明して
ノイズ
速度と距離、時間の関係式は
となります。
現在の位置
としてクリーンな画像
そういうわけで、フローマッチングでは時刻
とするのが素直な考え方です。
しかし、これまでの議論からそうすべきではない、と言いたいことが伝わるのではないでしょうか。多様体仮説が本当なら、速度
モデルの予測対象の実験
モデルが実データ
以下は ReLU を用いた 5 層の MLP (隠れ次元 256) を用いて、多様体を想定した
データ設定
- 多様体データ:
(\hat{x} \in \mathbb{R}^d はd よりも小さく、常に低次元。実験ではD )d = 2 - 観測できるデータ:
(x = P\hat{x} \in \mathbb{R}^D 次元であり、実験ではD )D \in \{2, 8, 16, 512\}
このように設定することで、実際は低次元だが観測できるのは高次元空間、という多様体仮説の仮定を再現しています。

JiT論文 Figure 2より。
さまざまなロスの取り方
実データ

JiT論文 Table 1より。モデル
xθ から vθ を計算する例
まず先ほど、
z_t = tx + (1 - t) \epsilon v = x - \epsilon
であることを確認しています。まず不明なパラメータであるノイズ
となります。これを2つ目の式に代入します。
を得ることができます。ちゃんと表の式と同じになりました。
このように、モデルが実画像、ノイズ、速度のいずれかを予測したとき、既知の情報のみから実画像、ノイズ、速度を導出することが可能 というわけです。ロスターゲットも同様に計算できるので、任意のターゲットとロスを計算することができます。そのため、モデルは必ずしも直接知りたいパラメータを直接予測しなくても、別のパラメータを予測して間接的に同じ計算を実現することができるのです。
しかし、どのような組み合わせが良いのでしょうか? 実画像を予測した方が簡単と何度も言ってきたので、実画像を予測して実画像でロスを取るのが良いのでしょうか?
実は少し違います。以下は、JiT-B をベースに ImageNet を使って、さまざまなモデル予測とロスの組み合わせで学習した時の FID スコアを比較したものになります。

JiT論文 Table 2より。(b) 64x64解像度 ImageNet をパッチサイズ 4 の JiT-B/4 で学習したところ、
この表を読むと、
-
モデルの予測はノイズ予測
や速度予測\epsilon\text{-pred} よりも、実画像予測v\text{-pred} の方が圧倒的に性能が良いx\text{-pred} -
ロスのターゲットは、そのまま実画像
やノイズx を使うよりも、速度\epsilon でロスをとった方が多少性能が良くなるv
ということがわかります。また、
しかし、なぜこうなるのでしょうか?予測する分には実画像の方が多様体上にあって簡単、というのは何度も言ってきた通りでしたが、ロスは速度
先ほどのロスターゲットを計算する式を表した表を確認してみると、ロスのターゲットはそれぞれ以下のように表されます:
-
モデル予測(実画像):
(x_\theta = \text{net}_\theta(z_t, t) は時刻、t はその時のノイジー画像)z_t -
実画像:
(そのまま)x_\theta -
ノイズ:
\epsilon_\theta = (z_t - t x_\theta) / (1 - t) -
速度:
v_\theta = (x_\theta - z_t) / (1 - t)
論文の著者によると、速度ロスの場合はモデルの入力であるノイジー画像
そう言うならノイズロス
ロスの実装
以上から、モデルの予測は実画像
PyTorch での実装は以下のようになります:
個人的に、めちゃくちゃシンプルな変更だけで済んでいて良いなと思いました。
ここで、v_pred を作っている部分に注目すると、
v_pred = (x_pred - z) / (1 - t).clamp_min(self.t_eps)
.clamp_min(self.t_eps) をしており、(1 - t) の値が最低でも self.t_eps となるようにされています。これは、self.t_eps = 0.05 が使われていました。
学習中は clamp することで安定して学習することができますが、推論時は clamp しないほうが品質が良くなるそうです:
これは推論時のステップ数を多くした際に、一度にデノイズする時刻が 0.05 よりも小さい時に誤って 0.05 に clamp してしまうと、正しい距離のデノイズとならないからだと思います。(25ステップ生成であれば、1ステップで
アーキテクチャ内部の変更点
JiT では DiT をベースにしていますが、以下の要素が取り入れられています:
- SwiGLU
- RMSNorm
- RoPE
- QKNorm
- CFG interval
- in-context クラストークン
以下はそれぞれを採用した際の性能の比較です:

JiT論文 Table 4より。ベースライン(SwiGLU, RMSNorm)に対して、RoPE, QKNorm, in-context class tokens の採用で FID が下がっている。カッコ内は CFG interval の適用ありの場合。
あまり見かけない CFG interval と in-context クラストークン について説明します。
CFG interval
簡潔に言えば、CFG を適用する時刻を限定することで、高い CFG scale を使っても色が飽和したり多様性が失われることがなく、より詳細に生成されるようになる という手法です。

公式GitHubより
元の研究は diffusion 向けでしたが、flow-matching でも使えるようで、JiT 論文中と公式実装では時刻
in-context クラストークン
ViT ではパッチシーケンスと一緒に、画像のクラスを表すクラストークンを1つだけ追加して Transformer に入力します。(シーケンスに含めるのを in-context と言う)。また、DiT では同様に in-context でクラスや時刻情報を処理する方法も試されていましたが、最終的には AdaLayerNorm-Zero でのみ処理するようになっています。
これらに対して JiT では、クラストークンを何回か複製し、専用の位置埋め込みを足し合わせた上でシーケンスに含める ということを行なっています。論文中ではクラストークンを 32 個繰り返すようになっています。
これらに加えて、Transformer ブロックの最初のブロックからこれらの条件を加えるのではなく、途中から差し込むと効果が良かったそうです。(この方法を in-context start block と呼ぶ)。実験ではモデルのブロック数に応じておよそ 1/3 のブロック番目に差し込むようになっています。
が、論文中ではこのトリックの有無による比較は見つかりませんでした。
学習中の工夫
タイムステップサンプリング
JiT ではタイムステップ
論文中では P_mean = -0.8、P_std = 0.8 が使われています。
分布をグラフにしてみると以下のようになります:

ChatGPTが上のサンプリング関数を元に作成
可視化コード by ChatGPT
# Visualize the distribution of t = sigmoid(z) where z ~ N(P_mean, P_std^2)
import torch
import matplotlib.pyplot as plt
def sample_t(n: int, P_std: float, P_mean: float, device=None) -> torch.Tensor:
z = torch.randn(n, device=device) * P_std + P_mean
return torch.sigmoid(z)
# Parameters
P_std = 0.8
P_mean = -0.8
n = 100000
# Sample
t = sample_t(n, P_std, P_mean)
# Plot
plt.figure()
plt.hist(t.numpy(), bins=100, density=True)
plt.xlabel("t = sigmoid(z)")
plt.ylabel("Density")
plt.title("Distribution of t for P_std=0.8, P_mean=-0.8")
plt.show()
グラフを見るとわかるように
ノイズスケール
学習や推論で使用するノイズは基本的に正規分布からサンプリングしますが、解像度に応じてスケールが行われています:
self.noise_scale は 1.0 ですが、512x512 解像度なら 2.0 になります。
まとめ
JiT は多様体仮説に基づいて、以下を示しました:
- 実データ予測 (
) が最重要x\text{-pred} -
速度ロス (
) は効果的 だが、クリティカルではない (v\text{-loss} や\epsilon\text{-pred} を救うことはできない)v\text{-pred} - ボトルネックは効果的
感想
個人的に、今まで「数式に沿っていればモデルの予測やロスターゲットがどうであれ差が出ない」「フローマッチングは速度を予測して当然」「実画像予測は難しいからやってないだけ」「VAE ないと高解像度生成は計算量増えすぎて無理」と思い込んでたものが全部ひっくり返されたので、とても面白いです。
また、モダンな改善を取り入れながらもシンプルな変更だけでピクセル空間での生成を実現しているのも最高にクールですね。同時期に PixelDiT などのピクセル空間生成の手法がいくつか出ているのですが、それらに比べてもシンプルな手法でピクセル空間での生成が実現できているのはカッコいい。





Discussion