🏗️

【Adrian's作業メモ】Pydantic AIとwatsonx.aiの連携

に公開

TL;DR

Pydantic AIwatsonx.aiを連携させ、AI Agentを構築した。
Watsonx.aiだとツールコール(Tool Call)ができず、自前で実装してみた。
参考Repoはこらら

導入

Pydantic AIは、PythonのAgentフレームワークです。OpenAI、Anthropic、Geminiなどの主要なLLMプロバイダーをサポートしていますが、IBM watsonx.aiのような独自APIを持つプロバイダーとの連携も可能です。

本記事では、Watsonx.aiをPydanticAIで使用するためのカスタムモデルアダプターの実装方法を解説します。

Pydantic AIとは?

PydanticAIの特徴:

  • 型安全性: Pydanticの検証機能を活用したLLMの入出力管理
  • 統一インターフェース: 複数のLLMプロバイダーを同じAPIで扱える
  • エージェント機能: ツール呼び出しや会話履歴の管理を簡素化
  • 拡張性: カスタムモデルの実装が可能

LLMプロバイダー

PydanticAIは、以下のLLMプロバイダーをデフォルトでサポートしています:

デフォルトサポート対象

プロバイダー モデルクラス 主要モデル例
OpenAI OpenAIModel GPT-4, GPT-4 Turbo, GPT-3.5
Anthropic AnthropicModel Claude 3.5 Sonnet, Claude 3 Opus
Google Gemini GeminiModel Gemini 1.5 Pro, Gemini 1.5 Flash
Groq GroqModel Llama 3, Mixtral
Ollama OllamaModel ローカルモデル(Llama, Mistral等)

使用例

from pydantic_ai import Agent

roulette_agent = Agent(  
    'openai:gpt-4o',
    deps_type=int,
    output_type=bool,
    system_prompt=(
        'Use the `roulette_wheel` function to see if the '
        'customer has won based on the number they provide.'
    ),
)

Custom LLMの実装

上記以外のプロバイダー(Watsonx.aiなど)を使用する場合は、Model抽象クラスを継承してカスタムモデルアダプターを実装します:

from pydantic_ai.models import Model, StreamedResponse
from pydantic_ai.messages import ModelMessage, ModelResponse

@dataclass
class WatsonxAIModel(Model):
    """Watsonx.aiのカスタムモデルアダプター"""
    
    client: ModelInference  # Watsonx SDKのクライアント
    
    @property
    def system(self) -> str:
        return "watsonx"
    
    @property
    def model_name(self) -> str:
        return self._model_name
    
    async def request(
        self,
        message_history: list[ModelMessage],
        model_settings: Any,
        model_request_parameters: Any,
    ) -> ModelResponse:
        """非ストリーミングリクエストの実装"""
        # 実装詳細...

主要な実装要素

1. 初期化と認証

def __init__(
    self,
    model_id: str = "openai/gpt-oss-120b",
    params: Optional[TextGenParameters] = None,
) -> None:
    # 環境変数から認証情報を取得
    endpoint = os.getenv("WATSONX_ENDPOINT")
    api_key = os.getenv("IBM_CLOUD_API_KEY")
    project_id = os.getenv("WATSONX_PROJECT_ID")
    
    # Watsonxクライアントの初期化
    credentials = Credentials(url=endpoint, api_key=api_key)
    self.client = ModelInference(
        model_id=model_id,
        credentials=credentials,
        project_id=project_id,
        params=params or default_params,
    )

ポイント:

  • 環境変数ベースの設定でセキュリティを確保
  • デフォルトパラメータ(temperature、max_tokensなど)を設定

2. プロンプト構築

PydanticAIのメッセージ形式をWatsonx用のテキストプロンプトに変換:

def _build_prompt(self, messages: list[ModelMessage]) -> str:
    """PydanticAIのメッセージ → Watsonxプロンプトに変換"""
    lines: list[str] = []
    for msg in messages:
        if isinstance(msg, ModelRequest):
            for p in msg.parts:
                if isinstance(p, UserPromptPart):
                    lines.append(f"{p.content}")
        elif isinstance(msg, ModelResponse):
            for p in msg.parts:
                if isinstance(p, TextPart):
                    lines.append(f"{p.content}")
    return "\n".join(lines)

3. API呼び出しとレスポンス変換

async def request(
    self,
    message_history: list[ModelMessage],
    model_settings: Any,
    model_request_parameters: Any,
) -> ModelResponse:
    # プロンプト構築
    prompt = self._build_prompt(message_history)
    
    # Watsonx API呼び出し
    result = await self.client.agenerate(prompt, params=self._params.to_dict())
    generated_text = result["results"][0]["generated_text"]
    
    # メタデータ抽出
    input_tokens = result["results"][0]["input_token_count"]
    output_tokens = result["results"][0]["generated_token_count"]
    
    # PydanticAI形式のレスポンス作成
    usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens)
    response = ModelResponse(
        parts=[TextPart(content=generated_text)],
        model_name=result["model_id"],
        timestamp=datetime.fromisoformat(result["created_at"].replace("Z", "+00:00")),
    )
    
    # メタデータを付与
    response.usage = usage
    return response

重要な変換処理:

  • Watsonxの生レスポンス → PydanticAIの標準形式
  • トークン使用量、モデル情報、タイムスタンプの保持
  • メタデータを動的属性として追加

4. ツール呼び出しの実装

Watsonx.aiはOpenAIのFunction Callingのようなネイティブなツール呼び出し機能を持っていません。そのため、プロンプトエンジニアリングでツール呼び出しを実装します。

[補足:10/26]
Watsonx.aiはツール実行を行うモデルはありますが、この記事の時点ではPydantic AIの既存のツールコールはエラーが発生しています。

Watsonx.aiのツールコールについてはこちらを参考してください。

ツール呼び出しのPrompt

system_prompt="""
You help users select travel destinations.

IMPORTANT: When the user asks for a "random city", you MUST respond EXACTLY with:
CALL_TOOL: choose_random_city

Do NOT select a city yourself. Always use the tool when randomness is requested.

After I provide the tool result, format it as JSON:
{"destination": "City Name"}
"""

ツール呼び出しの検出と実行

def handle_tool_calls(agent_response: str, available_tools: dict) -> str | None:
    """モデルのレスポンスからツール呼び出しを検出して実行"""
    if "CALL_TOOL:" in agent_response:
        # ツール名を抽出
        tool_match = re.search(r"CALL_TOOL:\s*(\w+)", agent_response)
        if tool_match:
            tool_name = tool_match.group(1)
            
            # ツールが存在すれば実行
            if tool_name in available_tools:
                tool_func = available_tools[tool_name]
                result = tool_func()
                return result
    return None

実行例

ソースコード


@dataclass
class TravelDeps:
    user_name: str  # ユーザー名
    origin_city: str  # 出発地


destination_agent = Agent(
    model=watsonx_ai_model,
    deps_type=TravelDeps,
    # output_type=DestinationOutput,  # Watsonxだと自前で実装する必要があり、一旦JSONパースで対応
    system_prompt="""\
You help users select travel destinations.

IMPORTANT: When the user asks for a "random city", you MUST respond EXACTLY with:
CALL_TOOL: choose_random_city

Do NOT select a city yourself. Always use the tool when randomness is requested.

After I provide the tool result, format it as JSON:
{"destination": "City Name"}
""",
)

deps = TravelDeps(user_name="Maria", origin_city="Berlin")
user_request = "I want a random city within Europe."

result = destination_agent.run_sync(user_request, deps=deps)
agent_response = result.output if hasattr(result, "output") else ""
cleaned_response = extract_json_from_text(agent_response)
fallback_tool_result = handle_tool_calls(agent_response, AVAILABLE_TOOLS)

# Define available tools
def choose_random_city() -> str:
    """Choose a random city from a predefined list."""
    print("\n🔧 TOOL CALLED: choose_random_city", flush=True)
    cities = ["London", "Paris", "Berlin", "Dublin", "Madrid"]
    selected = random.choice(cities)
    print(f"   → Selected city: {selected}\n", flush=True)
    return selected

AVAILABLE_TOOLS = {
    "choose_random_city": choose_random_city,
}


fallback_tool_result = handle_tool_calls(agent_response, AVAILABLE_TOOLS)

if fallback_tool_result:
    print(f"📦 Fallback Tool Result: {fallback_tool_result}")
    print(f"✅ Creating structured response from tool result\n")
    destination_data = {"destination": fallback_tool_result}
    cleaned_response = json.dumps(destination_data, indent=2)

出力

> uv run watsonx_demo.py
============================================================
🚀 Starting Travel Planning Agent with Manual Tool Execution
============================================================

👤 User: I want a random city within Europe.
📍 Origin: Berlin
👋 Traveler: Maria

🔍 Detected tool call request from agent: choose_random_city

🔧 TOOL CALLED: choose_random_city
   → Selected city: Dublin

✅ Tool executed successfully
📦 Fallback Tool Result: Dublin
✅ Creating structured response from tool result

DESTINATION DATA: {'destination': 'Dublin'}

============================================================
🤖 Agent: Dublin
============================================================

============================================================
📊 USAGE STATISTICS
============================================================
Total Requests: 1
Input Tokens: 76
Output Tokens: 70
Total Tokens: 146

============================================================
✅ Demo completed successfully!
============================================================

watsonx.aiモデルクラス

@dataclass
class WatsonxAIModel(Model):
    """
    Pydantic AI Model for IBM Watsonx, supporting both non-streaming and streaming generation.
    """

    client: ModelInference
    _http_client: AsyncHTTPClient
    _model_name: str
    _endpoint: str
    _project_id: str
    _params: TextGenParameters

    def __init__(
        self,
        model_id: str = "openai/gpt-oss-120b",
        params: Optional[TextGenParameters] = None,
    ) -> None:
        """
        Initialize WatsonxAIModel using environment-based credentials.
        :param model_id: Watsonx model ID.
        :param params: Optional text generation parameters.
        :raises UserError: If required environment variables are missing.
        """
        check_allow_model_requests()
        endpoint = os.getenv("WATSONX_ENDPOINT")
        api_key = os.getenv("IBM_CLOUD_API_KEY")
        project_id = os.getenv("WATSONX_PROJECT_ID")
        if not endpoint or not api_key or not project_id:
            raise UserError(
                "Missing required environment variables: WATSONX_ENDPOINT, IBM_CLOUD_API_KEY, or WATSONX_PROJECT_ID."
            )
        credentials = Credentials(url=endpoint, api_key=api_key)
        default_params = TextGenParameters(
            temperature=0.0,
            max_new_tokens=1000,
            random_seed=42,
            decoding_method="greedy",
            min_new_tokens=1,
        )
        self.client = ModelInference(
            model_id=model_id,
            credentials=credentials,
            project_id=project_id,
            params=params or default_params,
        )
        self._http_client = cached_async_http_client()
        self._model_name = model_id
        self._endpoint = endpoint
        self._project_id = project_id
        self._params = params or default_params

    @property
    def system(self) -> str:
        """
        System identifier for Watsonx.
        """
        return "watsonx"

    @property
    def model_name(self) -> str:
        """
        Returns the Watsonx model name.
        """
        return self._model_name

    async def request(
        self,
        message_history: list[ModelMessage],
        model_settings: Any,
        model_request_parameters: Any,
    ) -> ModelResponse:
        """
        Perform non-streaming generation using Watsonx's agenerate.
        """
        check_allow_model_requests()
        prompt = self._build_prompt(message_history)
        result = await self.client.agenerate(prompt, params=self._params.to_dict())
        generated_text = result["results"][0]["generated_text"]

        # Extract JSON from the response to handle cases where model includes extra text
        cleaned_text = extract_json_from_text(generated_text)

        # Extract comprehensive metadata from Watsonx response
        input_tokens = 0
        output_tokens = 0
        stop_reason = None
        seed = None
        model_id = None
        created_at = None
        warnings = []

        # Extract from result metadata
        if "model_id" in result:
            model_id = result["model_id"]
        if "created_at" in result:
            created_at = result["created_at"]
        if "system" in result and "warnings" in result["system"]:
            warnings = result["system"]["warnings"]

        # Extract from results array
        if "results" in result and len(result["results"]) > 0:
            result_data = result["results"][0]

            # Token counts
            if "input_token_count" in result_data:
                input_tokens = result_data["input_token_count"]
            if "generated_token_count" in result_data:
                output_tokens = result_data["generated_token_count"]

            # Generation metadata
            if "stop_reason" in result_data:
                stop_reason = result_data["stop_reason"]
            if "seed" in result_data:
                seed = result_data["seed"]

        # Create RequestUsage with token counts
        usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens)

        # Create ModelResponse with all metadata
        response = ModelResponse(
            parts=[TextPart(content=cleaned_text)],
            model_name=model_id or self._model_name,
            timestamp=datetime.fromisoformat(created_at.replace("Z", "+00:00"))
            if created_at
            else datetime.utcnow(),
        )

        # Add usage attribute to response for pydantic_ai compatibility
        response.usage = usage  # type: ignore[attr-defined]

        # Store additional Watsonx-specific metadata as attributes
        response.stop_reason = stop_reason  # type: ignore[attr-defined]
        response.seed = seed  # type: ignore[attr-defined]
        response.warnings = warnings  # type: ignore[attr-defined]

        return response

    @asynccontextmanager
    async def request_stream(
        self,
        message_history: list[ModelMessage],
        model_settings: Any,
        model_request_parameters: Any,
    ) -> AsyncIterator[StreamedResponse]:
        """
        Execute streaming generation via Watsonx's agenerate_stream.
        """
        check_allow_model_requests()
        prompt = self._build_prompt(message_history)
        stream_generator = await self.client.agenerate_stream(
            prompt, params=self._params.to_dict()
        )
        try:
            yield WatsonxStreamedResponse(
                _model_name=self._model_name, _async_generator=stream_generator
            )
        finally:
            pass

    def _build_prompt(self, messages: list[ModelMessage]) -> str:
        """
        Constructs a Watsonx prompt from pydantic_ai messages.
        """
        lines: list[str] = []
        for msg in messages:
            if isinstance(msg, ModelRequest):
                for p in msg.parts:
                    if isinstance(p, UserPromptPart):
                        lines.append(f"{p.content}")
                    elif hasattr(p, "content") and isinstance(p.content, str):
                        lines.append(f"{p.content}")
            elif isinstance(msg, ModelResponse):
                # Handle system messages and previous responses
                for p in msg.parts:
                    if isinstance(p, TextPart):
                        lines.append(f"{p.content}")
        return "\n".join(lines)

まとめ

本記事では、PydanticAIとWatsonx.aiを連携させる方法を解説しました。

重要なポイント:

  1. カスタムモデルアダプター: Modelクラスを継承して実装
  2. プロトコル変換: Watsonx API ↔ PydanticAIメッセージ形式
  3. メタデータ管理: トークン使用量やモデル情報の保持
  4. ツール呼び出し: プロンプトエンジニアリングで手動実装

PydanticAIの抽象化により、Watsonx.aiのような独自APIを持つプロバイダーでも、統一されたインターフェースで利用できるようになります。

この実装パターンは、他のカスタムLLMプロバイダーにも応用可能です。

サンプルコード

完全なサンプルコードは、以下のリポジトリで公開しています:

Discussion