🤖

Amazon Bedrock の Code Interpreter を使ってみる

2024/08/13に公開

はじめに

最近、Amazon Bedrock 生成AIアプリ開発入門 [AWS深掘りガイド] という本を読む機会があり、OpenAIがAdvanced Data Analysis(旧Code Interpreter)という機能を提供していることを知りました。

Advanced Data Analysisは、ユーザーの質問を受け取って回答を生成する際に、プログラミングによる処理が必要であると判断した場合、自動的にコードを生成して、その実行結果を出力するというものになっています。

Amazon Bedrockにおいてもそのような機能が使えないか調べていると、2024年7月10日に、Agents for Amazon Bedrockの新機能(プレビュー版)として、Advanced Data Analysisと同様の機能を持つCode Interpreterが使えるようになったという紹介があったことを知りました。

Agents for Amazon Bedrock

エージェントとは、ある目的を達成するために自律的に行動するプログラムのことを指します。
そのため、AIエージェントは、LLMが持つ強力な思考能力を利用して、AI(LLM)自身が行動を選択して実行するエージェントということになります。

もう少し具体的に言うと、AIエージェントは、事前に用意された関数(ツール)の中から回答の生成に最も適切な関数を自ら判断して実行することができます。

Langchain等が提供するモジュールを組み合わせて、このAIエージェントの機能を実装することができますが、Agents for Amazon Bedrockを使用すると、AIエージェントの機能自体を実装する必要がなく、事前に用意する関数(ツール)やプロンプトの準備に注力できるようになります。

Code Interpreter

Code Interpreterは、Agents for Amazon Bedrockで開発したAIエージェントが使用できるツールの1つという認識をしています。

Code Interpreterの機能は、Agents for Amazon Bedrock上のエージェントに設定でCode InterpreterをEnabledにするだけで使い始めることができます(参考)。

Code Interpreterの機能を使えるようにしておくと、AIエージェントは、サンドボックス環境でコードを動的に生成して実行できるようになります。データ分析、視覚化、テキスト処理、方程式解決、最適化問題などの複雑なタスクも対処できるユースケースとして想定されているようです。

さらに、CSVやPDF等のファイルを読み込んでそれを処理することもできます。入力するファイルの最大サイズは10MBと記載されていまして、おそらく1つのファイルサイズだと思うのですが、サンドボックスに持ち込める合計のファイルサイズの最大値については私はよくわかっていません。出力としてグラフ等の図も出力できます。

プレビュー中の現在は、米国東部 (バージニア北部)、米国西部 (オレゴン)、および欧州 (フランクフルト) リージョンで使用できます。

実行環境

  • Python: 3.12.3
  • boto3: 1.34.156
  • matplotlib: 3.9.1.post1

ノートブックの実行

sample.ipynb
import boto3
import matplotlib.pyplot as plt
import io
import uuid

class MyBedrockAgentClient:
    def __init__(
        self, 
        region_name: str,
        agent_id: str,
        agent_alias_id: str
    ) -> None:
        self.agent_id = agent_id
        self.agent_alias_id = agent_alias_id

        # ランタイムクライアントをセットアップする
        self.boto3_session = boto3.Session(profile_name="PROFILE_NAME")
        self.bedrock_agent_runtime = self.boto3_session.client(
            service_name = "bedrock-agent-runtime",
            region_name = region_name,
        )
    
    def invoke(
        self,
        session_id: str,
        input_text: str,
        csv_file: bytes | None = None,
    ) -> None:
        
        session_state = {}
        if csv_file is not None:
            session_state = {
                "files": [
                    {
                        "name": "csv",
                        "source": {
                            "sourceType": "BYTE_CONTENT",
                            "byteContent": {
                                "mediaType": "text/csv", 
                                "data": csv_file,
                            }
                        },
                        "useCase": "CODE_INTERPRETER"
                    }
                ]
            }


        # エージェントを実行する
        response = self.bedrock_agent_runtime.invoke_agent(
            sessionId=session_id,
            agentId=self.agent_id,
            agentAliasId=self.agent_alias_id,
            inputText=input_text,
            sessionState=session_state,
            enableTrace=True,
        )

        # エージェントの実行結果を取得する
        event_stream = response["completion"]

        # イベントストリームは複数のイベントを含むため、
        # 各イベントをforループで処理する
        for event in event_stream:

            # トレースの取得
            if "trace" in event:
                trace = event["trace"]["trace"]

                # OrchestrationTrace: オーケストレーションステップの入力と出力をトレースする
                # オーケストレーションステップ: ユーザーの入力を解釈して、
                # Code interpreterの実行、アクショングループの呼び出し、ナレッジベースへのクエリ等を行う
                # 次に、エージェントはオーケストレーションを続行するか、ユーザーに応答するために出力を返す
                # 参考: https://docs.aws.amazon.com/ja_jp/bedrock/latest/userguide/trace-events.html
                if "orchestrationTrace" in trace:
                    orchestration_trace = trace["orchestrationTrace"]

                    if "modelInvocationInput" in orchestration_trace:
                        print("-"*10, "オーケストレーションステップへの入力", "-"*10)
                        print(json.dumps(orchestration_trace, indent=2, ensure_ascii=False))
                        print("\n")
                    
                    if "rationale" in orchestration_trace:
                        print("-"*10, "必要な行動と理由の推論", "-"*10)
                        print(json.dumps(orchestration_trace, indent=2, ensure_ascii=False))
                        print("\n")
                    
                    if "invocationInput" in orchestration_trace:
                        print("-"*10, "コードインタープリター等に関する情報", "-"*10)
                        print(json.dumps(orchestration_trace, indent=2, ensure_ascii=False))
                        print("\n")
                    
                    if "observation" in orchestration_trace:
                        print("-"*10, "行動結果の観察", "-"*10)
                        print(json.dumps(orchestration_trace, indent=2, ensure_ascii=False))
                        print("\n")

            # イベントにchunkキーが含まれている場合、
            # そのチャンクのバイト列をデコードしてテキストに変換する
            if "chunk" in event:
                chunk = event["chunk"]
                if "bytes" in chunk:
                    text = chunk["bytes"].decode("utf-8")
                    print("-"*10, "最終出力", "-"*10)
                    print(f"Chunk: {text}")
                    print("\n")
                else:
                    print("-"*10, "最終出力", "-"*10)
                    print("Chunkには「bytes」が含まれていません。")
                    print("\n")
    
            # イベントにfilesキーが含まれている場合、
            # png形式なら表示する
            if "files" in event:
                files = event["files"]["files"]
                for file in files:
                    name = file["name"]
                    type = file["type"]
                    bytes_data = file["bytes"]
                    if type == "image/png":
                        print("-"*10)
                        img = plt.imread(io.BytesIO(bytes_data))
                        plt.figure(figsize=(10, 10))
                        plt.imshow(img)
                        plt.axis("off")
                        plt.title(name)
                        plt.show()
                        plt.close()
                        print("\n")

my_bedrock_agent_client = MyBedrockAgentClient(
    region_name="REGION_NAME",
    agent_id="AGENT_ID",
    agent_alias_id="TSTALIASID"
)

session_id = str(uuid.uuid1())
my_bedrock_agent_client.invoke(
    session_id=session_id,
    input_text="100までの素数を出力してください。"
)
出力結果
---------- オーケストレーションステップへの入力 ----------
{
  "modelInvocationInput": {
    "inferenceConfiguration": {
      "maximumLength": 2048,
      "stopSequences": [
        "</invoke>",
        "</answer>",
        "</error>"
      ],
      "temperature": 0.0,
      "topK": 250,
      "topP": 1.0
    },
    "text": "{\"system\":\"        あなたは、コード実行、チャート生成、... 常に努力してください。        You have been provided ... or file paths.                </guidelines>                                        \",\"messages\":[{\"content\":\"100までの素数を出力してください。\",\"role\":\"user\"}]}",
    "traceId": "98be2112-1878-4c98-a7f2-8c1157b6401d-0",
    "type": "ORCHESTRATION"
  }
}


---------- 必要な行動と理由の推論 ----------
{
  "rationale": {
    "text": "この問題を解決するには、以下のステップを踏む必要があります:\n\n1. 2から100までの数字のリストを作成する\n2. 各数字について、素数かどうかを判定する\n3. 素数であれば、リストに追加する\n4. 最終的な素数のリストを出力する\n\nこれらのステップを実行するためのPythonコードを書いて実行します。",
    "traceId": "98be2112-1878-4c98-a7f2-8c1157b6401d-0"
  }
}


---------- コードインタープリター等に関する情報 ----------
{
  "invocationInput": {
    "codeInterpreterInvocationInput": {
      "code": "\nprimes = []\nfor num in range(2, 101):\n    is_prime = True\n    for i in range(2, int(num**0.5) + 1):\n        if num % i == 0:\n            is_prime = False\n            break\n    if is_prime:\n        primes.append(num)\n\nprint(primes)\n"
    },
    "invocationType": "ACTION_GROUP_CODE_INTERPRETER",
    "traceId": "98be2112-1878-4c98-a7f2-8c1157b6401d-0"
  }
}


---------- 行動結果の観察 ----------
{
  "observation": {
    "codeInterpreterInvocationOutput": {
      "executionOutput": "[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]"
    },
    "traceId": "98be2112-1878-4c98-a7f2-8c1157b6401d-0",
    "type": "ACTION_GROUP_CODE_INTERPRETER"
  }
}


---------- オーケストレーションステップへの入力 ----------
{
  "modelInvocationInput": {
    "inferenceConfiguration": {
      "maximumLength": 2048,
      "stopSequences": [
        "</invoke>",
        "</answer>",
        "</error>"
      ],
      "temperature": 0.0,
      "topK": 250,
      "topP": 1.0
    },
    "text": "{\"system\":\"        あなたは、コード実行、チャート生成、... 常に努力してください。        You have been provided ... or file paths.                </guidelines>                                        \",\"messages\":[{\"content\":\"100までの素数を出力してください。\",\"role\":\"user\"},{\"content\":\"<thinking>この問題を解決するには、以下のステップを踏む必要があります:1. 2から100までの数字のリストを作成する2. 各数字について、素数かどうかを判定する3. 素数であれば、リストに追加する4. 最終的な素数のリストを出力するこれらのステップを実行するためのPythonコードを書いて実行します。</thinking><function_calls><invoke><tool_name>get::codeinterpreteraction::execute</tool_name><parameters><code>primes = []for num in range(2, 101):    is_prime = True    for i in range(2, int(num**0.5) + 1):        if num % i == 0:            is_prime = False            break    if is_prime:        primes.append(num)print(primes)</code></parameters></invoke></function_calls>\",\"role\":\"assistant\"},{\"content\":\"<function_results><result><tool_name>get::codeinterpreteraction::execute</tool_name><stdout>'code_execution_output': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]\\,\\'is_error': False,\\'is_timeout': False,\\'generated_files': []</stdout></result></function_results>\",\"role\":\"user\"}]}",
    "traceId": "98be2112-1878-4c98-a7f2-8c1157b6401d-1",
    "type": "ORCHESTRATION"
  }
}


---------- 行動結果の観察 ----------
{
  "observation": {
    "finalResponse": {
      "text": "上記のコードは、2から100までの範囲で素数を見つけ、それらをリストに追加しています。出力は、2から100までの素数のリスト [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97] となります。"
    },
    "traceId": "98be2112-1878-4c98-a7f2-8c1157b6401d-1",
    "type": "FINISH"
  }
}


---------- 最終出力 ----------
Chunk: 上記のコードは、2から100までの範囲で素数を見つけ、それらをリストに追加しています。出力は、2から100までの素数のリスト [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97] となります。

次に、CSVファイルを読み込んで処理してみます。CSVファイルは、東京都オープンデータカタログサイトの主な産業別就業者数のデータを使います。

このデータの上から3つの行が男女計の年ごとの平均になっているため、その部分のみを取り出したsample.csvを作成しました(下記のイメージ)。

sample.ipynb
with open("sample.csv", mode="rb") as f:
    csv_file: bytes = f.read()

session_id = str(uuid.uuid1())
my_bedrock_agent_client.invoke(
    session_id=session_id,
    input_text="""
    添付ファイルは、産業別就業者数のデータです。
    建設業の就業者数(男女計)の推移を示す折れ線グラフを作成してください。
    横軸には、西暦(年)を使ってください。
    さらに、そのグラフに製造業の折れ線グラフを追加してください。
    なお、グラフの縦軸と横軸と凡例は英語で記載して、グラフの画像を出力してください。
    """,
    csv_file=csv_file,
)
出力結果
---------- オーケストレーションステップへの入力 ----------
{
  "modelInvocationInput": {
    "inferenceConfiguration": {
      "maximumLength": 2048,
      "stopSequences": [
        "</invoke>",
        "</answer>",
        "</error>"
      ],
      "temperature": 0.0,
      "topK": 250,
      "topP": 1.0
    },
    "text": "{\"system\":\"        あなたは、コード実行、チャート生成、... 常に努力してください。        You have been provided or file paths.                </guidelines>                You have access to the following files:            <file path='$BASE_PATH$/csv' type='text/csv'>            <metadata>            {\\\"columns\\\":\\\"[年・期,西暦(年),四半期,建設業(千人),製造業(千人),情報通信業(千人),運輸業,郵便業(千人),卸売業,小売業(千人),金融業,保険業(千人),不動産業,物品賃貸業(千人),学術研究,専門・技術サービス業(千人),宿泊業,飲食サービス業(千人),生活関連サービス業,娯楽業(千人),教育,学習支援業(千人),医療,福祉(千人),サービス業(他に分類されないもの)(千人)]\\\"}            </metadata>            <head>            年・期,西暦(年),四半期,建設業(千人),製造業(千人),情報通信業(千人),運輸業,郵便業(千人),卸売業,小売業(千人),金融業,保険業(千人),不動産業,物品賃貸業(千人),学術研究,専門・技術サービス業(千人),宿泊業,飲食サービス業(千人),生活関連サービス業,娯楽業(千人),教育,学習支援業(千人),医療,福祉(千人),サービス業(他に分類されないもの)(千人)実数/男女計            </head>            </file>                        \",\"messages\":[{\"content\":\"添付ファイルは、産業別就業者数のデータです。    建設業の就業者数(男女計)の推移を示す折れ線グラフを作成してください。    横軸には、西暦(年)を使ってください。    さらに、そのグラフに製造業の折れ線グラフを追加してください。    なお、グラフの縦軸と横軸と凡例は英語で記載して、グラフの画像を出力してください。\",\"role\":\"user\"}]}",
    "traceId": "4ac6cf8d-add6-4405-aada-be9e69a9e4ad-0",
    "type": "ORCHESTRATION"
  }
}


---------- 必要な行動と理由の推論 ----------
{
  "rationale": {
    "text": "この質問に答えるには、以下の手順が必要です:\n\n1. CSVファイルからデータを読み込む\n2. 建設業と製造業の就業者数のデータを抽出する\n3. 西暦(年)を横軸に、就業者数を縦軸にしてグラフを作成する\n4. 建設業と製造業の折れ線グラフを1つのグラフに重ねる\n5. グラフの軸ラベルと凡例を英語で設定する\n6. グラフを画像ファイルとして出力する\n\nこれらの手順を実行するためのコードを書いて実行します。",
    "traceId": "4ac6cf8d-add6-4405-aada-be9e69a9e4ad-0"
  }
}


---------- コードインタープリター等に関する情報 ----------
{
  "invocationInput": {
    "codeInterpreterInvocationInput": {
      "code": "\nimport pandas as pd\nimport matplotlib.pyplot as plt\n\n# CSVファイルからデータを読み込む\ndf = pd.read_csv('$BASE_PATH$/csv')\n\n# 建設業と製造業の就業者数のデータを抽出する\nconstruction = df['建設業(千人)']\nmanufacturing = df['製造業(千人)']\nyears = df['西暦(年)']\n\n# グラフを作成する\nfig, ax = plt.subplots(figsize=(10, 6))\nax.plot(years, construction, label='Construction')\nax.plot(years, manufacturing, label='Manufacturing')\n\n# 軸ラベルと凡例を設定する\nax.set_xlabel('Year', fontsize=14)\nax.set_ylabel('Number of Employees (Thousands)', fontsize=14)\nax.legend()\n\n# グラフを画像ファイルとして出力する\nplt.savefig('$BASE_PATH$/employment_trends.png', bbox_inches='tight')\nprint('Graph saved as employment_trends.png')\n"
    },
    "invocationType": "ACTION_GROUP_CODE_INTERPRETER",
    "traceId": "4ac6cf8d-add6-4405-aada-be9e69a9e4ad-0"
  }
}


---------- 行動結果の観察 ----------
{
  "observation": {
    "codeInterpreterInvocationOutput": {
      "executionOutput": "Graph saved as employment_trends.png",
      "files": [
        "employment_trends.png"
      ]
    },
    "traceId": "4ac6cf8d-add6-4405-aada-be9e69a9e4ad-0",
    "type": "ACTION_GROUP_CODE_INTERPRETER"
  }
}


----------

建設業と製造業の列を使って折れ線グラフを作ってと指示しましたが合ってますね。

おわりに

とりあえず動くコードになるまで思考してくれるし、出力がJSON形式のためキーを指定すれば、生成してくれたコードや結果等、欲しい情報のみを得られるので、アプリケーションに組み込みやすいと思いました。

code interpreterを使ってデータ分析業務の補助をやってもらいたいと考えているので、次はもっと難しいことをやらせてみたいです。

code interpreterが実行される最大ファイルサイズが決まっていますが、データ分析の場合はデータの量が多く、それを超えることもあると思いますので、その場合はどうすればいいんだろう?という感じです。

ここまで、ご覧いただきありがとうございました。

参考

NCDCエンジニアブログ

Discussion