🐳

【Julia】グラフ畳み込みネットワークの推論:スクラッチ

に公開

『グラフニューラルネットワーク(講談社)』の学習記事です。本書73〜76ページのグラフ畳み込みネットワーク(GCN)の推論アルゴリズムをスクラッチで実装して、理解することが目標です。

1. アルゴリズム

  1. 入力の準備
    各ノードが持つ特徴ベクトルをまとめた「ノード特徴行列」\mathbf{X}(ノード数 × 特徴数)、グラフの構造を表す「隣接行列」\mathbf{A}(ノード数 × ノード数)、各層ごとの「重み行列」\mathbf{W}^{(l)}(特徴数 × 出力特徴数)を用意する。

  2. 自己ループの追加
    各ノードが自分自身の情報も参照できるように、隣接行列\mathbf{A}に単位行列\mathbf{I}を足して、自己ループ付き隣接行列\tilde{\mathbf{A}}を作成する。

    \tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I}
  3. 次数行列の作成
    自己ループ付き隣接行列\tilde{\mathbf{A}}の各行の合計を対角成分とする「次数行列」\tilde{\mathbf{D}}を作成する。

    \tilde{\mathbf{D}}_{uu} = \sum_{v} \tilde{\mathbf{A}_{uv}}
  4. 正規化隣接行列の作成
    ノードごとの影響を均等にするため、隣接行列を正規化する。

    \mathbf{A}_{\text{norm}} = \tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2}
  5. 初期値の設定
    最初の層のノード特徴を、ノード特徴行列\mathbf{X}で初期化する。

    \mathbf{H}^{(0)} = \mathbf{X}
  6. 各層での更新処理
    各層ごとに、正規化隣接行列を作成して「隣接ノードの特徴の重み付き和」を計算し、重み行列で線形変換した後、活性化関数(ReLUなど)を適用して、新しいノード特徴を取得する。

    \mathbf{H}^{(l+1)} = \sigma \left(\mathbf{A}_{\text{norm}} \mathbf{H}^{(l)} \mathbf{W}^{(l+1)} \right)
  7. 最終層の出力
    L層まで繰り返した後、最終層のノード特徴行列を出力する。

    \mathbf{H}^{(L)}

2. Julia実装

利用するライブラリを読み込む。

using LinearAlgebra

2.1. 入力の準備

# ノード特徴行列(3ノード、2特徴)
𝑿 = [
    1 2;  # ノード0
    2 3;  # ノード1
    3 4   # ノード2
]

# 隣接行列(3ノード)
𝑨 = [
    0 1 0; # ノード0はノード1と隣接
    1 0 1; # ノード1はノード0、2と隣接
    0 1 0  # ノード2はノード1と隣接
]

# 重み行列(2層)
# 1層目:入力2次元→中間3次元
𝑾₁ = [
    0.5 0.1 0.3;
    0.4 0.6 0.2
]
# 2層目:中間3次元→出力2次元
𝑾₂ = [
    0.2 0.7;
    0.5 0.1;
    0.6 0.3
]
𝑾ₛ = [𝑾₁, 𝑾₂]

2.2. 自己ループの追加

𝑰 = I(3)                                     
𝑨̂ = 𝑨 + 𝑰
𝑰
3×3 Diagonal{Bool, Vector{Bool}}:
 1  ⋅  ⋅
 ⋅  1  ⋅
 ⋅  ⋅  1
𝑨̂ 
3×3 Matrix{Int64}:
 1  1  0
 1  1  1
 0  1  1

2.3. 次数行列の作成

𝑫̂ = Diagonal(sum(𝑨̂, dims=2)[:])
𝑫̂_inv_sqrt = Diagonal(1 ./ sqrt.(sum(𝑨̂, dims=2)[:]));
𝑫̂
3×3 Diagonal{Int64, Vector{Int64}}:
 2  ⋅  ⋅
 ⋅  3  ⋅
 ⋅  ⋅  2
𝑫̂_inv_sqrt
3×3 Diagonal{Float64, Vector{Float64}}:
 0.707107   ⋅        ⋅ 
  ⋅        0.57735   ⋅ 
  ⋅         ⋅       0.707107

2.4. 正規化隣接行列の作成

𝑨̂ = 𝑫̂_inv_sqrt * 𝑨̂ * 𝑫̂_inv_sqrt
3×3 Matrix{Float64}:
 0.125      0.0680414  0.0
 0.0680414  0.037037   0.0680414
 0.0        0.0680414  0.125

2.5.初期値の設定 〜 各層での更新処理

# 活性化関数
relu(x) = max.(0, x)

# 推論
function GCN_inference(𝑿, 𝑨̂, 𝑾ₛ)
    𝑯 = 𝑿
    for 𝑾 in 𝑾ₛ
        𝑯 = relu(𝑨̂ * 𝑯 * 𝑾)
    end
    return 𝑯
end

2.6. 最終層の出力

𝑯ᴸ = GCN_inference(𝑿, 𝑨̂, 𝑾ₛ)
3×2 Matrix{Float64}:
 1.78735  1.72693
 2.42969  2.3622
 2.20235  2.15193

2.7. 全体まとめ

using LinearAlgebra

# 活性化関数
relu(x) = max.(0, x)

# GCNレイヤー
struct GCNLayer
    𝑾::Matrix{Float64}
end
function (layer::GCNLayer)(𝑨̂, 𝑯)
    relu(𝑨̂ * 𝑯 * layer.𝑾)
end

# GCNモデル
struct GCNModel
    layers::Vector{GCNLayer}
end
function (model::GCNModel)(𝑨̂, 𝑿)
    𝑯 = 𝑿
    for layer in model.layers
        𝑯 = layer(𝑨̂, 𝑯)
    end
    return 𝑯
end


# 入力の準備
𝑿 = [
    1 2;
    2 3;
    3 4
]
𝑨 = [
    0 1 0;
    1 0 1;
    0 1 0
]
𝑾₁ = [
    0.5 0.1 0.3;
    0.4 0.6 0.2
]
𝑾₂ = [
    0.2 0.7;
    0.5 0.1;
    0.6 0.3
]

# 自己ループの追加
𝑰 = I(3)                                     
𝑨̂ = 𝑨 + 𝑰

# 次数行列を作成
𝑫̂ = Diagonal(sum(𝑨̂, dims=2)[:])
𝑫̂_inv_sqrt = Diagonal(1 ./ sqrt.(sum(𝑨̂, dims=2)[:]))

# 正規化隣接行列
𝑨̂ = 𝑫̂_inv_sqrt * 𝑨̂ * 𝑫̂_inv_sqrt

# モデル構築
model = GCNModel([
    GCNLayer(𝑾₁),
    GCNLayer(𝑾₂)
])

# 推論
𝑯 = model(𝑨̂, 𝑿)

3. 終わりに

間違いなどありましたら遠慮なくご指摘いただけますと幸いです。

Discussion