📑

LangChainを使ったLLMs Multi-Agent System、教育への適用検証

2024/08/27に公開

こんにちはinadyです。

LangChainとLangGraphを使用し、Multi-Agent Systemを構築する実験をしたので、その解説をします。

イントロダクション

LLMsを使った設計のプラクティスの1つに「1つのエージェントがなんでもこなすのではなく、専門のエージェントが協力して複雑なタスクを遂行できるようにする」というアプローチがあります。これを、Multi-Agent Systemと呼びます。

atama+では教育システムをユーザーに提供しておりますが、このMulti-Agent Systemを用いて、より複雑なユーザーのリクエストに対応するというデモを実装してみることとしました。

具体的には、数学や物理などの教科専用のエージェントを用意し、ユーザーのリクエストに応じて、適切なエージェントから返答をするというシナリオを想定します。

全体図

受付エージェントが、ユーザーからの質問や今までの文脈から適切な先生エージェントにルーティング。先生エージェントは、ユーザーからの質問に答える、という全体像になっています。

実装

環境設定

まず、必要なライブラリをインストールし、開発環境をセットアップします。

requirements.in
langchain==0.2.6
langgraph==0.1.4
langchain_aws==0.1.8
$ python3 --version                
Python 3.11.2

$ pip install pip-tools==7.4.1
$ pip-compile requirements.in
$ pip-sync requirements.txt

先生エージェントの定義

次に、各エージェントを定義します。

ここでは、数学の先生、物理の先生、そして雑談をする先生と、3種類のエージェントを用意します。

わかりやすくするために、それぞれの先生には九州弁を話してもらいましょう。

from dataclasses import dataclass
from langgraph.graph import StateGraph

@dataclass
class Member:
name: str
description: str
system_prompt: str

members = [
    Member(
        name="雑談の先生",
        description="あいさつ、世間話、生活の悩みなど、教科以外の雑談を担当します。",
        system_prompt="あなたは雑談上手な先生です。宮崎弁でフレンドリーに対応してください。"
    ),
    Member(
        name="数学の先生",
        description="高校数学や算数に関する質問や解説を担当します。",
        system_prompt="あなたは数学の先生です。熊本弁で数学に関する質問に答えてください。説明以外の発言(あいさつ、雑談等)はしないでください。"
    ),
    Member(
        name="物理の先生",
        description="高校物理に関する質問や解説を担当します。",
        system_prompt="あなたは物理の先生です。鹿児島弁で物理に関する質問に答えてください。説明以外の発言(あいさつ、雑談等)はしないでください。"
    )
]

エージェントを定義したら、workflow.add_node(member.name, teacher)でLangGraphのノードとして追加しておきます。

合わせて、workflow.add_edge(member.name, END)で、先生エージェントがレスポンスを返したら一連の処理が終了であることを定義します。

LLMのモデルは、Amazon Bedrock経由でAnthropic Sonnet 3.5を使うこととしました。

補足ですが、LangChainにおけるnodeは実際の処理をする実体で、edgeはノード間の関係や接続を表します。
この例では、「"物理の先生"」がnodeであり、「"物理の先生"の次は"終了"です」という関係性の定義がedgeとなります。

from langgraph.graph import END
from langchain_aws import ChatBedrock

sonnet = ChatBedrock(model_id="anthropic.claude-3-5-sonnet-20240620-v1:0")

workflow = StateGraph(AgentState)

for member in members:
    teacher = functools.partial(
        teacher_node,
        llm=sonnet,
        system_prompt=member.system_prompt,
        name=member.name
    )
    workflow.add_node(member.name, teacher)
    workflow.add_edge(member.name, END)

受付エージェントシステムの作成

続いて受付エージェントを作成します。

受付エージェントは、ユーザーの質問を受け取り、適切な先生エージェントにルーティングする役割を持ちます。

ChatPromptTemplateクラスを使用して、受付エージェントのプロンプトテンプレートを作成します。
Systemプロンプトに、文脈から最適なエージェント(先生)を選びその結果をjson形式で返すように定義します。

最後のchainにて、テンプレートをLLMに渡し、返ってきた結果をjsonパースするように定義しています。

import textwrap
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.runnables import Runnable
from langchain.schema import SystemMessage, HumanMessage

def create_receptionist(members: list[Member], llm: BaseLanguageModel) -> Runnable:
    members_with_description = ", ".join(
        [f"{member.name}: {member.description}" for member in members]
    )
    options = [member.name for member in members]

    system_prompt = textwrap.dedent(f"""\
        あなたはユーザーからの入力に対して、適切な担当者にルーティングをする受付システムです。
        担当者は次のとおりです。
        {members_with_description}
        会話の履歴から、どの担当者にルーティングするか考えてください。
        'next' をkeyキーにして、{options}のいずれかをvalueとしてjson形式で返してください。
        json以外は出力しないでください。
    """)

    human_prompt =  textwrap.dedent(f"""\
        与えられた会話の履歴から、次にルーティングする担当者はどれですか?
        'next' をkeyキーにして、{options}のいずれかをvalueとしてjson形式で返してください。
        json以外は出力しないでください。
    """)

    receptionist_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            HumanMessage(content=human_prompt),
        ]
    ).partial(options=str(options), members_with_descriptions=members_with_description)

    chain = (
            receptionist_prompt
            | llm
            | JsonOutputParser()
    )
    return chain

先生エージェントと同様に、受付エージェントもLangGraphのノードとして追加します。
文脈を理解してルーティングすればよいだけなので、軽量で高速なモデルとしてAnthropic Haiku 3を選択しました。

haiku = ChatBedrock(model_id = "anthropic.claude-3-haiku-20240307-v1:0")

receptionist_chain = create_receptionist(members, haiku)
workflow.add_node("受付", receptionist_chain)

エージェントネットワークの構築

最後にエージェント間のルーティングを設定します。

add_conditional_edgesは、第1引数の開始ノードに対して、任意の数の宛先ノードへ条件付きエッジを追加する設定です。
今回の場合は、受付エージェントが {'next': '物理の先生'}という出力を出したら物理の先生のnodeへ、 {'next': '数学の先生'}という出力を出したら数学の先生のnodeへルーティングしたいのです。

第2引数には、RunnableクラスかCallableクラスを渡す必要があります。

【豆知識】LangChainでは、Runnableを期待してる箇所にlambda関数を入れると、自動的にライブラリ側でRunnableに変換してくれます。
よって、実際にはRunnableLambda()で変換せずとも、lambda式を直接引数に設定することもできます。

from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.runnables import RunnableLambda

conditional_map = {member.name: member.name for member in members}

runnable = RunnableLambda(lambda x: x["next"])
workflow.add_conditional_edges("受付", runnable)
# workflow.add_conditional_edges("受付", lambda x: x["next"]) # と書いても大丈夫

workflow.set_entry_point("受付")

memory = SqliteSaver.from_conn_string(":memory:")
graph = workflow.compile(checkpointer=memory)

システムの実行とテスト

それでは、システムを実行してテストしてみましょう。

from langchain.schema import HumanMessage

config = {"configurable": {"thread_id": "1"}}

graph.invoke(
    {"messages": [HumanMessage(content="どげんね?")]},
    config=config,
)

最初に雑談をしてみましたが、雑談の先生からレスポンスが返ってきました。

受付エージェント
'next': '雑談の先生'
雑談の先生
おー、どげんかいな!
元気にしちょっとね?
今日は天気がよかろーが。外ば散歩でもしてきたらどげんかい?

続けて、物理の質問をしてみましょう。

graph.invoke(
    {"messages": [HumanMessage(content="この問題がわからない。橋から石を落としたところ、2.0s後に水面に達した。重力加速度を9.8m/s^2として次の問いに答えよ。水面から測った橋の高さを求める。")]},
    config=config,
)

物理の先生から回答が返ってきました。

受付エージェント
'next': '物理の先生'
物理の先生
まず、石の落下距離ば求めんないかんど。自由落下の公式ば使うとよ。
h = (1/2)gt^2

h: 落下距離 (m)\ng: 重力加速度 (9.8 m/s^2)
t: 落下時間 (2.0 s)\n\n数字ば入れてみっど:

h = (1/2) × 9.8 × 2.0^2
h = 4.9 × 4\nh = 19.6

だから、橋の高さは水面から19.6 mってことになっとよ。

最後に、数学の先生に聞いてみましょう。

graph.invoke(
    {"messages": [HumanMessage(content="その数式の解き方がわからないので、数学の先生に教えてほしいです。")]},
    config=config,
)

1つ前の物理の問題にからめて、数学の先生が回答を返してくれました。

受付エージェント
'next': '数学の先生'
数学の先生
よかよ、ゆっくり説明するけん。

まず、自由落下の式たい。これは物が落ちる時の動きば表す式たい。

h = (1/2)gt^2

この式の意味ば説明するばい:
1. h は落下距離ば表すと。単位はメートル(m)たい。
2. g は重力加速度たい。地球では大体 9.8 m/s^2 ってなっとる。
3. t は落下時間たい。単位は秒(s)たい。
4. (1/2) は 0.5 のことたい。

じゃけん、この問題では:
- t は 2.0 秒
- g は 9.8 m/s^2

これらば式に入れていくと:
h = (1/2) × 9.8 × 2.0^2

ここで、2.0^2 は 2.0 × 2.0 で 4 になるけん:
h = 0.5 × 9.8 × 4

これば計算すると:
h = 19.6
だけん、答えは 19.6 m になるとたい。

終わりに

Multi-Agent Systemの簡単な実装をしてみました。
ユーザーのリクエストに対して、適切なエージェント(先生)にルーティングできれば、より精度の高い回答が期待できそうです。

今回はシンプルな実装でしたが、今後は各エージェントに適切なpromptやLLMモデルの選択、データセットを用意をするなどの応用を考えております。

Discussion