今更だけどTransformerが強い理由とマルチヘッドの重要性について考えてみた
Transformerは未だにほぼ一強の構造で、他のモデルが様々登場してはいるものの、全て蹴散らしていると言っても過言ではないと思う
Transformerが何故こんなにも強いのか、何故今の形のこのTransformerでなければならなかったのか、皆が使ってるから最適化が進んでいるとかそういう歴史的経緯抜きで再度考えてみた
まず、Transformerのコアを担っているAttention機構について
こいつはデータの注目部分を抽出して利用する代物だと説明される事が多いが、私の感覚で言えばそれは嘘である
Attention機構の正体は、あらゆるデータ間の関係を体現した汎用グラフとその上でのメッセージパッシングだ
厳密に言えば、途中softmax(QK^T)が作っている物がデータxとyの接続を示す重み付きの隣接行列であり、Vとの行列積を取ることで1hop分のメッセージパッシングを行っている
yの各要素がxの指定箇所に情報を与える構造になっているので、これは正にメッセージの伝達である(自己Attentionの場合、y=x)
通常、TransformerはAttentionを数レイヤー重ねて作るが、これはAttention一つが1hop分のメッセージ伝達しか担わないのを補うためだろう
ところで、通常のsoftmaxはmaxの確率的で連続的な表現なので、シングルヘッドAttentionは確率的に分散されてはいるものの、1つ分の接続エッジしか現していない
グラフの各要素が常に1つしか接続先を持たない状態でメッセージパッシングを重ねても、重ねたhop分の到達先程度にしか届かないだろう
話は少し逸れるが、スモールワールド現象という事象をご存知だろうか
知り合いの知り合いを辿っていくと、6人程度の伝達で世界中の誰にでも到達するという実験結果がある(正確には米国内で米国の誰かを対象にした実験)
ここで全員紹介できる知り合いが1人しか居ない場合、世界中の誰にでも到達するということは到底不可能なはずだ
1hop辺りの接続先を増やすことで、より広い範囲と情報交換が可能になる
つまり、Attentionのマルチヘッド化は気休めやオプションではなく、必須の重要な要素である
実際、接続先を1から2に増やすと、それだけでN-hop後の最終到達先は2^N要素にもなり、途中の経路の混ざり具合を考えると更に幅広い範囲に影響が及ぶ
ここまでで、Attentionがデータの関係を体現したグラフ上でのメッセージパッシングであり、その1hop分の接続先を広げるためにマルチヘッド化が必須であったこと、そのhop数を増やすために多層化=Transformer化が必要なことを話した
これまでの話を踏まえると、これでTransformerが表現しているものは関係を表したグラフ上での完全な情報伝達であるということになる
それができると何故強いのか、についてまだ触れていないのでこれから触れていく
まず、関係のグラフという概念がそもそも強い
ここでいうグラフというのはチャート図のことではなく、関係を表す有向グラフのことである
例えば、リストは関係のグラフで表す事ができる
インデックス順に前から後ろへ一つずつ繋げた関係のグラフはリストである
また、ツリーも関係のグラフで表す事ができる
葉から親のノード、更に親のノードと繋げた関係のグラフはツリーである
つまり、任意の関係のグラフ上で完全な情報伝達ができれば、リストやツリー上での情報伝達をエミュレートできるということだ
リストやツリーとの違いは、構造的な局所の偏りが強く出ざるを得ないか否かの違いでしかない
例えばリスト上を順に舐めると直近のデータの影響が遠方より必然的に強く残る
ツリーにしても親子関係の局所性が強く出る
グラフとして俯瞰する場合、それらは接続の仕方次第である
Attention重みをニューラルネット的に決めるということは、この関係のグラフの接続の仕方を動的に決めている事に他ならない
そしてTransformer化によりN-hopの情報伝達が完全化する
計算量こそO(N^2)になってしまうし、リストやツリーが得意な構造はそれらをわざわざエミュレートするために収束も遅くなってしまうが、グラフで表すことのできるあらゆる関係を処理できる
それがTransformerなのである
ある集合とある集合の全要素ずつの関係、あるいは一つ特定のある集合の中での全要素ずつの関係まで扱えれば、現実的な問題では十分強い
だからTransformerは強い
以上
Discussion