【Adrian's作業メモ】Pydantic AIとwatsonx.aiの連携
TL;DR
Pydantic AIとwatsonx.aiを連携させ、AI Agentを構築した。
Watsonx.aiだとツールコール(Tool Call)ができず、自前で実装してみた。
参考Repoはこらら!
導入
Pydantic AIは、PythonのAgentフレームワークです。OpenAI、Anthropic、Geminiなどの主要なLLMプロバイダーをサポートしていますが、IBM watsonx.aiのような独自APIを持つプロバイダーとの連携も可能です。
本記事では、Watsonx.aiをPydanticAIで使用するためのカスタムモデルアダプターの実装方法を解説します。
Pydantic AIとは?
PydanticAIの特徴:
- 型安全性: Pydanticの検証機能を活用したLLMの入出力管理
- 統一インターフェース: 複数のLLMプロバイダーを同じAPIで扱える
- エージェント機能: ツール呼び出しや会話履歴の管理を簡素化
- 拡張性: カスタムモデルの実装が可能
LLMプロバイダー
PydanticAIは、以下のLLMプロバイダーをデフォルトでサポートしています:
デフォルトサポート対象
| プロバイダー | モデルクラス | 主要モデル例 |
|---|---|---|
| OpenAI | OpenAIModel |
GPT-4, GPT-4 Turbo, GPT-3.5 |
| Anthropic | AnthropicModel |
Claude 3.5 Sonnet, Claude 3 Opus |
| Google Gemini | GeminiModel |
Gemini 1.5 Pro, Gemini 1.5 Flash |
| Groq | GroqModel |
Llama 3, Mixtral |
| Ollama | OllamaModel |
ローカルモデル(Llama, Mistral等) |
使用例
from pydantic_ai import Agent
roulette_agent = Agent(
'openai:gpt-4o',
deps_type=int,
output_type=bool,
system_prompt=(
'Use the `roulette_wheel` function to see if the '
'customer has won based on the number they provide.'
),
)
Custom LLMの実装
上記以外のプロバイダー(Watsonx.aiなど)を使用する場合は、Model抽象クラスを継承してカスタムモデルアダプターを実装します:
from pydantic_ai.models import Model, StreamedResponse
from pydantic_ai.messages import ModelMessage, ModelResponse
@dataclass
class WatsonxAIModel(Model):
"""Watsonx.aiのカスタムモデルアダプター"""
client: ModelInference # Watsonx SDKのクライアント
@property
def system(self) -> str:
return "watsonx"
@property
def model_name(self) -> str:
return self._model_name
async def request(
self,
message_history: list[ModelMessage],
model_settings: Any,
model_request_parameters: Any,
) -> ModelResponse:
"""非ストリーミングリクエストの実装"""
# 実装詳細...
主要な実装要素
1. 初期化と認証
def __init__(
self,
model_id: str = "openai/gpt-oss-120b",
params: Optional[TextGenParameters] = None,
) -> None:
# 環境変数から認証情報を取得
endpoint = os.getenv("WATSONX_ENDPOINT")
api_key = os.getenv("IBM_CLOUD_API_KEY")
project_id = os.getenv("WATSONX_PROJECT_ID")
# Watsonxクライアントの初期化
credentials = Credentials(url=endpoint, api_key=api_key)
self.client = ModelInference(
model_id=model_id,
credentials=credentials,
project_id=project_id,
params=params or default_params,
)
ポイント:
- 環境変数ベースの設定でセキュリティを確保
- デフォルトパラメータ(temperature、max_tokensなど)を設定
2. プロンプト構築
PydanticAIのメッセージ形式をWatsonx用のテキストプロンプトに変換:
def _build_prompt(self, messages: list[ModelMessage]) -> str:
"""PydanticAIのメッセージ → Watsonxプロンプトに変換"""
lines: list[str] = []
for msg in messages:
if isinstance(msg, ModelRequest):
for p in msg.parts:
if isinstance(p, UserPromptPart):
lines.append(f"{p.content}")
elif isinstance(msg, ModelResponse):
for p in msg.parts:
if isinstance(p, TextPart):
lines.append(f"{p.content}")
return "\n".join(lines)
3. API呼び出しとレスポンス変換
async def request(
self,
message_history: list[ModelMessage],
model_settings: Any,
model_request_parameters: Any,
) -> ModelResponse:
# プロンプト構築
prompt = self._build_prompt(message_history)
# Watsonx API呼び出し
result = await self.client.agenerate(prompt, params=self._params.to_dict())
generated_text = result["results"][0]["generated_text"]
# メタデータ抽出
input_tokens = result["results"][0]["input_token_count"]
output_tokens = result["results"][0]["generated_token_count"]
# PydanticAI形式のレスポンス作成
usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens)
response = ModelResponse(
parts=[TextPart(content=generated_text)],
model_name=result["model_id"],
timestamp=datetime.fromisoformat(result["created_at"].replace("Z", "+00:00")),
)
# メタデータを付与
response.usage = usage
return response
重要な変換処理:
- Watsonxの生レスポンス → PydanticAIの標準形式
- トークン使用量、モデル情報、タイムスタンプの保持
- メタデータを動的属性として追加
4. ツール呼び出しの実装
Watsonx.aiはOpenAIのFunction Callingのようなネイティブなツール呼び出し機能を持っていません。そのため、プロンプトエンジニアリングでツール呼び出しを実装します。
[補足:10/26]
Watsonx.aiはツール実行を行うモデルはありますが、この記事の時点ではPydantic AIの既存のツールコールはエラーが発生しています。
Watsonx.aiのツールコールについてはこちらを参考してください。
ツール呼び出しのPrompt
system_prompt="""
You help users select travel destinations.
IMPORTANT: When the user asks for a "random city", you MUST respond EXACTLY with:
CALL_TOOL: choose_random_city
Do NOT select a city yourself. Always use the tool when randomness is requested.
After I provide the tool result, format it as JSON:
{"destination": "City Name"}
"""
ツール呼び出しの検出と実行
def handle_tool_calls(agent_response: str, available_tools: dict) -> str | None:
"""モデルのレスポンスからツール呼び出しを検出して実行"""
if "CALL_TOOL:" in agent_response:
# ツール名を抽出
tool_match = re.search(r"CALL_TOOL:\s*(\w+)", agent_response)
if tool_match:
tool_name = tool_match.group(1)
# ツールが存在すれば実行
if tool_name in available_tools:
tool_func = available_tools[tool_name]
result = tool_func()
return result
return None
実行例
ソースコード
@dataclass
class TravelDeps:
user_name: str # ユーザー名
origin_city: str # 出発地
destination_agent = Agent(
model=watsonx_ai_model,
deps_type=TravelDeps,
# output_type=DestinationOutput, # Watsonxだと自前で実装する必要があり、一旦JSONパースで対応
system_prompt="""\
You help users select travel destinations.
IMPORTANT: When the user asks for a "random city", you MUST respond EXACTLY with:
CALL_TOOL: choose_random_city
Do NOT select a city yourself. Always use the tool when randomness is requested.
After I provide the tool result, format it as JSON:
{"destination": "City Name"}
""",
)
deps = TravelDeps(user_name="Maria", origin_city="Berlin")
user_request = "I want a random city within Europe."
result = destination_agent.run_sync(user_request, deps=deps)
agent_response = result.output if hasattr(result, "output") else ""
cleaned_response = extract_json_from_text(agent_response)
fallback_tool_result = handle_tool_calls(agent_response, AVAILABLE_TOOLS)
# Define available tools
def choose_random_city() -> str:
"""Choose a random city from a predefined list."""
print("\n🔧 TOOL CALLED: choose_random_city", flush=True)
cities = ["London", "Paris", "Berlin", "Dublin", "Madrid"]
selected = random.choice(cities)
print(f" → Selected city: {selected}\n", flush=True)
return selected
AVAILABLE_TOOLS = {
"choose_random_city": choose_random_city,
}
fallback_tool_result = handle_tool_calls(agent_response, AVAILABLE_TOOLS)
if fallback_tool_result:
print(f"📦 Fallback Tool Result: {fallback_tool_result}")
print(f"✅ Creating structured response from tool result\n")
destination_data = {"destination": fallback_tool_result}
cleaned_response = json.dumps(destination_data, indent=2)
出力
> uv run watsonx_demo.py
============================================================
🚀 Starting Travel Planning Agent with Manual Tool Execution
============================================================
👤 User: I want a random city within Europe.
📍 Origin: Berlin
👋 Traveler: Maria
🔍 Detected tool call request from agent: choose_random_city
🔧 TOOL CALLED: choose_random_city
→ Selected city: Dublin
✅ Tool executed successfully
📦 Fallback Tool Result: Dublin
✅ Creating structured response from tool result
DESTINATION DATA: {'destination': 'Dublin'}
============================================================
🤖 Agent: Dublin
============================================================
============================================================
📊 USAGE STATISTICS
============================================================
Total Requests: 1
Input Tokens: 76
Output Tokens: 70
Total Tokens: 146
============================================================
✅ Demo completed successfully!
============================================================
watsonx.aiモデルクラス
@dataclass
class WatsonxAIModel(Model):
"""
Pydantic AI Model for IBM Watsonx, supporting both non-streaming and streaming generation.
"""
client: ModelInference
_http_client: AsyncHTTPClient
_model_name: str
_endpoint: str
_project_id: str
_params: TextGenParameters
def __init__(
self,
model_id: str = "openai/gpt-oss-120b",
params: Optional[TextGenParameters] = None,
) -> None:
"""
Initialize WatsonxAIModel using environment-based credentials.
:param model_id: Watsonx model ID.
:param params: Optional text generation parameters.
:raises UserError: If required environment variables are missing.
"""
check_allow_model_requests()
endpoint = os.getenv("WATSONX_ENDPOINT")
api_key = os.getenv("IBM_CLOUD_API_KEY")
project_id = os.getenv("WATSONX_PROJECT_ID")
if not endpoint or not api_key or not project_id:
raise UserError(
"Missing required environment variables: WATSONX_ENDPOINT, IBM_CLOUD_API_KEY, or WATSONX_PROJECT_ID."
)
credentials = Credentials(url=endpoint, api_key=api_key)
default_params = TextGenParameters(
temperature=0.0,
max_new_tokens=1000,
random_seed=42,
decoding_method="greedy",
min_new_tokens=1,
)
self.client = ModelInference(
model_id=model_id,
credentials=credentials,
project_id=project_id,
params=params or default_params,
)
self._http_client = cached_async_http_client()
self._model_name = model_id
self._endpoint = endpoint
self._project_id = project_id
self._params = params or default_params
@property
def system(self) -> str:
"""
System identifier for Watsonx.
"""
return "watsonx"
@property
def model_name(self) -> str:
"""
Returns the Watsonx model name.
"""
return self._model_name
async def request(
self,
message_history: list[ModelMessage],
model_settings: Any,
model_request_parameters: Any,
) -> ModelResponse:
"""
Perform non-streaming generation using Watsonx's agenerate.
"""
check_allow_model_requests()
prompt = self._build_prompt(message_history)
result = await self.client.agenerate(prompt, params=self._params.to_dict())
generated_text = result["results"][0]["generated_text"]
# Extract JSON from the response to handle cases where model includes extra text
cleaned_text = extract_json_from_text(generated_text)
# Extract comprehensive metadata from Watsonx response
input_tokens = 0
output_tokens = 0
stop_reason = None
seed = None
model_id = None
created_at = None
warnings = []
# Extract from result metadata
if "model_id" in result:
model_id = result["model_id"]
if "created_at" in result:
created_at = result["created_at"]
if "system" in result and "warnings" in result["system"]:
warnings = result["system"]["warnings"]
# Extract from results array
if "results" in result and len(result["results"]) > 0:
result_data = result["results"][0]
# Token counts
if "input_token_count" in result_data:
input_tokens = result_data["input_token_count"]
if "generated_token_count" in result_data:
output_tokens = result_data["generated_token_count"]
# Generation metadata
if "stop_reason" in result_data:
stop_reason = result_data["stop_reason"]
if "seed" in result_data:
seed = result_data["seed"]
# Create RequestUsage with token counts
usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens)
# Create ModelResponse with all metadata
response = ModelResponse(
parts=[TextPart(content=cleaned_text)],
model_name=model_id or self._model_name,
timestamp=datetime.fromisoformat(created_at.replace("Z", "+00:00"))
if created_at
else datetime.utcnow(),
)
# Add usage attribute to response for pydantic_ai compatibility
response.usage = usage # type: ignore[attr-defined]
# Store additional Watsonx-specific metadata as attributes
response.stop_reason = stop_reason # type: ignore[attr-defined]
response.seed = seed # type: ignore[attr-defined]
response.warnings = warnings # type: ignore[attr-defined]
return response
@asynccontextmanager
async def request_stream(
self,
message_history: list[ModelMessage],
model_settings: Any,
model_request_parameters: Any,
) -> AsyncIterator[StreamedResponse]:
"""
Execute streaming generation via Watsonx's agenerate_stream.
"""
check_allow_model_requests()
prompt = self._build_prompt(message_history)
stream_generator = await self.client.agenerate_stream(
prompt, params=self._params.to_dict()
)
try:
yield WatsonxStreamedResponse(
_model_name=self._model_name, _async_generator=stream_generator
)
finally:
pass
def _build_prompt(self, messages: list[ModelMessage]) -> str:
"""
Constructs a Watsonx prompt from pydantic_ai messages.
"""
lines: list[str] = []
for msg in messages:
if isinstance(msg, ModelRequest):
for p in msg.parts:
if isinstance(p, UserPromptPart):
lines.append(f"{p.content}")
elif hasattr(p, "content") and isinstance(p.content, str):
lines.append(f"{p.content}")
elif isinstance(msg, ModelResponse):
# Handle system messages and previous responses
for p in msg.parts:
if isinstance(p, TextPart):
lines.append(f"{p.content}")
return "\n".join(lines)
まとめ
本記事では、PydanticAIとWatsonx.aiを連携させる方法を解説しました。
重要なポイント:
-
カスタムモデルアダプター:
Modelクラスを継承して実装 - プロトコル変換: Watsonx API ↔ PydanticAIメッセージ形式
- メタデータ管理: トークン使用量やモデル情報の保持
- ツール呼び出し: プロンプトエンジニアリングで手動実装
PydanticAIの抽象化により、Watsonx.aiのような独自APIを持つプロバイダーでも、統一されたインターフェースで利用できるようになります。
この実装パターンは、他のカスタムLLMプロバイダーにも応用可能です。
サンプルコード
完全なサンプルコードは、以下のリポジトリで公開しています:
- GitHub: REPO
Discussion