🌟

OpenAI Streaming + Function callingの並列実行に対応する

2024/02/05に公開

概要

OpenAI Python library を使ってFunction callingの並列実行に対応した実装を紹介する記事です。

ユーザーへの返答はStreamingで返答します。

私は最近LLMを使ったアプリケーション開発案件に関わっていますが、BtoB・BtoCに関わらずユーザー体験の観点からStreaming対応が必要になることが多いです。

その為、この記事のサンプルコードは実際のプロダクトにも応用できると思います。

対象読者

OpenAIを使ったアプリケーション開発の概要を理解している方が対象となります。

筆者のバックグラウンド

普段はTypeScript(Next.jsを主に利用)を用いたフロントエンド開発者です。

以前はバックエンドエンジニアで主にGoを使ってAWS上でAPIの開発などを行なっていました。

最近はLLMを用いたアプリケーション開発に関わっています。

Pythonにも少しは慣れてきました。

並列実行に対応したFunction callingの設定方法

最初にコードの全体を紹介します。以下の通りです。

generate_message_for_guest_user メソッドがメインの処理となります。(このコードを見て実装方法を理解できた人はこれ以降の章を読む必要はありません。)

import os
import math
import httpx
import json
from datetime import datetime
from zoneinfo import ZoneInfo
from typing import cast, List, TypedDict, Union
from collections.abc import AsyncIterator
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import (
    ChatCompletionMessageParam,
    ChatCompletionChunk,
    ChatCompletionToolParam,
    ChatCompletionMessageToolCall,
)
from domain.repository.cat_message_repository_interface import (
    CatMessageRepositoryInterface,
    GenerateMessageForGuestUserDto,
    GenerateMessageForGuestUserResult,
)


class FetchCurrentWeatherResponse(TypedDict):
    city_name: str
    description: str
    temperature: int


class GetCurrentDatetimeResponse(TypedDict):
    current_datetime: str


class OpenAiCatMessageRepository(CatMessageRepositoryInterface):
    def __init__(self) -> None:
        self.OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
        self.OPEN_WEATHER_API_KEY = os.environ["OPEN_WEATHER_API_KEY"]
        self.client = AsyncOpenAI(api_key=self.OPENAI_API_KEY)

    async def generate_message_for_guest_user(
        self, dto: GenerateMessageForGuestUserDto
    ) -> AsyncIterator[GenerateMessageForGuestUserResult]:
        messages = cast(List[ChatCompletionMessageParam], dto.get("chat_messages"))
        user = str(dto.get("user_id"))

        regenerated_messages = (
            await self._might_regenerate_messages_contain_tools_results_exec(
                dto,
                messages,
            )
        )

        response = await self.client.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=regenerated_messages,
            stream=True,
            temperature=0.7,
            user=user,
        )

        async for generated_response in self._extract_chat_chunks(response):
            yield generated_response

    # 必要に応じてtoolsを実行してメッセージのリストにtoolsの実行結果を含めて再生成する
    async def _might_regenerate_messages_contain_tools_results_exec(
        self,
        dto: GenerateMessageForGuestUserDto,
        messages: List[ChatCompletionMessageParam],
    ) -> List[ChatCompletionMessageParam]:
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "fetch_current_weather",
                    "description": "指定された都市の現在の天気を取得する。(日本の都市の天気しか取得出来ない)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city_name": {
                                "type": "string",
                                "description": "英語表記の日本の都市名",
                            }
                        },
                        "required": ["city_name"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_current_datetime_in_iso_format",
                    "description": "指定されたタイムゾーンの現在日時をISO 8601形式で返す。",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "timezone": {
                                "type": "string",
                                "description": "タイムゾーン名: 例: Asia/Tokyo, UTC, America/New_York",
                            }
                        },
                        "required": ["timezone"],
                    },
                },
            },
        ]
        tools_params = cast(List[ChatCompletionToolParam], tools)

        copied_messages = messages.copy()

        system_prompt = """
        あなたの役割は与えられた会話履歴からtoolsの利用が必要かどうか判断する事です。
        JSONのキーはuse_toolsとしてください。
        toolsの利用が必要な場合はtrue,不要な場合はfalseを返します。
        """

        copied_messages[0] = {
            "role": "system",
            "content": system_prompt,
        }

        response = await self.client.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=copied_messages,
            temperature=0.7,
            user=str(dto.get("user_id")),
            tools=tools_params,
            tool_choice="auto",
            response_format={"type": "json_object"},
        )

        tool_response_messages = []
        if response.choices[0].finish_reason == "tool_calls":
            tool_calls = response.choices[0].message.tool_calls

            if tool_calls is None:
                return messages

            for tool_call in tool_calls:
                tool_call_response = await self._might_call_tool(tool_call)
                if tool_call_response is not None:
                    tool_response_messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "content": json.dumps(
                                tool_call_response, ensure_ascii=False
                            ),
                        }
                    )
            # tools(Function calling等)の実行結果を含めて再生成したメッセージのリストを返す
            regenerated_messages = [
                *messages,
                response.choices[0].message,
                *tool_response_messages,
            ]

            return cast(List[ChatCompletionMessageParam], regenerated_messages)

        # ここに来たという事はtoolsの実行が必要ないという事なので、引数で渡されたmessagesをそのまま返す
        return messages

    async def _might_call_tool(
        self, tool_call: ChatCompletionMessageToolCall
    ) -> Union[None, FetchCurrentWeatherResponse, GetCurrentDatetimeResponse]:
        if tool_call.type == "function":
            return await self._might_call_function(tool_call)

    async def _might_call_function(
        self,
        tool_call: ChatCompletionMessageToolCall,
    ) -> Union[None, FetchCurrentWeatherResponse, GetCurrentDatetimeResponse]:
        if tool_call.function.name == "fetch_current_weather":
            function_arguments = json.loads(tool_call.function.arguments)
            city_name = function_arguments["city_name"]
            return await self._fetch_current_weather(city_name)

        if tool_call.function.name == "get_current_datetime_in_iso_format":
            function_arguments = json.loads(tool_call.function.arguments)
            timezone = function_arguments["timezone"]
            return await self._get_current_datetime_in_iso_format(timezone)

        return None

    async def _fetch_current_weather(
        self, city_name: str = "Tokyo"
    ) -> FetchCurrentWeatherResponse:
        async with httpx.AsyncClient() as client:
            geocoding_response = await client.get(
                "http://api.openweathermap.org/geo/1.0/direct",
                params={
                    "q": city_name + ",jp",
                    "limit": 1,
                    "appid": self.OPEN_WEATHER_API_KEY,
                },
            )
            geocoding_list = geocoding_response.json()
            geocoding = geocoding_list[0]
            lat, lon = geocoding["lat"], geocoding["lon"]

            current_weather_response = await client.get(
                "https://api.openweathermap.org/data/2.5/weather",
                params={
                    "lat": lat,
                    "lon": lon,
                    "units": "metric",
                    "lang": "ja",
                    "appid": self.OPEN_WEATHER_API_KEY,
                },
            )
            current_weather = current_weather_response.json()

            return {
                "city_name": city_name,
                "description": current_weather["weather"][0]["description"],
                "temperature": math.floor(current_weather["main"]["temp"]),
            }

    @staticmethod
    async def _get_current_datetime_in_iso_format(
        timezone: str,
    ) -> GetCurrentDatetimeResponse:
        current_datetime = datetime.now(ZoneInfo(timezone))

        return {
            "current_datetime": current_datetime.isoformat(),
        }

    @staticmethod
    async def _extract_chat_chunks(
        async_stream: AsyncStream[ChatCompletionChunk],
    ) -> AsyncIterator[GenerateMessageForGuestUserResult]:
        ai_response_id = ""
        async for chunk in async_stream:
            chunk_message: str = (
                chunk.choices[0].delta.content
                if chunk.choices[0].delta.content is not None
                else ""
            )

            if ai_response_id == "":
                ai_response_id = chunk.id

            if chunk_message == "":
                continue

            chunk_body: GenerateMessageForGuestUserResult = {
                "ai_response_id": ai_response_id,
                "message": chunk_message,
            }

            yield chunk_body

コードはGitHubで公開しています。

https://github.com/nekochans/ai-cat-api/blob/main/src/infrastructure/repository/openai/openai_cat_message_repository.py

以下はフロントエンド側のコードです。(Next.jsのAppRouterで作っています)

StreamingでPythonとFastAPIで作成したAPIサーバーからの返答をStreamingで表示させています。

https://github.com/nekochans/ai-cat-frontend/blob/main/src/app/chat/_components/ChatContent/ChatContent.tsx

これらのコードは私の個人開発サービスである AI Meow Cat の物です。

以下のページから動作確認が可能です。

天気や今の時刻を尋ねるような質問をするとFunction callingを使って答えを返してくれます。

「東京と大阪の天気を教えて!あと今の時刻も教えて欲しい!」などの複数の関数実行が必要な質問にも一回で答えられるようになっています。

https://www.ai-meow-cat.com/chat/moko

処理全体の流れ

以下の通りです。後ほど詳しく解説します。

  1. クライアントが generate_message_for_guest_userメソッドをコール。
  2. generate_message_for_guest_userメソッドは必要に応じてツールの実行やツールの実行結果を含めたメッセージを生成する
    _might_regenerate_messages_contain_tools_results_execメソッドをコール。
  3. _might_regenerate_messages_contain_tools_results_execメソッドは、OpenAIのチャットAPIを使用して、システムプロンプトに基づいたツールの呼び出し判定を実施。
  4. 必要に応じて、ツールの呼び出し(天気取得や現在時刻取得など)が _might_call_toolメソッドと _might_call_functionメソッドを通じて処理される。
  5. 最終的なメッセージが再構成され、OpenAIのチャットAPIを通じてユーザーメッセージとともに送信。
  6. 返されたチャットストリームは _extract_chat_chunksメソッドによって処理され、クライアントにチャットメッセージが提供される。

toolsの設定箇所の解説

toolsの設定箇所を解説します。対象箇所は以下の通りです。

Function callingの設定より1階層ネストが深くなっています。

ここで設定しているのは以下の2つです。

  • 指定された都市の天気を取得する関数
  • 現在時刻をISO 8601形式で返す関数

関数名の命名はとても重要です、主に以下の点を明確にしないとLLMが期待したとおりに関数を使ってくれません。

  • 関数名は明確に何をやっているかわかりやすい名前にする
  • descriptionには制約条件などがある場合は記載する
  • 必要に応じて引数の例等も記載する
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "fetch_current_weather",
                    "description": "指定された都市の現在の天気を取得する。(日本の都市の天気しか取得出来ない)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "city_name": {
                                "type": "string",
                                "description": "英語表記の日本の都市名",
                            }
                        },
                        "required": ["city_name"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_current_datetime_in_iso_format",
                    "description": "指定されたタイムゾーンの現在日時をISO 8601形式で返す。",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "timezone": {
                                "type": "string",
                                "description": "タイムゾーン名: 例: Asia/Tokyo, UTC, America/New_York",
                            }
                        },
                        "required": ["timezone"],
                    },
                },
            },
        ]
        tools_params = cast(List[ChatCompletionToolParam], tools)

        copied_messages = messages.copy()

        system_prompt = """
        あなたの役割は与えられた会話履歴からtoolsの利用が必要かどうか判断する事です。
        JSONのキーはuse_toolsとしてください。
        toolsの利用が必要な場合はtrue,不要な場合はfalseを返します。
        """

        copied_messages[0] = {
            "role": "system",
            "content": system_prompt,
        }

        response = await self.client.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=copied_messages,
            temperature=0.7,
            user=str(dto.get("user_id")),
            tools=tools_params,
            tool_choice="auto",
            response_format={"type": "json_object"},
        )

余談ですが、今後は "type": "function" 以外にも別のツールが利用可能となっていくと予想されます。

Beta版の Assistants API だとKnowledge Retrieval(独自文書の追加)などが利用できるのでChatAPIでも利用可能になるかもしれません。

toolsの利用可否を判定する

この部分の処理を解説します。

  1. クライアントが generate_message_for_guest_userメソッドをコール。
  2. generate_message_for_guest_userメソッドは必要に応じてツールの実行やツールの実行結果を含めたメッセージを生成する
    _might_regenerate_messages_contain_tools_results_execメソッドをコール。
  3. _might_regenerate_messages_contain_tools_results_execメソッドは、OpenAIのチャットAPIを使用して、システムプロンプトに基づいたツールの呼び出し判定を実施。

さきほど紹介したtoolsの設定を行ないLLMにリクエストを実施しています。

最初のリクエストの目的はFunction callingを含めたツールの実行が必要かどうかを判定して必要ならツール(関数)の実行を行ない実行結果とともに再生成したメッセージを返します。

該当箇所は以下の通りです。

ちなみにエンドユーザーが送信したメッセージは「もこちゃん🐱こんにちは東京の天気を教えて欲しいのだ🐱」です。

その結果 tool_calls の中身は以下のようになっています。

エンドユーザーの質問に答えるために複数の関数の実行を要求していることがわかります。

[
  {
    "id": "call_xxxxxxxxxxxxxxxxxxxxxxxx",
    "function": {
      "arguments": "{\"city_name\": \"Tokyo\"}",
      "name": "fetch_current_weather"
    },
    "type": "function"
  },
  {
    "id": "call_yyyyyyyyyyyyyyyyyyyyyyyy",
    "function": {
      "arguments": "{\"city_name\": \"Yokohama\"}",
      "name": "fetch_current_weather"
    },
    "type": "function"
  },
  {
    "id": "call_zzzzzzzzzzzzzzzzzzzzzzzz",
    "function": {
      "arguments": "{\"timezone\": \"Asia/Tokyo\"}",
      "name": "get_current_datetime_in_iso_format"
    },
    "type": "function"
  }
]

この要求を元に関数の実行を行ない、元々渡ってきた messages のリストに response.choices[0].message とツールの実行結果が含まれる tool_response_messages を含めます。

この regenerated_messages を用いて再度LLMにリクエストを送信することでLLMがツールの実行結果を見て回答内容を生成してくれます。(この部分は後ほど解説します)

        tool_response_messages = []
        if response.choices[0].finish_reason == "tool_calls":
            tool_calls = response.choices[0].message.tool_calls

            if tool_calls is None:
                return messages

            for tool_call in tool_calls:
                tool_call_response = await self._might_call_tool(tool_call)
                if tool_call_response is not None:
                    tool_response_messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "content": json.dumps(
                                tool_call_response, ensure_ascii=False
                            ),
                        }
                    )
            # tools(Function calling等)の実行結果を含めて再生成したメッセージのリストを返す
            regenerated_messages = [
                *messages,
                response.choices[0].message,
                *tool_response_messages,
            ]

            return cast(List[ChatCompletionMessageParam], regenerated_messages)

        # ここに来たという事はtoolsの実行が必要ないという事なので、引数で渡されたmessagesをそのまま返す
        return messages

必要な場合toolsの実行を行なう

この部分の処理の解説です。

  1. 必要に応じて、ツールの呼び出し(天気取得や現在時刻取得など)が _might_call_toolメソッドと _might_call_functionメソッドを通じて処理される

今回OpenAIにtoolsとして渡している関数は以下の2つです。

それぞれ天気を取得する関数と現在時刻を取得する関数です。

天気を取得する関数は OpenWeather のAPIを使って実際の天気情報を取得しています。

    async def _fetch_current_weather(
        self, city_name: str = "Tokyo"
    ) -> FetchCurrentWeatherResponse:
        async with httpx.AsyncClient() as client:
            geocoding_response = await client.get(
                "http://api.openweathermap.org/geo/1.0/direct",
                params={
                    "q": city_name + ",jp",
                    "limit": 1,
                    "appid": self.OPEN_WEATHER_API_KEY,
                },
            )
            geocoding_list = geocoding_response.json()
            geocoding = geocoding_list[0]
            lat, lon = geocoding["lat"], geocoding["lon"]

            current_weather_response = await client.get(
                "https://api.openweathermap.org/data/2.5/weather",
                params={
                    "lat": lat,
                    "lon": lon,
                    "units": "metric",
                    "lang": "ja",
                    "appid": self.OPEN_WEATHER_API_KEY,
                },
            )
            current_weather = current_weather_response.json()

            return {
                "city_name": city_name,
                "description": current_weather["weather"][0]["description"],
                "temperature": math.floor(current_weather["main"]["temp"]),
            }

    @staticmethod
    async def _get_current_datetime_in_iso_format(
        timezone: str,
    ) -> GetCurrentDatetimeResponse:
        current_datetime = datetime.now(ZoneInfo(timezone))

        return {
            "current_datetime": current_datetime.isoformat(),
        }

toolsを呼び出している箇所は以下の通りです。

_might_call_tool から _might_call_function をコールするようにしています。

理由は今後toolsにFunction calling以外の機能が利用可能になった際に対応しやすくするためです。

    async def _might_call_tool(
        self, tool_call: ChatCompletionMessageToolCall
    ) -> Union[None, FetchCurrentWeatherResponse, GetCurrentDatetimeResponse]:
        if tool_call.type == "function":
            return await self._might_call_function(tool_call)

    async def _might_call_function(
        self,
        tool_call: ChatCompletionMessageToolCall,
    ) -> Union[None, FetchCurrentWeatherResponse, GetCurrentDatetimeResponse]:
        if tool_call.function.name == "fetch_current_weather":
            function_arguments = json.loads(tool_call.function.arguments)
            city_name = function_arguments["city_name"]
            return await self._fetch_current_weather(city_name)

        if tool_call.function.name == "get_current_datetime_in_iso_format":
            function_arguments = json.loads(tool_call.function.arguments)
            timezone = function_arguments["timezone"]
            return await self._get_current_datetime_in_iso_format(timezone)

        return None

toolsの実行結果を会話履歴に含めて再度リクエストを送信する

この部分の解説になります。

  1. 最終的なメッセージが再構成され、OpenAIのチャットAPIを通じてユーザーメッセージとともに送信。
  2. 返されたチャットストリームは _extract_chat_chunksメソッドによって処理され、クライアントにチャットメッセージが提供される。

関数の実行結果を含めた regenerated_messages を利用して再度LLMにリクエストを行ないます。

ここはStreamingで結果を生成して欲しいので stream=True を指定します。

_extract_chat_chunks でデータの中身を解析して処理しています。

        regenerated_messages = (
            await self._might_regenerate_messages_contain_tools_results_exec(
                dto,
                messages,
            )
        )

        response = await self.client.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=regenerated_messages,
            stream=True,
            temperature=0.7,
            user=user,
        )

        async for generated_response in self._extract_chat_chunks(response):
            yield generated_response

    @staticmethod
    async def _extract_chat_chunks(
        async_stream: AsyncStream[ChatCompletionChunk],
    ) -> AsyncGenerator[GenerateMessageForGuestUserResult, None]:
        ai_response_id = ""
        async for chunk in async_stream:
            chunk_message: str = (
                chunk.choices[0].delta.content
                if chunk.choices[0].delta.content is not None
                else ""
            )

            if ai_response_id == "":
                ai_response_id = chunk.id

            if chunk_message == "":
                continue

            chunk_body: GenerateMessageForGuestUserResult = {
                "ai_response_id": ai_response_id,
                "message": chunk_message,
            }

            yield chunk_body

今後の改善点

応答時間の改善

最初にtoolsの利用可否を判断するためにLLMにリクエストを送っている理由ですが、stream=True でtoolsを設定すると以下のように function.arguments の値が常に空文字になってしまう現象が発生しました。

[
  {
    "id": "call_xxxxxxxxxxxxxxxxxxxxxxxx",
    "function": {
      "arguments": "",
      "name": "fetch_current_weather"
    },
    "type": "function"
  },
  {
    "id": "call_yyyyyyyyyyyyyyyyyyyyyyyy",
    "function": {
      "arguments": "",
      "name": "fetch_current_weather"
    },
    "type": "function"
  },
  {
    "id": "call_zzzzzzzzzzzzzzzzzzzzzzzz",
    "function": {
      "arguments": "",
      "name": "get_current_datetime_in_iso_format"
    },
    "type": "function"
  }
]

一旦今の実装にしましたが、今の実装は関数の利用が必要ない場合でもLLMへのリクエストが2回発生しているので、その分レスポンス速度が遅くなっています。

この問題は詳しく調べてみないと解消できるかわかりませんが、解消できればLLMへのリクエストを一回減らせるのでやってみる価値はあると思っています。

tools実行回数が増加によるパフォーマンス低下に備える

実装を見るとわかるようにtoolsの実行を順次処理で行なっているので、関数の実行要求数が多いとそれだけ応答時間が長くなってしまいます。

また関数の実行時間が長くなるとサーバーにも負荷がかかってしまいます。

今後も実行可能なtoolsが増える場合は concurrent.futures.ProcessPoolExecutor 等でマルチプロセスで実行するようにするなどの対処法も必要になってくる可能性があります。

ただしエラーハンドリングも複雑化するので場合によっては一度に実行可能なtoolsの数を条件分岐やプロンプトで制御するという方向性も考えられます。

どちらにせよ、実行可能なtoolsが増える場合はこの部分を考えていく必要があると思っています。

エラーハンドリングとエラーレポート

現状でも以下のように最低限スタックトレース付きのログを出すようにしていますが、toolsの実行中に例外が発生した場合、独自エラーをThrowするように改修するなどをしてエラーの発生箇所を簡単に特定できるようにしていきたいです。

その際はLLMから返ってきたtoolsの呼び出し要求情報もログに出力するようにすると調査がしやすくなるのでそのあたりも対応する予定です。

おわりに

以上がtoolsを利用したFunction callingの並列実行の解説になります。

多様すると応答時間が長くなってしまうなどのデメリットもありますが、簡単にLLMに外部知識を与えることができるのは大きなメリットだと思います。
簡単なAIエージェントならtoolsで十分に実装可能です。

今後も利用可能なtoolsが増えることが予想されるので今後の発展が楽しみです。

以上になります、最後まで読んでいただきありがとうございました。最後になりましたが、この記事を書く際に以下の記事を参考にさせていただきましたので、ここに記載しておきます。

Discussion