【連載】pydantic-ai徹底解説 (2) ツール呼び出しと依存性注入
前回は pydantic-ai を用いてシンプルにLLMへ問い合わせる方法を紹介しました。
今回は、pydantic-ai の強力な機能である「ツール呼び出し(Function Tools)」と「依存性注入(DI)」を用いて、LLM に外部データやビジネスロジックを利用させる方法を説明します。
インストール方法
pydantic-aiとその例を実行するには、以下の手順でインストールを行います:
基本的なインストール
pipまたはuvを使用して、以下のコマンドでインストールできます:
# pipを使用する場合
! pip install 'pydantic-ai[examples]' loguru
# uvを使用する場合(リポジトリをクローンした場合)
# uv sync --extra examples
import os
from google.colab import userdata
os.environ['GEMINI_API_KEY'] = userdata.get('GEMINI_API_KEY')
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')
import nest_asyncio
nest_asyncio.apply()
環境変数の設定
LLMを使用するために、以下のいずれかの環境変数を設定する必要があります:
# OpenAIを使用する場合
# export OPENAI_API_KEY=your-api-key
# Google Geminiを使用する場合は別途設定が必要です
例の実行方法
インストール後、以下のコマンドで例を実行できます:
# 基本的な実行方法
python -m pydantic_ai_examples.<example_module_name>
# 例:pydantic_modelの例を実行
python -m pydantic_ai_examples.pydantic_model
# uvを使用したワンライナーでの実行
OPENAI_API_KEY='your-api-key' \
uv run --with 'pydantic-ai[examples]' \
-m pydantic_ai_examples.pydantic_model
例を編集して実行したい場合は、以下のコマンドで例をコピーできます:
# python -m pydantic_ai_examples --copy-to examples/
ツール(Function Tools)とは?
LLMは基本的にテキストを入力しテキストを出力する存在です。しかし、実際のアプリケーションでは、データベースへの問い合わせやAPIコールなど、外部情報が必要な場合があります。
pydantic-aiでは、LLMが「関数ツール」を呼び出せる仕組みを提供します。
関数を@agent.tool
デコレータで登録することで、LLMはJSONスキーマを理解し、適切な引数を与えて関数を呼び出すよう求められます。
依存性注入(Dependencies)
さらに、ツールやシステムプロンプトで外部リソース(DB接続、APIクライアント)を扱いたい場合は、deps_type
を指定して依存性注入が可能です。
これにより、テスト時に依存性を差し替えたり、実行環境に応じて異なる接続情報を渡したりできます。
サンプル例:銀行サポートエージェント
以下に、銀行サポート用エージェントの例を示します。
シナリオは、顧客IDとDBコネクションを依存として受け取り、顧客名を取得してシステムプロンプトに付与し、customer_balance
というツールで口座残高を取得する、といった流れです。また、最終結果はSupportResult
というpydanticモデルで構造化します。
from dataclasses import dataclass
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
class DatabaseConn:
async def customer_name(self, id: int) -> str:
if id == 123:
return "John"
async def customer_balance(self, id: int, include_pending: bool) -> float:
return 123.45
@dataclass
class SupportDependencies:
customer_id: int
db: DatabaseConn
class SupportResult(BaseModel):
support_advice: str = Field(description='Advice to the customer')
block_card: bool = Field(description='Whether to block the customer\'s card')
risk: int = Field(description='Risk level', ge=0, le=10)
support_agent = Agent(
'openai:gpt-4o',
# 'gemini:gemini-1.5-pro-latest',
deps_type=SupportDependencies,
result_type=SupportResult,
system_prompt=(
'You are a support agent in our bank. Provide customer support and risk assessment.'
),
)
# システムプロンプトへの依存性活用
@support_agent.system_prompt
async def add_customer_name(ctx: RunContext[SupportDependencies]) -> str:
name = await ctx.deps.db.customer_name(id=ctx.deps.customer_id)
return f"The customer's name is {name!r}"
# ツール定義:LLMが呼び出せる関数
@support_agent.tool
async def customer_balance(ctx: RunContext[SupportDependencies], include_pending: bool) -> float:
"""Returns the customer's account balance."""
return await ctx.deps.db.customer_balance(id=ctx.deps.customer_id, include_pending=include_pending)
# 実行例
deps = SupportDependencies(customer_id=123, db=DatabaseConn())
result = await support_agent.run('What is my balance?', deps=deps)
print(result.data)
サンプル例:銀行サポートエージェント(日本語 Ver)
from dataclasses import dataclass
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from loguru import logger
import sys
class DatabaseConn:
async def customer_name(self, id: int) -> str:
"""顧客名を取得するメソッド"""
logger.debug(f"顧客名を取得: id={id}")
if id == 123:
return "John"
async def customer_balance(self, id: int, include_pending: bool) -> float:
"""口座残高を取得するメソッド"""
logger.debug(f"残高を取得: id={id}, include_pending={include_pending}")
return 123.45
@dataclass
class SupportDependencies:
customer_id: int
db: DatabaseConn
class SupportResult(BaseModel):
support_advice: str = Field(description='顧客へのアドバイス')
block_card: bool = Field(description='カードをブロックするかどうか')
risk: int = Field(description='リスクレベル(0-10)', ge=0, le=10)
def __str__(self) -> str:
"""結果を見やすく整形"""
return (
f"サポート結果:\n"
f" アドバイス: {self.support_advice}\n"
f" カードブロック: {'はい' if self.block_card else 'いいえ'}\n"
f" リスクレベル: {self.risk}/10"
)
support_agent = Agent(
'openai:gpt-4o',
deps_type=SupportDependencies,
result_type=SupportResult,
system_prompt=(
'あなたは当行のサポートエージェントです。顧客サポートとリスク評価を提供してください。'
),
)
@support_agent.system_prompt
async def add_customer_name(ctx: RunContext[SupportDependencies]) -> str:
"""顧客名をシステムプロンプトに追加"""
name = await ctx.deps.db.customer_name(id=ctx.deps.customer_id)
logger.info(f"システムプロンプトに顧客名を追加: {name}")
return f"お客様の名前は {name!r} です"
@support_agent.tool
async def customer_balance(ctx: RunContext[SupportDependencies], include_pending: bool) -> float:
"""顧客の口座残高を返す"""
balance = await ctx.deps.db.customer_balance(
id=ctx.deps.customer_id,
include_pending=include_pending
)
logger.info(f"口座残高を取得: {balance:,.2f}円")
return balance
async def main():
try:
logger.info("サポートエージェントを開始")
deps = SupportDependencies(customer_id=123, db=DatabaseConn())
# エージェントの実行
logger.info("顧客の残高照会を実行")
result = await support_agent.run('残高を教えてください', deps=deps)
# 結果の表示
logger.success("処理が正常に完了")
logger.info(str(result.data))
except Exception as e:
logger.error(f"エラーが発生: {str(e)}")
raise
if __name__ == "__main__":
import asyncio
asyncio.run(main())
サンプル例:高度な銀行AIエージェントシステム(日本語 Ver)
from dataclasses import dataclass
from datetime import datetime, timedelta
from decimal import Decimal
from enum import Enum
from typing import List, Dict, Optional, Any
from pydantic import BaseModel, Field, validator
from pydantic_ai import Agent, RunContext
from loguru import logger
import sys
import asyncio
from abc import ABC, abstractmethod
# ロガー設定
logger.remove()
logger.add(
sys.stdout,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
level="INFO"
)
class Currency(str, Enum):
"""対応通貨の定義"""
JPY = "JPY"
USD = "USD"
EUR = "EUR"
GBP = "GBP"
class TransactionType(str, Enum):
"""取引種別の定義"""
DEPOSIT = "入金"
WITHDRAWAL = "出金"
TRANSFER = "振込"
PAYMENT = "支払い"
INVESTMENT = "投資"
class CustomerSegment(str, Enum):
"""顧客セグメントの定義"""
STANDARD = "一般"
PREMIUM = "プレミアム"
PRIVATE = "プライベート"
BUSINESS = "ビジネス"
class Transaction(BaseModel):
"""取引情報モデル"""
id: str
timestamp: datetime
type: TransactionType
amount: Decimal
currency: Currency
description: str
location: Optional[str]
merchant: Optional[str]
risk_score: float = Field(ge=0, le=1)
@validator('amount')
def validate_amount(cls, v):
"""金額の検証"""
if v <= 0:
raise ValueError("金額は0より大きい必要があります")
return v
class CustomerProfile(BaseModel):
"""顧客プロファイルモデル"""
id: int
name: str
segment: CustomerSegment
risk_tolerance: int = Field(ge=1, le=5)
preferred_currency: Currency
investment_goals: List[str]
annual_income: Optional[Decimal]
class Config:
arbitrary_types_allowed = True
class InvestmentAdvice(BaseModel):
"""投資アドバイスモデル"""
recommendation: str
risk_level: int = Field(ge=1, le=5)
expected_return: float
time_horizon: str
suitable_products: List[str]
class FraudDetectionResult(BaseModel):
"""不正検知結果モデル"""
is_suspicious: bool
confidence: float = Field(ge=0, le=1)
risk_factors: List[str]
recommended_actions: List[str]
class MarketData(BaseModel):
"""市場データモデル"""
currency_pair: str
current_rate: float
trend: str
volatility: float
@dataclass
class AdvancedBankDependencies:
"""拡張された依存関係"""
customer_id: int
db: 'AdvancedDatabaseConn'
fraud_detector: 'FraudDetector'
investment_advisor: 'InvestmentAdvisor'
exchange_service: 'CurrencyExchange'
class AdvancedBankResult(BaseModel):
"""拡張された結果モデル"""
support_advice: str = Field(description='顧客へのアドバイス')
block_card: bool = Field(description='カードブロック要否')
risk: int = Field(description='リスクレベル', ge=0, le=10)
fraud_detection: Optional[FraudDetectionResult]
investment_advice: Optional[InvestmentAdvice]
transaction_analysis: Optional[Dict[str, float]]
def __str__(self) -> str:
"""結果の整形表示"""
result = [
"=== 銀行エージェント分析結果 ===",
f"アドバイス: {self.support_advice}",
f"カードブロック: {'要' if self.block_card else '不要'}",
f"リスクレベル: {self.risk}/10"
]
if self.fraud_detection:
result.extend([
"--- 不正検知結果 ---",
f"不審な活動: {'あり' if self.fraud_detection.is_suspicious else 'なし'}",
f"信頼度: {self.fraud_detection.confidence:.2%}",
"リスク要因:",
*[f"- {factor}" for factor in self.fraud_detection.risk_factors]
])
if self.investment_advice:
result.extend([
"--- 投資アドバイス ---",
f"推奨: {self.investment_advice.recommendation}",
f"想定リターン: {self.investment_advice.expected_return:.1%}",
f"投資期間: {self.investment_advice.time_horizon}",
"推奨商品:",
*[f"- {product}" for product in self.investment_advice.suitable_products]
])
if self.transaction_analysis:
result.extend([
"--- 取引分析 ---",
*[f"{k}: {v:.2f}" for k, v in self.transaction_analysis.items()]
])
return "\n".join(result)
class InvestmentAdvisor:
"""投資アドバイザーシステム"""
async def get_market_data(self) -> List[MarketData]:
"""市場データを取得"""
await asyncio.sleep(0.1) # 外部APIコールのシミュレート
return [
MarketData(
currency_pair="USD/JPY",
current_rate=110.0,
trend="上昇",
volatility=0.12
),
MarketData(
currency_pair="EUR/JPY",
current_rate=130.0,
trend="横ばい",
volatility=0.15
)
]
async def get_advice(self, profile: CustomerProfile) -> InvestmentAdvice:
"""投資アドバイスを生成"""
logger.info(f"投資アドバイスを生成: customer_id={profile.id}")
market_data = await self.get_market_data()
# リスク許容度に基づく商品選定
if profile.risk_tolerance <= 2:
products = ["安全性重視型投資信託", "国債", "定期預金"]
expected_return = 0.02
elif profile.risk_tolerance <= 4:
products = ["バランス型投資信託", "優良株式", "社債"]
expected_return = 0.05
else:
products = ["成長株式ファンド", "新興国株式", "ハイイールド債"]
expected_return = 0.08
return InvestmentAdvice(
recommendation=f"{profile.segment.value}のお客様向けポートフォリオをご提案いたします",
risk_level=profile.risk_tolerance,
expected_return=expected_return,
time_horizon="中期(3-5年)",
suitable_products=products
)
class FraudDetector:
"""不正検知システム"""
async def analyze_transactions(self, transactions: List[Transaction]) -> FraudDetectionResult:
"""取引の不正検知を実行"""
logger.info("取引の不正検知を実行")
await asyncio.sleep(0.1) # 分析処理のシミュレート
suspicious_patterns = []
for tx in transactions:
if tx.risk_score > 0.7:
suspicious_patterns.append(f"高リスクスコア: {tx.description}")
if tx.amount > Decimal("500000"):
suspicious_patterns.append(f"大口取引: {tx.description}")
if tx.location == "海外ATM":
suspicious_patterns.append(f"海外ATM利用: {tx.description}")
confidence = 0.85 if suspicious_patterns else 0.15
recommended_actions = [
"カード一時停止",
"本人確認の実施",
"セキュリティ警告の送信"
] if suspicious_patterns else ["通常監視の継続"]
return FraudDetectionResult(
is_suspicious=bool(suspicious_patterns),
confidence=confidence,
risk_factors=suspicious_patterns,
recommended_actions=recommended_actions
)
class CurrencyExchange:
"""為替レート管理システム"""
async def get_rate(self, from_currency: Currency, to_currency: Currency) -> float:
"""為替レートを取得"""
logger.debug(f"為替レート取得: {from_currency} -> {to_currency}")
await asyncio.sleep(0.1) # 外部APIコールのシミュレート
rates = {
(Currency.USD, Currency.JPY): 110.0,
(Currency.EUR, Currency.JPY): 130.0,
(Currency.GBP, Currency.JPY): 150.0,
}
return rates.get((from_currency, to_currency), 1.0)
class AdvancedDatabaseConn:
"""拡張されたデータベース接続"""
def __init__(self):
self._transactions = [] # 取引履歴のキャッシュ
async def customer_profile(self, id: int) -> CustomerProfile:
"""顧客プロファイルを取得"""
logger.debug(f"顧客プロファイル取得: id={id}")
await asyncio.sleep(0.1) # DB遅延のシミュレート
return CustomerProfile(
id=id,
name="John Doe",
segment=CustomerSegment.PREMIUM,
risk_tolerance=3,
preferred_currency=Currency.JPY,
investment_goals=["資産形成", "老後資金"],
annual_income=Decimal("5000000")
)
async def recent_transactions(self, id: int, days: int = 30) -> List[Transaction]:
"""最近の取引履歴を取得"""
logger.debug(f"最近の取引履歴取得: id={id}, days={days}")
await asyncio.sleep(0.1) # DB遅延のシミュレート
if self._transactions:
return self._transactions
return [
Transaction(
id=f"tx_{i}",
timestamp=datetime.now() - timedelta(days=i),
type=TransactionType.PAYMENT,
amount=Decimal("10000"),
currency=Currency.JPY,
description="通常取引",
location="国内",
merchant="一般店舗",
risk_score=0.1
)
for i in range(5)
]
async def account_balance(self, id: int, currency: Currency) -> Decimal:
"""口座残高を取得"""
logger.debug(f"口座残高取得: id={id}, currency={currency}")
await asyncio.sleep(0.1) # DB遅延のシミュレート
return Decimal("1000000")
def set_transactions(self, transactions: List[Transaction]):
"""テスト用:取引履歴を設定"""
self._transactions = transactions
# エージェントの定義
advanced_agent = Agent(
'openai:gpt-4o',
deps_type=AdvancedBankDependencies,
result_type=AdvancedBankResult,
system_prompt=(
'あなたは当行の高度なAIアシスタントです。'
'顧客サポート、リスク評価、不正検知、投資アドバイスを提供してください。'
),
)
@advanced_agent.system_prompt
async def add_customer_context(ctx: RunContext[AdvancedBankDependencies]) -> str:
"""システムプロンプトにコンテキストを追加"""
profile = await ctx.deps.db.customer_profile(ctx.deps.customer_id)
return (
f"顧客プロファイル:\n"
f"- 名前: {profile.name}\n"
f"- セグメント: {profile.segment.value}\n"
f"- 投資リスク許容度: {profile.risk_tolerance}/5\n"
f"- 優先通貨: {profile.preferred_currency}"
)
@advanced_agent.tool
async def analyze_account(
ctx: RunContext[AdvancedBankDependencies],
include_transactions: bool = True,
check_fraud: bool = True
) -> Dict[str, Any]:
"""総合的なアカウント分析を実行"""
logger.info(f"アカウント分析開始: customer_id={ctx.deps.customer_id}")
profile = await ctx.deps.db.customer_profile(ctx.deps.customer_id)
balance = await ctx.deps.db.account_balance(ctx.deps.customer_id, profile.preferred_currency)
result = {
"balance": float(balance),
"currency": profile.preferred_currency,
"segment": profile.segment
}
if include_transactions:
try:
transactions = await ctx.deps.db.recent_transactions(ctx.deps.customer_id)
result["transaction_count"] = len(transactions)
if check_fraud and transactions:
try:
fraud_result = await ctx.deps.fraud_detector.analyze_transactions(transactions)
result["fraud_detection"] = fraud_result
except Exception as e:
logger.error(f"不正検知分析でエラー: {str(e)}")
result["fraud_detection_error"] = str(e)
except Exception as e:
logger.error(f"取引履歴の取得でエラー: {str(e)}")
result["transaction_error"] = str(e)
return result
# -------------------------------
@advanced_agent.tool
async def get_investment_recommendation(
ctx: RunContext[AdvancedBankDependencies]
) -> InvestmentAdvice:
"""投資アドバイスを取得"""
profile = await ctx.deps.db.customer_profile(ctx.deps.customer_id)
return await ctx.deps.investment_advisor.get_advice(profile)
@advanced_agent.tool
async def check_fraud(
ctx: RunContext[AdvancedBankDependencies],
transaction_days: int = 30
) -> FraudDetectionResult:
"""不正検知チェックを実行"""
transactions = await ctx.deps.db.recent_transactions(ctx.deps.customer_id, transaction_days)
return await ctx.deps.fraud_detector.analyze_transactions(transactions)
@advanced_agent.tool
async def convert_currency(
ctx: RunContext[AdvancedBankDependencies],
amount: float,
from_currency: Currency,
to_currency: Currency
) -> float:
"""通貨換算を実行"""
rate = await ctx.deps.exchange_service.get_rate(from_currency, to_currency)
return amount * rate
async def main():
"""メインの実行処理"""
try:
logger.info("=== 高度な銀行エージェントデモを開始 ===")
# 依存関係のインスタンス化
db = AdvancedDatabaseConn()
deps = AdvancedBankDependencies(
customer_id=123,
db=db,
fraud_detector=FraudDetector(),
investment_advisor=InvestmentAdvisor(),
exchange_service=CurrencyExchange()
)
# シナリオ1: 総合的な資産分析
logger.info("\n=== シナリオ1: 総合的な資産分析 ===")
result = await advanced_agent.run(
"""
以下の項目を含む総合的な資産分析をお願いします:
1. 現在の資産状況
2. 最近の取引パターン
3. リスク評価
4. 投資提案
""",
deps=deps
)
logger.info(str(result.data))
# シナリオ2: 不正検知アラート対応
logger.info("\n=== シナリオ2: 不正検知アラート対応 ===")
# 不正な取引をシミュレート
suspicious_transaction = Transaction(
id="tx_123",
timestamp=datetime.now(),
type=TransactionType.WITHDRAWAL,
amount=Decimal("1000000"),
currency=Currency.JPY,
description="大口出金",
location="海外ATM",
merchant=None,
risk_score=0.85
)
db.set_transactions([suspicious_transaction])
result = await advanced_agent.run(
"""
大口の海外ATM出金が検出されました。
1. リスク分析
2. 推奨アクション
3. 緊急対応の必要性
を評価してください。
""",
deps=deps
)
logger.info(str(result.data))
# シナリオ3: 投資ポートフォリオ最適化
logger.info("\n=== シナリオ3: 投資ポートフォリオ最適化 ===")
db.set_transactions([]) # 取引履歴をリセット
result = await advanced_agent.run(
"""
以下を考慮した投資ポートフォリオの提案をお願いします:
- リスク許容度
- 市場動向
- 長期的な資産形成目標
""",
deps=deps
)
logger.info(str(result.data))
# シナリオ4: マルチ通貨取引分析
logger.info("\n=== シナリオ4: マルチ通貨取引分析 ===")
# 複数通貨の取引をシミュレート
multi_currency_transactions = [
Transaction(
id=f"tx_{i}",
timestamp=datetime.now() - timedelta(days=i),
type=TransactionType.PAYMENT,
amount=Decimal("100"),
currency=currency,
description=f"{currency.value}での支払い",
location="オンライン",
merchant="EC Store",
risk_score=0.1
)
for i, currency in enumerate([Currency.USD, Currency.EUR, Currency.JPY])
]
db.set_transactions(multi_currency_transactions)
result = await advanced_agent.run(
"""
複数通貨での取引履歴を分析し、以下を報告してください:
1. 通貨ごとの取引傾向
2. 為替変動の影響
3. 通貨最適化の提案
""",
deps=deps
)
logger.info(str(result.data))
# シナリオ5: 顧客セグメント別サービス提案
logger.info("\n=== シナリオ5: 顧客セグメント別サービス提案 ===")
result = await advanced_agent.run(
"""
プレミアム顧客向けの以下のサービス提案をお願いします:
1. 資産管理サービス
2. 優遇金利商品
3. コンシェルジュサービス
""",
deps=deps
)
logger.info(str(result.data))
logger.success("全デモシナリオが正常に完了しました")
except Exception as e:
logger.error(f"エラーが発生: {str(e)}")
raise
if __name__ == "__main__":
import asyncio
asyncio.run(main())
ポイント:
-
deps_type=SupportDependencies
としているため、RunContext[SupportDependencies]
を通してツールやプロンプト関数内でctx.deps
を参照可能。 - LLMは
customer_balance
ツールを呼び出してから最終的なSupportResult
を返します。 -
SupportResult
という構造化された戻り値を得るため、LLMはJSONとして正しい形式で回答することが求められます。
pydantic-aiはこのJSONレスポンスをバリデートし、失敗すれば再試行を促します。
まとめ
今回の記事で、pydantic-aiを用いてLLMにツール関数を呼び出させ、依存性注入で外部データベース等のリソースにアクセスさせる方法を学びました。
これにより、単なるテキスト変換器だったLLMが、外部環境とのやり取りを行い、アプリケーションロジックの一部として組み込むことが可能になります。
次回は、さらに複雑な構造化レスポンスやストリーミング、再試行機構、そしてユニットテスト・評価(Evals)のためのTestModel
などについて掘り下げます。
📒ノートブック
参考サイト
<script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
Discussion