SwarmでテキストからSQLを生成するWebアプリを作ってみた
はじめに
先日OpenAIのAIエージェントSwarmが発表され、複雑なタスクを遂行させるらしいので、
さっそく試してみることにしました。
今回開発したいのは、テキストからSQLを生成し、観光地案内をするようなアプリです。
アプリのソースコードはこちらで公開しています。
Swarmとは
Swarmは、OpenAIが開発した実験的な教育用フレームワークです。このフレームワークの主な目的は、複数のAIエージェントを軽量かつ人間工学的に連携させる方法を探求することです。
主な特徴は以下の通りです。
- 「Agent」と「handoff」という2つの基本的な概念を使用している
- Agentは指示とツールを持ち、必要に応じて他のAgentに会話を引き継ぐことができる
- 完全にクライアント側で動作し、呼び出し間で状態を保持しない
- Chat Completions APIを使用している
- 開発者が複数のエージェントの調整や実行を容易に制御できるように設計されている
- 大規模で独立した機能や、単一のプロンプトに組み込むのが難しい指示を扱う状況に適している
- Swarmは現在実験的なフレームワークであり、本番環境での使用は想定されていない
Agent
Agentは、特定のタスクや機能を実行するための独立した単位です。
各Agentは独自の指示セットとツールセットを持っています。
Agentは特定の目的のために設計され、その目的に関連するタスクを効率的に実行します。
handoff
handoffは、あるAgentから別のAgentに会話や制御を移す過程を指します。
これにより、複雑なタスクを複数の専門化されたAgentで分担して処理することができます。
handoffは、現在のAgentが自身の専門外の問題に直面したときや、別のAgentがより適切に対応できると判断したときに発生します。
環境構築
環境構築はとてもシンプルで、まずSwarmをインストールします。
pip install git+ssh://git@github.com/openai/swarm.git
そしてOPENAI_API_KEYを設定します。
export OPENAI_API_KEY="sk-..."
アプリケーションの作成
観光地案内アプリを作成するために、以下の内容を実装します。
- データベースの準備
- Agentの作成(テキストからSQLの生成AgentとSQLの説明Agent)
- SQLを実行する機能
- FastAPIでWebページで表示
データベースの準備
ガッツリ開発ではないので、適当にSQLiteを使ってデータベースを準備します。
class Database:
def __init__(self):
self.conn = None
@contextmanager
def get_connection(self):
if self.conn is None:
self.conn = sqlite3.connect(':memory:', check_same_thread=False)
self.init_db()
try:
yield self.conn
finally:
pass
@contextmanager
def get_cursor(self):
with self.get_connection() as conn:
cursor = conn.cursor()
try:
yield cursor
finally:
cursor.close()
def init_db(self):
with self.get_cursor() as cursor:
# 観光地テーブルの作成
cursor.execute('''
CREATE TABLE attractions (
id INTEGER PRIMARY KEY,
name TEXT,
city_id INTEGER,
type TEXT,
admission_fee INTEGER
)
''')
# 都市テーブルの作成
cursor.execute('''
CREATE TABLE cities (
id INTEGER PRIMARY KEY,
name TEXT,
prefecture TEXT
)
''')
# 観光地データの挿入
attractions = [
(1, '東京スカイツリー', 1, 'タワー', 2100),
(2, '浅草寺', 1, '寺院', 0),
(3, '大阪城', 2, '城', 600),
(4, '道頓堀', 2, '繁華街', 0),
(5, '金閣寺', 3, '寺院', 400),
(6, '伏見稲荷大社', 3, '神社', 0),
(7, '横浜中華街', 4, '繁華街', 0),
(8, '箱根温泉', 5, '温泉', 500),
(9, '富士山', 6, '山', 1000),
(10, '札幌時計台', 7, '歴史的建造物', 200)
]
cursor.executemany('INSERT INTO attractions VALUES (?,?,?,?,?)', attractions)
# 都市データの挿入
cities = [
(1, '東京', '東京都'),
(2, '大阪', '大阪府'),
(3, '京都', '京都府'),
(4, '横浜', '神奈川県'),
(5, '箱根', '神奈川県'),
(6, '富士宮', '静岡県'),
(7, '札幌', '北海道')
]
cursor.executemany('INSERT INTO cities VALUES (?,?,?)', cities)
# レビューテーブルの作成
cursor.execute('''
CREATE TABLE reviews (
id INTEGER PRIMARY KEY,
attraction_id INTEGER,
rating INTEGER,
comment TEXT,
review_date DATE,
FOREIGN KEY (attraction_id) REFERENCES attractions (id)
)
''')
# レビューデータの挿入
reviews = [
(1, 1, 5, '景色が素晴らしい!', '2023-05-01'),
(2, 1, 4, '混んでいたが、価値がある', '2023-05-02'),
(3, 2, 5, '歴史を感じる素晴らしい寺院', '2023-05-03'),
(4, 3, 4, '大阪の象徴、素晴らしい', '2023-05-04'),
(5, 4, 3, '賑やかで面白い', '2023-05-05'),
(6, 5, 5, '美しい金閣寺、必見', '2023-05-06'),
(7, 6, 4, '鳥居の並びが印象的', '2023-05-07'),
(8, 7, 4, '美味しい中華料理がたくさん', '2023-05-08'),
(9, 8, 5, 'リラックスできる温泉', '2023-05-09'),
(10, 9, 5, '日本の象徴、絶景', '2023-05-10'),
(11, 10, 3, '歴史的な建物だが、期待ほどではない', '2023-05-11')
]
cursor.executemany('INSERT INTO reviews VALUES (?,?,?,?,?)', reviews)
self.conn.commit()
db = Database()
def get_db():
with db.get_cursor() as cursor:
yield cursor
Agentの作成
まずはSwarmのclientの初期化を行います。
from swarm import Swarm, Agent
client = Swarm()
続いて二つのAgentを作成します。
# テキストからSQLを生成するAgent
agent = Agent(
name="SQLAgent",
instructions="""
あなたは日本語の自然言語クエリをSQLクエリに変換できるAIアシスタントです。
データベースには3つのテーブルがあります:
1. 'attractions'テーブル以下の列を含む:id, name, city_id, type, admission_fee
2. 'cities'テーブル、以下の列を含む:id, name, prefecture
3. 'reviews'テーブル、以下の列を含む:id, attraction_id, rating, comment, review_date
SQLクエリのみを返し、他のテキストや説明を含めないでください。複雑なクエリをサポートし、複数テーブルの結合、比較、ソート、集計関数を含みます。
"""
)
# SQLの説明をするAgent
explanation_agent = Agent(
name="ExplanationAgent",
instructions="""
あなたはSQL専門家のAIアシスタントです。与えられたSQLクエリを非技術者にも分かりやすく説明することが任務です。
クエリの目的、関連するテーブル、使用される操作(結合、フィルタリング、ソートなど)、期待される結果を含む、簡潔かつ包括的な説明をMarkdown形式で提供してください。
説明のみを返し、他の内容は含めないでください。
"""
)
SQLを実行する機能
続いて、SQLを実行する機能を作成します。
他にSQLの整形や結果がわかりやすく表示するようにするための関数も作成します。
# SQLクエリをクリーンアップし、可能なMarkdown形式と余分な空白を削除する
def clean_sql_query(sql_query: str) -> str:
"""SQLクエリをクリーンアップし、可能なMarkdown形式と余分な空白を削除する"""
cleaned = re.sub(r'```sql\s*|\s*```', '', sql_query).strip()
return cleaned
# SQLクエリを実行し、結果を返す
def execute_sql(sql_query: str, cursor) -> tuple:
try:
cursor.execute(sql_query)
results = cursor.fetchall()
return results, cursor.description
except sqlite3.Error as e:
return f"SQLエラー: {e}", None
# AIを使用してSQLクエリの説明を生成する
def explain_query(sql_query: str) -> str:
"""AIを使用してSQLクエリの説明を生成する"""
response = client.run(
agent=explanation_agent,
messages=[{"role": "user", "content": f"以下のSQLクエリを説明してください:\n\n{sql_query}"}],
)
print(response.messages[-1]["content"])
return response.messages[-1]["content"]
# クエリ結果をフォーマットし、コンテキストと単位を追加する
def format_results(results, description) -> str:
"""クエリ結果をフォーマットし、コンテキストと単位を追加する"""
if isinstance(results, str):
return results
if not results or not description:
return "一致する結果が見つかりませんでした。"
headers = [desc[0] for desc in description]
# 入場料に円単位を追加
if 'admission_fee' in headers:
fee_index = headers.index('admission_fee')
results = [list(row) for row in results]
for row in results:
row[fee_index] = f"{row[fee_index]}円" if row[fee_index] > 0 else "無料"
formatted_results = tabulate(results, headers=headers, tablefmt="grid")
return formatted_results
続いて、上記機能を統合して結果をフロントエンドで表示するようにします。
def process_query(natural_language_query: str, cursor) -> dict:
"""自然言語クエリを処理し、SQLに変換し、実行して結果を返す"""
# Swarmを使用して自然言語をSQLに変換
response = client.run(
messages=[{"role": "user", "content": natural_language_query}],
agent=agent,
)
sql_query = clean_sql_query(response.messages[-1]["content"])
# SQLクエリを実行
results, description = execute_sql(sql_query, cursor)
# クエリの説明を取得
explanation = explain_query(sql_query)
# 結果をフォーマット
formatted_results = format_results(results, description)
return {
"sql_query": sql_query,
"explanation": explanation,
"results": formatted_results
}
FastAPIでWebページで表示
FastAPIでWebページで表示するようにします。
実装は以下の通りです。
サーバー側:
from fastapi import FastAPI, Request, Form, Depends
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse, JSONResponse
from contextlib import contextmanager, asynccontextmanager
app = FastAPI()
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/query")
async def query(user_input: str = Form(...), cursor = Depends(get_db)):
result = process_query(user_input, cursor)
return JSONResponse(content=result)
@asynccontextmanager
async def lifespan(app: FastAPI):
await shutdown_event()
yield
async def shutdown_event():
if db.conn:
db.conn.close()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
テンプレートファイル:
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>日本の観光案内データベース</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
form {
margin-bottom: 20px;
}
input[type="text"] {
width: 70%;
padding: 10px;
}
input[type="submit"] {
padding: 10px 20px;
}
.card {
background-color: #f0f0f0;
border-radius: 5px;
padding: 15px;
margin-bottom: 20px;
}
h3 {
margin-bottom: 5px;
}
pre {
white-space: pre-wrap;
word-wrap: break-word;
margin: 0;
}
#loading {
display: none;
text-align: center;
margin-top: 20px;
}
</style>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
</head>
<body>
<h1>日本の観光案内データベース</h1>
<p>自然言語で観光地や都市に関する質問ができます。</p>
<form id="queryForm">
<input type="text" id="userInput" name="user_input" placeholder="質問を入力してください"
value="{{ user_input if user_input else '' }}">
<input type="submit" value="質問する">
</form>
<div id="loading">検索中...</div>
<div id="result"></div>
<script>
document.getElementById('queryForm').addEventListener('submit', function (e) {
e.preventDefault();
const userInput = document.getElementById('userInput').value;
const loadingDiv = document.getElementById('loading');
const resultDiv = document.getElementById('result');
loadingDiv.style.display = 'block';
resultDiv.innerHTML = '';
fetch('/query', {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: 'user_input=' + encodeURIComponent(userInput)
})
.then(response => response.json())
.then(data => {
loadingDiv.style.display = 'none';
resultDiv.innerHTML = ``;
document.getElementById('explanation').innerHTML = marked.parse(data.explanation);
})
.catch(error => {
loadingDiv.style.display = 'none';
resultDiv.innerHTML = '<p>エラーが発生しました。もう一度お試しください。</p>';
console.error('Error:', error);
});
});
</script>
</body>
</html>
以上で必要な機能は実装できました。
動作確認
まずはOpenAIのAPIキーを設定します。
export OPENAI_API_KEY="sk-..."
続いてサーバーを起動します。
uvicorn app:app --reload
ブラウザでhttp://localhost:8000/
にアクセスして、動作確認を行います。
これでシンプルな検索フォームが表示されます。
一番評価が高い観光地と入場料金という質問をしてみます。
ちゃんと結果が表示されました。
しかも複数のテーブルも結合して検索しているのでとても優秀です。
他の質問もしてみます。
存在していない情報を検索してみます。
おわりに
今回はSwarmを使ってテキストからSQLを生成し、観光地案内をするようなアプリを作成しました。
個人的な感想としてはとても使いやすく、AgentとAgentとの連携も実装しやすいイメージでした。
また、Swarmのフレームワーク自体がまだ実験的なフレームワークで本番環境での使用は想定されていないので、
外部提供ではなく、社内での運用や学習向けのツールとして使うのが良いのではないかと思いました。
Discussion