## GNNを学ぶ Vol.1:グラフニューラルネットワーク入門
こんにちは、古閑です。
前回までで、GISデータの基本(ベクター・ラスター)や、QGISでのデータ選択方法について学びました。今回からは角度を変えてグラフニューラルネットワークについても学習ヲ進めていこうと思います。
今回は、**グラフニューラルネットワーク(GNN)**の基礎を学ぶため、
『Hands-On Graph Neural Networks Using Python: Practical techniques and architectures for building powerful graph and deep learning apps with PyTorch (English Edition)』
という書籍を読み始めました。
GISとGNNとの隣接点として、15章の「Forecasting Traffic Using A3T-GCN」から取り組んでいきます。今回読んだ内容は、交通渋滞の予測が現実世界でどんな効果があるのか?という説明。
これは、最近のスマートシティという流行に対してGNNを用いた交通予測がもたらす利益について説明しています。
また、この章で実際に行う内容として、
*生のCSVファイルを探して加工すること、
*T-GNNを用いて新しい個通予測を適用すること、
*結果を可視化して比較すること、
を実行します。
1. GNNは何を学んでいるのか?
GNNは、グラフ構造のデータを扱うための機械学習モデルです。ここでいうグラフとは、ノード(点)とエッジ(線)で構成される、ネットワークのようなデータのことです。
GNNが学習する際に重要な要素は、主に二つあります。
- ノードの特徴量(Feature): ノードが持つ個々の情報です。例えば、圃場のノードであれば「地表面温度」や「NDVI」などがこれに当たります。
- ノード間の関係性(トポロジー): ノードがどのように繋がっているか、というネットワーク全体の構造です。
GNNの最も興味深い点は、この二つの要素を同時に学習し、ノードの情報をアップデートしていく点にあります。
2. グラフをコンピュータに理解させる方法
グラフをGNNに学習させるには、その構造を数値データに変換する必要があります。そのための主要なデータ構造が以下の二つです。
- 隣接行列(Adjacency matrix): どのノードとどのノードが繋がっているか(エッジ)を表現する行列です。
- 特徴行列(Feature matrix): 各ノードが持つ特徴量を数値で表現する行列です。
この二つの行列を組み合わせることで、GNNはグラフの構造と個々のノードの情報を同時に理解できるようになります。
3. データセットのダウンロード
「Hands on GNN」 のGitHubより
dataset:the Caltrans Performance Measurement System(PeMS)
を google colab を用いてダウンロードしました。
* データセットについて:
引用文
"""
このデータセットは、PeMSD7データセットの中規模版です。
このデータセットの元データは、カリフォルニア州運輸局の**PeMS(Performance Measurement System)**というシステムから取得されました。このシステムは、カリフォルニア州の主要な高速道路に設置された約39,000台のセンサーから、リアルタイムの交通速度データを収集しています。
今回使用するデータセットでは、その中からカリフォルニア州第7区に位置する228箇所のセンサー局のデータのみを対象としています。データは2012年5月と6月の平日に収集されたものです。
元々は30秒ごとの速度計測データですが、このデータセットでは5分間隔に集約されています。
"""
- データフレームの概要を表示します。
speeds.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 12672 entries, 0 to 12671
Columns: 228 entries, 0 to 227
dtypes: float64(228)
memory usage: 22.0 MB
* データセットの最初の5行を表示します。
speeds.head(5)
index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 71.1 | 66.0 | 64.6 | 65.6 | 67.1 | 71.9 | 68.6 | 67.7 | 65.8 | 40.9 | 67.6 | 69.8 | 65.5 | 69.9 | 63.3 | 69.8 | 65.0 | 64.2 | 71.9 | 70.5 |
1 | 68.1 | 66.8 | 61.7 | 66.7 | 64.5 | 71.6 | 72.29999999999998 | 64.9 | 65.6 | 40.1 | 68.1 | 70.8 | 65.7 | 70.2 | 67.0 | 68.9 | 64.9 | 65.8 | 72.7 | 72.29999999999998 |
2 | 68.0 | 64.3 | 66.6 | 68.7 | 68.1 | 70.5 | 70.2 | 61.7 | 63.4 | 39.6 | 69.4 | 67.4 | 65.5 | 70.8 | 69.2 | 66.4 | 64.9 | 68.1 | 71.7 | 71.8 |
3 | 68.3 | 67.8 | 65.9 | 66.6 | 67.9 | 70.3 | 69.8 | 67.6 | 63.2 | 37.6 | 67.0 | 67.6 | 65.2 | 71.1 | 66.0 | 66.1 | 64.9 | 65.6 | 70.9 | 71.5 |
4 | 68.9 | 69.5 | 61.2 | 67.4 | 64.0 | 68.1 | 67.0 | 66.7 | 64.2 | 36.8 | 66.6 | 68.5 | 65.2 | 69.8 | 70.0 | 68.6 | 64.9 | 62.7 | 71.1 | 69.0 |
- データフレームの概要を表示します。
distances.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 228 entries, 0 to 227
Columns: 228 entries, 0 to 227
dtypes: float64(228)
memory usage: 406.3 KB
- データセットの最初の5行を表示します。
distances.head(5)
index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 3165.94 | 8731.54 | 11903.45 | 7757.18 | 19878.51 | 18436.35 | 2213.47 | 5887.74 | 16132.12 | 13615.5 | 4269.59 | 9099.25 | 25422.62 | 13520.72 | 5371.21 | 16011.3 | 15778.86 | 18061.91 | 17633.57 |
1 | 3165.94 | 0.0 | 5625.76 | 8749.35 | 4695.51 | 16716.96 | 15274.0 | 1037.65 | 2910.29 | 12966.19 | 10451.78 | 1330.37 | 5982.33 | 22258.06 | 10357.2 | 2391.12 | 13509.06 | 13160.73 | 14952.42 | 15294.95 |
2 | 8731.54 | 5625.76 | 0.0 | 3280.12 | 1035.65 | 11465.75 | 10027.42 | 6659.59 | 2905.29 | 7631.76 | 5065.2 | 4470.84 | 387.42 | 16912.7 | 4968.68 | 3392.76 | 8920.21 | 8335.059999999998 | 10189.58 | 10953.16 |
3 | 11903.45 | 8749.35 | 3280.12 | 0.0 | 4315.55 | 8193.1 | 6759.32 | 9761.19 | 6171.81 | 4355.25 | 1790.75 | 7688.279999999999 | 2892.84 | 13633.8 | 1694.58 | 6641.74 | 8214.35 | 7388.649999999999 | 7128.46 | 10322.56 |
4 | 7757.18 | 4695.51 | 1035.65 | 4315.55 | 0.0 | 12500.86 | 11062.04 | 5733.14 | 1889.16 | 8667.35 | 6100.79 | 3487.65 | 1423.07 | 17947.24 | 6004.27 | 2391.72 | 9351.43 | 8847.3 | 11188.26 | 11320.83 |
引用文
"""
このデータセットでやりたいことは、交通速度の変化を可視化することです。これは、季節性といった特徴が非常に役立つため、時系列予測では典型的な作業です。一方、非定常な時系列データは、使用する前にさらなる処理が必要になる場合があります。
"""
データの可視化:時系列グラフで時間の変化を見る
GNNの応用例として交通予測を学ぶ上で、まずはじめに行うべきは、データセットを可視化して、その中にどんなパターンが隠されているかを把握することです。
今回使用するデータセットは、5分間隔で交通速度を記録したものです。この交通速度が時間とともにどのように変化しているかを可視化するには、時系列グラフが最も適しています。
plt.figure(figsize=(10, 5))
plt.plot(speeds)
plt.grid(linestyle=':')
plt.xlabel('Time(5 min)')
plt.ylabel('Traffic speed')
このコードを実行すると、横軸が**「5分間隔の時間」、縦軸が「交通速度」**となる折れ線グラフが生成されます。
このグラフから、交通速度が朝や夕方のピーク時にどのように変動するのか、あるいは特定の曜日にどんな傾向があるのかといった、**「季節性」や「トレンド」**を直感的に読み取ることができます。
この可視化は、GNNにデータを学習させる前の、非常に重要な第一歩となります。データの全体像を把握することで、次にどのようなモデルや処理が必要になるのかが見えてくるのです。
今回は、matplotlib
というライブラリを使って、交通速度を時系列で可視化にも挑戦しました。
4. まとめと今後の展望
GNNの基礎を学ぶことで、GISデータが持つ「空間的な関係性」という強力な情報を、機械学習に活かせる可能性が少し見えてきた気がします。
本日はPeMS-M dataset のダウンロードとpandasへの変換。そしてmatplotliibでの視覚化をおこなっいました。英文の文章でなかなか取り組みが進まなかったので、ブログでアウトプットしながら読むと「ブログに書かなきゃ!」というモチベーションがでるので読む気力が湧きます。次回はdatasetをGNNで交通予測に取り組みたいと思います。
参考文献
- 『Hands-On Graph Neural Networks Using Python: Practical techniques and architectures for building powerful graph and deep learning apps with PyTorch (English Edition) 』第15章
Discussion