🌐

グラフニューラルネットワークって何?②GNNと他の深層学習モデル:Transformer

に公開

GNNって結局どんなもの?ほかの深層モデルとどう違う?

さて、上の質問の一つ目に答えるのは比較的簡単である。というのも、GNNというものは非常に抽象的な概念であり、実際にはグラフ畳み込みネットワーク(GCN)やグラフアテンションネットワーク(GAT)等の具体的なモデルになってはじめて何をやっているのかが具体的に分かるのである。一方で、二つ目の質問は、私も学生時代に投げかけられて興味深いと感じていたので、今回の内容で触れ始められたらいいなと考える。皆さんも自分のなじみのあるモデルとGNNの関連性を知ることができれば、「他人」のように遠かったGNNを「遠い親戚」くらいには親しく感じてくれるかもしれない。

※注意?前回の記事より理論というか数式が増えています。

GNNって何?概要と数式で理解してみよう

GNN(Graph Neural Network)は、グラフ構造をもつデータ(例:ソーシャルネットワーク、化学構造、知識グラフなど)を処理するためのニューラルネットワークである。このことについては、すでに前回の記事である程度触れたかもしれない。

では実際にどのようにグラフの各エッジやノードと関わってくるのだろうか?今回は話が複雑になるのでエッジについての説明を省き、ノードにフォーカスを当ててGNNのノード埋め込み表現獲得のプロセスを見ていこう。ちなみにエッジを情報として含めた埋め込み表現の獲得も可能です。

※急にノード埋め込み表現といわれて、反射的に「何それ?」と感じた人もいるかもしれない。このことは前に記事の「GNNのできること基本例」のところで少し触れたが、あれだけではわからないという人のために、以下に簡潔な説明を試みる。ノードの埋め込み表現を獲得するというのはつまり、今のノードの持つ表現(ノードの特徴などが詰め込まれたベクトルのような形であることが多い)を周りのノードとの関係性を考慮してより適切な表現(これを埋め込み表現と呼んでいる)に上書きしているというイメージである。


ノードの更新ルール(典型的なGNN)

各ノード v の特徴量 h_v を以下のように更新します:

h_v^{(l+1)} = \text{UPDATE} \left( h_v^{(l)}, \text{AGGREGATE} \left( \{ h_u^{(l)} : u \in \mathcal{N}(v) \} \right) \right)
  • \mathcal{N}(v):ノード ( v ) の隣接ノード集合(近傍)
  • h_v^{(l)}:レイヤー l 時点での特徴ベクトル(l+1は更新後のノードの特徴ベクトル)
  • AGGREGATE:近傍の情報をまとめる(例:平均、和、重み付き和)
  • UPDATE:まとめた情報を使って自身の状態を更新(例:MLP)

これはいわば、ノードが「隣の人の話を聞いて」自分をアップデートするプロセスです(ゆえに前の記事でも述べたがMessage Passing Neural Networkと呼ばれたりもする)。


GNNの簡単なPyTorch実装例

Pytorchにすでに慣れ親しんできた人たちのために、もっと直感的にコードを示そう。

import torch
import torch.nn as nn

class GNNLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GNNLayer, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, h, adj):
        h_agg = torch.matmul(adj, h)  # 隣接ノードの特徴量を集約
        h_updated = self.linear(h_agg)
        return h_updated

ちなみに、GNNはPyTorch Geometryというライブラリを導入することでいろいろできるようになるのでできる人はいろいろ遊んでみてもいいかもしれない。使い勝手はPyTorchとほぼ同じである。


GNNの直感的な例:GNNの仕組みを「部活動」で理解してみよう!

さて、一旦、ここまで読んできてくださった皆様、ありがとうございました。上の情報はある程度深層学習やPyTorchになじみのある人向けに書いたので、そこを理解して読み進めてくれた方たちに対しては、その前提知識を築くに至った努力に感謝し、何もわからずそれでも読み進めた方たちに対しては、その忍耐と勇気に感謝したい。

ちなみに、理解できなかった方たちは、ご心配なく。この記事はむしろこれからこの領域に踏み込みたいと考えている方たちのためのものであると考えて書いている。そこで、ここからは比較的フランクな例でGNNのやっていることを示してみようと思う。以下の例で、GNNのプロセスを通して各ノードが自分自身の情報のみならず、自分の周りのノードの情報も取り入れられることを示せれば幸いである。

例:ノードは部活動の所属を示す!

ノード 部活動 ベクトル(One-Hot)
A サッカー部 [1, 0, 0]
B バスケ部 [0, 1, 0]
C バスケ部 [0, 1, 0]
D バレー部 [0, 0, 1]

グラフ構造(隣接関係)

  • Aの友達 → B, C (つまりAの近傍 = BとC)
  • Bの友達 → A
  • Cの友達 → A
  • Dは誰ともつながっていない(孤立ノード)
  B     C
   \   /
     A       D

GNNレイヤーの構造

今回は以下のようにシンプルな感じにしようと思う。

  • AGGREGATE = 総和
  • UPDATE = AGGREGATEと足した後のベクトルをSoftmaxにかける
  • Softmax = 受け取ったベクトルの中の各エントリの合計が1になるようにする

今回は ノードA に注目して、周囲の影響を受けてどうベクトルが変わるかを見てみましょう。

※上記を見て気づくかもしれないがAGGREGATEもUPDATEも自分の好きな、目的に応じた任意の関数を使っていいのである。重要なのはAGGREGATEは平均や和などいくつかのノードの情報をまとめるものである必要があり、UPDATEは既存のノード情報をAGGREGATEの内容を使って更新できるものである必要がある。

ステップ1:AGGREGATE(隣接ノードのベクトルを足し合わせる)

Aの近傍は B と C。両方とも [0, 1, 0](バスケ部)

\text{AGGREGATE}(A) = [0,1,0] + [0,1,0] = [0,2,0]

ステップ2:UPDATE(自分 + AGGREGATE → Softmax)

A自身のベクトルは [1, 0, 0]

\mathbf{v} = [1, 0, 0],\quad \text{AGGREGATE} = [0,2,0]

和を取る:

\mathbf{v}_{\text{sum}} = [1, 2, 0]

Softmaxを適用:

\text{Softmax}([1,2,0]) = [0.236, 0.643, 0.087]

結果:サッカー部のAが“バスケっぽい確率分布”になる!

  • 元のA: [1, 0, 0](100%サッカー)
  • GNN後のA: 約 [0.236, 0.643, 0.087]
    サッカー: 23.6%、バスケ: 64.3%、バレー: 8.7% と周囲の影響を反映!

Transformerとは?Self-Attentionの仕組み

かの伝説の論文「Attention Is All You Need」が発表されてから、Transformerはさまざまな分野で活躍している。だが、実は、TransformerはGNNの「特殊なケース」であるという見方もあることはご存じだろうか。本記事ではその関係性を、数式やコードを交えてできるだけシンプルに示せたらいいと考える。

まずはTransformerについて簡単に説明しよう。Transformerは主に自然言語処理で活躍してきたモデルで、「Self-Attention」によって全入力トークン同士の関係を捉えるところがポイントとなっている。詳しく知りたい人は、上記の論文や、それを解説している記事がこの世には星の数ほど?あるはずなので調べてみると面白いかもしれない。

Self-Attentionの計算(1ヘッドのみ)

\text{Attention}(Q, K, V) = \text{Softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right)V

ここで:

  • Q = XW_Q:Query
  • K = XW_K:Key
  • V = XW_V:Value

すべて X \in \mathbb{R}^{n \times d_{\text{model}}} を線形変換して得ている。つまり、\text{Attention}(Q, K, V)というのはX(これは入力された文章のトークンのベクトルの集まり)にいろいろな線形変換を行い、上記の数式にあてはめたものとなっている。

\text{Attention}(X) = \text{Softmax}\left( \frac{(XW_Q)(XW_K)^\top}{\sqrt{d_k}} \right)(XW_V)

数式に頼った説明となってしまって申し訳ない。しかし、TransformerのSelf‐Attentionの説明をしようとするとそれだけで記事が一つ書けてしまう内容であり、また、ほかにもとても分かりやすい記事が多数存在しているため、ここでは詳細な説明には踏み入らない。

重要なのは、Self-Attentionは、全トークンがお互いを見合って影響を与え合う構造であるということである。言い換えれば、各トークンがほかのトークンとの関係性に着目している、トークン同士の全結合構造 = 完全グラフ構造をしているといえる。

Wikipedia:完全グラフより


実はTransformerはGNNの一種だった!?

ここが本章の核心であるが、Self-Attentionも、実は以下のように書き換えられる。急に以下の式が登場したのに混乱した方も多くいるかもしれないが、その場合は、Self‐Attention等の記事をほかで読んでみて、上の式が、トークン単位でみると下の式に変換できることを確認してほしい。本当はここの式の変換も私が説明するべきなのかもしれないが、線形代数等の知識を要することと、本記事で扱いたいテーマからはずれるので割愛させてもらった。申し訳ないです

token_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W_V \> token_j
  • xをtokenに表示しなおしているのはわかりやすさのため。
  • \alpha_{ij} = \frac{\exp(q_i^\top k_j / \sqrt{d_k})}{\sum_{j'} \exp(q_i^\top k_{j'} / \sqrt{d_k})}学習された重み付きAGGREGATE。(\text{Softmax}\left( \frac{(XW_Q)(XW_K)^\top}{\sqrt{d_k}} \right)の部分)
  • \mathcal{N}(i) = \{1, 2, ..., n\}:全ノードを集計しているので完全グラフであるといえる。

ちなみにこの総和は自分自身であるtoken_iも含んでいるため、以下のように書き直せる。

token_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W_V \> token_j = (\alpha_{ii} W_V \> token_i)+(\sum_{j \in \mathcal{N}(i)\>without\>i} \alpha_{ij} W_V \> token_j)

これを

  • \text{AGGREGATE} \left( \{ h_u^{(l)} : u \in \mathcal{N}(v) \} \right) = \sum_{j \in \mathcal{N}(i)\>without\>i} \alpha_{ij} W_V \> token_j
  • \text{UPDATE}(token_i) = (\alpha_{ii} W_V \> token_i)+\text{AGGREGATE}

と考えると、完全にGNNの枠組みに収まるといえるのではないだろうか。

つまり、TransformerはGNNのように近傍ノードを集約するようなプロセスを持ち、対象となるグラフはトークン同士の完全グラフである。というわけで、Transformerは完全グラフを使い、AGGREGATE関数やUPDATE関数がSelf-AttentionになったGNNとも言えるのではないだろうか、という話である。

※ちなみにUPDATE関数がどこにあるのかはある程度議論の余地がある。今回は強引にSelf‐Attention内に収めたが、Transformerにおける、Self‐Attention後のベクトルにかける線形変換こそがUPDATEであるとすることもできる。GNNは抽象的な概念であるため、様々なとらえ方ができるが、重要なのはSelf-Attentionと完全グラフ、そしてAGGREGATEのプロセスとの関連性である。

GNNとTransformerの比較

項目 GNN Transformer
入力構造 任意のグラフ 完全グラフ(全ノード接続)
隣接ノードの選び方 明示的に指定 全ノード対象(Softmaxで重み付け)
AGGREGATEの重み 手動 or 固定 学習される(Attention)

まとめ

さて、早いものでまとめの時間がやってきた。今回持ち帰ってほしい内容は以下のものである。

  • GNN概要:GNNはノードが近傍から情報を集めて自分を更新する構造
  • Transformerも、Self-Attentionを通じて「全ノードを近傍」とみなして重み付き集約を行う。よって、Transformerは「完全グラフ構造上のGNNの一種」と捉えることができる

これだけである。

後半のTransformerとGNNの関連性についてはTransformerの説明をしっかりできなかったためしっくりこない人たちも多いかもしれないが、申し訳ない。これは私の怠惰によるものであり、理解できなかった人たちは全く悪くないので心配しないでほしい。

※ただし、Transformerについて数式レベルで詳細な理解をしているそこのあなた!あなたが理解できなかったとしたらそれは、私の展開した論理が間違っているか、あなたのTransformerについての理解の道がまだ完結していないかのどちらかである可能性が高い(記事を書き終わった本人としては、後者であってほしい、じゃないとGNNとTransformerについてもう一度考察し直すこととなる)。


次回予告?

前回から内容を追っている方は、ふと思うだろう。「あれ?CNNとの関連性も解説してくれないのか?」と。心配せずとも次回の記事で取り上げようと考えている。本来はCNN関連の内容も載せようと考えていたのだが、記事があまりに長くなり、数式があふれて、数式アレルギーの人が二度と寄り付かなくなってしまったなんてことを避けたかったので、次章に譲ることとした。

ところで、CNNについてはすでに一部の人が感じている通り、今回のTransformerについての内容より簡単である。それはCNNがグリッド上のグラフ構造上で畳み込みという名のAGGRAGATEを行っていると示すのが比較的容易だからである。

さて、ネタバレはこの辺にして、今日は休もうと思う。ゴールデンウィークの最終日になぜこれを書き始めたのかはわからないが、多分雨が降っていて退屈だったのが原因だろう。次の記事がいつ出るのか、これもまた神のみぞ知る話となるが、楽しみにしている皆様(もしそんな人がいれば幸いだが)にはあまり期待せず気長に待っていてほしい。


GMOペパボ株式会社

Discussion