Closed9

Vertex AI で Gemini 1.5を使う 2: Function Calling

kun432kun432

Gemini を使用した関数呼び出しの概要

以下のnotebookを進める。

https://github.com/GoogleCloudPlatform/generative-ai/tree/main/gemini/function-calling/intro_function_calling.ipynb

事前準備

セットアップ回りは前回の記事を参考に、ということで詳細は割愛。

パッケージインストール

!pip install --upgrade --user --quiet google-cloud-aiplatform

ColaboratoryからGCPを使えるように認証を行う。

from google.colab import auth

auth.authenticate_user()

Vertex AIへの接続

import vertexai

PROJECT_ID="YOUR_PROJECT_ID"
REGION="asia-northeast1"

vertexai.init(
    project=PROJECT_ID,
    location=REGION
)

関数の定義

サンプルに従って、Google Storeで製品情報を取得、Geminiがこの情報を元に回答する、というのをやってみる。

まずFunction Callingでモデルに渡す関数の定義を行う。vertexai.generative_models.FunctionDeclarationを使う。

from vertexai.generative_models import FunctionDeclaration

                                        
get_product_info = FunctionDeclaration(
    name="get_product_info",
    description="指定された商品の在庫量とIDを取得する",
    parameters={
        "type": "object",
        "properties": {
            "product_name": {"type": "string", "description": "商品名"}
        },
    },
)

get_store_location = FunctionDeclaration(
    name="get_store_location",
    description="最寄りの店舗の場所を取得する",
    parameters={
        "type": "object",
        "properties": {"location": {"type": "string", "description": "場所"}},
    },
)

place_order = FunctionDeclaration(
    name="place_order",
    description="商品を注文する",
    parameters={
        "type": "object",
        "properties": {
            "product": {"type": "string", "description": "商品名"},
            "address": {"type": "string", "description": "送付先住所"},
        },
    },
)

これらの関数定義をツールとして定義する。vertexai.generative_models.Toolを使う。

from vertexai.generative_models import Tool

retail_tool = Tool(
    function_declarations=[
        get_product_info,
        get_store_location,
        place_order,
    ],
)

モデルの定義を行い、ここでツールを渡す。なお、今回はGemini 1.5 Proを使う。バージョンの表記については以下を参照。

https://ai.google.dev/gemini-api/docs/models/gemini?hl=ja#model-versions

from vertexai.generative_models import GenerationConfig, GenerativeModel

model = GenerativeModel(
    "gemini-1.5-pro-001",
    generation_config=GenerationConfig(temperature=0),
    tools=[retail_tool],
)

chat = model.start_chat()

Function Callingを使ったチャット

ではクエリに投げる。

import json

response = chat.send_message("Pixel 8 Proの在庫について教えて。")
print(json.dumps(response.to_dict(), ensure_ascii=False, indent=2))
{
  "candidates": [
    {
      "content": {
        "role": "model",
        "parts": [
          {
            "function_call": {
              "name": "get_product_info",
              "args": {
                "product_name": "Pixel 8 Pro"
              }
            }
          }
        ]
      },
      "finish_reason": "STOP",
      "safety_ratings": [
        {
          "category": "HARM_CATEGORY_HATE_SPEECH",
          "probability": "NEGLIGIBLE",
          "probability_score": 0.15022333,
          "severity": "HARM_SEVERITY_NEGLIGIBLE",
          "severity_score": 0.11183353
        },
        {
          "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
          "probability": "NEGLIGIBLE",
          "probability_score": 0.26940945,
          "severity": "HARM_SEVERITY_LOW",
          "severity_score": 0.22600734
        },
        {
          "category": "HARM_CATEGORY_HARASSMENT",
          "probability": "NEGLIGIBLE",
          "probability_score": 0.096052155,
          "severity": "HARM_SEVERITY_NEGLIGIBLE",
          "severity_score": 0.059279762
        },
        {
          "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
          "probability": "NEGLIGIBLE",
          "probability_score": 0.060781546,
          "severity": "HARM_SEVERITY_NEGLIGIBLE",
          "severity_score": 0.040733125
        }
      ]
    }
  ],
  "usage_metadata": {
    "prompt_token_count": 57,
    "candidates_token_count": 12,
    "total_token_count": 69
  }
}

response.candidates[0].content.parts[0]に"function_call"が含まれていれば、Function CallingによりGeminiがツールの使用を提案したということになる。

このレスポンスのnameargsを元に実際の関数なりAPIを叩いて結果を返すということになるのだが、このnotebookでは簡単のため仮の結果を使う。

api_response = {"sku": "GA04834-US", "in_stock": "yes"}

この結果をGeminiに送信する。前回のマルチモーダルのときにも出てきたが、vertexai.generative_models.Partはいろいろなフォーマットをモデルへ渡すための形式に変換してくれるっぽくて、関数の結果の場合はfrom_function_response()を使う。

from vertexai.generative_models import Part

response = chat.send_message(
    Part.from_function_response(
        name="get_product_info",
        response={
            "content": api_response,
        },
    ),
)
print(response.text)
Pixel 8 Proは在庫ありです。SKUはGA04834-USです。他に何か知りたいことはありますか?

続けて質問してみる。

response = chat.send_message("じゃあ Pixel 8 の方はどう?あと、カリフォルニア州マウンテンビューで試せる店はある?")
print(json.dumps(response.candidates[0].content.parts[0].to_dict(), ensure_ascii=False, indent=2))
{
  "function_call": {
    "name": "get_product_info",
    "args": {
      "product_name": "Pixel 8"
    }
  }
}

先ほどと同じ関数が提案されている。これについても仮の結果をGeminiに返すこととする。

api_response = {"sku": "GA08475-US", "in_stock": "yes"}

response = chat.send_message(
    Part.from_function_response(
        name="get_product_info",
        response={
            "content": api_response,
        },
    ),
)

print(json.dumps(response.candidates[0].content.parts[0].to_dict(), ensure_ascii=False, indent=2))
{
  "function_call": {
    "name": "get_store_location",
    "args": {
      "location": "Mountain View, CA"
    }
  }
}

別の関数が提案されている。回答を返すために複数の関数の結果が必要な場合はこのようになる。ではこれも仮の結果を返してみる。

api_response = {"store": "2000 N Shoreline Blvd, Mountain View, CA 94043, US"}

response = chat.send_message(
    Part.from_function_response(
        name="get_store_location",
        response={
            "content": api_response,
        },
    ),
)

print(response.text)
Pixel 8 も在庫ありです。SKUはGA08475-USです。カリフォルニア州マウンテンビューには、2000 N Shoreline Blvd, Mountain View, CA 94043, US にお店があります。

続けてみる。

response = chat.send_message("Pixel 8 Proを注文したいです。配送先は 1155 Borregas Ave, Sunnyvale, CA 94089 でお願いします。")
print(json.dumps(response.candidates[0].content.parts[0].to_dict(), ensure_ascii=False, indent=2))
{
  "function_call": {
    "name": "place_order",
    "args": {
      "product": "Pixel 8 Pro",
      "address": "1155 Borregas Ave, Sunnyvale, CA 94089"
    }
  }
}
api_response = {
    "payment_status": "支払済",
    "order_number": 12345,
    "est_arrival": "2日",
}

response = chat.send_message(
    Part.from_function_response(
        name="place_order",
        response={
            "content": api_response,
        },
    ),
)
response.text
注文を受け付けました。お届け予定日は2日後、注文番号は12345です。お支払いは確認済みです。
kun432kun432

マップAPIで住所をジオコーディングするためにFunction Callingを使用する

実際にAPIアクセスと連携するのを試してみる。住所から緯度経度を取得する。

ツールの定義。

get_location = FunctionDeclaration(
    name="get_location",
    description="Get latitude and longitude for a given location",
    parameters={
        "type": "object",
        "properties": {
            "poi": {"type": "string", "description": "Point of interest"},
            "street": {"type": "string", "description": "Street name"},
            "city": {"type": "string", "description": "City name"},
            "county": {"type": "string", "description": "County name"},
            "state": {"type": "string", "description": "State name"},
            "country": {"type": "string", "description": "Country name"},
            "postal_code": {"type": "string", "description": "Postal code"},
        },
    },
)

location_tool = Tool(
    function_declarations=[get_location],
)
model = GenerativeModel(
    "gemini-1.5-pro-001",
    generation_config=GenerationConfig(temperature=0),
    tools=[location_tool],
)

chat = model.start_chat()

prompt = """
以下の住所の緯度経度を教えて:
1600 Amphitheatre Pkwy, Mountain View, CA 94043, US
"""
response = chat.send_message(prompt)

print(json.dumps(response.candidates[0].content.parts[0].to_dict(), ensure_ascii=False, indent=2))

以下の応答が返ってくる。

{
  "function_call": {
    "name": "get_location",
    "args": {
      "country": "US",
      "city": "Mountain View",
      "postal_code": "94043",
      "street": "1600 Amphitheatre Pkwy",
      "state": "CA"
    }
  }
}

実際にAPIに問い合わせてみる。

import requests

x = response.candidates[0].content.parts[0].function_call.args

url = "https://nominatim.openstreetmap.org/search?"
for i in x:
    url += '{}="{}"&'.format(i, x[i])
url += "format=json"

headers = {"User-Agent": "none"}
x = requests.get(url, headers=headers)
content = x.json()

print(json.dumps(content, ensure_ascii=False, indent=2))
[
  {
    "place_id": 377680635,
    "licence": "Data © OpenStreetMap contributors, ODbL 1.0. http://osm.org/copyright",
    "osm_type": "node",
    "osm_id": 2192620021,
    "lat": "37.4217636",
    "lon": "-122.084614",
    "class": "office",
    "type": "it",
    "place_rank": 30,
    "importance": 0.6949356759210291,
    "addresstype": "office",
    "name": "Google Headquarters",
    "display_name": "Google Headquarters, 1600, Amphitheatre Parkway, Mountain View, Santa Clara County, California, 94043, United States",
    "boundingbox": [
      "37.4217136",
      "37.4218136",
      "-122.0846640",
      "-122.0845640"
    ]
  }
]

この結果をGeminiに返す。

response = chat.send_message(
    Part.from_function_response(
        name="get_location",
        response={
            "content": content,
        },
    ),
)

print(response.text)
緯度経度は、緯度37.4217636、経度-122.084614です。
kun432kun432

エンティティ抽出のみにFunction Callingを使用する

Function Callingをシンプルなエンティティ抽出だけ使う。JSONフォーマットでレスポンスが欲しい、とかそういう場合にも使えるやつ。

extract_log_data = FunctionDeclaration(
    name="extract_log_data",
    description="生ログデータのエラーメッセージから詳細を抽出する",
    parameters={
        "type": "object",
        "properties": {
            "locations": {
                "type": "array",
                "description": "エラー",
                "items": {
                    "description": "エラーの詳細",
                    "type": "object",
                    "properties": {
                        "error_message": {
                            "type": "string",
                            "description": "完全なエラーメッセージ",
                        },
                        "error_code": {"type": "string", "description": "エラーコード"},
                        "error_type": {"type": "string", "description": "エラー種別"},
                    },
                },
            }
        },
    },
)

extraction_tool = Tool(
    function_declarations=[extract_log_data],
)

prompt = """
[15:43:28] ERROR: Could not process image upload: Unsupported file format. (Error Code: 308)
[15:44:10] INFO: Search index updated successfully.
[15:45:02] ERROR: Service dependency unavailable (payment gateway). Retrying... (Error Code: 5522)
[15:45:33] ERROR: Application crashed due to out-of-memory exception. (Error Code: 9001)
"""

response = model.generate_content(
    prompt,
    generation_config=GenerationConfig(temperature=0),
    tools=[extraction_tool],
)

print(json.dumps(response.to_dict()["candidates"][0]["content"]["parts"][0]["function_call"], ensure_ascii=False, indent=2))
{
  "name": "extract_log_data",
  "args": {
    "locations": [
      {
        "error_message": "Could not process image upload: Unsupported file format.",
        "error_type": "ERROR",
        "error_code": "308"
      },
      {
        "error_code": "5522",
        "error_type": "ERROR",
        "error_message": "Service dependency unavailable (payment gateway). Retrying..."
      },
      {
        "error_message": "Application crashed due to out-of-memory exception.",
        "error_code": "9001",
        "error_type": "ERROR"
      }
    ]
  }
}
kun432kun432

マルチステップのFunction Calling

以下と同じ例で。

https://zenn.dev/kun432/scraps/18d2b102faea9b

# ダミーのデータベース
product_list = {
    'スマートフォン': 'E1001',
    'ノートパソコン': 'E1002',
    'タブレット': 'E1003',
    'Tシャツ': 'C1001',
    'ジーンズ':'C1002',
    'ジャケット': 'C1003',
}

product_catalog = {
    'E1001': {'price': 500, 'stock_level': 20},
    'E1002': { 'price': 1000, 'stock_level': 15},
    'E1003': {'price': 300, 'stock_level': 25},
    'C1001': {'price': 20, 'stock_level': 100},
    'C1002': {'price': 50, 'stock_level': 80},
    'C1003': {'price': 100, 'stock_level': 40},
}


def get_product_id_from_product_name(product_name: str) -> dict:
    return {"product_name": product_name, "product_id": product_list[product_name]}


def get_product_info_from_product_id(product_id: str) -> dict:
    return {"product_id": product_id, "product_info": product_catalog[product_id]}


get_product_id = FunctionDeclaration(
    name="get_product_id_from_product_name",
    description="「商品名」から「商品ID」を取得する。",
    parameters={
        "type": "object",
        "properties": {
            "product_name": {
                "type": "string",
                "description": "「商品ID」を取得するための「商品名」を指定する。「商品名」は一般名詞で指定する必要がある。ex. タブレット、ジャケット、等。"
            },
        },
    },
)

get_product_info = FunctionDeclaration(
    name="get_product_info_from_product_id",
    description="「商品ID」から「商品情報(価格、在庫)」を取得する。",
    parameters={
        "type": "object",
        "properties": {
            "product_id": {
                "type": "string",
                "description": "「商品情報(価格、在庫)」を取得するための「商品ID」を指定する。「商品ID」は [A-Z]{1}[0-9]{3} で指定すること。ex. X0002、等。"
            },
        },
    },
)

tool_to_function_map = {
    "get_product_id_from_product_name": get_product_id_from_product_name,
    "get_product_info_from_product_id": get_product_info_from_product_id,
}

retail_tool = Tool(
    function_declarations=[
        get_product_id,
        get_product_info,
    ],
)

model = GenerativeModel(
    "gemini-1.5-pro-001",
    generation_config=GenerationConfig(temperature=0),
    tools=[retail_tool],
)
chat = model.start_chat()

res = chat.send_message("タブレットの在庫を調べて。")
response = res.candidates[0].content.parts[0]

while True:
    if response.function_call:
        tool_name = response.function_call.name
        tool_params = {key: value for key, value in response.function_call.args.items()}
        function_to_call = tool_to_function_map[tool_name]
        tool_result = function_to_call(**tool_params)
        print("name: ", tool_name)
        print("params: ", tool_params)
        print("result: ", tool_result)
        print("----")
        res = chat.send_message(
            Part.from_function_response(
                name=tool_name,
                response={"content": tool_result}
            )
        )
        response = res.candidates[0].content.parts[0]
    else:
        break

print(response.text)
name:  get_product_id_from_product_name
params:  {'product_name': 'タブレット'}
result:  {'product_name': 'タブレット', 'product_id': 'E1003'}
----
name:  get_product_info_from_product_id
params:  {'product_id': 'E1003'}
result:  {'product_id': 'E1003', 'product_info': {'price': 300, 'stock_level': 25}}
----
タブレットの在庫は25個です。
kun432kun432

動きを見ているとParallelに動くってのはない感じがする。

https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling

「ニューデリーとサンフランシスコの天気の詳細情報を取得しますか?」などのプロンプトでは、モデルが複数の並列関数呼び出しを提案する場合があります。並列関数呼び出しはプレビュー版の機能です。Gemini 1.5 Pro モデルと Gemini 1.5 Flash モデルでサポートされています。詳細については、並列関数呼び出しの例をご覧ください。

あぁ、プレビュー版ならできるのか

このスクラップは2ヶ月前にクローズされました