LangGraphによる「AIエージェントWebアプリ」を作成する【Next.js】
はじめに
今回はLangGraphで実装したAIエージェントを利用したWebアプリを作ろうと思います。
下記の画像のような、よく見るチャット型のWebアプリを作ることが目標です。本記事を読めば誰でも作れるようになっているはずです!
LangGraphを利用したAIエージェントのロジックの作り方は、過去の記事で解説しているのでそちらをご参照ください。
本記事のアプリで実装しているAIエージェントの上記の記事の内容のものになります。
コードの解説部分の分量がだいぶ長くなってしまいましたが、コードをクローンして使うだけなら、簡単にできるので、ぜひお試しください!
参考文献
LangChainとLangGraphによるRAG・AIエージェント[実践]入門
いつも提示させていただいておりますが、本当にLangGraphを使う上で、この本があれば大体十分なんですよね。
いつもいつもありがとうございます。
作成するAIエージェント
基本設計
今回作成するのは、「今日、明日、明後日」の日付の範囲内で、「天気」もしくは「日付」に関してのユーザ質問を回答するAIエージェントです。
より具体的には、まず、下記の3つのワークフローを作成します。
- tool01
- ユーザの質問が「天気に関する質問」か「日付に関する質問」かを前段のLLMが判定して、条件分岐を行い、後段の「今日の天気回答専門LLM」と『今日の日付回答専門LLM」が設定されているSYSTEMプロンプトの情報(今日のダミー日付と今日のダミー天気)に則って回答を行います。
- tool02
- tool01とほぼ同様ですが、「明日」の天気・日付を回答します。
- tool03
- tool01とほぼ同様ですが、「明後日」の天気・日付を回答します。
続いて、上記のワークフローを束ねるAIエージェントを構築します。
エージェントはユーザの質問が
「今日に関する質問」なら「tool01」ワークフローを
「明日に関する質問」なら「tool02」ワークフローを
「明後日に関する質問」なら「tool03」ワークフローを起動して、
ユーザの質問内容を、起動したワークフローに流し込んで回答を取得し、その結果をユーザに回答します。
ここまでの内容は下記の記事をご覧ください。
追加機能
そして、天気においては、システムプロンプトには午前の天気と午後の天気の情報を記載しております。
したがって、ユーザから天気を聞かれた時に、午前の天気を答えるべきか、午後の天気を答えるべきかがわからないという問題があります。
したがって、ユーザが「今日の天気を教えて?」という感じに、時間情報を指定せずに質問してきたら、Aiエージェント側からユーザに質問をする必要があります。
(例えば、「時間の情報を入力してください。」など)
その結果、ユーザから「15時の天気を教えて?」のように情報を得たら、「今日の午後の天気」の情報をシステムプロンプトの通りに出力します。
このように、状況に応じて動的にユーザ入力を受け付けるような処理も実装します。
この内容は下記の記事をご覧ください。
Webアプリ化する上で
さらにWebアプリとして設計する上で、3つ追加で考慮するべきことがあります。
- お手軽無料chatGPTとして使われないようにする工夫
- サーバに同時に接続するユーザごとに履歴を区別する工夫
- ユーザの会話履歴を永続的に保持する工夫
もちろん、他にも考慮するべきことはありますが、一旦上記を考えます。
なぜなら、上記3つは全て
LangGraphを使うことで簡単に実装できるからです。
お手軽無料chatGPTとして使われないようにする工夫
ここでは、処理の一番最初に質問が不適切でないかを判定するLLMを導入します。
つまり、「天気か日付」の質問以外がユーザから入力された場合に、それを検出してブロックする機能です。
これはlangGraphにて、一番最初のノードに検出機構を実装してみます
ユーザごとに履歴を区別する工夫
LangGraphでは、下記のような形でthread_id
を指定することができます。
thread_config = {"configurable": {"thread_id": session_id}}
for event in graph.stream(state, thread_config):
・・・
ここで、thread_id
をユーザごとに異なる値を入れることで、同時にサーバに接続されたとしても、会話履歴がごっちゃにならずに処理を解決することが可能になります。
ユーザの会話履歴を永続的に保持する工夫
langGraphでは、ワークフローの実行中に、特定の地点でのState
を「チェックポイント」として保存する機能があります。
その上で、チェックポイントをSQLite
などのDBに保存して、永続化することもできます。
チェックポイントの保存して利用できるのは、
- インメモリチェックポインター
- SQLiteチェックポインター
- PostgreSQLチェックポインター
が利用できます。
このうち、一番上は非常に使いやすいですが、セッション終了時に消去されてしまいます。
一方で下二つはチェックポイントをDBとして保存できるため、永続化が可能です。
今回は、簡単にローカルで保存できるSQLiteの方を利用しようと思います。
成果物
できたもの
まず、作ったものの動画を下記に提示します。
動画を見ていただけるとわかりますが、下記の機能が達成されていると思います。
- 無関係な質問が来たらブロックする
- 日付の質問が来たらシステムプロンプトに合わせて回答する
- だから、いまだに10月の日付を回答しますw
- 天気の質問が来たら、時間の情報の有無で分岐する
- 時間の情報があればシステムプロンプトの情報に合わせて回答する
- 時間の情報がなければ、ユーザに追加で質問する
- 質問した結果、時間の情報が得られれば、それ込みで回答する
- 質問した結果、時間の情報が得られなければ、再度確認する
また、動画ではわからないですが、会話履歴などはDBとして、サーバ側のフォルダに保存されています。
ちなみにフロントエンドは基本的にv0に作ってもらいました。
非常に簡単にフロントエンドが作れるので楽ですね。
実装コード
下記のリポジトリに置いてあります
下記のコマンドでクローンをしてください。
git clone https://github.com/personabb/chatbot_langgraph_sample_local.git
動作方法
著者の実行環境は下記です。
OS:M2 Mac
RAM:64GB
Python:3.10.14
Node.js:v19.6.1
pnpm:8.6.2
環境構築
フロントエンドとして、Next.js+pnpm、
バックエンドとしてPythonが利用できる環境(もしくはGoogle Colabが利用できる環境)になっていれば問題ありません。
既に、利用できる方は下記のステップは飛ばしてください。
環境構築
フロントエンド、バックエンド両方
gitを導入している方は下記のコマンドで、リポジトリをクローンしてください
git clone https://github.com/personabb/chatbot_langgraph_sample_local.git
gitを導入していない場合は、下記のページで緑色の「Code」ボタンをクリックして、下の方にある「Download ZIP」をクリックすることでダウンロードできます。
この後は、このリポジトリをカレントディレクトリとして、コマンドなどの実行をしてください。
このリポジトリ自体は、Desktopにおいても良いですし、そのままDownloadにおいておいても良いですが、ターミナルのカレントディレクトリだけ、このディレクトリを指定しておいてください。
フロントエンド (Next.jsの環境構築)
フロントエンドではNext.js(+ pnpm)を利用しています。
下記の通り、環境構築をしてください。(Macを想定しています。Windowsの方申し訳ございません)
(2023年の環境構築に使った資料やスクリーンショットを引っ張ってきているので、古いかもしれません。最終的にNext.js(+ pnpm)を利用できるようになっていればいいので、他の方法で環境構築しても構いません)
Brewの導入(導入済みの方はスキップ)
バッケージインストーラであるbrewを導入します。
ターミナルを起動して、下記コマンドを実行してください。パスワードを求められるので入力してください。
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
途中でEnterキーを押すことを求められます。押してください。
下記のコマンドで環境変数にパスを通します(一行ずつ実行してください)
echo 'eval "$(/opt/homebrew/bin/brew shellenv)"' >> ~/.zshrc
eval "$(/opt/homebrew/bin/brew shellenv)"
PATH=/usr/local/bin:$PATH
export PATH
下記のコマンドを実行して、brewのバーションが表示されれば成功です。
brew -v
Node.jsの導入(nodebrewの導入)
下記コマンドでnodebrewをインストールします
brew install nodebrew
下記の画像の赤枠部分の記載されているコマンドを実行します(人によっておそらく違うのでよく見てください)
上記の画像なら下記コマンドです
/opt/homebrew/opt/nodebrew/bin/nodebrew setup_dirs
続いては、上記の画像の、赤枠の部分に記載されているPathを環境変数に追加します。
上記の画像なら、下記のコマンドで環境変数を追加できます。
echo export PATH=$HOME/.nodebrew/current/bin:$PATH >> ~/.zshrc
ただし、export PATH=$HOME/.nodebrew/current/bin:$PATH
の部分は、皆様の実行画面(の赤枠の部分)に合わせて変更してください。
その後、設定した環境変数を適用するために下記コマンドで更新します。
source ~/.zshrc
Node.jsの導入(Node.jsの導入)
導入したnodebrewを使ってnodeのvarsion 19.6.1をインストールします。
(インストールするバージョンは、おそらくなんでもいいです。私が上記の環境のため書いていますが、基本は最新のバージョンをお勧めします)
下記のコマンドを実行します
nodebrew install v19.6.1
Installed successfully
と表示されればOKです。
続いて、インストールしたnodeのvarsion 19.6.1を利用できるように下記コマンドを実行します。
nodebrew use v19.6.1
use v19.6.1
と表示されればOKです。
nodebrewを利用することで、複数のバージョンをインストールして、用途ごとにバージョンを使い分けることができるようになっています。
(もちろん、最終的にNext.js(+ pnpm)が利用できればいいので、他の方法でインストールしても良いです)
pnpmのインストール
下記コマンドでpnpmをインストールします
npm install --global pnpm@8.6.0
以上で、Next.jsとpnpmが利用できるようになりました。
バックエンド(ローカルPCを利用)
カレントディレクトリを./AdvancedLivePortrait_Nextjs/alp_backend
に設定するために、下記のコマンドを実行します。
cd alp_backend
pythonの導入
前提として、pythonのバージョンは3.10もしくは3.11を利用します。
pythonはpyenvを利用して、バージョンを指定しながら導入します。
pythonのバージョンはpyenvで指定します。
pyenv自体の導入については下記をご覧ください。
pyenvが導入できていれば、下記のコマンドでpythonのバージョンを指定できます。
pyenv install 3.10.14 #もしくは3.11.9など
pyenv local 3.10.14 #もしくはpyenv global 3.10.14
これでpythonのバージョンが指定できます。
pyenv global
はシステム全体に、このバージョンを反映させたい時に利用してください。
pyenv local
は現在のカレントディレクトリでのみ、このバージョンを反映させたい場合に利用します。
下記コマンドを実行して、pythonのバージョンが変更されているかを確認してください。
python -V
# Python 3.10.14
pythonの仮装環境の設定
続いて、必要なパッケージをインストールするために仮想環境を構築します
venvで仮想環境を構築します。
venvはpython公式の仮装環境のため、pythonが利用可能であれば導入の必要なく利用できます。
python -m venv env
source env/bin/activate
以降、バックエンドを実行する場合は、この仮装環境に毎回入って実行してください。
次回以降、仮装環境に入るだけなら下記コマンドだけで大丈夫です
source env/bin/activate
実行準備
フロントエンド
カレントディレクトリは./chatbot_langgraph_sample_local
を想定しています。
バックエンドのURLを設定
.env.local
をリポジトリ直下(./
)に作成して、下記の通り設定してください。
URLはバックエンドサーバのURLです。(以下は例です)
NEXT_PUBLIC_BASE_URL=http://127.0.0.1:8002
Next.jsで必要なパッケージのインストールとビルド
下記のコマンドを一つずつ実行してください。
pnpm i
pnpm build
バックエンド
バックエンドの準備では、./chatbot_langgraph_sample_local/backend_python
をカレントディレクトリとして、それ以降のコマンドを実行してください。
pythonパッケージのインストール
必要なpythonのパッケージをインストールします。下記コマンドを実行してください。
pip install -r requirement.txt
実行
フロントエンド
カレントディレクトリは./chatbot_langgraph_sample_local
を想定しています。
フロントエンドサーバの起動
下記のコマンドを実行してください
pnpm dev
ターミナルに接続用IPアドレス(同じPCからでしか、このアドレスには接続できません)が表示されたら、フロントエンドサーバの準備は完了です。
端末からの接続
フロントエンドサーバと同じPCからであれば、http://localhost:3000/
にブラウザから接続することで、Webアプリに接続できます。
バックエンド(ローカルPCの場合)
バックエンドでは、./chatbot_langgraph_sample_local/backend_python
をカレントディレクトリとして、それ以降のコマンドを実行してください。
バックエンドサーバの起動
下記コマンドを実行して、バックエンドサーバを立ち上げてください。
python LangGraph_server.py
ターミナルに緑の文字で接続用IPアドレス(このアドレスに接続しても接続できません)が表示されたら、バックエンドサーバの準備は完了です。
使い方
タブレットやPCなどから、フロントエンドサーバに接続すると下記のような画面が表示されます。
あとは、画面下部の入力欄から質問内容を入力して、「Enter」か送信アイコンで送信してください。
SQLiteデータベースの中身の確認
Webアプリで会話をするたびに、セッションIDが発行され、そのIDごとに会話履歴がSQLiteデータベースとして、サーバ型のフォルダに保存されます。
保存場所は、./chatbot_langgraph_sample_local/backend_python/sqlite_db/
です。
この中にある.sqlite
ファイルの中にデータベースの中身が格納されています。しかし簡単に中が見れないので、ちゃんと会話履歴が保存されていることを確認したいと思います。
そのためには下記のpythonコードを利用します。
#本コードはlangGraphのSqliteSaverで保存されたチェックポイントの中身を見るためのコードです。
import sqlite3
import pandas as pd
import os
import json
import msgpack
def decode_binary_data(df):
"""
データフレーム内のバイナリデータをデコードして文字列に変換する。
:param df: Pandas DataFrame
:return: デコードされたDataFrame
"""
for col in df.columns:
if df[col].dtype == object:
df[col] = df[col].apply(lambda x: decode_if_binary(x))
return df
def decode_if_binary(value):
"""
値がバイナリデータの場合にデコードし、それ以外の場合はそのまま返す。
:param value: バイナリデータまたはその他の値
:return: デコードされた値または元の値
"""
if isinstance(value, bytes):
try:
# msgpackとしてデコード
return msgpack.unpackb(value, raw=False)
except (msgpack.exceptions.ExtraData, msgpack.exceptions.FormatError):
try:
# UTF-8としてデコード
return value.decode('utf-8')
except UnicodeDecodeError:
# JSONとしてデコード
try:
return json.loads(value.decode('utf-8'))
except Exception:
return value # デコードできない場合はそのまま返す
return value
def save_sqlite_tables_to_csv(db_path, output_dir):
"""
SQLiteデータベース内の各テーブルをCSVファイルとして保存する。
バイナリデータはデコードして保存。
:param db_path: SQLiteデータベースファイルのパス
:param output_dir: CSVファイルを保存するディレクトリ
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
print("=== テーブル一覧 ===")
for table in tables:
print(f"- {table[0]}")
print("\n=== 各テーブルをCSVに保存 ===")
for table in tables:
table_name = table[0]
print(f"テーブル名: {table_name}")
try:
df = pd.read_sql_query(f"SELECT * FROM {table_name}", conn)
df = decode_binary_data(df)
output_path = os.path.join(output_dir, f"{table_name}.csv")
df.to_csv(output_path, index=False, encoding="utf-8")
print(f"テーブル '{table_name}' をCSVファイルに保存しました: {output_path}")
except Exception as e:
print(f"エラー: テーブル '{table_name}' の処理中に問題が発生しました: {e}")
conn.close()
# データベースのパスを指定して実行
db_path = "./sqlite_db/xxxxx.sqlite"
output_dir = "./sqlite_output"
save_sqlite_tables_to_csv(db_path, output_dir)
下記の部分において、中身を確認したいデータベースのファイル名を書き換えてください。
db_path = "./sqlite_db/xxxxx.sqlite"
その上で、このコードを実行すると、./chatbot_langgraph_sample_local/backend_python/sqlite_output
にcsv
ファイルが吐き出されます。
その中で特にcheckpoints.csv
の中身を見ると、下記の画像のように会話履歴が保存されていることがわかります。
コードの解説
忘備録的に記載します。
フロントエンド
基本的にv0に作ってもらいましたが、忘備録としてどういうふうに動いているのかを記載します。
コード全文
page.tsx 全文
'use client'
import { useState, useRef, useEffect } from 'react'
import { Send } from 'lucide-react'
// UI コンポーネント
const Button = ({ children, ...props }: React.ButtonHTMLAttributes<HTMLButtonElement> & { children: React.ReactNode }) => (
<button
className="inline-flex items-center justify-center rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 bg-primary text-primary-foreground hover:bg-primary/90 h-10 px-4 py-2"
{...props}
>
{children}
</button>
)
const Input = ({ ...props }) => (
<input
className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50"
{...props}
/>
)
const Card = ({ children, className, ...props }: { children: React.ReactNode, className?: string }) => (
<div className={`rounded-lg border bg-card text-card-foreground shadow-sm ${className}`} {...props}>
{children}
</div>
)
const CardHeader = ({ children, ...props }: { children: React.ReactNode }) => (
<div className="flex flex-col space-y-1.5 p-6" {...props}>
{children}
</div>
)
const CardTitle = ({ children, ...props }: { children: React.ReactNode }) => (
<h3 className="text-2xl font-semibold leading-none tracking-tight" {...props}>
{children}
</h3>
)
const CardContent = ({ children, className, ...props }: { children: React.ReactNode, className?: string }) => (
<div className={`p-6 pt-0 ${className}`} {...props}>
{children}
</div>
)
const CardFooter = ({ children, ...props }: { children: React.ReactNode }) => (
<div className="flex items-center p-6 pt-0" {...props}>
{children}
</div>
)
const ChatBubble = ({ children, role }: { children: React.ReactNode, role: 'user' | 'assistant' }) => (
<div className={`flex ${role === 'user' ? 'justify-end' : 'justify-start'} mb-4`}>
<div className={`relative max-w-[70%] p-2 rounded-lg shadow-md ${role === 'user' ? 'bg-green-500 text-white' : 'bg-white text-black'}`}>
{children}
</div>
</div>
)
type Message = {
role: 'user' | 'assistant'
content: string
}
export default function Chatbot() {
const [messages, setMessages] = useState<Message[]>([])
const [input, setInput] = useState('')
const [sessionId, setSessionId] = useState<string>("None")
//const [sessionId, setSessionId] = useState<string>("asap2650")
const [interruptSet, setInterruptSet] = useState<boolean>(false)
const messagesEndRef = useRef<HTMLDivElement>(null)
const scrollToBottom = () => {
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" })
}
useEffect(scrollToBottom, [messages])
const handleSubmit = async (e: React.FormEvent) => {
//ユーザがボタンを押すたびに、ページがリロードされるのを防ぐ
e.preventDefault();
if (!input.trim()) return;
setMessages(prev => [...prev, { role: 'user', content: input }]);
setInput('');
try {
const endpoint = interruptSet ? '/continue' : '/ask';
const response = await fetch(`${process.env.NEXT_PUBLIC_BASE_URL}${endpoint}`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(interruptSet ? { session_id: sessionId, additional_input: input } : {session_id: sessionId, user_input: input }),
});
if (!response.ok) throw new Error('ネットワークエラーが発生しました');
const data = await response.json();
if (data.response !== null) {
setMessages(prev => [...prev, { role: 'assistant', content: data.response }]);
}
console.log('handleSubmit data.interrupt:', data.interrupt);
console.log('handleSubmit sessionId:', data.session_id);
console.log('handleSubmit interrupt_event:', data.interrupt_event);
setSessionId(data.session_id);
if (data.interrupt) {
setInput(''); // 入力欄をクリア
setSessionId(data.session_id);
setInterruptSet(true);
console.log('handleInterrupt messages:', messages);
} else {
// セッションが中断されていない場合、フラグを折る
setInterruptSet(false);
}
} catch (error) {
console.error('エラー:', error);
setMessages(prev => [...prev, { role: 'assistant', content: 'エラーが発生しました。もう一度お試しください。' }]);
}
}
return (
<Card className="w-full max-w-2xl mx-auto">
<CardHeader>
<CardTitle>AIエージェントチャットボット</CardTitle>
</CardHeader>
<CardContent className="h-[75vh] overflow-y-auto">
{messages.map((message, index) => (
<ChatBubble key={index} role={message.role}>
{message.content}
</ChatBubble>
))}
<div ref={messagesEndRef} />
</CardContent>
<CardFooter>
<form onSubmit={handleSubmit} className="flex w-full space-x-2">
<Input
value={input}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => setInput(e.target.value)}
placeholder="メッセージを入力..."
aria-label="メッセージを入力"
/>
<Button type="submit" aria-label="送信">
<Send className="h-4 w-4" />
</Button>
</form>
</CardFooter>
</Card>
)
}
解説
パーツコンポーネント
まず、一番上部には、パーツのコンポーネントを定義しています。
基本UI部分はv0に作ってもらいましたが、その中でも少しこだわって変更したのが下記の部分です。
特に、吹き出しの位置や色などを定義しているのが下記の部分です。
const ChatBubble = ({ children, role }: { children: React.ReactNode, role: 'user' | 'assistant' }) => (
<div className={`flex ${role === 'user' ? 'justify-end' : 'justify-start'} mb-4`}>
<div className={`relative max-w-[70%] p-2 rounded-lg shadow-md ${role === 'user' ? 'bg-green-500 text-white' : 'bg-white text-black'}`}>
{children}
</div>
</div>
)
-
Props
-
children
:
メッセージのテキストを受け取るプロパティ -
role
:
'user'
または'assistant'
の値のどちらかを受け取り、それに応じて表示が変化
-
-
外側の
<div>
-
className="flex ${role === 'user' ? 'justify-end' : 'justify-start'} mb-4"
:-
flex
: 横並びの配置。 -
justify-end
orjustify-start
: 吹き出しを右寄せ (user
) または左寄せ (assistant
) に設定。 -
mb-4
: 各吹き出し間の余白。
-
- したがって、ユーザの
role
に応じて、吹き出しの位置が変わるLineのような実装になります
-
-
内側の
<div>
-
className="relative max-w-[70%] p-2 rounded-lg shadow-md ${role === 'user' ? 'bg-green-500 text-white' : 'bg-white text-black'}"
:-
relative
: レイアウトの基点。 -
max-w-[70%]
: 吹き出しの最大幅を画面幅の70%に制限。 -
p-2
: 内部余白。 -
rounded-lg
: 角を丸くする。 -
shadow-md
: 影を追加。 -
背景色と文字色:
-
user
:bg-green-500 text-white
(緑背景に白文字)。 -
assistant
:bg-white text-black
(白背景に黒文字)。
-
-
- ここもロールに合わせて、色を変えたり影をつけたりして、Lineのような実装を目指しました。
-
API呼び出し
フロントに入力された内容をバックエンド側にAPIで投げる処理をexport default function Chatbot()
の中で記載します。
初期状態の設定
export default function Chatbot() {
const [messages, setMessages] = useState<Message[]>([])
const [input, setInput] = useState('')
const [sessionId, setSessionId] = useState<string>("None")
//const [sessionId, setSessionId] = useState<string>("asap2650")
const [interruptSet, setInterruptSet] = useState<boolean>(false)
const messagesEndRef = useRef<HTMLDivElement>(null)
-
messages
:- 人間とAIの会話履歴を管理する配列
-
input
:- ユーザが入力・送信したメッセージ内容を入れる変数。
str
- ユーザが入力・送信したメッセージ内容を入れる変数。
-
sessionId
:- フロント側を一意に識別する
id
- 現在は、初期値
None
をバックエンドに送ると、一意のid
を付与して、返してくれる- その
id
を利用し続けることで、バックエンドが会話履歴を保持し続けることができる。
- その
- 一意であれば、こちらで決めて送っても良い。テストとして私の
id
を入れたコードをコメントアウトしている
- フロント側を一意に識別する
-
interruptSet
:- バックエンドのAiエージェントの処理が中断されたかどうかを占めるフラグ
- このフラグが立っているか折れているかで、叩くAPIのURLが変わる
-
messagesEndRef
:- チャットの最下部にスクロールするための参照(
useRef
)
- チャットの最下部にスクロールするための参照(
スクロールの管理
const scrollToBottom = () => {
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" })
}
useEffect(scrollToBottom, [messages])
scrollToBottom
はmessagesEndRef
を参照して、チャットの最下部にスムーズにスクロールする関数です。
そして。useEffect(scrollToBottom, [messages])
により、メッセージが更新されるたびに、scrollToBottom
が実行されます。
メッセージ送信&APIリクエスト
const handleSubmit = async (e: React.FormEvent) => {
//ユーザがボタンを押すたびに、ページがリロードされるのを防ぐ
e.preventDefault();
if (!input.trim()) return;
setMessages(prev => [...prev, { role: 'user', content: input }]);
setInput('');
try {
const endpoint = interruptSet ? '/continue' : '/ask';
const response = await fetch(`${process.env.NEXT_PUBLIC_BASE_URL}${endpoint}`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(interruptSet ? { session_id: sessionId, additional_input: input } : {session_id: sessionId, user_input: input }),
});
if (!response.ok) throw new Error('ネットワークエラーが発生しました');
const data = await response.json();
if (data.response !== null) {
setMessages(prev => [...prev, { role: 'assistant', content: data.response }]);
}
console.log('handleSubmit data.interrupt:', data.interrupt);
console.log('handleSubmit sessionId:', data.session_id);
console.log('handleSubmit interrupt_event:', data.interrupt_event);
setSessionId(data.session_id);
if (data.interrupt) {
setInput(''); // 入力欄をクリア
setSessionId(data.session_id);
setInterruptSet(true);
console.log('handleInterrupt messages:', messages);
} else {
// セッションが中断されていない場合、フラグを折る
setInterruptSet(false);
}
} catch (error) {
console.error('エラー:', error);
setMessages(prev => [...prev, { role: 'assistant', content: 'エラーが発生しました。もう一度お試しください。' }]);
}
}
上記の関数は、ユーザがフロントエンドにテキストを入力し、送信を決定したら呼ばれます。
まず最初に、フォーム送信時にページリロードを防ぐため、handleSubmit
内でe.preventDefault()
を呼び出します。
続いて、下記から、ユーザの入力にrole
を付与して、会話履歴を更新し、入力欄を空にします。
setMessages(prev => [...prev, { role: 'user', content: input }]);
setInput('');
try {
const endpoint = interruptSet ? '/continue' : '/ask';
const response = await fetch(`${process.env.NEXT_PUBLIC_BASE_URL}${endpoint}`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(interruptSet ? { session_id: sessionId, additional_input: input } : {session_id: sessionId, user_input: input }),
});
続いて、上記からinterruptSet
のフラグに応じて、呼び出すAPIのURLを変更し、APIへの呼び出し結果をresponse
に取得します。
当然、呼び出されるAPIが変わるので、入力変数に関しても、interruptSet
のフラグに応じて変更されます。
const data = await response.json();
if (data.response !== null) {
setMessages(prev => [...prev, { role: 'assistant', content: data.response }]);
}
・・・
setSessionId(data.session_id);
まず、一番最後の行では、APIからの返信があった場合、バックエンドから送られてきたid
を取得して格納します。
このid
を以降の会話でも利用することで、会話履歴を保持し続けることができます。
また、このid
はこちらから指定した場合は、そのid
が利用されますが、None
で送った場合は、バックエンド側が一意に決定して送信してきます。
続いてconst data = await response.json();
により、APIからAIのメッセージが返ってきたら、APIで返却されたデータを取得して、AIからの発言を、messages
に格納しています。
このとき、LangGraph側で追加の質問が必要と判断した場合は、そのための質問文もdata.response
で取得されます。
LangGraphの中断がある場合
LangGraphの中断がある場合はdata.interrupt
がTrue
になるため、下記の通り処理が発生します。
if (data.interrupt) {
setInput(''); // 入力欄をクリア
setSessionId(data.session_id);
setInterruptSet(true);
console.log('handleInterrupt messages:', messages);
} else {
// セッションが中断されていない場合、フラグを折る
setInterruptSet(false);
}
中断がある場合は、バックエンドが発行したセッションIDを控えて、中断フラグを立てます。
一方もし、中断なく処理が完了し、APIからメッセージを受け取れた場合は、interruptSet
のフラグが折れるため、次のAPIは/ask/
が叩かれることになります。
UI
export default function Chatbot() {
・・・
return (
<Card className="w-full max-w-2xl mx-auto">
<CardHeader>
<CardTitle>AIエージェントチャットボット</CardTitle>
</CardHeader>
<CardContent className="h-[75vh] overflow-y-auto">
{messages.map((message, index) => (
<ChatBubble key={index} role={message.role}>
{message.content}
</ChatBubble>
))}
<div ref={messagesEndRef} />
</CardContent>
<CardFooter>
<form onSubmit={handleSubmit} className="flex w-full space-x-2">
<Input
//ref={inputRef}
value={input}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => setInput(e.target.value)}
placeholder="メッセージを入力..."
aria-label="メッセージを入力"
/>
<Button type="submit" aria-label="送信">
<Send className="h-4 w-4" />
</Button>
</form>
</CardFooter>
</Card>
)
}
前述したコンポーネントを利用してUIを構築しています。
{messages.map((message, index) => (
<ChatBubble key={index} role={message.role}>
{message.content}
</ChatBubble>
))}
特に上記の部分では、messages.map((message, index))
にて、messages
配列を繰り返し処理し、各メッセージmessage
をチャットバブル (<ChatBubble>
) としてレンダリングしています。
<ChatBubble>
では、role
を指定して、配置場所と吹き出し・文字の色を決定し、文章の内容としてmessage.content
を表示しています。
続いて、
<CardFooter>
<form onSubmit={handleSubmit} className="flex w-full space-x-2">
<Input
//ref={inputRef}
value={input}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => setInput(e.target.value)}
placeholder="メッセージを入力..."
aria-label="メッセージを入力"
/>
<Button type="submit" aria-label="送信">
<Send className="h-4 w-4" />
</Button>
</form>
</CardFooter>
上記の部分が、メッセージの送信処理を司っています。
<form onSubmit={handleSubmit} ・・・>
にて、フォームの送信イベントをキャッチし、handleSubmit
関数を実行します。
<Input
value={input}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => setInput(e.target.value)}
placeholder="メッセージを入力..."
aria-label="メッセージを入力"
/>
続いて、上記の部分で、入力欄に書いた文字を取得しています。
文字を入力すると、値value
が変化するため、onChange={(e: React.ChangeEvent<HTMLInputElement>) => setInput(e.target.value)}
により、変化が発生するたびに、input
変数に格納されていきます。
<Button type="submit" aria-label="送信">
<Send className="h-4 w-4" />
</Button>
ボタンを押すことで、submit
が入ります。その結果、onSubmit={handleSubmit}
が起動することになります。
バックエンド
基本的なロジックのコードは、これまでの解説記事に記載しておりますので、差分だけ解説しようと思います。
記事は下記をご覧ください。
コード全文
LangGraph_server.py 全文
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uuid
import os
from langgraph.graph import StateGraph, END
from langchain_openai import AzureChatOpenAI
from langgraph.errors import NodeInterrupt
import sqlite3
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
#os.environ["OPENAI_API_VERSION"] = "2024-xx-xx-preview"
#os.environ["AZURE_OPENAI_ENDPOINT"] = "https://xxxx.openai.azure.com"
#os.environ["AZURE_OPENAI_API_KEY"] = "xxxx"
def get_local_db_path(session_id):
return f'./sqlite_db/{session_id}.sqlite'
# SQLiteデータベースの初期化またはロード
def load_or_create_db(session_id):
"""
ローカルフォルダでSQLiteデータベースをロードまたは新規作成します。
:param session_id: セッションID(データベース名に使用)
:param db_directory: SQLiteデータベースを保存するディレクトリ
:return: SQLiteデータベース接続
"""
local_db_path = get_local_db_path(session_id)
print(f"Local database path: {local_db_path}")
# データベース保存ディレクトリを確認または作成
print(f"Checking database directory: {os.path.dirname(local_db_path)}")
if not os.path.exists(os.path.dirname(local_db_path)):
os.makedirs(os.path.dirname(local_db_path), exist_ok=True)
if not os.path.exists(local_db_path):
# 初回セッションの場合、DBファイルを新規作成
print(f"No existing database for session {session_id}, creating new one.")
conn = sqlite3.connect(local_db_path)
# 必要なら初期化処理を実行
conn.execute("CREATE TABLE IF NOT EXISTS example_table (id INTEGER PRIMARY KEY, data TEXT)")
conn.commit()
conn.close()
else:
print(f"Loaded existing database for session: {session_id}")
# SQLiteデータベース接続を返す
return sqlite3.connect(local_db_path, check_same_thread=False)
app = FastAPI()
# CORSの設定
origins = [
"*" # フロントエンドのURLを指定
# 必要に応じて他のオリジンを追加
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"], # すべてのHTTPメソッドを許可
allow_headers=["*"], # すべてのヘッダーを許可
)
# 各セッションの状態を保存
sessions = {}
# グラフを流れるStateの方の定義
class State(BaseModel):
question_bool: bool = Field(default = False, description="不適切質問の抽出結果")
message_type: str = Field(default = "", description="ユーザからの質問の分類結果")
query: str = Field(default = "", description="これまでのプロンプト内容")
AI_messages: str = Field(default = "", description="AIからのメッセージ内容")
bool_time: bool = Field(default = False, description="時刻情報が含まれているかどうかの判定結果")
advance_messages: str = Field(default = "", description="追加質問のユーザ回答")
# 日付か天気かの質問分類器の出力
class MessageType(BaseModel):
message_type: str = Field(description="ユーザからの質問の分類結果", example="search")
# 天気の質問において、必要情報がすでに埋まっているかを判定する判定器の出力
class TimeType(BaseModel):
message_type: bool = Field(description="ユーザ質問に時刻の情報が含めれているかどうかの真偽値(bool)", example=True)
# 今日、明日、明後日のどのワークフローを選択するかの分類器の出力
class ToolType(BaseModel):
message_type: str = Field(description="どのツールを利用するかの判定結果", example="tool01")
# 日付か天気に関する質問をしているかどうかを判定
class date_weather_Type(BaseModel):
message_type: bool = Field(description="天気か日付の質問をユーザがしているかどうかの真偽値(bool)", example=True)
# APIリクエストの定義
class AskRequest(BaseModel):
session_id: str = None
user_input: str
class ContinueRequest(BaseModel):
session_id: str
additional_input: str
# StateGraphの作成 (既存コードの関数やエッジを設定)
def initialize_graph(sqlite_db):
# モデルの初期化
model = AzureChatOpenAI(azure_deployment="gpt-4o", temperature=0)
output_parser = StrOutputParser()
def not_date_weather_interrupt(State):
print("--not_date_weather_interrupt--")
raise NodeInterrupt("天気か日付に関する質問をしてください")
def interrupt(State):
print("--interrupt--")
if not State.bool_time:
raise NodeInterrupt("天気を知りたい時間を入力してください")
return State
def chat_w1(State):
print("--chat_w1--")
if State.query:
sys_prompt = "あなたはユーザからの質問を繰り返してください。その後、質問に回答してください。ただし今日の午前は雨で、午後は雪です"
prompt = None
if not State.advance_messages:
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}")
]
)
else:
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}"),
("assistant", "天気を知りたい時間を入力してください(例:「午前中」「20時」など): "),
("human",State.advance_messages)
]
)
chain = prompt | model | output_parser
dict = {
"query":State.query,
"AI_messages": chain.invoke({"user_input": State.query})
}
return dict
return {
"AI_messages": "No user input provided"
}
def chat_d1(State):
print("--chat_d1--")
if State.query:
sys_prompt = "あなたはユーザの質問内容を繰り返し発言した後、それに対して回答してください。ただし今日は10/23です"
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}")
]
)
chain = prompt | model | output_parser
return {
"query":State.query,
"AI_messages": chain.invoke({"user_input": State.query})
}
return {
"AI_messages": "No user input provided"
}
def chat_w2(State):
print("--chat_w2--")
if State.query:
sys_prompt = "あなたはユーザからの質問を繰り返してください。その後、質問に回答してください。ただし明日の午前は曇りで、午後は霰です"
prompt = None
if not State.advance_messages:
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}")
]
)
else:
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}"),
("assistant", "天気を知りたい時間を入力してください(例:「午前中」「20時」など): "),
("human",State.advance_messages)
]
)
chain = prompt | model | output_parser
dict = {
"query":State.query,
"AI_messages": chain.invoke({"user_input": State.query})
}
return dict
return {
"AI_messages": "No user input provided"
}
def chat_d2(State):
print("--chat_d2--")
if State.query:
sys_prompt = "あなたはユーザの質問内容を繰り返し発言した後、それに対して回答してください。ただし明日は10/24です"
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}")
]
)
chain = prompt | model | output_parser
return {
"query":State.query,
"AI_messages": chain.invoke({"user_input": State.query})
}
return {
"AI_messages": "No user input provided"
}
def chat_w3(State):
print("--chat_w3--")
if State.query:
sys_prompt = "あなたはユーザからの質問を繰り返してください。その後、質問に回答してください。ただし明後日の午前は晴れで、午後は霧です"
prompt = None
if not State.advance_messages:
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}")
]
)
else:
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}"),
("assistant", "天気を知りたい時間を入力してください(例:「午前中」「20時」など): "),
("human",State.advance_messages)
]
)
chain = prompt | model | output_parser
dict = {
"query":State.query,
"AI_messages": chain.invoke({"user_input": State.query})
}
return dict
return {
"AI_messages": "No user input provided"
}
def chat_d3(State):
print("--chat_d3--")
if State.query:
sys_prompt = "あなたはユーザの質問内容を繰り返し発言した後、それに対して回答してください。ただし明後日は10/25です"
prompt = ChatPromptTemplate.from_messages(
[
("system",sys_prompt),
("human", "{user_input}")
]
)
chain = prompt | model | output_parser
return {
"query":State.query,
"AI_messages": chain.invoke({"user_input": State.query})
}
return {
"AI_messages": "No user input provided"
}
def response(State):
print("--response--")
return State
# 日付か天気かの質問分類器の出力
def classify(State):
print("--classify--")
classifier = model.with_structured_output(MessageType)
# プロンプトの作成
classification_prompt = """
## You are a message classifier.
## ユーザが天気に関しての質問をしていたら"weather"と返答してください。
## それ以外の質問をしていたら、"day"と返答してください。
"""
if State.query:
prompt = ChatPromptTemplate.from_messages(
[
("system",classification_prompt),
("human", "{user_input}")
]
)
chain = prompt | classifier
return {
"message_type": chain.invoke({"user_input": State.query}).message_type,
"query": State.query
}
else:
return {"AI_messages": "No user input provided"}
# 天気の質問において、必要情報がすでに埋まっているかを判定する判定器の出力
def classify_time(State):
print("--classify_time--")
classifier_time = model.with_structured_output(TimeType)
# プロンプトの作成
classification_prompt = """
## You are a message classifier.
## ユーザが、日付以外の時間を指定して質問している場合(例えば、「午前」「午後」「12時」「5:20」などがある場合)はTrueと返答してください。
## そうでない場合はFalseと返答してください。
TrueかFalse以外では回答しないでください。
"""
if State.query:
if State.advance_messages:
prompt = ChatPromptTemplate.from_messages(
[
("system",classification_prompt),
("human", "{user_input}ただし、{advance_messages}")
]
)
else:
prompt = ChatPromptTemplate.from_messages(
[
("system",classification_prompt),
("human", "{user_input}")
]
)
chain = prompt | classifier_time
if State.advance_messages:
dicts = {
"bool_time": chain.invoke({"user_input": State.query, "advance_messages": State.advance_messages}).message_type,
}
return dicts
else:
dicts = {
"bool_time": chain.invoke({"user_input": State.query}).message_type,
}
return dicts
else:
return {"AI_messages": "No user input provided"}
# 今日、明日、明後日のどのワークフローを選択するかの分類器の出力
def select_tool(State):
print("--select_tool--")
tools = model.with_structured_output(ToolType)
# プロンプトの作成
classification_prompt = """
## You are a message classifier.
## 今日についての質問の場合は"tool01"と返答してください。
## 明日についての質問の場合は"tool02"と返答してください。
## 明後日についての質問の場合は"tool03"と返答してください。
"""
user_prompt = """
# ユーザからの質問内容
{user_input}
"""
if State.query:
prompt = ChatPromptTemplate.from_messages(
[
("system",classification_prompt),
("human", user_prompt)
]
)
chain = prompt | tools
return {
"message_type": chain.invoke({"user_input": State.query}).message_type,
}
else:
return {"AI_messages": "No user input provided"}
# 日付か天気に関する質問をしているかどうかを判定
def date_weather(State):
tools = model.with_structured_output(date_weather_Type)
print("--date_weather--")
# プロンプトの作成
classification_prompt = """
## You are a message classifier.
## このチャットボットは日付か天気に関する質問しか答えることはできません。
## それ以外の質問には答えることができません。
## そのため、それ以外の質問をしていた場合は"False"と返答してください。
## 日付か天気に関する質問の場合は"True"と返答してください。
## ユーザからの質問内容に日付や天気の情報が入っていたとしても、最終的な質問内容が天気や日付を回答するものでない場合は"False"と返答してください。
"""
user_prompt = """
# ユーザからの質問内容
{user_input}
"""
if State.query:
prompt = ChatPromptTemplate.from_messages(
[
("system",classification_prompt),
("human", user_prompt)
]
)
chain = prompt | tools
return {
"question_bool": chain.invoke({"user_input": State.query}).message_type,
}
else:
return {"AI_messages": "No user input provided"}
# tool1
#ノードの追加
graph_builder = StateGraph(State)
graph_builder.add_node("date_weather", date_weather)
graph_builder.add_node("not_date_weather_interrupt", not_date_weather_interrupt)
graph_builder.add_node("select_tool", select_tool)
graph_builder.add_node("classify1", classify)
graph_builder.add_node("classify_time_1", classify_time)
graph_builder.add_node("interrupt_1", interrupt)
graph_builder.add_node("chat_w1", chat_w1)
graph_builder.add_node("chat_d1", chat_d1)
graph_builder.add_node("response1", response)
# エッジの追加
graph_builder.add_edge("classify_time_1", "interrupt_1")
graph_builder.add_edge("chat_d1", "response1")
graph_builder.add_edge("chat_w1", "response1")
# 条件分岐
graph_builder.add_conditional_edges("date_weather", lambda state: state.question_bool, {True: "select_tool", False: "not_date_weather_interrupt"})
graph_builder.add_conditional_edges("classify1", lambda state: state.message_type, {"weather": "classify_time_1", "day": "chat_d1"})
graph_builder.add_conditional_edges("interrupt_1", lambda state: state.bool_time, {True: "chat_w1", False: "classify_time_1"})
# tool2
#ノードの追加
graph_builder.add_node("classify2", classify)
graph_builder.add_node("classify_time_2", classify_time)
graph_builder.add_node("interrupt_2", interrupt)
graph_builder.add_node("chat_w2", chat_w2)
graph_builder.add_node("chat_d2", chat_d2)
graph_builder.add_node("response2", response)
# エッジの追加
graph_builder.add_edge("classify_time_2", "interrupt_2")
graph_builder.add_edge("chat_d2", "response2")
graph_builder.add_edge("chat_w2", "response2")
# 条件分岐
graph_builder.add_conditional_edges("classify2", lambda state: state.message_type, {"weather": "classify_time_2", "day": "chat_d2"})
graph_builder.add_conditional_edges("interrupt_2", lambda state: state.bool_time, {True: "chat_w2", False: "classify_time_2"})
# tool3
#ノードの追加
graph_builder.add_node("classify3", classify)
graph_builder.add_node("classify_time_3", classify_time)
graph_builder.add_node("interrupt_3", interrupt)
graph_builder.add_node("chat_w3", chat_w3)
graph_builder.add_node("chat_d3", chat_d3)
graph_builder.add_node("response3", response)
# エッジの追加
graph_builder.add_edge("classify_time_3", "interrupt_3")
graph_builder.add_edge("chat_d3", "response3")
graph_builder.add_edge("chat_w3", "response3")
# 条件分岐
graph_builder.add_conditional_edges("classify3", lambda state: state.message_type, {"weather": "classify_time_3", "day": "chat_d3"})
graph_builder.add_conditional_edges("interrupt_3", lambda state: state.bool_time, {True: "chat_w3", False: "classify_time_3"})
# All
#ノードの追加
graph_builder.add_node("response", response)
# エッジの追加
graph_builder.add_edge("response1", "response")
graph_builder.add_edge("response2", "response")
graph_builder.add_edge("response3", "response")
# 条件分岐
graph_builder.add_conditional_edges("select_tool", lambda state: state.message_type, {"tool01": "classify1", "tool02": "classify2", "tool03": "classify3"})
# 開始位置、終了位置の指定
graph_builder.set_entry_point("date_weather")
graph_builder.set_finish_point("response")
# グラフ構築
#memory = MemorySaver()
memory = SqliteSaver(sqlite_db)
graph = graph_builder.compile(checkpointer=memory)
return graph
@app.post("/ask")
async def ask(request: AskRequest):
# 新しいセッションIDを生成
if request.session_id == "None":
session_id = str(uuid.uuid4())
else:
session_id = request.session_id
user_input = request.user_input
sqlite_db = load_or_create_db(session_id)
graph = initialize_graph(sqlite_db)
print("Session ID:", session_id)
print("User Input:", user_input)
# Stateの初期化
state = {
"question_bool": False,
"message_type": "",
"query": user_input,
"AI_messages": "",
"bool_time": False,
"advance_messages": ""
}
thread_config = {"configurable": {"thread_id": session_id}}
# イベントのリストと中断フラグ
event_list = []
interrupt = False
last_content = None
# LangGraphからのイベントを取得し、中断チェック
for event in graph.stream(state, thread_config):
#グラフ途中の中断を検出
event_list.append(event)
if "__interrupt__" in event:
interrupt = True
break
# 最後の 'response' から 'messages' の content を取得
if "response" in event and "AI_messages" in event["response"]:
last_content = event["response"]["AI_messages"]
if interrupt:
for key in event_list[-2].keys():
#KEYを取り出す処理。key = next(iter(event_list[-2].keys()))やkey = list(event_list[-2].keys())[0]でも良いし、そちらの方がいいかも
#中断した処理の直前のノードの名前によって処理を変える。(ただし、今回は一つだけ)
if "classify_time" in key:
last_content = "天気を知りたい時間を入力してください(例:「午前中」「20時」など)"
elif key == "date_weather":
last_content = "天気か日付に関する質問をしてください"
# セッションの状態を保存
sessions[session_id] = {
"initial_input": user_input,
"state": state,
"interrupt": interrupt,
"event_list": event_list,
"interrupt_event": list(event_list[-2].keys()),
}
sqlite_db.commit()
sqlite_db.close()
# 応答と中断フラグを返す
return {
"session_id": session_id,
"response": last_content,
"interrupt": interrupt,
"interrupt_event": list(event_list[-2].keys())
}
@app.post("/continue")
async def continue_conversation(request: ContinueRequest):
session_id = request.session_id
additional_input = request.additional_input
sqlite_db = load_or_create_db(session_id)
graph = initialize_graph(sqlite_db)
print("Session ID:", session_id)
print("Additional Input:", additional_input)
# セッションが存在するか確認
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
session_data = sessions[session_id]
interrupt = session_data["interrupt"]
add_state = {}
if session_data["interrupt_event"][0] == "date_weather":
#assistant_message = "天気か日付に関する質問をしてください"
# Stateの更新
add_state = {
"question_bool": False,
"message_type": "",
"query": additional_input,
"AI_messages": "",
"bool_time": False,
"advance_messages": ""
}
elif "classify_time" in session_data["interrupt_event"][0]:
#assistant_message = "天気を知りたい時間を入力してください(例:「午前中」「20時」など)"
# Stateの更新
add_state = {
"query": session_data["initial_input"],
"advance_messages":additional_input,
}
# 中断がない場合のエラー処理
if not interrupt:
return {"response": "No interrupt in this session."}
# 直前の状態を取得して分岐
all_states = []
for state in graph.get_state_history({"configurable": {"thread_id": session_id}}):
all_states.append(state)
to_replay = all_states[1] if len(all_states) > 1 else all_states[0]
branch_config = graph.update_state(config=to_replay.config, values=add_state)
# LangGraphの再実行
last_content = None
event_list = []
for event in graph.stream(None, branch_config):
event_list.append(event)
if "__interrupt__" in event:
interrupt = True
break
# 最後の 'response' から 'messages' の content を取得
if "response" in event and "AI_messages" in event["response"]:
last_content = event["response"]["AI_messages"]
if last_content:
interrupt = False
if interrupt:
for key in event_list[-2].keys():
#KEYを取り出す処理。key = next(iter(event_list[-2].keys()))やkey = list(event_list[-2].keys())[0]でも良いし、そちらの方がいいかも
#中断した処理の直前のノードの名前によって処理を変える。(ただし、今回は一つだけ)
if "classify_time" in key:
last_content = "天気を知りたい時間を入力してください(例:「午前中」「20時」など)"
elif key == "date_weather":
last_content = "天気か日付に関する質問をしてください"
if session_data["interrupt_event"][0] == "date_weather":
# セッションの状態を更新
sessions[session_id] = {
"initial_input": additional_input,
"state": state,
"interrupt": interrupt,
"event_list": event_list,
"interrupt_event": list(event_list[-2].keys()),
}
elif "classify_time" in session_data["interrupt_event"][0]:
sessions[session_id]["interrupt"] = interrupt
sessions[session_id]["event_list"] = event_list
sessions[session_id]["interrupt_event"] = list(event_list[-2].keys())
sqlite_db.commit()
sqlite_db.close()
# 応答を返却
return {
"session_id": session_id,
"response": last_content,
"interrupt": interrupt,
"interrupt_event": list(event_list[-2].keys())
}
if __name__ == "__main__":
import uvicorn
uvicorn.run("LangGraph_server:app", host="0.0.0.0", port=8002, reload=True)
解説
会話履歴を保存するSQLiteデータベース
def get_local_db_path(session_id):
return f'./sqlite_db/{session_id}.sqlite'
# SQLiteデータベースの初期化またはロード
def load_or_create_db(session_id):
"""
ローカルフォルダでSQLiteデータベースをロードまたは新規作成します。
:param session_id: セッションID(データベース名に使用)
:param db_directory: SQLiteデータベースを保存するディレクトリ
:return: SQLiteデータベース接続
"""
local_db_path = get_local_db_path(session_id)
print(f"Local database path: {local_db_path}")
# データベース保存ディレクトリを確認または作成
print(f"Checking database directory: {os.path.dirname(local_db_path)}")
if not os.path.exists(os.path.dirname(local_db_path)):
os.makedirs(os.path.dirname(local_db_path), exist_ok=True)
if not os.path.exists(local_db_path):
# 初回セッションの場合、DBファイルを新規作成
print(f"No existing database for session {session_id}, creating new one.")
conn = sqlite3.connect(local_db_path)
# 必要なら初期化処理を実行
conn.execute("CREATE TABLE IF NOT EXISTS example_table (id INTEGER PRIMARY KEY, data TEXT)")
conn.commit()
conn.close()
else:
print(f"Loaded existing database for session: {session_id}")
# SQLiteデータベース接続を返す
return sqlite3.connect(local_db_path, check_same_thread=False)
上記は、会話履歴を格納するSQLiteデータベースをreturnする関数です。
この関数を実行すると、引数のsession_id
に応じて、ローカルに、該当のデータベースがあればそれに接続し、なければ新しくデータベースを作成した上で初期化して、接続する関数です。
APIとして利用するために
pp = FastAPI()
# CORSの設定
origins = [
"*" # フロントエンドのURLを指定
# 必要に応じて他のオリジンを追加
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"], # すべてのHTTPメソッドを許可
allow_headers=["*"], # すべてのヘッダーを許可
)
・・・
if __name__ == "__main__":
import uvicorn
uvicorn.run("LangGraph_server:app", host="0.0.0.0", port=8002, reload=True)
APIサーバとして利用するためにFastAPIを利用しています。
フロントエンドのURLに対して接続に制限をなくすために設定しています。
利用するフロントエンドのURLを許可しないと、接続できないエラーが発生するので、毎回記述が必要になります。
langGraphのmemoryにSQLite DBを利用する
def initialize_graph(sqlite_db):
・・・
memory = SqliteSaver(sqlite_db)
graph = graph_builder.compile(checkpointer=memory)
return graph
LangGraphのグラフを構築するためのノードやエッジの定義や、ノードに利用する関数の定義などは全て、initialize_graph
関数にまとめています。
この関数の引数に、接続したSQLiteデータベースを設定することで、LangGraphのmemoryにSQLiteを利用することができます。
一般的な一般的なインメモリに保存する場合は、memory = MemorySaver()
のように書いていたと思いますが、SQLiteを利用する場合は、memory = SqliteSaver(sqlite_db)
とかけます。
/ask/ API
@app.post("/ask")
async def ask(request: AskRequest):
# 新しいセッションIDを生成
if request.session_id == "None":
session_id = str(uuid.uuid4())
else:
session_id = request.session_id
user_input = request.user_input
sqlite_db = load_or_create_db(session_id)
graph = initialize_graph(sqlite_db)
print("Session ID:", session_id)
print("User Input:", user_input)
# Stateの初期化
state = {
"question_bool": False,
"message_type": "",
"query": user_input,
"AI_messages": "",
"bool_time": False,
"advance_messages": ""
}
thread_config = {"configurable": {"thread_id": session_id}}
# イベントのリストと中断フラグ
event_list = []
interrupt = False
last_content = None
# LangGraphからのイベントを取得し、中断チェック
for event in graph.stream(state, thread_config):
#グラフ途中の中断を検出
event_list.append(event)
if "__interrupt__" in event:
interrupt = True
break
# 最後の 'response' から 'messages' の content を取得
if "response" in event and "AI_messages" in event["response"]:
last_content = event["response"]["AI_messages"]
if interrupt:
for key in event_list[-2].keys():
#KEYを取り出す処理。key = next(iter(event_list[-2].keys()))やkey = list(event_list[-2].keys())[0]でも良いし、そちらの方がいいかも
#中断した処理の直前のノードの名前によって処理を変える。(ただし、今回は一つだけ)
if "classify_time" in key:
last_content = "天気を知りたい時間を入力してください(例:「午前中」「20時」など)"
elif key == "date_weather":
last_content = "天気か日付に関する質問をしてください"
# セッションの状態を保存
sessions[session_id] = {
"initial_input": user_input,
"state": state,
"interrupt": interrupt,
"event_list": event_list,
"interrupt_event": list(event_list[-2].keys()),
}
sqlite_db.commit()
sqlite_db.close()
# 応答と中断フラグを返す
return {
"session_id": session_id,
"response": last_content,
"interrupt": interrupt,
}
基本的に、直前の会話で中断が発生していない時に呼ばれるAPIです。
つまり、質問の一番最初に呼ばれるAPIになります。
セッションIDの識別・新規作成
引数のrequest
には、ユーザからの質問request.user_input
とセッションIDrequest.session_id
がフロントから送られます。
フロントから送られてくるrequest.session_id
がNone
の場合は、こちらで一意にIDを付与しますが、もしIDがフロントから送られている場合は、そのIDを今後利用します。
SQLiteデータベースとLangGraphの初期化
sqlite_db = load_or_create_db(session_id)
graph = initialize_graph(sqlite_db)
続いて、上記で、セッションIDに応じて、SQLiteのデータベースに接続(場合によっては新規作成・初期化・接続)を行い、それを利用して、LangGraphのグラフを構築します。
Stateの設定とLangGraphのチェックポイントにid指定
# Stateの初期化
state = {
"question_bool": False,
"message_type": "",
"query": user_input,
"AI_messages": "",
"bool_time": False,
"advance_messages": ""
}
thread_config = {"configurable": {"thread_id": session_id}}
for event in graph.stream(state, thread_config):
・・・
得られた、ユーザ入力を用いてState
を設定し、作成したセッションIDをもとに、LangGraphのチェックポイントを一意に決定します。
これにより、別々のユーザが同時に接続してきても、混乱することなく、質問応答が可能になります。
langGraph中断時の処理
if interrupt:
for key in event_list[-2].keys():
#KEYを取り出す処理。key = next(iter(event_list[-2].keys()))やkey = list(event_list[-2].keys())[0]でも良いし、そちらの方がいいかも
#中断した処理の直前のノードの名前によって処理を変える。(ただし、今回は一つだけ)
if "classify_time" in key:
last_content = "天気を知りたい時間を入力してください(例:「午前中」「20時」など)"
elif key == "date_weather":
last_content = "天気か日付に関する質問をしてください"
中断フラグが立っている時は、フロント側に返信を返す前に、追加質問の文章をフロント側に返して、それを表示してもらう必要があります。
今回は、LangGraphの処理のどの場所で中断したかで、条件分岐し、追加質問文章を決定し、フロントエンドにAIからのメッセージとして返却しています。
上記の処理では、LangGraphにおける中断箇所の手前のノードのkey
をAPIの返答で受け取っており、そのノードのkey
の値に応じて、画面に表示するメッセージが変わります。
-
if "classify_time" in key:
- ユーザが天気の質問をした際に「時間情報」が不足している場合に、中断されたことを示します。
- この場合は、「天気を知りたい時間を入力してください(例:「午前中」「20時」など)」と画面に表示します。
- また、この場合のノードの
key
は「classify_time_1
」、「classify_time_2
」
「classify_time_3
」、の3パターンあるため、どれでも対応できるようにif文を組んでいます。
- ユーザが天気の質問をした際に「時間情報」が不足している場合に、中断されたことを示します。
-
elif key == "date_weather":
- ユーザが「天気か日付の質問」以外の内容について入力しているため、ブロックされたことを示します。
- したがって、「天気か日付に関する質問をしてください」と画面に表示します。
- ユーザが「天気か日付の質問」以外の内容について入力しているため、ブロックされたことを示します。
APIの返り値の設定とユーザごとにセッションの保存
# セッションの状態を保存
sessions[session_id] = {
"initial_input": user_input,
"state": state,
"interrupt": interrupt,
"event_list": event_list,
"interrupt_event": list(event_list[-2].keys()),
}
sqlite_db.commit()
sqlite_db.close()
# 応答と中断フラグを返す
return {
"session_id": session_id,
"response": last_content,
"interrupt": interrupt,
}
ここでは、セッションIDごとに必要な情報をsessions
として保存しています。
そして、この後記載しますが、/continue
のAPIが呼ばれたときに、どのノードで中断したのかや、最初のユーザ質問はなんだったかなどの情報を、このsessions
から取得します。
また、APIの返り値としては、セッションIDsession_id
とAIからのメッセージresponse
、中断したかどうかのフラグinterrupt
の情報が返ります。
その全てが、フロント側で処理をするのに必要な情報であることは、ここまで読んでいただけた方なら理解いただけると思います。
/continue/ API
@app.post("/continue")
async def continue_conversation(request: ContinueRequest):
session_id = request.session_id
additional_input = request.additional_input
sqlite_db = load_or_create_db(session_id)
graph = initialize_graph(sqlite_db)
print("Session ID:", session_id)
print("Additional Input:", additional_input)
# セッションが存在するか確認
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
session_data = sessions[session_id]
interrupt = session_data["interrupt"]
add_state = {}
if session_data["interrupt_event"][0] == "date_weather":
# Stateの更新
add_state = {
"question_bool": False,
"message_type": "",
"query": additional_input,
"AI_messages": "",
"bool_time": False,
"advance_messages": ""
}
elif "classify_time" in session_data["interrupt_event"][0]:
# Stateの更新
add_state = {
"query": session_data["initial_input"],
"advance_messages":additional_input,
}
# 中断がない場合のエラー処理
if not interrupt:
return {"response": "No interrupt in this session."}
# 直前の状態を取得して分岐
all_states = []
for state in graph.get_state_history({"configurable": {"thread_id": session_id}}):
all_states.append(state)
to_replay = all_states[1] if len(all_states) > 1 else all_states[0]
branch_config = graph.update_state(config=to_replay.config, values=add_state)
# LangGraphの再実行
last_content = None
event_list = []
for event in graph.stream(None, branch_config):
event_list.append(event)
if "__interrupt__" in event:
interrupt = True
break
# 最後の 'response' から 'messages' の content を取得
if "response" in event and "AI_messages" in event["response"]:
last_content = event["response"]["AI_messages"]
if last_content:
interrupt = False
if interrupt:
for key in event_list[-2].keys():
#KEYを取り出す処理。key = next(iter(event_list[-2].keys()))やkey = list(event_list[-2].keys())[0]でも良いし、そちらの方がいいかも
#中断した処理の直前のノードの名前によって処理を変える。(ただし、今回は一つだけ)
if "classify_time" in key:
last_content = "天気を知りたい時間を入力してください(例:「午前中」「20時」など)"
elif key == "date_weather":
last_content = "天気か日付に関する質問をしてください"
if session_data["interrupt_event"][0] == "date_weather":
# セッションの状態を更新
sessions[session_id]["initial_input"] = additional_input
sessions[session_id]["interrupt"] = interrupt
sessions[session_id]["event_list"] = event_list
sessions[session_id]["interrupt_event"] = list(event_list[-2].keys())
elif "classify_time" in session_data["interrupt_event"][0]:
sessions[session_id]["interrupt"] = interrupt
sessions[session_id]["event_list"] = event_list
sessions[session_id]["interrupt_event"] = list(event_list[-2].keys())
sqlite_db.commit()
sqlite_db.close()
# 応答を返却
return {
"session_id": session_id,
"response": last_content,
"interrupt": interrupt,
}
上記は、中断フラグが立っているときに、フロントからの入力を処理する際に呼ばれる関数です。
したがって、中断されているLangGraphの途中から処理を開始する必要があるため、APIを分けて実装しています。
セッションIDの識別
session_id = request.session_id
additional_input = request.additional_input
sqlite_db = load_or_create_db(session_id)
graph = initialize_graph(sqlite_db)
中断フラグが立っている場合に呼ばれるAPIのため、フロント側から必ずセッションIDが送られてきますので、それに合わせてLangGraphのチェックポイントを格納しているSQliteのデータベースに接続し、LangGraphを初期化します。
sessionsの取得
session_data = sessions[session_id]
session_data
は、/ask/
APIの最後に保存した、sessions[session_id]
のことです。
これにより、前のAPIで保存したデータ(前の質問文など)を今回のAPI処理でも利用することができます。
また、session_id
ごとに分けて保存しているため、他のユーザの情報と混ざってしまうこともありません。
Stateの更新
if session_data["interrupt_event"][0] == "date_weather":
# Stateの更新
add_state = {
"question_bool": False,
"message_type": "",
"query": additional_input,
"AI_messages": "",
"bool_time": False,
"advance_messages": ""
}
elif "classify_time" in session_data["interrupt_event"][0]:
# Stateの更新
add_state = {
"query": session_data["initial_input"],
"advance_messages":additional_input,
}
中断したノードに合わせて、State
を更新します。
最初のノードで中断した場合は、このAPIの入力として「新規の質問文」が入ってくるはずなので、上のように、query
にadditional_input
が入ります。
一方で、天気の質問の途中で中断した場合は、このAPIの入力として「追加の時間情報」が入ってくるはずなので、advance_messages
にadditional_input
が入ります。
LangGraphのチェックポイントにid指定
# 直前の状態を取得して分岐
all_states = []
for state in graph.get_state_history({"configurable": {"thread_id": session_id}}):
all_states.append(state)
to_replay = all_states[1] if len(all_states) > 1 else all_states[0]
branch_config = graph.update_state(config=to_replay.config, values=add_state)
・・・
# LangGraphの再実行
for event in graph.stream(None, branch_config):
・・・
LangGraphにおいて中断後の処理で、特徴的なのはgraph.stream(None, branch_config)
のようにState
をNone
として実行する必要があることです。
None
を指定することで、中断箇所から実行開始することができます。
では、どのノード位置から、どのState
で実行すれば良いかは、それより上のコードで定義しています。
for state in graph.get_state_history({"configurable": {"thread_id": session_id}}):
all_states.append(state)
to_replay = all_states[1] if len(all_states) > 1 else all_states[0]
上記の部分において、フロントから送られてくるセッションIDをもとに、チェックポイントからState
の履歴を取得します。
State
の履歴は最新のものから逆順に格納されているので、all_states[1]
を取得することで、中断したノードの直前のノードの情報を取得できます。
branch_config = graph.update_state(config=to_replay.config, values=add_state)
そして、上記により、中断したノードの直前のノードでのState
が持っていたconfig
を指定し、State
の中身として、前述したadd_state
に更新をして、それをもとにgraph.stream(None, branch_config)
をしているので、更新したState
を中断したノードの直前のノードから再開できるというわけです。
langGraph中断時の処理
if interrupt:
for key in event_list[-2].keys():
#KEYを取り出す処理。key = next(iter(event_list[-2].keys()))やkey = list(event_list[-2].keys())[0]でも良いし、そちらの方がいいかも
#中断した処理の直前のノードの名前によって処理を変える。(ただし、今回は一つだけ)
if "classify_time" in key:
last_content = "天気を知りたい時間を入力してください(例:「午前中」「20時」など)"
elif key == "date_weather":
last_content = "天気か日付に関する質問をしてください"
中断フラグが立っている時は、フロント側に返信を返す前に、追加質問の文章をフロント側に返して、それを表示してもらう必要があります。
今回は、LangGraphの処理のどの場所で中断したかで、条件分岐し、追加質問文章を決定し、フロントエンドにAIからのメッセージとして返却しています。
上記の処理では、LangGraphにおける中断箇所の手前のノードのkey
をAPIの返答で受け取っており、そのノードのkey
の値に応じて、画面に表示するメッセージが変わります。
-
if "classify_time" in key:
- ユーザが天気の質問をした際に「時間情報」が不足している場合に、中断されたことを示します。
- この場合は、「天気を知りたい時間を入力してください(例:「午前中」「20時」など)」と画面に表示します。
- また、この場合のノードの
key
は「classify_time_1
」、「classify_time_2
」
「classify_time_3
」、の3パターンあるため、どれでも対応できるようにif文を組んでいます。
- ユーザが天気の質問をした際に「時間情報」が不足している場合に、中断されたことを示します。
-
elif key == "date_weather":
- ユーザが「天気か日付の質問」以外の内容について入力しているため、ブロックされたことを示します。
- したがって、「天気か日付に関する質問をしてください」と画面に表示します。
- ユーザが「天気か日付の質問」以外の内容について入力しているため、ブロックされたことを示します。
APIの返り値の設定とユーザごとにセッションの保存
if session_data["interrupt_event"][0] == "date_weather":
# セッションの状態を更新
sessions[session_id]["initial_input"] = additional_input
sessions[session_id]["interrupt"] = interrupt
sessions[session_id]["event_list"] = event_list
sessions[session_id]["interrupt_event"] = list(event_list[-2].keys())
elif "classify_time" in session_data["interrupt_event"][0]:
sessions[session_id]["interrupt"] = interrupt
sessions[session_id]["event_list"] = event_list
sessions[session_id]["interrupt_event"] = list(event_list[-2].keys())
sqlite_db.commit()
sqlite_db.close()
# 応答を返却
return {
"session_id": session_id,
"response": last_content,
"interrupt": interrupt,
}
ここでは、セッションIDごとに、必要な情報をsessions
として保存しています。
まず、前回の質問で、中断が発生したノードの手前のノードのkey
に応じて、何を更新するべきかが変わります。
例えば、前回の質問が最初のノードで中断(不適切な質問による中断)した場合は、前回の質問の内容initial_input
を今回の入力内容additional_input
で更新する必要があります。
一方で、前回の質問が、天気の質問の時間情報の不足により中断した場合は、前回の質問内容自体は生きているため、更新してはいけません。
ということを考慮して、sessions[session_id]
を更新しています。
また、APIの返り値としては、セッションIDsession_id
とAIからのメッセージresponse
、中断したかどうかのフラグinterrupt
の情報が返ります。
これは、/ask/
APIと同じです。
まとめ
ここまで読んでくださってありがとうございました!
さて、実は、今回あえてAPIの関数の中にgraph = initialize_graph(sqlite_db)
を記述しています。
もし、バックエンドのサーバを用意して、常にサーバを立てているのであれば、APIの関数の外でグラフを初期化して、初期化したものを呼ばれるたびに使いまわした方が、処理時間は短くなります。
今回のコードでは、APIが呼ばれるたびにグラフが再度初期化されるため、処理時間に無駄があります。
なぜ、このような処理にしたかというと、lambdaなどのサーバレスで実装しようと思った場合に、この初期化処理が毎回必要になるからです。
そのため、FastAPI関数の中にLangGraphの初期化を入れてみました。
したがって、続きは、今回の内容をVercelとlambdaにて、デプロイしてみようと思います!
Discussion