🤖

Attention, MLP, すべての答え

2024/01/30に公開

はじめに

Attention機構も、MLPも、一般的な機械学習のあらゆるネットワークは「生きた柔軟な辞書」という同一のモノである。
(これは全く厳密性のないポエムです)

そもそもNNの行列積は何を計算しているか

一般的な機械学習のネットワークにおける行列積が計算しているものは、データ(ベクトル)とフィルター(ベクトル)の内積(相関値)のフィルター数分のバリエーションである。

A(1, 784) @ B(784, 784) => C (1, 784)

Aは784次元のデータひとつ。Bは784次元のデータが784フィルター個。AのデータをBのフィルターにかけると、データひとつ(784次元ベクトル)とフィルターひとつ(784次元ベクトル)の内積(相関値)がスカラー値としてひとつ、得られる。それがフィルター数個(784個)の明滅パターンを形成し、今回の場合は元のデータと同じ形に再生成される。

つまり、重み行列との行列積は、入力値とフィルター群を元に、入力に応じた新しい値を作る。

一般的に、重みが十分にランダムなら、AとCはほぼ似ていない値になる。
ただし、Aに十分似た入力A'について、その出力C'はCにある程度似る。

これを入力に応じて、重みから値を引き出している。と捉える。

行列積が多段化されたときの解釈

一般的な三層MLPにおいて、ネットワーク内の計算は、簡略化すると

activation(x @ w1) @ w2

のようになる。

上で示したように、一般的な機械学習のネットワークにおける入力値と重みとの行列積は、入力値と、重みというフィルター群を元に新しい値を作る操作である。

つまり、x @ w1 は x と w1 を元にした新しい入力である。

MLPは新しい入力 x @ w1 に activationを掛けたもの k = activation(x @ w1) と w2 との間で更に行列積を取る。

activationについては一旦保留するとして、x と w1 を元にした新しい値である k を用いて、更にw2 から新しい値を引き出している。

構造としては、xを元にkeyを得、keyを元にvalueを得ている。入力(=問題)と重みからヒントごとの倍率を得て、ヒントごとの倍率と更なる重み(=ヒントフィルター群)から最終的な解答を合成して引き出している。

直接的にxからvalueを得るより効果的なのは、activationにより負の相関(=役に立たないヒントの適用倍率)を切って0にできるから。

また、activationが無ければ、x @ w1 @ w2 は w1 @ w2 が単純なひとつの行列に集約できてしまうので、多段化する意味が無く、重み w12 = w1 @ w2 を一回掛けるのと同じ結果になる。

何故三層MLPが単層パーセプトロンより効果的なのか。また、何故四層以上のMLPが劇的な効果をもたらさないかについては、これで説明がつく。

三層MLPの時点で辞書として完全な機能を持っているため、単層では足らず、四層にする意味もあまりない。ということ。

Attention機構の解釈

Attention機構の肝は softmax(Q @ K.T) @ V の計算にある。

Attention機構の行列積についても、行列積は行列積であり、行列積自体の解釈は変わらない。

入力に応じて、重みから値を引き出している。

Attention機構はそれ自体がMLP的に働く。違いは時系列方向に次元が伸びているかどうかだけ。

つまり、Q @ K.T の解釈は後に置くとして、最終的には softmax(Q @ K.T) を入力として V から値を引き出している。

x => Q, K, V => key = softmax(Q @ K.T), value = V => key @ value

先程の多段化された行列積、三層のMLPと構造は同じ。

違いは、 Q @ K.T が時系列間の LxL (Lは系列長) 行列を作るという点。

これについては、グラフニューラルネットワーク的な「隣接行列」として解釈できる。

valueの値を「すべてそのまま」引き出すような key は、「単位行列」である。

I @ value => value

隣接行列的に考えると、単位行列は「自身から自身への1.0の重みノードの集合」である。

Q @ K.T の作る Attention weight map は、通常は単位行列よりももっと乱れている。

1番目から2番目への重みや3番目から6番目への重みなどが入り混じっているため、次の値を得るときに内部で混合が起きる。ただし、未来から過去方向への重みは Transfomer Decoder では通常 causal attention mask として下三角形状にマスクされる。

乱れた Q @ K.T により、引き出し重み key が value を混合して引き出すため、入力から(Q, K, Vへの投影時の学習重みを経て適度に変化した)新しい値が得られる。

MLPと違って多層化が有効なのは、Attentionは(計算量的な問題もあり)時系列方向の全結合的なMixを直接的に行わないから。

時系列方向の混合は一度に全て行われるのでなく、Q @ K.T の単位行列からの外れ方に応じて徐々にリークするように広がっていく。

「生きた柔軟な辞書」としてのMLP, Attention

ここまでで述べたように、MLPもAttention機構も、入力xからインデックスkeyを作り、valueから値を引き出す構造が共通している。

keyが常にone-hotになるならそれはハードな辞書であるが、通常のネットワークが形成するのはkeyが分散した重み配分の値となるような、ソフトな辞書である。加えて、テーブルから値を引くだけの固定的な辞書とは違い、ネットワークは学習を通じてkeyもvalueも変化・最適化するので、生きた辞書であると言える。

ゆえに、MLPもAttentionも生きた柔軟な辞書である。と。

Discussion