🔖
PydanticAIでエージェントを作る-4:番外編・構造化出力はGPT-4oが飛び抜けてすごいのかもしれない
TL;DR
前回の記事[1]でdeps_typeを使って入出力の制御と更新が簡単になるという話をしました。
モデルはgpt-4o-miniを使っていましたが、今回はollamaで試してみました。
ollamaならAPIコストを無くせるかも・・・という下心だけでお試ししましたが、なかなか厳しい結果でした。
- 試したモデル(@ollama)
- llama3.2:1b
- qwen2.5
- qwen2.5:1.5b
- シンプルな構造型なら一応対応できる。
- ネストすると失敗(難しいのか空オブジェクトを返してくる確率が高い。)
コード変更点
あまりうまく行かなかったのでコード全体の話は補足で。
前回作成したコード[1:1]の一部を変更するだけです。
(ollamaはOpenAIとAPIが互換らしく、OpenAIModelに、モデル名とbase_url(ollamaサーバ)を指定するだけで使えました。)
from pydantic_ai.models.openai import OpenAIModel
# modelを指定。ollamaからqwenを指定
model_llama = OpenAIModel(model_name="qwen2.5", base_url="http://localhost:11434/v1")
scheduler_agent = Agent(
model=model_qwen, # モデルを指定
deps_type=input_info, # 入力の型を指定
result_type=output_info, # レスポンスの型を指定
system_prompt="You are an AI assistant helping a user with organizing schedules.",
retries=10, # retries:デフォだと1
)
で・・・出力なんですが、色々問題が出ました。
- Agentで指定しているretriesがデフォは1なんですが、これだと失敗することが多かったです。
(実はgpt-4oや4o-miniでは構造ミスが出たことが無く、リトライ数を変更していませんでした。) - というわけでリトライを10回にしてみたところ、うまくデータを作成できないのか空オブジェクトで返すことが多かったです。
(まれに成功するので、プログラムの問題ではなさそう。) - 記事1[2]のような2つの項目だけの場合は一応出力が出ました。
というわけで、スケジュールのリスト、のようなネスト化した形式が難しいかもしれません。
ちなみに巷で噂のdeepseek-r1はそもそも推論特化のモデルのせいか、構造化出力自体対応していませんでした。
ポイント
- 他のモデルのAPIには登録してないのでわかりませんが、ネスト化した構造化出力を安定して得るというのは神業かもしれないと思いました。
- gpt-4o(-mini)、恐ろしい子・・・
補足(コード全体:検証用)
リトライが多すぎるのと、空オブジェクト回答が多すぎて使い物にはなりませんでした。
import difflib
import nest_asyncio
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from typing import List, Optional, Literal
from datetime import datetime
from pydantic_ai.models.openai import OpenAIModel
nest_asyncio.apply()
load_dotenv()
# modelを指定。ollamaからqwenを指定
model_qwen = OpenAIModel(model_name="qwen2.5", base_url="http://localhost:11434/v1")
# 入力の型を定義
class schedule(BaseModel):
task: str # 入力の型を定義
Start: str = Field(
"",
description="The start time of the task. The format should be YYYY-MM-DD.",
) # 開始日時の型を定義
Finish: str = Field(
"",
description="The finish time of the task. The format should be YYYY-MM-DD.",
) # 終了日時の型を定義
class input_info(BaseModel):
prompts: str # 入力の型を定義
schedules: List[schedule]
# 入力の型を定義
class output_info(BaseModel):
schedules: List[schedule] # レスポンスの型を定義、入力の型と同じ
# レスポンスの型を定義
# Pydanticのエージェントを作成
# 環境変数にOPENAI_API_KEYは設定済み
scheduler_agent = Agent(
model = model_qwen, # モデルを指定
deps_type=input_info, # 入力の型を指定
result_type=output_info, # レスポンスの型を指定
system_prompt="You are an AI assistant helping a user with organizing schedules.",
retries=10,
)
@scheduler_agent.system_prompt
async def get_system_prompt(ctx: RunContext[input_info]) -> str:
"""
System prompt for the AI agent to handle with list_of_schedules.
"""
added_prompt = f"""
Toeday is {datetime.now().strftime("%Y-%m-%d")}.
User's current schedules:
{ctx.deps.schedules}.
"""
return added_prompt
usual_agent = Agent(
model = model_qwen, # モデルを指定
system_prompt="You are an AI assistant helping a user with organizing schedules.",
retries=10,
)
if "user_input" not in st.session_state:
temp_input = input_info(prompts="", schedules=[]) # 入力の初期値を設定
st.session_state.user_input = (
temp_input.model_dump()
) # 入力をdictでセッションステートに保存
if "result" not in st.session_state:
temp_result = output_info(schedules=[]) # レスポンスの初期値を設定
st.session_state.result = (
temp_result.model_dump()
) # レスポンスをdictでセッションステートに保存
st.write("更新前データ")
st.json(st.session_state.user_input) # 入力をJSON形式で表示
st.table(
pd.DataFrame.from_dict(st.session_state.user_input["schedules"])
) # データフレームを表として表示
st.write("更新予定データ")
st.json(st.session_state.result) # レスポンスをJSON形式で表示
st.table(
pd.DataFrame.from_dict(st.session_state.result["schedules"])
) # データフレームを表として表示
# データ更新ボタン
if st.button("データ更新"):
st.session_state.user_input.update(st.session_state.result)
st.rerun()
# streamlitのUI作成
st.title("Pydanticを使ったAIエージェントの作成")
prompts = st.text_area("プロンプト入力欄", "ここに入力してください")
button1, button2 = st.columns(2)
if button1.button("チャット"):
st.write("チャットボタンがクリックされました")
user_input_deps = input_info(**st.session_state.user_input) # 入力を取得
user_input_deps.prompts = prompts # ユーザー入力を更新
result = scheduler_agent.run_sync(user_input_deps.prompts, deps=user_input_deps)
st.session_state.user_input.update(
user_input_deps.model_dump()
) # レスポンスをセッションステート
st.session_state.result.update(
result.data.model_dump()
) # レスポンスをセッションステートに保存
print(result.data)
st.rerun()
if button2.button("通常チャット"):
st.write("通常チャットボタンがクリックされました")
result = usual_agent.run_sync(prompts)
print(result.data)
st.write(result.data)
st.rerun()
Discussion