🎃

Graph Convolutional Networks で使われている計算の最小動作例 (PyTorch Geometric編)

2023/10/25に公開

グラフデータを深層学習で扱うアプローチとして、Graph Neural Networks (GNNs) がよく利用されています。論文などで提案されている数式をプログラムとして実装するためには、pytorch-geometricdeep graph library (dgl) などが利用されます。

前回(大昔)の記事はこちらです。

https://zenn.dev/takilog/articles/e54a45d6f7266229e367

最近、いくつかのgithub上のレポジトリを読んでいたら、pytorch-geometricを使っている例を確認しました。

この記事では、GCN (グラフ畳み込みネットワーク; Graph Convolutional Network) の実装例を通じて、ライブラリの使い方や処理の流れを step-by-step で触れてみようと思います。

環境とデータ

主要なライブラリのバージョンは以下の通りです:

  • torch: 1.12.1
  • torch_geometric: 2.4.0

データの例は前回の画像で作ったものを再利用します。

画像

以下で利用するために、スクリプトに起こしておきます。

import networkx as nx
import torch
import torch_geometric

def get_graph() -> nx.Graph:
    """
    例題のグラフを返す
    """
    n = 6
    nxg = nx.Graph()
    nxg.add_nodes_from(range(n))
    E = [(0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (4, 5)]
    for u, v in E:
        nxg.add_edge(u, v)
    return nxg

def get_fixed_feature() -> torch.FloatTensor:
    """
    例題のノード特徴ベクトル (各ノードが3次元の特徴ベクトルを持つ) を作り
    Tensorにして返す
    """
    X = torch.tensor(
        [
            [0, 0, 0],  # 0
            [0, 0, 1],  # 1
            [0, 1, 0],  # 2
            [0, 1, 1],  # 3
            [1, 0, 0],  # 4
            [1, 1, 0],  # 5
        ]
    ).float()
    return X

PyTorch Geometricを利用した簡単なGCNの例

ここからPyTorch Geometricの動作例を確認していきます。

データの用意

PyTorch Geometricを利用した処理 (NNモジュールのforward部分) の特徴として、頂点に対応した属性ベクトルだけではなく、辺の情報も同時に渡すことがあります。

上記の例では \{01, 02, 12, 23, 24, 45\} の無向辺を持ちますが、これを効率的に管理するため(COO format/インデクスとデータを疎で格納する)、PyTorch Geometricにnetworkxのデータを変換させると、辺\{u, v\}に対してu\to vv\to uの両方を管理し、合計で12本の辺情報を持ちます。例を見てみます。

from torch_geometric.utils import from_networkx

# networkx のデータから PyG のデータを作成する
G = get_graph()
data = from_networkx(G)

# data.edge_index は辺情報を格納している (12本分)
# 出力の例
# tensor([[0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5],
#         [1, 2, 0, 2, 0, 1, 3, 4, 2, 2, 5, 4]])
print(data.edge_index)

手計算

GCNでは、あるノード v に着目したとき、v の隣接ノード u\in \mathcal{N}_G(v) の特徴ベクトルを集約して、新しい特徴ベクトルを計算します。ここでは次の計算を考えます。

  • 頂点 v は次元 d (ここでは d=3) の特徴ベクトルを持つ。
  • 情報を伝播するとき、d 次元のベクトルを行列で変換し、d'次元のベクトルに変換する。ここではd'=2とする。
    • 行列で書くと、書く頂点が(1, d) 型のベクトルで表現された特徴を持ちます (pytorchのnn.Embeddingをイメージする)。
    • d=3 \to d'=2という次元の変換は、(d, d')型の行列で書かれます (pytorchのnn.Linearをイメージする)。
  • 情報を集約するとき、単純に特徴ベクトルを足し算する。

このような処理を例題の上の頂点0で確認してみましょう。行列は123/456という簡単な形で初期化し、バイアスもなくしてあります。行と列の説明がひっくり返っている (nn.Linearの都合) ので注意してください。

# 例題
W = Linear(3, 2, bias=False)
W.weight = Parameter(torch.Tensor([[1, 2, 3], [4, 5, 6]]).float())
for i in G.nodes():
    sum_i = W(X[i, :])  # 自分自身i
    for j in G.neighbors(i):
        sum_i += W(X[j, :])  # 隣接ノードj
    print("ノード:", i, " 手計算:", sum_i.data)

# 出力
# ノード: 0  手計算: tensor([ 5., 11.])
# ノード: 1  手計算: tensor([ 5., 11.])
# ノード: 2  手計算: tensor([11., 26.])
# ノード: 3  手計算: tensor([ 7., 16.])
# ノード: 4  手計算: tensor([ 6., 18.])
# ノード: 5  手計算: tensor([ 4., 13.])
  • 0は頂点12に隣接しているので、0の特徴ベクトルと、隣接した1/2の特徴ベクトルを集約して、新しい0の特徴ベクトルをつくります。
    • 0[0, 0, 0] の特徴ベクトルを持つので、変換後は [0, 0] です。
    • 1[0, 0, 1] の特徴ベクトルを持つので、変換後は [3, 6] です。
    • 2[0, 1, 0] の特徴ベクトルを持つので、変換後は [2, 5] です。
    • これらを加えるため、新しい特徴ベクトルは[5, 11] です。
  • 他の頂点も同様です。

このように実際に行列計算や特徴ベクトルの入力を試すと、GCNの動作が掴めます (これは前回の記事で使っていたdglのGraphConvでも同様でした)。

PyTorch Geometricを利用した実装

ここから公式ドキュメントに少し戻り、上の実装をPyTorch Geometric上の実装に持っていきます。こちらのドキュメントを参考にしてもらうと、公式のGCNの実装チュートリアルを見ることができます。

https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html

上のGCN実装は細かい係数などが入っているため、先程の足し算するだけのメッセージパッシングを実装してみます。

from torch_geometric.nn import MessagePassing
class SimpleGNN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        # in_channels: ノードの入力特徴量の次元 (d).
        # out_channels: ノードの出力特徴量の次元 (d')
        # メッセージの処理は加えるだけ (add)
        super().__init__(aggr='add')

        # 同じ初期化
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.lin.weight = Parameter(torch.Tensor([[1, 2, 3], [4, 5, 6]]).float())
                
    def forward(self, x, edge_index):
        # 自己ループを加える (自分自身の特徴ベクトルも伝搬させる)
        # - edge_index: は Dataで取得した(src, dst)の表現
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    
        # Wx の計算 (すべてのノードの特徴ベクトルがWで変換される)
        x = self.lin(x)

        # messageを呼ぶ
        out = self.propagate(edge_index, x=x, norm=None)
        return out
    
    def message(self, x_j, norm):
        # 何もしない (addに任せるだけ)
        return x_j

実際に動作させ、手計算した場合と比較します。

# 例題 その2 (d=3, d'=2)
gnn = SimpleGNN(3, 2)
out = gnn(X, data.edge_index)
for i in G.nodes():
    # 手計算
    sum_i = gnn.lin(X[i, :])
    for j in G.neighbors(i):
        sum_i += gnn.lin(X[j, :])

    # out[i].data はi番目のノードの計算結果 (addした後)
    print("ノード:", i, " 手計算:", sum_i.data, " GNN計算:", out[i].data)

# 出力
# ノード: 0  手計算: tensor([ 5., 11.])  GNN計算: tensor([ 5., 11.])
# ノード: 1  手計算: tensor([ 5., 11.])  GNN計算: tensor([ 5., 11.])
# ノード: 2  手計算: tensor([11., 26.])  GNN計算: tensor([11., 26.])
# ノード: 3  手計算: tensor([ 7., 16.])  GNN計算: tensor([ 7., 16.])
# ノード: 4  手計算: tensor([ 6., 18.])  GNN計算: tensor([ 6., 18.])
# ノード: 5  手計算: tensor([ 4., 13.])  GNN計算: tensor([ 4., 13.])

計算結果が一致することが確認できました。

これぐらいの例ではわざわざMessagePassingを拡張したSimpleGCNクラスを実装する旨味がない気がしますが、edge_indexを渡してforwardを回す処理 (gnn(X, data.edge_index)) や message() の処理など、わかりやすくするための例として作成しました。

サンプルプログラム (全体) の書かれたgistはこちらです。

https://gist.github.com/cocomoff/144b295a8ff2dd57482a9087cbc476d2

次回のstep-by-stepでは真面目なGCNの実装と、NGCF/LightGCNの比較についてまとめるつもりです。

Discussion