🐳
【Julia】グラフ畳み込みネットワークの推論:スクラッチ
『グラフニューラルネットワーク(講談社)』の学習記事です。本書73〜76ページのグラフ畳み込みネットワーク(GCN)の推論アルゴリズムをスクラッチで実装して、理解することが目標です。
1. アルゴリズム
-
入力の準備
各ノードが持つ特徴ベクトルをまとめた「ノード特徴行列」 (ノード数 × 特徴数)、グラフの構造を表す「隣接行列」\mathbf{X} (ノード数 × ノード数)、各層ごとの「重み行列」\mathbf{A} (特徴数 × 出力特徴数)を用意する。\mathbf{W}^{(l)} -
自己ループの追加
各ノードが自分自身の情報も参照できるように、隣接行列 に単位行列\mathbf{A} を足して、自己ループ付き隣接行列\mathbf{I} を作成する。\tilde{\mathbf{A}} \tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I} -
次数行列の作成
自己ループ付き隣接行列 の各行の合計を対角成分とする「次数行列」\tilde{\mathbf{A}} を作成する。\tilde{\mathbf{D}} \tilde{\mathbf{D}}_{uu} = \sum_{v} \tilde{\mathbf{A}_{uv}} -
正規化隣接行列の作成
ノードごとの影響を均等にするため、隣接行列を正規化する。\mathbf{A}_{\text{norm}} = \tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2} -
初期値の設定
最初の層のノード特徴を、ノード特徴行列 で初期化する。\mathbf{X} \mathbf{H}^{(0)} = \mathbf{X} -
各層での更新処理
各層ごとに、正規化隣接行列を作成して「隣接ノードの特徴の重み付き和」を計算し、重み行列で線形変換した後、活性化関数(ReLUなど)を適用して、新しいノード特徴を取得する。\mathbf{H}^{(l+1)} = \sigma \left(\mathbf{A}_{\text{norm}} \mathbf{H}^{(l)} \mathbf{W}^{(l+1)} \right) -
最終層の出力
L層まで繰り返した後、最終層のノード特徴行列を出力する。\mathbf{H}^{(L)}
2. Julia実装
利用するライブラリを読み込む。
using LinearAlgebra
2.1. 入力の準備
# ノード特徴行列(3ノード、2特徴)
𝑿 = [
1 2; # ノード0
2 3; # ノード1
3 4 # ノード2
]
# 隣接行列(3ノード)
𝑨 = [
0 1 0; # ノード0はノード1と隣接
1 0 1; # ノード1はノード0、2と隣接
0 1 0 # ノード2はノード1と隣接
]
# 重み行列(2層)
# 1層目:入力2次元→中間3次元
𝑾₁ = [
0.5 0.1 0.3;
0.4 0.6 0.2
]
# 2層目:中間3次元→出力2次元
𝑾₂ = [
0.2 0.7;
0.5 0.1;
0.6 0.3
]
𝑾ₛ = [𝑾₁, 𝑾₂]
2.2. 自己ループの追加
𝑰 = I(3)
𝑨̂ = 𝑨 + 𝑰
𝑰
3×3 Diagonal{Bool, Vector{Bool}}:
1 ⋅ ⋅
⋅ 1 ⋅
⋅ ⋅ 1
𝑨̂
3×3 Matrix{Int64}:
1 1 0
1 1 1
0 1 1
2.3. 次数行列の作成
𝑫̂ = Diagonal(sum(𝑨̂, dims=2)[:])
𝑫̂_inv_sqrt = Diagonal(1 ./ sqrt.(sum(𝑨̂, dims=2)[:]));
𝑫̂
3×3 Diagonal{Int64, Vector{Int64}}:
2 ⋅ ⋅
⋅ 3 ⋅
⋅ ⋅ 2
𝑫̂_inv_sqrt
3×3 Diagonal{Float64, Vector{Float64}}:
0.707107 ⋅ ⋅
⋅ 0.57735 ⋅
⋅ ⋅ 0.707107
2.4. 正規化隣接行列の作成
𝑨̂ = 𝑫̂_inv_sqrt * 𝑨̂ * 𝑫̂_inv_sqrt
3×3 Matrix{Float64}:
0.125 0.0680414 0.0
0.0680414 0.037037 0.0680414
0.0 0.0680414 0.125
2.5.初期値の設定 〜 各層での更新処理
# 活性化関数
relu(x) = max.(0, x)
# 推論
function GCN_inference(𝑿, 𝑨̂, 𝑾ₛ)
𝑯 = 𝑿
for 𝑾 in 𝑾ₛ
𝑯 = relu(𝑨̂ * 𝑯 * 𝑾)
end
return 𝑯
end
2.6. 最終層の出力
𝑯ᴸ = GCN_inference(𝑿, 𝑨̂, 𝑾ₛ)
3×2 Matrix{Float64}:
1.78735 1.72693
2.42969 2.3622
2.20235 2.15193
2.7. 全体まとめ
using LinearAlgebra
# 活性化関数
relu(x) = max.(0, x)
# GCNレイヤー
struct GCNLayer
𝑾::Matrix{Float64}
end
function (layer::GCNLayer)(𝑨̂, 𝑯)
relu(𝑨̂ * 𝑯 * layer.𝑾)
end
# GCNモデル
struct GCNModel
layers::Vector{GCNLayer}
end
function (model::GCNModel)(𝑨̂, 𝑿)
𝑯 = 𝑿
for layer in model.layers
𝑯 = layer(𝑨̂, 𝑯)
end
return 𝑯
end
# 入力の準備
𝑿 = [
1 2;
2 3;
3 4
]
𝑨 = [
0 1 0;
1 0 1;
0 1 0
]
𝑾₁ = [
0.5 0.1 0.3;
0.4 0.6 0.2
]
𝑾₂ = [
0.2 0.7;
0.5 0.1;
0.6 0.3
]
# 自己ループの追加
𝑰 = I(3)
𝑨̂ = 𝑨 + 𝑰
# 次数行列を作成
𝑫̂ = Diagonal(sum(𝑨̂, dims=2)[:])
𝑫̂_inv_sqrt = Diagonal(1 ./ sqrt.(sum(𝑨̂, dims=2)[:]))
# 正規化隣接行列
𝑨̂ = 𝑫̂_inv_sqrt * 𝑨̂ * 𝑫̂_inv_sqrt
# モデル構築
model = GCNModel([
GCNLayer(𝑾₁),
GCNLayer(𝑾₂)
])
# 推論
𝑯 = model(𝑨̂, 𝑿)
3. 終わりに
間違いなどありましたら遠慮なくご指摘いただけますと幸いです。
Discussion