😽

拡散モデルの説明における『ノイズ除去』に対する誤解について

2023/12/20に公開

本記事は、筆者が初めてzenn.devで書いた記事であり、ぶっちゃけ練習みたいな側面を多分に含みます。ご容赦ください。

この記事が目指すもの

拡散モデルがどのようにデータを生成しているのかという原理を説明するとき、だいたい以下のような表現をすることが多いと思います。

完全なノイズからスタートし、少しずつノイズを除去することによって最終的にデータを生成する。

この説明は全く正しくて、私自身もこのように説明することが多いのですが、ここで述べている『ノイズ除去』が指すものの意味を誤解されることが非常に多いです。具体的には、

「なるほど、ノイズのついたデータが与えられて、ノイズがのる前のデータに相当する唯一の正解データを推定するんだな」

と解釈されてしまう場合が多々あります。が、残念ながらこちらは誤解で正しくありません。もう少し詳しく言うと、この解釈は拡散モデルの学習時の処理としては正しいと言えなくもないですが、生成時の処理としては正しくありません。したがって、拡散モデルによるデータ生成の原理を説明する上では誤りということになります。

この記事では、拡散モデルの原理を簡単なたとえ話で解説しながら、上記のような誤解がなぜ起こるのかについて説明します。

本記事で採用するたとえ話

本記事では、拡散モデルが前提としている拡散過程や逆拡散過程を『人生ゲーム』に例えて説明します。例えば拡散過程は、

  1. ノイズをランダムに決める
  2. 決めたノイズをデータに足す

という2段階の処理を繰り返すことによってデータに少しずつノイズを付加しますが、これは人生ゲームに例えると、

  1. 駒が進む数を(ルーレットを回して)ランダムに決める
  2. 決めた数を駒の位置に足す(つまり、駒が進む)

という手順に対応しています。ただし、この人生ゲームは通常のものとは異なり、より現実味を持たせるために以下のようなルールが採用されています。

  • 人は生まれを選べない。スタート地点がたくさんあり、ランダムに選んだところから始める
  • 人生にゴールはない。ゴール地点が無く、代わりに一定のターン数が経過したらゲーム終了

これらのルールは後ほど重要になりますが、とりあえず今は忘れてもらってもかまいません。とにかく重要なことは、「駒の場所がデータに対応していること」と「ルーレットを回した結果がノイズに対応していること」です。

拡散過程はノイズの付加、逆拡散過程はノイズの除去に対応するので、人生ゲームの例えで言うと、

  • 拡散過程について考える = 次のターンに駒がどこにあるかを考える
  • 逆拡散過程について考える = 前のターンに駒がどこにあったかを考える

という風に対応します。

拡散過程:次のターンに駒はどこにあるか?

上で述べたように、データがノイズへと崩壊していく拡散過程は、人生ゲームでは駒が進んでいく過程に対応します。あなたの駒はボード上の適当な場所に存在していて、今まさにルーレットを回して駒を進めようとしています。さて、ここで問題です。

「次のターンに自分の駒がどこにいるか」を予測することはできるでしょうか?

すぐにわかりますが、これは「できる」とも「できない」とも言えます。ルーレットには1から10までの数字しか書かれていないので、次のターンに今の位置から10マス以内の位置にいることは確実です。したがって、この意味では予測できると言ってよいでしょう。一方で、その10マスのうちどの位置にいるかは、これから回すルーレットによってランダムに決まるので今はわかりません。その意味では予測できないとも言えます。

まとめると、次のターンの駒の位置(つまりノイズを付加した後のデータ)について以下のようなことがわかります。

  • 今の位置から10マス以内のどこかにいることはわかる
    (ノイズを乗せても元のデータからはそんなに離れることはできないことに対応)
  • その10マスのうちどこにいるかはわからないので、適当に予測するしかない
    (どんなノイズが乗るかはわからないことに対応)

逆拡散過程:前のターンに駒はどこにあったか?

さて、本題の逆拡散過程を考えます。あなたが人生ゲームで楽しんでいたところ、友人が現れたとしましょう。友人は今来たところなので、今のボード上の駒の位置はわかりますが、それまでの駒の動きは一切知りません。さて、ここで問題です。

友人は「前のターンにあなたの駒がどこにいたか」を予測することができるでしょうか?

少し考えるとわかりますが、こちらの問題も先ほどとほとんど同じような回答になります。つまり、ルーレットには1から10までの数字しか書かれていないので、前のターンに今の位置から10マス以内の位置にいたことは予測できます。一方で、その10マスのうちどの位置にいたかは予測することは難しそうなので、適当に予測するしかなさそうです(注:若干歯切れが悪い理由は、本記事の最後でわかります)

先ほど考えた場合と比べると、10マス以内というのが前方向か後ろ方向かという違いだけで対称な答えが得られていることがわかるでしょうか?

ノイズ除去の意味

次のターンの駒の位置の予測(=ノイズの付加)と、前のターンの駒の位置の予測(=ノイズの除去)には、対称的な難しさがあることがわかりました。つまり、『ノイズの除去』は本質的に予測できない問題であり、ノイズの付加と同じようなレベルで適当に決めるしかない部分が存在します。この意味において、唯一の正解を予測しているわけではなく、正解と考えられる候補(人生ゲームの例では「10マス以内」)から適当に選ぶことで予測を行うしかありません

これが、冒頭に述べた「唯一の正解データを予測」が誤解である理由です。正解は唯一ではなく無数に存在しており、その範囲内で適当に決めてよいのです

「いや、でも、前のターンに自分の駒が実際にどこにいたかは正解があるわけだし…」と思われるかもしれません。しかし、今来たばかりの友人にとっては、目の前にあるボードの情報が全てです。あなたが前のターンに1を出していようが10を出していようが、移動した後の駒の位置が同じであれば、友人はその2つの過去を見分けることはできません。前のターンでの駒の位置が1こ前であることと10こ前であることは、友人にとっては同じぐらいありえそうなことであり、その意味でどちらも正解と考えられる候補なのです。

たとえ話から元の話に戻すと、同じデータにノイズを乗せると様々なノイズ付きデータが得られるように、同じノイズ付きデータからノイズを除去しても様々なデータが得られます。「いや、でも、実際にノイズを乗せる前のデータには正解があるわけだし…」と思うのは幻想で、ノイズが付加された後のノイズ付きデータ(今の駒の位置)しか見ていなければ、元のデータ(前の駒の位置)として妥当と考えられる候補は無数に存在するのです。

誤解を招く原因

さて、逆拡散過程において『正解がただ一つ存在する』と誤解してしまう原因についてですが、大きく2つあるのではないかなと思います。

ノイズ除去という表現が招く先入観

1つ目は、『ノイズ除去』という言葉が持つニュアンスのせいではないかなと思っています。ノイズが乗っているということは、元々きれいなデータが存在していたということであり、つまりノイズを載せる前のデータが唯一の正解だろうなという思い込みが発生しがちです。したがって、「モデルがノイズの除去を行っているのであれば、その唯一の正解を求めているんだな」という先入観が自然と生まれてしまうのではないでしょうか。

学習時に行うモデルの動作

2つ目は、拡散モデルの学習時の動作によるものです。拡散モデルの学習では、まず学習データに適当にノイズを乗せ、そのノイズ付きデータから元のデータ(あるいはノイズ)を推定するようにモデルを訓練します。つまり、学習時の動作としては、モデルが唯一の正解を予測しようとしていると解釈してしまってもあまり問題が生じません

ただし、この解釈もあまり正しいとは言えません。人生ゲームの例に戻って、学習時の動作について考えてみましょう。あなたのもとに現れた友人が尋常ならぬ執念をもって、前のターンのあなたの駒の位置について「何とかしてうまく予測できるように頑張ろう」と思ったとします。その友人はあなたが人生ゲームに興じているところにたびたび現れ、前のターンの駒の位置を予測してはあなたに正解を教えてもらいます。正解を教えてもらうたびに、友人は予測方法を修正して精度を高めていくでしょう。この意味では、唯一の正解をうまく予測しようとすることで学習を行っているように見えます。

しかし、よく考えてみるとこれはなかなか理不尽なゲームで、全く同じ駒の位置だったとしても、正解が1だったり10だったりします。友人はこのような経験を経て、唯一の正解を正しく予測しようとすることは徒労だと思い知ります。友人が最大限できることは、駒がいたと考えられる位置の候補を絞ることだけであり、唯一の正解を当てることではありません。

ただ、実は拡散モデルでよく使われる過程(ガウシアンノイズの付加)では、唯一の正解を予測しようとすることと正解の候補を絞ることが等価になります。つまり、実際に生成時に行うことは候補を絞ることなのですが、学習時にはひたすら唯一の正解を当てにいけばよいわけです。これが、今回問題にしている誤解を生んでいる1つの原因なのではないかなと思います。

拡散過程と逆拡散過程の非対称性はどこから来るか?

お気づきかと思いますが、このままでは友人の努力は完全に徒労に終わります。なぜなら、10マス以内というところまでは絞れても、それ以上の予測はどうしようもないからです。でも、実際には拡散モデルは何かしらの学習ができていますし、そもそもノイズの付加とノイズの除去が完全に対称だなんてどう考えても変です。では、何が原因で対称性が壊れているのでしょうか?

結論から言うと、対称性が壊れる原因はスタートの存在にあります。最初に述べたように、この人生ゲームにはスタートはたくさんありますが、ゴールはありません。これが対称性を壊す原因であり、友人がうまく予測を立てられるようになる手がかりになります。

もし、友人が来た時にあなたがまだ2回しか駒を動かしておらず、あなたの駒の近くにはスタート地点が1つしかないとしましょう。あなたは友人に対して2回しか駒を動かしていないことと、あなたの駒がスタート地点から18マス離れたところにあると教えてあげたとします。

これは非常に重大な手がかりです。

なぜなら、2回ルーレットを回して18マス進むには「8⇒10」「9⇒9」「10⇒8」の3通りしか考えられません。したがって、前のターンの駒の位置として考えられる候補は「8マス前」「9マス前」「10マス前」の3つしかありません。これは「10マス以内」としかわからなかった場合と比べると、格段に精度の高い予測と言えるでしょう。

拡散モデルは非対称性を学んでいる

このように、「最初はスタート地点にいた」という事実から推理することで、前のターンの駒の位置の予測を精度よく行うことができる場合があります。上の例からわかるように、この推理のためには「今、何ターン目かという情報」と「スタート地点がどのあたりにあるかの知識」の2つになります。この2つの情報は、拡散モデルの文脈ではそれぞれ「時刻(タイムステップ)」と「データ分布」と呼ばれており、両者の情報をうまく使って学習を行うことによって、ノイズ除去(=前のターンの駒の位置の予測)を高精度に行うことができるようになります。

拡散モデルによるデータ生成は完全なノイズからスタートするので、人生ゲームの例で言うとゲーム終了時から推理が始まります。最初のうちはスタート地点が遠い上にルーレットを回した回数も多いので、直前のターンの駒の位置をうまく予測することはできず、1から10マス前までをほとんど適当に選ぶしかありません。推理が進んでゲーム初期のターンに近づくほど、今の駒の位置とターン数の情報、スタート地点に関する知識などから、駒の位置の候補を絞ることができるようになり、高精度な予測ができるようになります。

まとめ

拡散モデルは友だちです。拡散モデルの説明における「ノイズ除去」は、ノイズが乗る前のデータに相当する唯一の正解データを当てにいっているわけではなく、ノイズが乗る前のデータとして考えられる無数の正解データの候補から1つを適当に選ぶという処理です。

Discussion