🤖

Gemini API の Function Calling 機能で LLM Agent を実装する

2024/06/28に公開

LLM Agent 入門

データ処理パイプラインと LLM Agent の違い

Google Cloud の Gemini API には Function Calling 機能が実装されており、基盤モデルの Gemini に「外部 API を利用して回答に必要な情報を収集する」という動作が追加できます。ここでポイントになるのは、「どの API をどのように使用すれば回答に必要な情報が得られるか?」という部分を Gemini 自身に考えさせるという点です。これを利用すると、いわゆる LLM Agent が実装できます。

集めるべき情報の種類や処理の手順があらかじめ決まっている場合は、LLM によるテキスト生成を組み込んだデータ処理パイプラインを実装する方が安定的に動作する(期待する結果が確実に得られる)はずですが、特定の手順を前提としない柔軟な処理を実現する際は LLM Agent が向いているでしょう。

たとえば、都道府県名を選ぶとその地域の天候を短いテキストにまとめて教えてくれるサービスを作るのであれば、

  • 都道府県名を受け取る
     ↓
  • 「天候情報サービス API」で該当地域の天候情報を取得する
     ↓
  • 得られた情報を LLM に渡して天候状況をまとめたテキストを生成する

という一連の処理をパイプラインとして実装すればよいでしょう。

一方、「東京と京都の現在の気温差を知りたい」など、天候に関するさまざまな質問に柔軟に回答するシステムを作る場合、質問内容に応じて、API サービスから取得する情報を切り替えたり、複数の API サービスを組み合わせて利用する必要があります。これを実現するのが LLM Agent の役割です。

LLM Agent を手動で実現してみる

LLM Agent の動作を理解するために、シングルターンの LLM を利用して、「東京と京都の現在の気温差を知りたい」という質問に答える LLM Agent の動作を擬似的に再現してみます。以下は、Cloud Console の Vertex AI Studio で、gemini-1.5-pro-001 を使って実際に得られた結果です。

まず、次のプロンプトを入力して、応答を得ます。プロンプト内で利用可能な API サービスを指定して、必要な API を LLM に選択させようと目論んでいます。

プロンプト1

# 依頼事項
質問に答えてください。
ツールを使用する必要がある場合は、ツール名と入力データを答えてください。
マルチターンで回答します。ツールは1度に1つしか使えません。

[ツール]
ツール名:気象情報API
→ 機能:{location: '都市名', data_type: '気温' or '湿度'} を入力すると、指定した都市の現在の気温、または、湿度が得られます。

ツール名:交通情報API
→ 機能:{location: '都市名'} を入力すると、指定した都市の現在の交通情報が得られます。

[質問]
東京と京都の現在の気温差を教えてください。

# あなたの回答1

次の応答が得られます。「気象情報 API」と「交通情報 API」がありますが、気象情報 APIを正しく選択したようです。

東京と京都の現在の気温を知る必要がありますね。 まず、気象情報APIを使って東京の気温を調べましょう。

使うツール:気象情報API 入力データ:{location: '東京', data_type: '気温'}

ただし、LLM 自身はこの API サービスを呼び出すことはできませんので、この部分は人間が手伝います。上記のデータ {location: '東京', data_type: '気温'} を架空の気象情報 API に入力して、{location: '東京', 気温:'32度'} という結果が得られたとします。

そして、先ほどのプロンプトに LLM の応答と API サービスから得られた結果を付け加えたものを新しいプロンプトとして、LLM に入力します。

プロンプト2

# 依頼事項
質問に答えてください。
ツールを使用する必要がある場合は、ツール名と入力データを答えてください。
マルチターンで回答します。ツールは1度に1つしか使えません。

[ツール]
ツール名:気象情報API
→ 機能:{location: '都市名', data_type: '気温' or '湿度'} を入力すると、指定した都市の現在の気温、または、湿度が得られます。

ツール名:交通情報API
→ 機能:{location: '都市名'} を入力すると、指定した都市の現在の交通情報が得られます。

[質問]
東京と京都の現在の気温差を教えてください。

# あなたの回答1
東京と京都の現在の気温を知る必要がありますね。 まず、気象情報APIを使って東京の気温を調べましょう。

使うツール:気象情報API 入力データ:{location: '東京', data_type: '気温'}

# ツールの応答1
{location: '東京', 気温:'32度'}

# あなたの回答2

LLM は、このプロンプトに含まれるこれまでのやりとりを勘案して、次のステップを応答として返してくれます。

東京の気温は32度ですね。次に、京都の気温を調べましょう。

使うツール: 気象情報API 入力データ: {location: '京都', data_type: '気温'}

先ほどと同様に、API サービスを呼び出す部分は人間が手伝います。上記のデータ {location: '京都', data_type: '気温'} を架空の気象情報 API に入力して、{location: '京都', 気温:'28度'} という結果が得られたとします。先ほどのプロンプトに LLM の応答と API サービスの出力をさらに付け加えて、再度、LLM を呼び出します。

プロンプト3

# 依頼事項
質問に答えてください。
ツールを使用する必要がある場合は、ツール名と入力データを答えてください。
マルチターンで回答します。ツールは1度に1つしか使えません。

[ツール]
ツール名:気象情報API
→ 機能:{location: '都市名', data_type: '気温' or '湿度'} を入力すると、指定した都市の現在の気温、または、湿度が得られます。

ツール名:交通情報API
→ 機能:{location: '都市名'} を入力すると、指定した都市の現在の交通情報が得られます。

[質問]
東京と京都の現在の気温差を教えてください。

# あなたの回答1
東京と京都の現在の気温を知る必要がありますね。 まず、気象情報APIを使って東京の気温を調べましょう。

使うツール:気象情報API 入力データ:{location: '東京', data_type: '気温'}

# ツールの応答1
{location: '東京', 気温:'32度'}

# あなたの回答2
東京の気温は32度ですね。次に、京都の気温を調べましょう。

使うツール: 気象情報API 入力データ: {location: '京都', data_type: '気温'}

# ツールの応答2
{location: '京都', 気温:'28度'}

# あなたの回答3

LLM の応答は次の様になります。無事に質問の回答が得られました。

東京の気温は32度、京都の気温は28度なので、その差は4度です。

東京と京都の現在の気温差は 4度 です。

この例では、気象情報 API しか使いませんでしたが、天候と交通の両方に関わる質問であれば、気象情報 API と交通情報 API を順番に利用する流れになることは容易に想像できるでしょう。

また、この例では、LLM が指定した API サービスを実行する部分は人間が手動で行いましたが、LLM の応答をパースして該当の API を呼び出す処理を自動実行するコードを書くのはそれほど難しくはないでしょう。上記の一連の作業をループで回して自動化することもできるはずです。

ただし、LLM の応答を安定的にパースするには、LLM の出力形式を固定するための指示をプロンプトに作り込む必要があります。また、複数の LLM Agent を作成・メンテナンスすることを考えると、プロンプト内にツールの機能を記述する際のフォーマットも標準化する必要がありそうです。

—— そこで登場するのが Function Calling です。

Function Calling の役割

Function Calling を利用すると、主に次のことが可能になります。

  • 利用する API サービスの仕様を OpenAPI 3.0 の標準フォーマットで記述する
  • LLM の応答をパースして、API サービスの使用を要求しているかどうかを自動で判別する
  • LLM が指定した API サービスへの入力データを構造化データとして取り出す

前述の作業をループで回す部分は Function Calling だけでは実現できず、この部分は普通にコードとして実装する必要がありますが、これらの機能があれば、コードの実装はかなり楽になりそうです。

また、先ほどの例では、プロンプトに新しい内容を追加していきましたが、Gemini API の「チャットプロンプト」を利用するとこの部分も簡単になります。追加部分だけを入力すれば、自動的に過去のやり取りを記録したプロンプトに新しい内容が追加されていきます。

Function Calling を用いた LLM Agent の実装例

NYC TLC Trips データセットの検索 Agent

BigQuery のオープンデータセットに含まれる NYC TLC Trips データセットを使って、タクシーの利用客に関する質問に答える LLM Agent を作成します。NYC TLC Trips データセットの説明は次のとおりですが、ここでは、2022 年の Yellow taxi に関する情報に限定して使用します。

This dataset is collected by the NYC Taxi and Limousine Commission (TLC) and includes trip records from all trips completed in Yellow and Green taxis in NYC, and all trips in for-hire vehicles (FHV) in the last 5 years. Records include fields capturing pick-up and drop-off dates/times, pick-up and drop-off locations, trip distances, itemized fares, rate types, payment types, and driver-reported passenger counts. For detailed information about this dataset, go to TOC Trip Record Data

実際の使用例は次のようになります。

Input

question = 'チップがたくさんもらえる場所は?'
response = ask_bigquery(question)
print('回答\n', response)

Output

** クエリ **
SELECT taxi_zone_geom.zone_name, AVG(tip_amount) AS average_tip FROM bigquery-public-data.new_york_taxi_trips.tlc_yellow_trips_2022 INNER JOIN bigquery-public-data.new_york_taxi_trips.taxi_zone_geom ON tlc_yellow_trips_2022.pickup_location_id = taxi_zone_geom.zone_id GROUP BY taxi_zone_geom.zone_name ORDER BY average_tip DESC LIMIT 1

** 検索結果 **
[{"zone_name": "Newark Airport", "average_tip": "11.954689339"}]

回答
 チップをたくさんもらえる場所は、ニューアーク空港です。平均チップ額が11.95ドルで最も高くなっています。これは、空港への送迎が長距離であることが多く、チップが高額になる傾向があるためと考えられます。 

Input

question = '平均乗客数が多い地域と時間帯の組み合わせについてトップ10を教えて。'
response = ask_bigquery(question)
print('回答\n', response)

Output

** クエリ **

SELECT tzpu.zone_name AS pickup_zone,
       EXTRACT(HOUR FROM pickup_datetime) AS pickup_hour,
       AVG(passenger_count) AS average_passenger_count
FROM `bigquery-public-data`.new_york_taxi_trips.tlc_yellow_trips_2022 AS yellow
JOIN `bigquery-public-data`.new_york_taxi_trips.taxi_zone_geom AS tzpu ON yellow.pickup_location_id = tzpu.zone_id
GROUP BY pickup_zone, pickup_hour
ORDER BY average_passenger_count DESC
LIMIT 10


** 検索結果 **
[{"pickup_zone": "South Beach/Dongan Hills", "pickup_hour": "3", "average_passenger_count": "5.0"}, {"pickup_zone": "Country Club", "pickup_hour": "4", "average_passenger_count": "4.0"}, {"pickup_zone": "New Dorp/Midland Beach", "pickup_hour": "20", "average_passenger_count": "4.0"}, {"pickup_zone": "Rikers Island", "pickup_hour": "15", "average_passenger_count": "4.0"}, {"pickup_zone": "Ocean Parkway South", "pickup_hour": "2", "average_passenger_count": "3.8571428571428568"}, {"pickup_zone": "Green-Wood Cemetery", "pickup_hour": "14", "average_passenger_count": "3.5"}, {"pickup_zone": "Ocean Parkway South", "pickup_hour": "4", "average_passenger_count": "3.5"}, {"pickup_zone": "East Flushing", "pickup_hour": "18", "average_passenger_count": "3.3333333333333335"}, {"pickup_zone": "Ocean Parkway South", "pickup_hour": "1", "average_passenger_count": "3.3333333333333335"}, {"pickup_zone": "Green-Wood Cemetery", "pickup_hour": "16", "average_passenger_count": "3.0"}]

回答
2022年のイエロータクシー乗車データによると、平均乗客数が最も多い地域と時間帯の組み合わせトップ10は以下の通りです。

1. South Beach/Dongan Hills, 3時: 平均乗客数 5.0人
2. Country Club, 4時: 平均乗客数 4.0人
3. New Dorp/Midland Beach, 20時: 平均乗客数 4.0人
4. Rikers Island, 15時: 平均乗客数 4.0人
5. Ocean Parkway South, 2時: 平均乗客数 3.86人
6. Green-Wood Cemetery, 14時: 平均乗客数 3.5人
7. Ocean Parkway South, 4時: 平均乗客数 3.5人
8. East Flushing, 18時: 平均乗客数 3.33人
9. Ocean Parkway South, 1時: 平均乗客数 3.33人
10. Green-Wood Cemetery, 16時: 平均乗客数 3.0人 

なかなか実用的で便利そうですね。Function Calling では、API サービスへの入力データも LLM 自身が考えますが、ここでは、BigQuery を API サービスとして利用しており、その入力データとなる SQL のクエリを LLM 自身が生成しています。

実装手順

事前準備

それでは、具体的な実装内容を説明します。Vertex AI Workbench のノートブック上で実装するので、新しいプロジェクトを作成したら、Cloud Shell のコマンド端末を開いて、Vertex AI Workbench を使用するのに必要な API を有効化します。あわせて、BigQuery の API も有効化しておきます。

gcloud services enable \
  aiplatform.googleapis.com \
  notebooks.googleapis.com \
  bigquery.googleapis.com

次のコマンドで Workbench インスタンスを作成します。

PROJECT_ID=$(gcloud config list --format 'value(core.project)')
gcloud workbench instances create agent-development \
  --project=$PROJECT_ID \
  --location=us-central1-a \
  --machine-type=e2-standard-2

クラウドコンソールのナビゲーションメニューから「Vertex AI」→「ワークベンチ」を選択すると、作成したインスタンス agent-develpment があります。インスタンスの起動が完了するのを待って、「JUPYTERLAB を開く」をクリックしたら、「Python 3(ipykernel)」の新規ノートブックを作成します。

この後は、ノートブックのセルでコードを実行していきます。

API サービスの定義

まずは、必要なモジュールをインポートします。

import json, vertexai
from google.cloud import bigquery
from vertexai.generative_models import \
    FunctionDeclaration, GenerationConfig, GenerativeModel, Part, Tool

そして、OpanAPI のフォーマットで Agent が使用する API サービスの仕様を宣言します。

sql_query_func = FunctionDeclaration(
    name='sql_query',
    description='Get factual information from BigQuery using SQL queries',
    parameters={
        'type': 'object',
        'properties': {
            'query': {
                'type': 'string',
                'description': 'SQL query on a single line that will help give quantitative answers'
            }
        },
        'required': ['query']
    }
)

ここでは、「sql_query という名称のサービスがあって、query パラメーターに SQL のクエリを文字列で与えると BigQuery から情報が得られる」という事実を宣言しています。FunctionDeclaration というクラス名からわかるように、ここで定義するものは、API サービスに限定されるものではなく、「パラメーターを与えると応答が得られる」という挙動が期待されるもの、つまり、広い意味での関数として振る舞うものであればなんでも構いません。

LLM は、使用する API サービスを選択する際に description 部分に書かれた情報を参考にしますので、この部分にはできるだけ詳しい情報を書いておくことをお勧めします。用意した API サービスを Agent が想定通りに利用してくれない際は、この部分の記述を工夫することで改善する場合もあります。

つづいて、実際に使用する LLM モデルのオブジェクトを取得しますが、この際に、使用可能なツールの情報を tools オプションで受け渡します。

bq_tool = Tool(
    function_declarations=[sql_query_func]
)

model = GenerativeModel(
    'gemini-1.5-pro-001',
    generation_config=GenerationConfig(temperature=0.4),
    tools=[bq_tool]
)

ここでは、2 行目で、先ほど定義した sql_query_func を指定しています。複数のツールがある場合はこのリストに追加します。このモデルにプロンプトを送信すると、プロンプトで明示的に指定しなくても、必要な際はツールの使用を要求する応答を返すようになります。

プロンプトテンプレートの用意

そして、このモデルに質問を投げるためのプロンプトのテンプレートを用意します。最後の {} の部分に実際の質問を埋め込んで使います。

prompt_template = '''\
You are a data analytics expert. Work on the following tasks.
    
[task]
A. Answer the question with the reason based on the data you get from BigQuery.

[condition]
A. Use SQL queries to get information from BigQuery using the column definitions in the [table information].
A. The answer and the reason must be based on the quantitative information in tables.
A. Use concrete area names in the answer instead of zone_id or location_id.

[format instruction]
In Japanese. In plain text, no markdowns.

[table information]
columns of the table 'bigquery-public-data.new_york_taxi_trips.taxi_zone_geom'
- zone_id : Unique ID number of each taxi zone. Corresponds with the pickup_location_id and dropoff_location_id in each of the trips tables
- zone_name : Full text name of the taxi zone

columns of the table: 'bigquery-public-data.new_york_taxi_trips.tlc_yellow_trips_2022'
- pickup_datetime : The date and time when the meter was engaged
- dropoff_datetime : The date and time when the meter was disengaged
- passenger_count : The number of passengers in the vehicle. This is a driver-entered value.
- trip_distance : The elapsed trip distance in miles reported by the taximeter.
- fare_amount : The time-and-distance fare calculated by the meter
- tip_amount : Tip amount. This field is automatically populated for credit card tips. Cash tips are not included.
- tolls_amount : Total amount of all tolls paid in trip.
- total_amount : The total amount charged to passengers. Does not include cash tips.
- pickup_location_id : TLC Taxi Zone in which the taximeter was engaged
- dropoff_location_id : TLC Taxi Zone in which the taximeter was disengaged

[question]
{}
'''

[task] 部分で、BigQuery のデータを使って質問に答えるように指示しています。また、[table information] の部分に使用するテーブルの情報が記載されていますが、この部分は、実際のテーブルのスキーマ情報をそのままコピペしてあります。たとえば、taxi_zone_geom テーブルのスキーマ情報は次のように定義されています。


taxi_zone_geom テーブルのスキーマ情報

Agent 本体の実装

そしていよいよ、Agent の処理を実装したメインパートです。

def ask_bigquery(question):
    chat = model.start_chat()
    client = bigquery.Client()
    prompt = prompt_template.format(question)
    response = chat.send_message(prompt).candidates[0].content.parts[0]

    while True:
        try:
            # Throw AttributeError unless function call is required.    
            function_call = response.function_call
            params = {key: value for key, value in function_call.args.items()}

            if function_call.name == 'sql_query':
                    try:
                        query = params['query']
                        print(f'** クエリ **\n{query}\n')
                        query_job = client.query(query)
                        result = query_job.result()
                        result = [dict(row) for row in result]
                        result = [{key: str(value) for key, value in raw.items()} for raw in result]
                        api_response = json.dumps(result)
                    except Exception as e:
                        api_response = json.dumps({'error message': f'{str(e)}'})
                    print(f'** 検索結果 **\n{api_response}\n')

            response = chat.send_message(
                Part.from_function_response(
                    name=function_call.name,
                    response={'content': api_response}
                )
            ).candidates[0].content.parts[0]
        except AttributeError:
            break

    return response.text

ここでは、Gemini API のチャットプロンプトを使用しています。chat = model.start_chat() で用意したオブジェクトに chat.send_message() でプロンプトを繰り返し送信すると、LLM は「過去のプロンプトとその応答の履歴に、新しく指定したプロンプトを追加した内容」を受け取ったものとして応答を返します。

while True: のループが始まる直前の部分では、先ほど用意したテンプレートに引数 question で受け取った質問を埋め込んだものを最初のプロンプトとして送信して、その応答を受け取っています。

次に、while True: の内部では、LLM からの応答が API サービス(ツール)の実行を要求しているかを判断して処理を分岐します。

  • API サービスの実行を要求している場合は、要求通りに実行して、その結果をプロンプトとして送信する。(ループの先頭に戻る)
  • API サービスの実行を要求していない場合は、最終の回答が得られたものとして、それを返却する。(ループを抜けて終了する)

全体として、次のようなループが回ることになります。


Function Calling を使用した Agent の動作

ここでは、主要なステップの実装を詳しく説明しておきます。

まず、ループの先頭では、応答に含まれる response.function_call 要素を変数 function_call に保存していますが、この要素は、LLM が API サービスの実行を要求している場合にのみ存在します。そのため、API サービスの実行を要求していない場合は、ここで例外が発生して、except AttributeError: に処理が飛んでループが終了します。

この要素が存在する場合、function_call.name は、使用する API サービスの名称(事前に定義した内容に含まれる name 要素)を表すので、この値を見て、対応する API サービスをコードから実行します。この例では、sql_query の一択になります。API サービスに送信するべきパラメーターは、function_call.args.items から取得します。

そして、この例では、BigQuery のクライアント SDK を使ってクエリを実行した後、得られた結果を次の手順で JSON 文字列に変換しています。

  • 各レコードをディクショナリに変換したものを集めたリストを作る
     ↓
  • リスト内のディクショナリの個々のバリューを文字列型に変換する
     ↓
  • 得られた結果を JSON 文字列に変換する

これは、Function Calling のお約束で、API サービスからの応答は JSON 文字列で受け渡すことになっているからです。また、クエリの結果がエラーになった場合は、エラーメッセージを応答として返します。これは重要なポイントで、LLM が用意したクエリが間違っていた場合、LLM はエラーメッセージを参考にして、再度、新しいクエリを提案するという動作を行います。

最後に、API サービスからの応答をプロンプトとして送信する下記の部分について補足します。

            response = chat.send_message(
                Part.from_function_response(
                    name=function_call.name,
                    response={'content': api_response}
                )
            ).candidates[0].content.parts[0]

この部分は、Function Calling のお作法に従って、使用した API サービスの名前 function_call.name と実行結果 api_responsePart.from_function_response() メソッドで Part オブジェクトに変換したものを送信していますが、現在の実装では、これは必須ではありません。次の様にプレインテキストで送信しても、同等の結果が得られます。

            response = chat.send_message(
                f'[response from {function_call.name}]\n{api_response}'
            ).candidates[0].content.parts[0]

LLM Agent の実行

これで LLM Agent が完成しました。関数 ask_bigquery() に質問を投げると回答が得られます。

question = '乗客数とチップの平均額を表にして、乗客数とチップの額に関連性があるか調べて。'
response = ask_bigquery(question)
print('回答\n', response)

実行結果

** クエリ **

SELECT passenger_count, AVG(tip_amount) AS average_tip_amount FROM bigquery-public-data.new_york_taxi_trips.tlc_yellow_trips_2022 GROUP BY passenger_count


** 検索結果 **
[{"passenger_count": "5", "average_tip_amount": "2.682891441"}, {"passenger_count": "7", "average_tip_amount": "7.887932692"}, {"passenger_count": "3", "average_tip_amount": "2.732784987"}, {"passenger_count": "2", "average_tip_amount": "2.878064225"}, {"passenger_count": "6", "average_tip_amount": "2.702456693"}, {"passenger_count": "9", "average_tip_amount": "9.481282051"}, {"passenger_count": "1", "average_tip_amount": "2.64627696"}, {"passenger_count": "4", "average_tip_amount": "2.749474037"}, {"passenger_count": "None", "average_tip_amount": "3.718722887"}, {"passenger_count": "8", "average_tip_amount": "8.517593985"}, {"passenger_count": "0", "average_tip_amount": "2.437039335"}]

回答
 乗客数とチップの平均額は以下の通りです。

| 乗客数 | チップの平均額 |
|---|---|
| 0 | 2.44 |
| 1 | 2.65 |
| 2 | 2.88 |
| 3 | 2.73 |
| 4 | 2.75 |
| 5 | 2.68 |
| 6 | 2.70 |
| 7 | 7.89 |
| 8 | 8.52 |
| 9 | 9.48 |

このデータを見ると、乗客数が7人以上の場合にチップの平均額が大きく増加していることがわかります。つまり、乗客数とチップの額にはある程度の関連性があると言えるでしょう。 

まとめ

実装例を使用する様子を見ると、シンプルな実装ながら実用性の高い LLM Agent が実現できていることがわかります。実業務で使う際は、LLM が生成したクエリが間違っていないかをチェックできる程度のリテラシーは利用者に必要かもしれませんが、調べたい内容にあわせて自分でクエリを考えるよりはよほど高速に調べ物が進みそうです。

今回の実装例からわかるように、外部の API サービスを呼び出す処理は、Python のコードで普通に実行していますので、呼び出しに必要なパラメーターさえ決まれば、任意のツールを利用することができます。他の LLM Agent を外部 API サービスとして利用する多段構成も容易に実現できるでしょう。

また、一般に「エージェント」というと、外部システムから情報を収集するだけではなく、ユーザーの代わりに外部システムを操作する仕組みを想像する方もいるかも知れません。今回実装したシステムを応用すれば、これも簡単に実現できます。たとえば、部屋のエアコンを操作する API サービスを作っておき、Function Calling でこの API サービスの情報を登録した LLM を用意します。この LLM に「部屋が暑いからなんとかして」というと、エアコンの設定温度を下げるリクエストが応答として返ってきます。あとは、先ほどと同様のコードで、このリクエストを実際の API サービスに投げれば OK です。

Google Cloud Japan

Discussion