😊

簡単!AWS BedrockエージェントでWebサーチする方法

に公開

https://aws.amazon.com/jp/bedrock/agents/

1. はじめに

アクセンチュア株式会社の桐山です。
今回は、AWS BedrockエージェントでWeb検索する方法を紹介したいと思います。

Bedrockエージェントの「アクショングループ」にWeb検索を行うLambdaを登録することで、簡単にWeb検索を実装することができます。
今回はWeb検索に Tavily を使用します。

https://www.tavily.com/

なお、本記事はみのるんさんの以下を参考にさせていただきました。
https://qiita.com/minorun365/items/85cb57f19fe16a87acff#7-フロントエンド開発

2. ゴール

Webサーチの対象として、弊社のHPを指定してみました!
LLMの学習に含まれていない情報でも、Web検索を行い回答できています。

Bedrockエージェント内の「テスト」で確認

3. Bedrockエージェント、Lambdaの設定

  1. Bedrockエージェントを作成していきます。名前はいったんデフォルトのまま進めていきます。

  2. 「モデルを選択」を開き、使用するLLMを指定します。ここではClaude3.7 Sonnetを指定します。

  3. 「エージェント向けの指示」に、一例ですが以下のプロンプトを指定します。

  4. 「アクショングループ」にWeb検索を実行するLambdaを登録します。

  5. 「アクショングループ関数1」に以下のパラメータを指定します。

4. 実装(アクショングループのLambda)

  1. Lambdaレイヤーに tavily-python をアップロードします。
  2. Lambdaに以下のコードを指定します。
import os
import json
from tavily import TavilyClient

def lambda_handler(event, context):
    # 環境変数からAPIキーを取得
    tavily_api_key = os.environ.get('TAVILY_API_KEY')
    
    # eventからクエリパラメータを取得
    parameters = event.get('parameters', [])
    for param in parameters:
        if param.get('name') == 'query':
            query = param.get('value')
            break
    
    # Tavilyクライアントを初期化して検索を実行
    client = TavilyClient(api_key=tavily_api_key)
    search_result = client.get_search_context(
        query=query,
        search_depth="advanced",
        max_results=10
    )
    
    # 成功レスポンスを返す
    return {
        'messageVersion': event['messageVersion'],
        'response': {
            'actionGroup': event['actionGroup'],
            'function': event['function'],
            'functionResponse': {
                'responseBody': {
                    'TEXT': {
                        'body': json.dumps(search_result, ensure_ascii=False)
                    }
                }
            }
        }
    }
  1. Tavilyのサイトで取得したAPIキーを、Lambda環境変数にTAVILY_API_KEYとして設定します。

5. 実装(Bedrock Agent APIを呼び出すイメージ)

import json
import uuid
import boto3
import streamlit as st
from botocore.exceptions import ClientError
from botocore.eventstream import EventStreamError

agent_id = "XXXXXXXXXX" # エージェントID
agent_alias_id = "XXXXXXXXXX" # エイリアスID

def initialize_session():
    """セッションの初期設定を行う"""
    if "client" not in st.session_state:
        st.session_state.client = boto3.client("bedrock-agent-runtime")
    
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())
    
    if "messages" not in st.session_state:
        st.session_state.messages = []
    
    if "last_prompt" not in st.session_state:
        st.session_state.last_prompt = None
    
    return st.session_state.client, st.session_state.session_id, st.session_state.messages

def display_chat_history(messages):
    """チャット履歴を表示する"""
    st.title("チャットボット")
    st.text("Web検索もできるよ!")
    
    for message in messages:
        with st.chat_message(message['role']):
            st.markdown(message['text'])

def handle_trace_event(event):
    """トレースイベントの処理を行う"""
    if "orchestrationTrace" not in event["trace"]["trace"]:
        return
    
    trace = event["trace"]["trace"]["orchestrationTrace"]
    
    # 「モデル入力」トレースの表示
    if "modelInvocationInput" in trace:
        with st.expander("🤔 思考中…", expanded=False):
            input_trace = trace["modelInvocationInput"]["text"]
            try:
                st.json(json.loads(input_trace))
            except:
                st.write(input_trace)
    
    # 「モデル出力」トレースの表示
    if "modelInvocationOutput" in trace:
        output_trace = trace["modelInvocationOutput"]["rawResponse"]["content"]
        with st.expander("💡 思考がまとまりました", expanded=False):
            try:
                thinking = json.loads(output_trace)["content"][0]["text"]
                if thinking:
                    st.write(thinking)
                else:
                    st.write(json.loads(output_trace)["content"][0])
            except:
                st.write(output_trace)
    
    # 「根拠」トレースの表示
    if "rationale" in trace:
        with st.expander("✅ 次のアクションを決定しました", expanded=True):
            st.write(trace["rationale"]["text"])
    
    # 「ツール呼び出し」トレースの表示
    if "invocationInput" in trace:
        invocation_type = trace["invocationInput"]["invocationType"]
                
        if invocation_type == "ACTION_GROUP":
            with st.expander("💻 Lambdaを実行中…", expanded=False):
                st.write(trace['invocationInput']['actionGroupInvocationInput'])
    
    # 「観察」トレースの表示
    if "observation" in trace:
        obs_type = trace["observation"]["type"]
        
        if obs_type == "ACTION_GROUP":
            with st.expander(f"💻 Lambdaの実行結果を取得しました", expanded=False):
                st.write(trace["observation"]["actionGroupInvocationOutput"]["text"])
                
def invoke_bedrock_agent(client, session_id, prompt):
    """Bedrockエージェントを呼び出す"""
    return client.invoke_agent(
        agentId=agent_id,
        agentAliasId=agent_alias_id,
        sessionId=session_id,
        enableTrace=True,
        inputText=prompt,
    )

def handle_agent_response(response, messages):
    """エージェントのレスポンスを処理する"""
    with st.chat_message("assistant"):
        for event in response.get("completion"):
            if "trace" in event:
                handle_trace_event(event)
            
            if "chunk" in event:
                answer = event["chunk"]["bytes"].decode()
                st.write(answer)
                messages.append({"role": "assistant", "text": answer})

def show_error_popup(exeption):
    """エラーポップアップを表示する"""
    if exeption == "throttlingException":
        error_message = "【エラー】Bedrockのモデル負荷が高いようです。1分ほど待ってから、ブラウザをリロードして再度お試しください🙏(改善しない場合は、モデルを変更するか[サービスクォータの引き上げ申請](https://aws.amazon.com/jp/blogs/news/generative-ai-amazon-bedrock-handling-quota-problems/)を実施ください)"
    st.error(error_message)

def main():
    """メインのアプリケーション処理"""
    client, session_id, messages = initialize_session()
    display_chat_history(messages)
    
    if prompt := st.chat_input(""):
        messages.append({"role": "human", "text": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        
        try:
            response = invoke_bedrock_agent(client, session_id, prompt)
            handle_agent_response(response, messages)
            
        except (EventStreamError, ClientError) as e:
            if "throttlingException" in str(e):
                show_error_popup("throttlingException")
            else:
                raise e

if __name__ == "__main__":
    main()

6. さいごに

いかがでしたでしょうか。
今回は、BedrockエージェントでLambdaによるWebサーチを試してみました。
Bedrockエージェントを使用することで、自律的なWebサーチを簡単に実装することができました。
今回は「アクショングループ」のLambdaを試してみましたが、Bedrockエージェントはナレッジベースの検索も簡単に行うことができます。

次回は、Bedrockエージェントのマルチエージェントを試してみたいと思っています。

Accenture Japan (有志)

Discussion