Graph Convolutional Networks で使われている計算の最小動作例 (PyTorch Geometric編)
グラフデータを深層学習で扱うアプローチとして、Graph Neural Networks (GNNs) がよく利用されています。論文などで提案されている数式をプログラムとして実装するためには、pytorch-geometric や deep graph library (dgl) などが利用されます。
前回(大昔)の記事はこちらです。
最近、いくつかのgithub上のレポジトリを読んでいたら、pytorch-geometricを使っている例を確認しました。
この記事では、GCN (グラフ畳み込みネットワーク; Graph Convolutional Network) の実装例を通じて、ライブラリの使い方や処理の流れを step-by-step で触れてみようと思います。
環境とデータ
主要なライブラリのバージョンは以下の通りです:
- torch: 1.12.1
- torch_geometric: 2.4.0
データの例は前回の画像で作ったものを再利用します。
以下で利用するために、スクリプトに起こしておきます。
import networkx as nx
import torch
import torch_geometric
def get_graph() -> nx.Graph:
"""
例題のグラフを返す
"""
n = 6
nxg = nx.Graph()
nxg.add_nodes_from(range(n))
E = [(0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (4, 5)]
for u, v in E:
nxg.add_edge(u, v)
return nxg
def get_fixed_feature() -> torch.FloatTensor:
"""
例題のノード特徴ベクトル (各ノードが3次元の特徴ベクトルを持つ) を作り
Tensorにして返す
"""
X = torch.tensor(
[
[0, 0, 0], # 0
[0, 0, 1], # 1
[0, 1, 0], # 2
[0, 1, 1], # 3
[1, 0, 0], # 4
[1, 1, 0], # 5
]
).float()
return X
PyTorch Geometricを利用した簡単なGCNの例
ここからPyTorch Geometricの動作例を確認していきます。
データの用意
PyTorch Geometricを利用した処理 (NNモジュールのforward部分) の特徴として、頂点に対応した属性ベクトルだけではなく、辺の情報も同時に渡すことがあります。
上記の例では
from torch_geometric.utils import from_networkx
# networkx のデータから PyG のデータを作成する
G = get_graph()
data = from_networkx(G)
# data.edge_index は辺情報を格納している (12本分)
# 出力の例
# tensor([[0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5],
# [1, 2, 0, 2, 0, 1, 3, 4, 2, 2, 5, 4]])
print(data.edge_index)
手計算
GCNでは、あるノード
- 頂点
は次元v (ここではd ) の特徴ベクトルを持つ。d=3 - 情報を伝播するとき、
次元のベクトルを行列で変換し、d 次元のベクトルに変換する。ここではd' とする。d'=2 - 行列で書くと、書く頂点が
型のベクトルで表現された特徴を持ちます (pytorchのnn.Embeddingをイメージする)。(1, d) -
という次元の変換は、d=3 \to d'=2 型の行列で書かれます (pytorchのnn.Linearをイメージする)。(d, d')
- 行列で書くと、書く頂点が
- 情報を集約するとき、単純に特徴ベクトルを足し算する。
このような処理を例題の上の頂点
# 例題
W = Linear(3, 2, bias=False)
W.weight = Parameter(torch.Tensor([[1, 2, 3], [4, 5, 6]]).float())
for i in G.nodes():
sum_i = W(X[i, :]) # 自分自身i
for j in G.neighbors(i):
sum_i += W(X[j, :]) # 隣接ノードj
print("ノード:", i, " 手計算:", sum_i.data)
# 出力
# ノード: 0 手計算: tensor([ 5., 11.])
# ノード: 1 手計算: tensor([ 5., 11.])
# ノード: 2 手計算: tensor([11., 26.])
# ノード: 3 手計算: tensor([ 7., 16.])
# ノード: 4 手計算: tensor([ 6., 18.])
# ノード: 5 手計算: tensor([ 4., 13.])
-
は頂点0 と1 に隣接しているので、2 の特徴ベクトルと、隣接した0 /1 の特徴ベクトルを集約して、新しい2 の特徴ベクトルをつくります。0 -
は0 の特徴ベクトルを持つので、変換後は[0, 0, 0] です。[0, 0] -
は1 の特徴ベクトルを持つので、変換後は[0, 0, 1] です。[3, 6] -
は2 の特徴ベクトルを持つので、変換後は[0, 1, 0] です。[2, 5] - これらを加えるため、新しい特徴ベクトルは
です。[5, 11]
-
- 他の頂点も同様です。
このように実際に行列計算や特徴ベクトルの入力を試すと、GCNの動作が掴めます (これは前回の記事で使っていたdglのGraphConvでも同様でした)。
PyTorch Geometricを利用した実装
ここから公式ドキュメントに少し戻り、上の実装をPyTorch Geometric上の実装に持っていきます。こちらのドキュメントを参考にしてもらうと、公式のGCNの実装チュートリアルを見ることができます。
上のGCN実装は細かい係数などが入っているため、先程の足し算するだけのメッセージパッシングを実装してみます。
from torch_geometric.nn import MessagePassing
class SimpleGNN(MessagePassing):
def __init__(self, in_channels, out_channels):
# in_channels: ノードの入力特徴量の次元 (d).
# out_channels: ノードの出力特徴量の次元 (d')
# メッセージの処理は加えるだけ (add)
super().__init__(aggr='add')
# 同じ初期化
self.lin = Linear(in_channels, out_channels, bias=False)
self.lin.weight = Parameter(torch.Tensor([[1, 2, 3], [4, 5, 6]]).float())
def forward(self, x, edge_index):
# 自己ループを加える (自分自身の特徴ベクトルも伝搬させる)
# - edge_index: は Dataで取得した(src, dst)の表現
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Wx の計算 (すべてのノードの特徴ベクトルがWで変換される)
x = self.lin(x)
# messageを呼ぶ
out = self.propagate(edge_index, x=x, norm=None)
return out
def message(self, x_j, norm):
# 何もしない (addに任せるだけ)
return x_j
実際に動作させ、手計算した場合と比較します。
# 例題 その2 (d=3, d'=2)
gnn = SimpleGNN(3, 2)
out = gnn(X, data.edge_index)
for i in G.nodes():
# 手計算
sum_i = gnn.lin(X[i, :])
for j in G.neighbors(i):
sum_i += gnn.lin(X[j, :])
# out[i].data はi番目のノードの計算結果 (addした後)
print("ノード:", i, " 手計算:", sum_i.data, " GNN計算:", out[i].data)
# 出力
# ノード: 0 手計算: tensor([ 5., 11.]) GNN計算: tensor([ 5., 11.])
# ノード: 1 手計算: tensor([ 5., 11.]) GNN計算: tensor([ 5., 11.])
# ノード: 2 手計算: tensor([11., 26.]) GNN計算: tensor([11., 26.])
# ノード: 3 手計算: tensor([ 7., 16.]) GNN計算: tensor([ 7., 16.])
# ノード: 4 手計算: tensor([ 6., 18.]) GNN計算: tensor([ 6., 18.])
# ノード: 5 手計算: tensor([ 4., 13.]) GNN計算: tensor([ 4., 13.])
計算結果が一致することが確認できました。
これぐらいの例ではわざわざMessagePassingを拡張したSimpleGCNクラスを実装する旨味がない気がしますが、edge_indexを渡してforwardを回す処理 (gnn(X, data.edge_index)
) や message()
の処理など、わかりやすくするための例として作成しました。
サンプルプログラム (全体) の書かれたgistはこちらです。
次回のstep-by-stepでは真面目なGCNの実装と、NGCF/LightGCNの比較についてまとめるつもりです。
Discussion