📘

Amazon Bedrock+LangChainを使う際のClaudeの最大入力トークン数への対処

2024/08/14に公開

はじめに

AnthropicのClaudeモデルへの入力トークン数は、1リクエストあたり最大20万トークンまでとなっています。これを超える文書(コンテキスト)を与えて質問をしたいという時があり、その時に困ったので本記事にまとめておきます。

本記事で取り上げる対処法

このような場合の対処法として以下の3つが挙げられるようです(参考)。

  1. より大きな最大入力トークン数を持つLLMに変更する
  2. 文書を分割して、分割した文書から要約等を取得する
  3. RAGの場合は、文書を分割して、その中から質問と関連性があると思われるものを抽出する

本記事で取り上げるのは、2の対処法を取りたい状況のときになります。
Amazon Bedrockだけだと実現できなさそうだと思ったので、LangChainにそのような機能がないか調べてみました。

2の対処法の中でも、Map-ReduceRefineというものがあり、どちらもLangChainで実現できそうでした。

Map-Reduceは、元の文書を分割した後、分割した文書をそれぞれ要約して、各要約文を統合して最終的な要約文を作るイメージです。

Refineは、元の文書を例えば文書1、文書2、文書3に分割した後、文書1を要約し、次に文書2を要約する際に文書1の要約を与え、文書3を要約する際に文書2の要約を与えて、最終的な要約文を作るイメージです。

今回はMap-Reduceについてやってみたので、それをまとめておきます。

環境

上記のMap-Reduceのドキュメントを参考にやっていきました。
LangGraphによる実装が推奨されていたため、それに従って進めていきます。

  • Python: 3.12.3
  • LangGraph: 0.2.3

ノートブックの実行

sample.ipynb
import operator
from typing import Annotated, List, TypedDict
import boto3
from langchain_aws import ChatBedrock
from langgraph.constants import Send
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langgraph.graph import END, START, StateGraph
from langchain_text_splitters import RecursiveCharacterTextSplitter
from IPython.display import Image

# 要約したい文章
document = """
りんごは赤いです。
ブルーベリーは青いです。
バナナは黄色いです。
"""

# 要約したい文章を(本例は短いが長いと仮定して)分割する
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n"],
    chunk_size=15,
    chunk_overlap=0,
)
documents: List[Document] = text_splitter.create_documents([document])
print(documents)
実行結果
[Document(page_content='りんごは赤いです。'),
Document(page_content='ブルーベリーは青いです。'),
Document(page_content='バナナは黄色いです。')]
sample.ipynb
# 使用するLLMの設定
boto3_session = boto3.Session(
    profile_name="PROFILE_NAME",
)
bedrock_client = boto3_session.client(
    service_name="bedrock-runtime",
    region_name="REGION_NAME",
)
llm = ChatBedrock(
    client=bedrock_client,
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    model_kwargs={
        "max_tokens": 4096,
        "temperature": 0,
        "top_p": 1,
    },
)

# 分割後の文書のうちのある1つを要約するためのプロンプトとチェーンの設定
map_template = """
次のコンテキストを読んで、簡潔で自然な要約文を作成してください。
必ず要約文のみを出力してください。
あなたの考えは不要です。

コンテキスト:
------------
{context}
------------
"""
map_prompt = ChatPromptTemplate([("human", map_template)])
map_llm_chain = map_prompt | llm | StrOutputParser()

# 一連の要約文を統合して最終的な要約文を作るためのプロンプトとチェーンの設定
reduce_template = """
以下の一連の要約文を統合して、最終的な要約文を作成してください。
以下の全ての要約文は同じ日に手に入れました。
必ず日本語で回答してください。
必ず要約文のみを出力してください。
あなたの考えは不要です。

一連の要約文:
{docs}
"""
reduce_prompt = ChatPromptTemplate([("human", reduce_template)])
reduce_llm_chain = reduce_prompt | llm | StrOutputParser()

class OverallState(TypedDict):
    """
    グラフ全体の状態を表す型
    - contents: 入力ドキュメント(分割後)のリスト
    - summaries: contentsに対応する要約のリスト
    - final_summary: 最終的な要約
    """
    contents: List[str]

    # Annotated[list, operator.add]:
    # 各ドキュメントから生成した全ての要約を結合して1つのリストに戻す。
    summaries: Annotated[list, operator.add]
    final_summary: str

class SummaryState(TypedDict):
    """
    分割後の文書のうちのある1つの文書の型
    """
    content: str

async def generate_summary(state: SummaryState) -> dict:
    """
    generate_summaryノード:
    分割後の文書のうちのある1つを要約して返す。
    """
    response = await map_llm_chain.ainvoke(state["content"])
    return {"summaries": [response]}

def map_summaries(state: OverallState) -> List[Send]:
    """
    STARTからgenerate_summaryノードへの分岐に使用する関数:
    グラフ全体のState(OverallState)にあるcontentsの要素を
    それぞれgenerate_summaryノードに送信する。
    """
    return [
        Send(
            node="generate_summary",    # node: グラフ内のノード名
            arg={                       # arg: 指定したノード名のノードに送信する内容
                "content": content
            }
        ) for content in state["contents"]
    ]

async def generate_final_summary(state: OverallState) -> dict:
    """
    generate_final_summaryノード:
    一連の要約文を統合して最終的な要約文を返す。
    """
    response = await reduce_llm_chain.ainvoke(state["summaries"])
    return {"final_summary": response}

# グラフの構成
graph = StateGraph(OverallState)
graph.add_node("generate_summary", generate_summary)
graph.add_node("generate_final_summary", generate_final_summary)
graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
graph.add_edge("generate_summary", "generate_final_summary")
graph.add_edge("generate_final_summary", END)
app = graph.compile()

# グラフの実行
async for step in app.astream({"contents": [doc.page_content for doc in documents]}):
    print(step)
実行結果
{'generate_summary': {'summaries': ['りんごは赤い。']}}
{'generate_summary': {'summaries': ['バナナは黄色い果物です。']}}
{'generate_summary': {'summaries': ['ブルーベリーは青い果実である。']}}
{'generate_final_summary': {'final_summary': 'りんごは赤い果物、ブルーベリーは青い果実、バナナは黄色い果物である。'}}

app.astreamでストリームモードで実行すると、ステップごとの出力結果を表示できます。

分割された文書が並列でそれぞれ要約されていることがわかります。
さらに、それらを元に要約文が生成されていることがわかります。

sample.ipynb
Image(app.get_graph().draw_mermaid_png())

実行結果

構成したグラフの図です。
STARTからmap_summaries()によって各文書がそれぞれgenerate_summaryノードに渡されて要約された後、generate_final_summaryノードに移行して最終的な要約文が生成されていることがわかります。

おわりに

本記事では、要約タスクを行いましたが、分割された各文書に対するプロンプトと、最終的にどう統合するかのプロンプトを変えられるため、他のタスクであっても長い文書に対応できそうです。

次は、Refineもやってみたいと思います。
ここまでご覧いただき、ありがとうございました。

NCDCエンジニアブログ

Discussion