🐷

GNN(グラフニューラルネットワーク)についてと、簡単な見える化アプリ

2024/09/18に公開

https://github.com/atarum/GraphNeuralNetworks
今回の参照記事は、こちら。
学んでる途中なので、機械学習のところは今回は飛ばします。
GNNの概要と、それが実装できそうな見える化アプリを簡単に作ってみます。

GNNとは?

GNNの定義

現代の機械学習では、画像や音声などのデータだけでなく、グラフデータと呼ばれる複雑な関係性を持つデータを扱う必要性が高まっています。例えば、SNS上のユーザーのつながりや、化学分子の構造などがこれに該当します。これらのデータを効率的に処理するために登場したのが、
GNN(グラフニューラルネットワーク)
です。

2. GNNの概要

2.1 GNNとは何か?

GNNは、グラフ構造を持つデータを処理するためのニューラルネットワークの一種です。従来のニューラルネットワークは、画像(2次元グリッド)やテキスト(1次元配列)などの固定的なデータ構造を扱うのに適していますが、ノード(点)とエッジ(線)で構成されるグラフデータを扱うことは困難でした。GNNはこの課題を解決し、グラフ上での情報伝達やパターン認識を可能にします。

用語解説

  • ノード(Node):グラフの頂点。オブジェクトや個体を表す。
  • エッジ(Edge):ノード間の関係性を表す線。
  • メッセージパッシング:ノード間で情報をやり取りするプロセス。
  • アテンション機構:重要度に応じて情報に重み付けを行う仕組み。

3. GNNの特徴・メリット・デメリット

3.1 特徴

  • 関係性の学習:各ノードが隣接するノードから情報を受け取り、自身の状態を更新します。
  • 汎用性:様々な分野のグラフデータに適用可能。

3.2 メリット

  • 複雑な関係性のモデリング:グラフ構造を直接扱うため、ノード間の複雑な関係性を学習できます。
  • 高い表現力:従来の手法では難しかったパターンや特徴を捉えることが可能です。

3.3 デメリット

  • 計算コストが高い:グラフのサイズが大きくなると計算量が増加します。
  • データの前処理が複雑:グラフデータの準備や処理には専門的な知識が必要です。

4. GNNの種類と機械学習タスクの例

4.1 機械学習タスクの例

  • ノード分類
    • 入力:グラフ構造と各ノードの特徴量。
    • 出力:各ノードのクラス(カテゴリ)ラベル。
    • :SNSでのユーザー属性の推定。
  • リンク予測
    • 入力:部分的なグラフ構造。
    • 出力:エッジ(リンク)の有無や強さの予測。
    • :友達推薦機能、化学結合の予測。
  • グラフ分類
    • 入力:複数のグラフデータ。
    • 出力:グラフ全体のクラスラベル。
    • :化合物の毒性判定、タンパク質の機能分類。

4.2.1 Graph Convolutional Network (GCN)

GCNは、グラフ上での畳み込み操作を実現したモデルです。畳み込みとは、画像処理で使われる技術で、周辺のピクセル情報を利用して特徴を抽出します。GCNはこれをグラフ構造に拡張したものです。低次元

4.2.2 Graph Attention Network (GAT)

GATは、各エッジに注意(アテンション)機構を導入したモデルです。これにより、重要な隣接ノードからの情報に重みを置いて学習できます。

4.2.3 GraphSAGE

GraphSAGEは、大規模なグラフデータにも適用可能なサンプリング手法を導入したモデルです。隣接ノードの一部をサンプリングして計算することで、効率化を図ります。

4.3 機械学習タスクの具体例

  • ソーシャルネットワーク分析:友人関係の予測やコミュニティ検出
  • 化学物質の特性予測:新薬の開発や材料の特性評価
  • 推薦システム:ユーザーと商品間の関係性を利用した推薦

5. まとめ

GNNは、グラフデータを効率的に扱うための強力なツールであり、多くの分野で活用が期待されています。しかし、その学習には計算資源や専門知識が必要です。今後の研究や技術の進歩により、これらの課題が解決され、より広く普及することが期待されます。

可視化アプリを作ってみる。

勉強してきましたが、想像以上にまだまだ理解が足りないので、今回はグラフ構造を見える化するアプリでお茶を濁します。

理想の形はこれ、
https://neo4j.com/product/auradb/

今回は、Nextjsを用いて、これの機能面を実装してみます。

環境準備

nextjsの環境をインストールします。

npx create-next-app@latest my-gnn-app --typescript
Need to install the following packages:
  create-next-app@14.2.11
Ok to proceed? (y)Would you like to use ESLint?No / YesWould you like to use Tailwind CSS?No / YesWould you like to use `src/` directory?No / YesWould you like to use App Router? (recommended)No / YesWould you like to customize the default import alias (@/*)? … No / Yes

移動して、グラフ描画に必要なライブラリをインストールします。

cd my-gnn-app
npm install react-force-graph-2d

ライブラリはこちらです。
https://github.com/vasturiano/react-force-graph?tab=readme-ov-file

app/globals.cssを下記の形に書き直します。

globals.css
@tailwind base;
@tailwind components;
@tailwind utilities;

念の為、下記コードで、動く確認してください。

npm run dev

実装

フォルダー階層をいじっていきます。下記の画像のように、フォルダーとファイルを作っていきましょう。

具体的には、

  • components
    • GraphVisualization.tsx
    • InputForm.tsx

を作成します。

グラフ情報を可視化

src/components/GraphVisualization.tsx
import React, { useEffect, useState } from 'react';

interface Node {
  id: string;
  name: string;
}

interface Link {
  source: string;
  target: string;
  name: string;
}

interface GraphVisualizationProps {
  nodes: Node[];
  links: Link[];
}

const GraphVisualization: React.FC<GraphVisualizationProps> = ({ nodes, links }) => {
  const [ForceGraph2D, setForceGraph2D] = useState<any>(null);

  useEffect(() => {
    import('react-force-graph-2d').then((module) => {
      setForceGraph2D(() => module.default);
    });
  }, []);

  if (!ForceGraph2D) {
    return <div>Loading...</div>;
  }

  const data = {
    nodes: nodes,
    links: links,
  };

  return (
    <div className="w-full h-screen">
      <ForceGraph2D
        graphData={data}
        nodeAutoColorBy="id"
        linkDirectionalArrowLength={6}
        linkDirectionalArrowRelPos={1}
        nodeCanvasObject={(node: any, ctx: CanvasRenderingContext2D, globalScale: number) => {
          const label = node.name;
          const fontSize = 12 / globalScale;
          ctx.font = `${fontSize}px Sans-Serif`;
          ctx.fillStyle = 'black';
          ctx.textAlign = 'center';
          ctx.textBaseline = 'middle';
          ctx.fillText(label, node.x, node.y);
        }}
        linkCanvasObjectMode={() => 'after'}
        linkCanvasObject={(link: any, ctx: CanvasRenderingContext2D, globalScale: number) => {
          const label = link.name;
          if (!label) return;

          const start = link.source;
          const end = link.target;

          // 中間点を計算
          const x = (start.x + end.x) / 2;
          const y = (start.y + end.y) / 2;

          const fontSize = 12 / globalScale;
          ctx.font = `${fontSize}px Sans-Serif`;
          ctx.fillStyle = 'red';
          ctx.textAlign = 'center';
          ctx.textBaseline = 'middle';
          ctx.fillText(label, x, y);
        }}
      />
    </div>
  );
};

export default GraphVisualization;

グラフ情報を入力・削除

src/components/InputForm.tsx
"use client";

import React, { useState } from 'react';

interface Node {
  id: string;
  name: string;
}

interface Link {
  source: string;
  target: string;
  name: string;
}

interface InputFormProps {
  nodes: Node[];
  links: Link[];
  onDataSubmit: (nodes: Node[], links: Link[]) => void;
  onDeleteNode: (nodeId: string) => void;
}

const InputForm: React.FC<InputFormProps> = ({ nodes, links, onDataSubmit, onDeleteNode }) => {
  // ノードの状態
  const [nodeId, setNodeId] = useState('');
  const [nodeName, setNodeName] = useState('');

  // リンクの状態
  const [sourceId, setSourceId] = useState('');
  const [targetId, setTargetId] = useState('');
  const [linkName, setLinkName] = useState('');

  const handleAddNode = () => {
    if (nodeId.trim() === '' || nodeName.trim() === '') return;
    // 既存のノードと重複チェック
    if (nodes.find(node => node.id === nodeId.trim())) {
      alert('同じIDのノードが既に存在します。');
      return;
    }
    const newNode = { id: nodeId.trim(), name: nodeName.trim() };
    onDataSubmit([...nodes, newNode], links);
    setNodeId('');
    setNodeName('');
  };

  const handleAddLink = () => {
    if (sourceId.trim() === '' || targetId.trim() === '' || linkName.trim() === '') return;
    // ノードが存在するかチェック
    if (!nodes.find(node => node.id === sourceId.trim()) || !nodes.find(node => node.id === targetId.trim())) {
      alert('指定されたノードIDが存在しません。');
      return;
    }
    const newLink = { source: sourceId.trim(), target: targetId.trim(), name: linkName.trim() };
    onDataSubmit(nodes, [...links, newLink]);
    setSourceId('');
    setTargetId('');
    setLinkName('');
  };

  const handleSubmit = () => {
    onDataSubmit(nodes, links);
  };

  return (
    <div className="p-4 bg-white rounded shadow mb-4 w-full max-w-md">
      <h2 className="text-2xl font-bold mb-4">ノードの追加</h2>
      <div className="flex flex-col mb-4">
        <input
          type="text"
          placeholder="ノードID"
          value={nodeId}
          onChange={(e) => setNodeId(e.target.value)}
          className="border p-2 mb-2"
        />
        <input
          type="text"
          placeholder="ノード名"
          value={nodeName}
          onChange={(e) => setNodeName(e.target.value)}
          className="border p-2 mb-2"
        />
        <button onClick={handleAddNode} className="bg-blue-500 text-white px-4 py-2 rounded">
          ノードを追加
        </button>
      </div>
      <h2 className="text-2xl font-bold mb-4">ノード一覧</h2>
      <ul className="mb-4">
        {nodes.map((node) => (
          <li key={node.id} className="flex justify-between items-center border-b p-2">
            <span>{node.name} ({node.id})</span>
            <button onClick={() => onDeleteNode(node.id)} className="bg-red-500 text-white px-2 py-1 rounded">
              削除
            </button>
          </li>
        ))}
      </ul>
      <h2 className="text-2xl font-bold mb-4">リンクの追加</h2>
      <div className="flex flex-col mb-4">
        <input
          type="text"
          placeholder="ソースノードID"
          value={sourceId}
          onChange={(e) => setSourceId(e.target.value)}
          className="border p-2 mb-2"
        />
        <input
          type="text"
          placeholder="ターゲットノードID"
          value={targetId}
          onChange={(e) => setTargetId(e.target.value)}
          className="border p-2 mb-2"
        />
        <input
          type="text"
          placeholder="リンク名(関係性)"
          value={linkName}
          onChange={(e) => setLinkName(e.target.value)}
          className="border p-2 mb-2"
        />
        <button onClick={handleAddLink} className="bg-blue-500 text-white px-4 py-2 rounded">
          リンクを追加
        </button>
      </div>
      <button onClick={handleSubmit} className="bg-green-500 text-white px-4 py-2 rounded w-full">
        グラフを更新
      </button>
    </div>
  );
};

export default InputForm;

メインページを編集

src/app/page.tsx
"use client";

import { NextPage } from 'next';
import Head from 'next/head';
import { useState } from 'react';
import InputForm from '../components/InputForm';
import GraphVisualization from '../components/GraphVisualization';

interface Node {
  id: string;
  name: string;
}

interface Link {
  source: string;
  target: string;
  name: string;
}

const Home: NextPage = () => {
  const [nodes, setNodes] = useState<Node[]>([]);
  const [links, setLinks] = useState<Link[]>([]);

  const handleDataSubmit = (newNodes: Node[], newLinks: Link[]) => {
    setNodes(newNodes);
    setLinks(newLinks);
  };

  const handleDeleteNode = (nodeId: string) => {
    // ノードを削除
    const updatedNodes = nodes.filter(node => node.id !== nodeId);
    // 削除されたノードに関連するリンクを削除
    const updatedLinks = links.filter(link => link.source !== nodeId && link.target !== nodeId);
    setNodes(updatedNodes);
    setLinks(updatedLinks);
  };

  return (
    <div>
      <Head>
        <title>GNNグラフ可視化アプリ</title>
        <meta name="description" content="Next.jsとTypeScriptで作るGNNグラフ可視化アプリ" />
      </Head>
      <main className="flex flex-col items-center justify-start min-h-screen bg-background p-4">
        <h1 className="text-4xl font-bold my-8 text-foreground">
          GNNグラフ可視化アプリ
        </h1>
        <InputForm
          nodes={nodes}
          links={links}
          onDataSubmit={handleDataSubmit}
          onDeleteNode={handleDeleteNode}
        />
        <GraphVisualization nodes={nodes} links={links} />
      </main>
    </div>
  );
};

export default Home;

下記のような状態に習ったら、成功です。何かを入力してみて、ください。

Discussion