🔥

PyTorch実装で理解するST-GCN

2023/02/14に公開

Spatial Temporal Graph Convolutional Network

このnotebookではSpatial Temporal Graph Convolional Network (ST-GCN)の基礎の部分を噛み砕いて理解していきます.
PyTorch実装で理解するGraph Convolutional NetworkでGCNの基礎知識について理解しておくことを推奨します.

この内容はipynb形式で書かれたものをmarkdownに変換したものです.
originialのnotebookはgithubで公開しています.

時系列グラフ(3次元グラフ)

spatial temporal(時空間)とあるように,GCNの畳み込みを時間方向に拡張したモデルです.
GCNで扱っていたグラフは2次元グラフでしたが,ST-GCNでは時系列情報も持った3次元グラフです(Fig.1).
この時フレーム間のノードの接続方法はいくつか考えられますが,とりあえずは簡単のため同一のノードを結ぶエッジとします.
3次元グラフも2次元グラフ同様\mathcal{G}=\{\mathcal{V,E}\}と表現できます.
この時ノードの集合\mathcal{V}=\{v_{it}\mid i = 1, \ldots,V;i = 1, \ldots,T\}であり,エッジに関しては一旦空間方向だけの集合\mathcal{E}=\{e_{ij}\}で定義しておきます.
理由としては時間方向のエッジは同一ノードを繋ぐだけなので,隣接行列を使わなくても普通の畳み込みで対応できるからです.この後の実装などで理解
できると思うので一旦スルーしましょう.

Fig.1 グラフ構造

ここで,プログラミングで扱えるようにするためFig.1の3次元グラフをtensorに置き換えます.
2次元の特徴量は\mathbf{X}_t=[\mathbf{x}_1,\mathbf{x}_2,\mathbf{x}_3,\mathbf{x}_4,\mathbf{x}_5]^\topとして表現できることは前回理解しましたね.
これをさらに3次元方向に重ねて\mathbf{X}=\{\mathbf{X}_t\}(\in\mathbb{R}^{T\times V\times C})と表現できます(Fig.2).

Fig.2 グラフ特徴と隣接行列

データ作成

せっかくなので,APIを利用した姿勢推定からデータを作成してみましょう.
ここではグーグル先生が公開してくれているMediaPipeを利用してみることとします.
裏で色々手こずりました.
Apple siliconは公式には対応していないみたいです.
Apple silicon環境の方はpip install mediapipe-siliconそれ以外の方はpip install mediapipeでインストールしてみてください!
動画はvideoACから拝借しました(https://video-ac.com/video/45).

https://video-ac.com/video/45
手元で試す際は適当な動画をasset/spatial-temporal-graph-convolutional-network/video.mp4に保存して使用してください.

事前準備が長くなりましたが,早速姿勢推定をしていきます.
基本的には公式のSolutionsに基づいて実装するので詳しい説明は割愛します.

import cv2
import numpy as np
import mediapipe as mp

# 今回使うグラフのセッティング
num_node = 17
E = [[15,13],[13,11],[16,14],[14,12],[11,12],[5,11],
     [6,12],[5,6],[5,7], [6,8],[7,9],[8,10],[1,2],
     [0,1],[0,2],[1,3],[2,4],[3,5],[4,6]]
reduced_keypoints = [0,2,5,7,8,11,12,13,14,15,16,23,24,25,26,27,28] # 関節数が多すぎてごちゃつくので必要な分だけピックアップ

# APIで推定した座標をnumpy配列に変換
def make_spatial_feature_mx(hand_landmarks, width, height):
    spatial_mx = []
    for i, lm in enumerate(hand_landmarks.landmark):
        if i in reduced_keypoints:
            x = lm.x * width
            y = lm.y * height
            spatial_mx.append([x, y])
    spatial_mx = np.array(spatial_mx)
    return spatial_mx   # shape (V, C)
    
# videoの読み込み
cap = cv2.VideoCapture("asset/spatial-temporal-graph-convolutional-network/video.mp4")
fps = cap.get(cv2.CAP_PROP_FPS)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)

X = []
#pose推定の設定
mp_pose = mp.solutions.pose
with mp_pose.Pose(
    min_detection_confidence=0.5,   
    min_tracking_confidence=0.5) as pose:
    
    idx = 0
    while True:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx * fps) # 動きがわかりやすいように1秒ごと取得
        success, image = cap.read()
        if not success:
            break

        image.flags.writeable = False
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = pose.process(image)
        
        X_t = make_spatial_feature_mx(results.pose_landmarks, width, height)    # 各フレームでの特徴行列を作成
        X.append(X_t)   # 時間方向に追加する
        idx += 1
        
X = np.array(X)
X.shape # T V C
(6, 17, 2)

ポーズ推定が終わったので,可視化してみましょう.

%matplotlib inline
from matplotlib import pyplot as plt
        
fig, ax = plt.subplots(2, 3, figsize=(16.0, 6.0) )
for i in range(6):
    r = i // 3
    c = i % 3
    # エッジを描画
    for e in E:
        ax[r,c].plot([X[i, e[0], 0], X[i, e[1], 0]], [X[i, e[0], 1], X[i, e[1], 1]], "c")
    # ノードを描画
    ax[r,c].scatter(X[i, :, 0], X[i, :, 1], s=7)
    # その他設定
    ax[r,c].set_xlim([0, width])
    ax[r,c].set_ylim([0, height])
    ax[r,c].set_aspect('equal', 'box')
    ax[r,c].set_xticks([])
    ax[r,c].set_yticks([])
    ax[r,c].invert_yaxis()

どうでしょうか?ヨガのポーズをしているグラフが描画できましたね.
このようにST-GCNなどのGraphベースの行動認識では前段の姿勢推定が必須です.
今回のようにAPIを活用すると比較的簡単に実装できますし,より精度や速度を求めてGitHubなどに上がっているソースから実装してもいいですね.

ついでに前回同様,隣接行列\mathbf{A}と正規化した隣接行列\mathbf{D}^{-\frac{1}{2}}\mathbf{A}\mathbf{D}^{-\frac{1}{2}}も定義してしまいます.

def edge2mat(E, num_node):
    A = np.zeros((num_node, num_node))
    for i, j in E:
        A[j, i] = 1
    return A

def get_D(A, pow=-1):
    d_ii = np.sum(A, 0)
    D = np.zeros_like(A)
    for i in range(len(A)):
        D[i, i] = d_ii[i]**(pow)
    return D

reversed_E = [[j, i] for [i, j] in E]
I = [[i, i] for i in range(num_node)] 
new_E = E + reversed_E + I

# 隣接行列の作成
A = edge2mat(new_E, num_node)  # ノードの数は5

# DAD
D = get_D(A,pow=-0.5)
DAD = D @ A @ D

fig, ax = plt.subplots(1, 2)
ax[0].pcolor(A, cmap=plt.cm.Blues)
ax[0].set_aspect('equal', 'box')
ax[0].set_title('A', fontsize=10)
ax[0].invert_yaxis()

ax[1].pcolor(DAD, cmap=plt.cm.Blues)
ax[1].set_aspect('equal', 'box')
ax[1].set_title('DAD', fontsize=10)
ax[1].invert_yaxis()

今回はこのような6フレーム分の時系列姿勢データ\mathbf{X}とその隣接行列\mathbf{D}^{-\frac{1}{2}}\mathbf{A}\mathbf{D}^{-\frac{1}{2}}を例に,ST-GCNを理解していきましょう!

ST-GCN

ただのGCNでもノードとか隣接行列とかいっぱいあって難しかったのに,時空間に畳み込むなんて複雑そう..と思ったかもしれませんが,GCNさえ理解できていればST-GCNは意外と単純です.
まずは図を使いながら理解していきましょう!
時空間の畳み込みと言いつつ,実は空間と時間は別の処理で畳み込みます.

空間畳み込み

まず,空間方向の畳み込みは前回のGCNを使用します.
このとき,時間方向を一気に扱いたいため,実装上は1x1Convを用いるので,それに準じて図示したものがFig.3となります.
理解が難しかった人は前回の記事に戻ってFig.8あたりを確認してみてください.
どうですか?処理自体は全く同じで,入力に時間方向の情報が増えただけですね!

Fig.3 空間方向の畳み込み

ではここまでの処理を実装していきます.

import torch
from torch import nn
from einops import rearrange

X_tensor = torch.tensor(X, dtype=torch.float32)
X_tensor = rearrange(X_tensor, "T V C -> C T V")    # チャネルファースト
A_tensor = torch.tensor(A, dtype=torch.float32)
D_tensor = torch.tensor(D, dtype=torch.float32)

# 重みをConvとして設定
in_channels = 2      # x座標とy座標なので
out_channels = 8     # 適当です

spatial_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 1x1conv

# 行列計算
XW = spatial_conv(X_tensor)
# einsumを使います.
# einsum使いたくないのであれば,for文で回すか,隣接行列をフレーム分拡張するなどの実装もできると思います.
spatial_X_tensor = torch.einsum('CTV,VW->CTW', XW, D_tensor @ A_tensor @ D_tensor)

print(f"入力のサイズ:{X_tensor.size()}")
print(f"出力のサイズ:{spatial_X_tensor.size()}")
入力のサイズ:torch.Size([2, 6, 17])
出力のサイズ:torch.Size([8, 6, 17])

今回は3次元データのため可視化は省略しましたが,特徴量次元が2→8に増えていることが確認できます.
これは畳み込みと隣接行列によってフレームごと隣接ノードを畳み込めたことを示しています.

時間畳み込み

次にフレーム間(時間方向)の畳み込みをしていきます.
空間方向ではノードによって隣接ノードの数が2つだったり3つだったりと異なっていましたよね?
では時間方向はどうでしょう?どのノードも前後のフレームとの接続しかないので2つだけですね.
つまりGCNのように隣接行列を使わなくても,Fig.4のように画像みたいに畳み込むことができます.
これが最初に空間方向のエッジしか定義しなかった理由です(エッジはあるんだけど,定義しなくても自明的なイメージですね).
少し噛み砕くと,ノード方向に1,時間方向に\tauのカーネルを用いてノード方向のストライドは1にして畳み込むと,ノードごとに時間方向が集約できます.
実際の実装ではpaddingを使ったり,時間方向のstrideを2にしてpoolingを兼ねた畳み込みのようにしたりなどの工夫がされています.

Fig.4 空間方向の畳み込み

では先ほどの空間畳み込みで集約された出力を時間方向に畳み込んでみましょう.

t_kernel = 3    # 何フレーム分畳み込むかの設定
t_stride = 1    # 時間方向のストライド
padding = ((t_kernel - 1) // 2, 0)  # 入力と同じサイズの出力にするために
temporal_conv = nn.Conv2d(out_channels, out_channels, kernel_size=(t_kernel, 1), stride=(t_stride, 1), padding=padding)

temporal_X_tensor = temporal_conv(spatial_X_tensor)

print(f"入力サイズ:{temporal_X_tensor.size()}")
print(f"出力サイズ:{temporal_X_tensor.size()}")
入力サイズ:torch.Size([8, 6, 17])
出力サイズ:torch.Size([8, 6, 17])

実装でみてみると時空間の畳み込みも思ったよりシンプルだったと思います.
ST-GCNではこの時間方向と空間方向の畳み込みを繰り返すことで時系列情報を抽出しています.
使い方的には画像のCNNと同じですね!

PyTroch Layerを作ってみる

STBlockとしてレイヤーを作成してみましょう!
(公式の実装とは異なります)

class STBlock(nn.Module):
    def __init__(self, in_features, out_features, t_kernel=9, t_stride=1):
        super(STBlock, self).__init__()
        # args
        self.in_features = in_features
        self.out_features = out_features
        padding = ((t_kernel - 1) // 2, 0)
        
        # layer
        self.spatial_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.temporal_conv = nn.Conv2d(out_channels, out_channels, kernel_size=(t_kernel, 1), stride=(t_stride, 1), padding=padding)
        
    def forward(self, input, adj):
        """
        Args:
            input (Tensor): graph feature
                            input.size() = (N, C, V, T)
            adj (Tensor): normalized adjacency matrix.
                          e.g. DAD or DA
                          input.size() = (V, V)
        Returns:
            Tensor: out.size() = (N, C_out, V, T)
        """
        XW = self.spatial_conv(input)
        spatial_feature = torch.einsum('NCTV,VW->NCTW', XW, adj)
        temporal_feature = self.temporal_conv(spatial_feature)
        return temporal_feature
from torch.utils.data import TensorDataset, DataLoader

# フェイクデータ
X = np.random.randn(10, 2, 10, 17) # num_data, node_feature, time, num_node
y = [0] * 5 + [1] * 5
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.int64)
data = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(data, batch_size=10)

# DADの定義
A_tensor = torch.tensor(A, dtype=torch.float32)
D_tensor = torch.tensor(D, dtype=torch.float32)
DAD = D_tensor @ A_tensor @ D_tensor

# モデルの定義
model = STBlock(in_channels, out_channels)

# バッチ処理
for input, label in loader:
    new_X_tensor = model(input, DAD)

new_X_tensor.size()
torch.Size([10, 8, 10, 17])

隣接行列の発展

行動認識をする上でより効率的な隣接行列の設定方法が提案されています.
ST-GCNの論文の5ページの図に何やら色々な集約方法が書いてありそうですね.
この章ではこれのそれぞれの意味を理解していきます.

ここからとりあえず文字だけで説明して,そのあと実装を用いてそれぞれの隣接行列を可視化します.

Uni-labeling

論文中では最も単純な分割戦略と説明されていますが,分割は行いません.
これは今までの説明でも使ってきた\mathbf{X}とその隣接行列\mathbf{D}^{-\frac{1}{2}}\mathbf{A}\mathbf{D}^{-\frac{1}{2}}と同じで,隣接ノードを全て等価と見なす方法です.

Distance partitioning

ここからは隣接行列を分割して複数定義する方法となります.
Distance pertitioningはその中でも最も単純な方法となり\mathbf{A}\mathbf{A}_0\mathbf{A}_1の2つに分割します.

  • \mathbf{A}_0はノード自身への接続(自己ループ)を指します.
  • \mathbf{A}_1はそれ以外の接続を指します.

Spatial configuration partitioning

ここでは\mathbf{A}\mathbf{A}_0\mathbf{A}_1\mathbf{A}_2の2つに分割します.
論文中では数式などを用いて難しそうに書いてありますが,実はそんなに大したことはしていません.

  • \mathbf{A}_0はノード自身への接続(自己ループ)を指します.
  • \mathbf{A}_1は手首→肘のような体の外から内に向けた接続.
  • \mathbf{A}_2は肘→手首のような体の内から外に向けた接続.

ではこれらを実装してみましょう!
と,実はその前に今まで何も言わずに使ってきエッジEですが,実は内向きの接続で書かれていました.
そこで,区別のため変数名を変えて定義します.

# 内向きエッジ
inward = E
# 外向きエッジ
outward = reversed_E
# 自己ループ
self_loop = I
print(f"inward:{inward}")
print(f"outward:{outward}")
print(f"self-loop:{self_loop}")
inward:[[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [7, 5], [8, 6], [9, 7], [10, 8], [1, 0], [2, 0], [3, 1], [4, 2], [3, 5], [4, 6]]
outward:[[13, 15], [11, 13], [14, 16], [12, 14], [12, 11], [11, 5], [12, 6], [6, 5], [5, 7], [6, 8], [7, 9], [8, 10], [0, 1], [0, 2], [1, 3], [2, 4], [5, 3], [6, 4]]
self-loop:[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, 11], [12, 12], [13, 13], [14, 14], [15, 15], [16, 16]]

では,本題に戻ります.
公式実装は論文の数式とやや挙動が違うようなので,自前で実装してみようと思います!
(具体的には\mathbf{D}^{-\frac{1}{2}}\mathbf{A}\mathbf{D}^{-\frac{1}{2}}ではなく\mathbf{D}^{-1}\mathbf{A}で正規化されていたり,\mathbf{A}を分割して正規化せずに正規化した\mathbf{D}^{-1}\mathbf{A}を分割してたり)

# 今まではDADを手動で計算していましたがここからは多用するので関数化します
def get_DAD(A):
    d_ii = np.sum(A, 0) + 0.001 # 0を避けるため0.001を足す
    D = np.zeros_like(A)
    for i in range(len(A)):
        D[i, i] = d_ii[i]**(-0.5)
    DAD = np.dot(np.dot(D, A), D)
    return DAD

# 公式では条件分岐で3つの戦略が1つの関数にまとめてありますが.
# ここではわかりやすさ優先で別で定義します
# uni-labelingの関数
def uni_labeling(inward, outward, self_loop, num_node):
    A = np.zeros((1, num_node, num_node))   # 分割しないので(1, num_node, num_node)の配列を作成
    # A_0について
    A[0] = get_DAD(A=edge2mat(inward+outward+self_loop, num_node))
    return A

# distance partitioningの関数
def distance(inward, outward, self_loop, num_node):
    A = np.zeros((2, num_node, num_node))   # A0とA1に分割されるので(2, num_node, num_node)の配列を作成
    # A_0について
    A[0] = get_DAD(A=edge2mat(self_loop, num_node)) # self-loopのみ
    # A_1について
    A[1] = get_DAD(A=edge2mat(inward+outward, num_node))   # self-loop以外
    return A

# Spatial configuration partitioningの関数
def spatial(inward, outward, self_loop, num_node):
    A = np.zeros((3, num_node, num_node))   # A0,A1,A2に分割されるので(3, num_node, num_node)の配列を作成
    # A_0について
    A[0] = get_DAD(A=edge2mat(self_loop, num_node)) # self-loopのみ
    # A_1について
    A[1] = get_DAD(A=edge2mat(inward, num_node))   # inwardのみ
    # A_2について
    A[2] = get_DAD(A=edge2mat(outward, num_node))   # outwardのみ
    return A

とりあえず一通り関数の定義ができたので,実際に可視化してみましょう!

# uni-labelingについて
print("Uni-labeling")
A = uni_labeling(inward, outward, self_loop, num_node)
fig, ax = plt.subplots(1, 3)
ax[0].pcolor(A[0], cmap=plt.cm.Blues, vmin=0, vmax=1)
ax[0].set_aspect('equal', 'box')
ax[0].set_title('A', fontsize=10)
ax[0].invert_yaxis()
ax[1].set_aspect('equal', 'box')
ax[1].invert_yaxis()
ax[2].set_aspect('equal', 'box')
ax[2].invert_yaxis()
plt.show()

# distance partitioningについて
print("Distance partitioning")
A = distance(inward, outward, self_loop, num_node)
fig, ax = plt.subplots(1, 3)
ax[0].pcolor(A[0], cmap=plt.cm.Blues, vmin=0, vmax=1)
ax[0].set_aspect('equal', 'box')
ax[0].set_title('self-loop', fontsize=10)
ax[0].invert_yaxis()
ax[1].pcolor(A[1], cmap=plt.cm.Blues, vmin=0, vmax=1)
ax[1].set_aspect('equal', 'box')
ax[1].set_title('other', fontsize=10)
ax[1].invert_yaxis()
ax[2].set_aspect('equal', 'box')
ax[2].invert_yaxis()
plt.show()

# Spatial configuration partitioningについて
print("Spatial configuration partitioning")
A = spatial(inward, outward, self_loop, num_node)
fig, ax = plt.subplots(1, 3)
ax[0].pcolor(A[0], cmap=plt.cm.Blues, vmin=0, vmax=1)
ax[0].set_aspect('equal', 'box')
ax[0].set_title('self-loop', fontsize=10)
ax[0].invert_yaxis()

ax[1].pcolor(A[1], cmap=plt.cm.Blues, vmin=0, vmax=1)
ax[1].set_aspect('equal', 'box')
ax[1].set_title('inward', fontsize=10)
ax[1].invert_yaxis()

ax[2].pcolor(A[2], cmap=plt.cm.Blues, vmin=0, vmax=1)
ax[2].set_aspect('equal', 'box')
ax[2].set_title('outword', fontsize=10)
ax[2].invert_yaxis()
plt.show()
Uni-labeling

Distance partitioning

Spatial configuration partitioning


どうでしょうか?なんとなく違いは分かりましたか?
下になるほどエッジの分割を細分化しています.
実装してみてみると大したことはしてない(著者に怒られろ)のですが,Spatialを使うとUni(今までの定義)を使うより精度が10%以上上がるらしいですよ.

ただ,今まで\mathbf{D}^{-\frac{1}{2}}\mathbf{A}\mathbf{D}^{-\frac{1}{2}}\mathbf{X}\mathbf{W}を使って畳み込んできたのに\mathbf{A}が増えたら計算できないですよね?
そこで提案されている方法が,それぞれの\mathbf{A}_jの隣接行列で畳み込んだ結果を足しちゃうって方法です.

\mathbf{X}_\mathrm{out}=\sum_j \boldsymbol{\Lambda}_j^{-\frac{1}{2}} \mathbf{A}_j \boldsymbol{\Lambda}_j^{-\frac{1}{2}}\mathbf{X}\mathbf{W}_j

\boldsymbol{\Lambda}は今まで\mathbf{D}で表してたような次数行列で,\mathbf{A}_jごとに定義されています(このへんは先ほどの実装ですでにget_DAD()として組み込まれています).
ここまで理解できているみなさんなら足すだけなので実装できるかもしれませんが,一応PyTorchで実装してみます.
まず,シンプルにさっきのレイヤーを拡張するパターン

class STBlock2(nn.Module):
    def __init__(self, in_features, out_features, len_A, t_kernel=9, t_stride=1): # len_Aとして隣接行列の分割数を追加 
        super(STBlock2, self).__init__()
        # args
        self.in_features = in_features
        self.out_features = out_features
        padding = ((t_kernel - 1) // 2, 0)
        
        # layer
        self.spatial_conv = nn.ModuleList((nn.Conv2d(in_channels, out_channels, kernel_size=1) for _ in range(len_A)))  # Aの数だけconvを定義
        self.temporal_conv = nn.Conv2d(out_channels, out_channels, kernel_size=(t_kernel, 1), stride=(t_stride, 1), padding=padding)
        
    def forward(self, input, adj):
        """
        Args:
            input (Tensor): graph feature
                            input.size() = (N, C, V, T)
            adj (Tensor): normalized adjacency matrix.
                          e.g. DAD or DA
                          input.size() = (K, V, V)
        Returns:
            Tensor: out.size() = (N, C_out, V, T)
        """
        spatial_feature = None
        for module, a in zip(self.spatial_conv, adj):
            XW = module(input)
            DADXW = torch.einsum('NCTV,VW->NCTW', XW, a)
            if spatial_feature is not None:
                spatial_feature = spatial_feature + DADXW 
            else:
                spatial_feature = DADXW
        temporal_feature = self.temporal_conv(spatial_feature)
        return temporal_feature
# フェイクデータ
X = np.random.randn(10, 2, 10, 17)
y = [0] * 5 + [1] * 5
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.int64)
data = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(data, batch_size=10)

# Aの定義
A = spatial(inward, outward, self_loop, num_node)
A = torch.tensor(A, dtype=torch.float32)

# モデルの定義
model = STBlock2(in_channels, out_channels, len_A=len(A))

# バッチ処理
for input, label in loader:
    new_X_tensor = model(input, A)

new_X_tensor.size()
torch.Size([10, 8, 10, 17])

これが素直に実装する方法ですが,公式実装では少し工夫されており,convのchnnel方向に拡張することで実装しています.
こっちの方がfor文を使わないのでスマートですね!

class STBlock2(nn.Module):
    def __init__(self, in_features, out_features, len_A, t_kernel=9, t_stride=1): # len_Aとして隣接行列の分割数を追加 
        super(STBlock2, self).__init__()
        # args
        self.in_features = in_features
        self.out_features = out_features
        padding = ((t_kernel - 1) // 2, 0)
        
        # layer
        self.spatial_conv = nn.Conv2d(in_channels, out_channels * len_A, kernel_size=1)# Aの数倍にchannelを拡張
        self.temporal_conv = nn.Conv2d(out_channels, out_channels, kernel_size=(t_kernel, 1), stride=(t_stride, 1), padding=padding)
        
    def forward(self, input, adj):
        """
        Args:
            input (Tensor): graph feature
                            input.size() = (N, C, V, T)
            adj (Tensor): normalized adjacency matrix.
                          e.g. DAD or DA
                          input.size() = (K, V, V)
        Returns:
            Tensor: out.size() = (N, C_out, V, T)
        """
        XW = spatial_conv(X)
        XW = rearrange(XW, "N (C K) T V -> N C K T V", C=self.out_features)
        spatial_feature = torch.einsum('NCKTV,KVW->NCTW', XW, adj)

        temporal_feature = self.temporal_conv(spatial_feature)
        return temporal_feature
torch.Size([10, 8, 10, 17])

終わりに

前回と今回でGCNについての記事は一旦終わりにしようと思います.
かなり独自の実装で解説しているので,公式の実装の解説は要望があればしようかなと思います.
他にもこんなモデルの解説をして欲しいとかあればコメントいただければ勉強しようと思います.

Discussion