🤖

LangGraphでつくる!プロンプトに応じてDBデータから自由にグラフ画像を生成するAIエージェント

に公開

モチベーション

仕様を決める際、本番データにはいっているお客様のデータを元に意思決定するときってありますよね。例えば電子カルテのプロダクトであれば、「DBを見たところ、クリニックに入っている予約は基本的に半年先までだから、最大で半年分のデータを一覧表示するときのパフォーマンスを確認すれば十分そう」みたいな具合です。
こういったケースで、さくっとChatGPTに問い合わせる感覚でDBのデータを俯瞰できたら嬉しいなと思い作ってみました。

作ったもの

大学での履修状況を想定し、学部・学生・教員・講義・履修成績の5つのテーブルを持つDBを作成しています。
今回はこのような構造のDBに対して、

$ python src/main.py --prompt 学部ごとの学生数を教えて

のようにプロンプトを投げると、DBへ問い合わせてデータを取得しグラフを図示するコードを実装しました。プロンプトはDBのデータに即したものであれば何でもOKです。

各テーブルのスキーマ

departments (学部)

カラム名 データ型 制約 説明
id SERIAL PRIMARY KEY 学部ID
name VARCHAR(100) NOT NULL 学部名

students (学生)

カラム名 データ型 制約 説明
id SERIAL PRIMARY KEY 学生ID
name VARCHAR(100) NOT NULL 氏名
year INT NOT NULL 学年(1〜4)
department_id INT NOT NULL, FOREIGN KEY → departments(id) 所属学部ID

professors (教員)

カラム名 データ型 制約 説明
id SERIAL PRIMARY KEY 教員ID
name VARCHAR(100) NOT NULL 氏名
department_id INT NOT NULL, FOREIGN KEY → departments(id) 所属学部ID

courses (講義)

カラム名 データ型 制約 説明
id SERIAL PRIMARY KEY 講義ID
name VARCHAR(100) NOT NULL 講義名
professor_id INT NOT NULL, FOREIGN KEY → professors(id) 担当教員ID
credits INT NOT NULL 単位数

enrollments (履修・成績)

カラム名 データ型 制約 説明
id SERIAL PRIMARY KEY 履修ID
student_id INT NOT NULL, FOREIGN KEY → students(id) 学生ID
course_id INT NOT NULL, FOREIGN KEY → courses(id) 講義ID
term VARCHAR(10) NOT NULL 学期(Spring等)
grade VARCHAR(2) 成績(A〜F等)

技術的な内容

今回実装したコードはこちらです。
https://github.com/puertocampo/prompt-to-query-to-graph

↓LangGraphのワークフロー図はこんな感じ↓

主な構成要素としては以下の4つです。

  1. SQLクエリを生成 (generate_sql_query)
  2. SQLクエリ実行 (exec_sql_query)
  3. グラフ表示コード生成 (generate_plot_code)
  4. グラフ表示コード実行 (execute_plot_code)

ここからは各構成要素の詳細について、プロンプト「学部ごとの学生数を教えて」を投げた場合の具体的なデータと合わせて見ていきます。

1. SQLクエリ生成 (generate_sql_query)

SQLクエリ生成時のプロンプト

以下のスキーマ構造を持つpostgreSQLにより構築されたrelationalデータベースに対して、「{user_prompt}」を実現するSQLクエリを生成してください。

スキーマ情報:
{schema}

※スキーマ情報については後述します

生成されたクエリ例は以下の通りです。

SELECT d.name AS department_name, COUNT(s.id) AS student_count
FROM departments d
LEFT JOIN students s ON d.id = s.department_id
GROUP BY d.id, d.name;

なお、ワークフロー図ではgenerate_sql_queryからcontinue, loop, endという3つの分岐が発生していますが、これは条件付きエッジと呼ばれています。
今回は、生成されたSQLがDBに副作用を与えることがないよう、以下のように分岐させています。

  • SELECT文のみ ... continueへ進み、そのままSQLクエリが実行されます。
  • INSERTなど破壊的なクエリを含む ... loopへ進み、再度クエリ生成を試みます。
  • クエリ生成を3回繰り返しても破壊的なクエリを含む ... endへ進み、実行を終了します。

2. SQLクエリ実行 (exec_sql_query)

文字列として受け取ったクエリをDBにそのまま問い合わせるだけです。
実行結果は以下の通りです。

[['Law', 24], ['Economics', 25], ['Science', 16], ['Literature', 20], ['Engineering', 15]]

3. グラフ表示コード生成 (generate_plot_code)

グラフ表示コード生成時のプロンプト

以下のデータを可視化するmatplotlibのPythonコードを生成してください。

データ:
{data}

可視化の要件:
{user_prompt}

以下の要件を満たすコードを生成してください:

  • dataをハードコーディングによって書き下し、引数の不要なコードとすること
  • dataの他にマスタテーブルのデータを用いる必要がある場合は、data内の値からマスタテーブルのデータを推論すること
  • 関数内で利用する変数が正しく定義済みかどうかを確認すること
  • Pythonによって記述されていること
  • matplotlibを使用すること
  • 日本語のラベルやタイトルが正しく表示されるように設定すること
  • グラフは見やすく、色分けを用いて理解しやすいものにすること
  • 必要に応じて適切な色やスタイルを設定すること
  • X軸・Y軸のラベルとグラフタイトルは英語で記述すること
  • 最終的にグラフをplt.show()で表示すること

前節で取得された実行結果の先頭にカラム名(['department_name', 'student_count'])も加えたデータと、グラフ化の解釈がぶれないよう元のプロンプトをあわせて投げてコードを生成します。

生成されたコード例は以下のとおりです。

import matplotlib.pyplot as plt
import numpy as np

# データをハードコーディング
data = [['department_name', 'student_count'], ['Law', 24], ['Economics', 25], ['Science', 16], ['Literature', 20], ['Engineering', 15]]

# データの抽出
departments = [row[0] for row in data[1:]]  # 学部名
student_counts = [row[1] for row in data[1:]]  # 学生数

# グラフの設定
plt.figure(figsize=(10, 6))  # グラフのサイズ
colors = ['#ff9999','#66b3ff','#99ff99','#ffcc99','#c2c2f0']  # 色の設定

# 棒グラフの作成
bars = plt.bar(departments, student_counts, color=colors)

# 各バーにラベルを追加
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2, yval, int(yval), ha='center', va='bottom')

# タイトルとラベルの設定
plt.title('Student Count by Department')  # グラフタイトル
plt.xlabel('Department')  # X軸ラベル
plt.ylabel('Student Count')  # Y軸ラベル

# 日本語のラベルを設定
plt.xticks(rotation=45)  # X軸ラベルを45度回転
plt.grid(axis='y')  # Y軸にグリッドを追加

# グラフの表示
plt.tight_layout()  # レイアウトの調整
plt.show()

4. グラフ表示コード実行 (execute_plot_code)

pythonに組み込まれている関数 exec(code: str) を呼び出すだけです!

AIにドメイン知識をどう持たせるか問題

ここに関してはRAGやfine-tuningなど色んな方法がありますが、今回は最も簡単な、プロンプトにスキーマ情報を直接仕込む形で実装しました。
pg_descriptionなどのカタログデータからテーブルとそのカラムに付属しているコメントをjson形式に抽出しておき、これをそのままstringとしてSQLクエリ生成時のプロンプトに含めています。
カタログデータの抽出は上述したGithubリポジトリのoutput_db_schema.pyファイルで実装しています。

[
    {
        "table": "departments",
        "description": "学部情報を格納するテーブル",
        "columns": [
            {
                "column_name": "id",
                "data_type": "integer",
                "description": "学部ID (主キー)"
            },
            {
                "column_name": "name",
                "data_type": "character varying",
                "description": "学部名"
            }
        ]
    },
    {
        "table": "students",
        "description": "学生情報を格納するテーブル",
        "columns": [
            {
                "column_name": "id",
                "data_type": "integer",
                "description": "学生ID (主キー)"
            },
            ...
         ]
    }
    ...
]

ただ、実務で利用しているようなテーブル・カラムの多いDBの場合、トークン数が膨れ上がるため、RAGを整備する・プロンプトキャッシングを利用するなどの省トークンのための工夫が必要そうです。

まとめ

実装にあたって30回ほどグラフ生成しましたが、生成されたクエリやコードの実行に失敗することは1,2回しかありませんでした。すごい精度ですね...!
現状、社内チャットボットはかなり話題にあがっていますが、今回のような社内データベースの超フレキシブルなデータの可視化についても今後流行していきそうな気がしています!
各関数ノードの中身も基本的にOpenAIのAPIにプロンプトを投げているだけなので、めっちゃ簡単ですし、LangGraphのチュートリアルとしても良い題材になると思いますので、興味がある方はぜひご自身でも実装してみてください☺️

Discussion