Open8

LangGraphのcreate_react_agentについてのメモ

mah_labmah_lab

シンプルな実装例。
エージェントの出力にツール呼び出しがなくなるまで処理が続行される。
create_react_agentの返り値の型はCompiledGraphなので、通常のコンパイル済みGraphと同様にinvoke等のメソッドでエージェントの実行が可能。

from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent

# モデルの準備
model = ChatOpenAI(model="gpt-4o")

# ツールの準備
@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2

tools = [magic_function]

# クエリの準備
query = "what is the value of magic_function(3)?"

# ReactAgentExecutorの準備
app = create_react_agent(model, tools)

# ReactAgentExecutorの実行
messages = app.invoke({"messages": [("human", query)]})

{
    "input": query,
    "output": messages["messages"][-1].content,
}
{'input': 'what is the value of magic_function(3)?',
 'output': 'The value of `magic_function(3)` is 5.'}
mah_labmah_lab

create_react_agentの入力パラメータは以下の通り。

パラメータ 説明 デフォルト
model LanguageModelLike ツール呼び出しをサポートするLangChainチャットモデル。 必須
tools Union[ToolExecutor, Sequence[BaseTool]] ツールのリストまたはToolExecutorインスタンス。 必須
messages_modifier Optional[Union[SystemMessage, str, Callable, Runnable]] 任意のメッセージ修正機能。メッセージがLLMに渡される前に適用される。以下の形式が可能: - SystemMessage: メッセージリストの先頭に追加される。 - str: SystemMessageに変換され、メッセージリストの先頭に追加される。 - Callable: メッセージのリストを入力として受け取り、出力が言語モデルに渡される関数。 - Runnable: メッセージのリストを入力として受け取り、出力が言語モデルに渡されるRunnable。 None
checkpointer Optional[BaseCheckpointSaver] 任意のチェックポイントセーバーオブジェクト。グラフの状態(例: チャットメモリ)の永続化に使用される。 None
interrupt_before Optional[Sequence[str]] 任意の中断前ノード名のリスト。次のいずれかの値を指定可能: "agent", "tools"。アクションを実行する前にユーザー確認や他の中断を追加する場合に使用。 None
interrupt_after Optional[Sequence[str]] 任意の中断後ノード名のリスト。次のいずれかの値を指定可能: "agent", "tools"。出力を返す前に、直接戻るか追加処理を行う場合に使用。 None
debug bool デバッグモードを有効にするかを示すフラグ。 False

返り値の型はLangGraphの型であるCompiledGraphを返すようになっている。

戻り値 説明
CompiledGraph CompiledGraph チャットインタラクションに使用できるLangChain実行可能のコンパイル済みグラフ。
mah_labmah_lab

create_react_agentの実装は以下の場所にある。

https://github.com/langchain-ai/langgraph/blob/main/langgraph/prebuilt/chat_agent_executor.py#L168

この関数によって生成されたLangGraphエージェントのAgentStateは以下のように定義されている。

class AgentState(TypedDict):
    """The state of the agent."""
    messages: Annotated[Sequence[BaseMessage], add_messages]
    is_last_step: IsLastStep

invokeを実行した際には上記のステートが返る。

mah_labmah_lab

create_react_agentによって作成されるグラフの全体像。

# 新しいグラフを定義
workflow = StateGraph(AgentState)

# ループする2つのノードを定義
workflow.add_node("agent", RunnableLambda(call_model, acall_model))
workflow.add_node("tools", ToolNode(tools))

# エントリーポイントを `agent` に設定
# これは最初に呼び出されるノードを意味する
workflow.set_entry_point("agent")

# 条件付きエッジを追加
workflow.add_conditional_edges(
    # 最初に開始ノードを定義。ここでは `agent` を使用。
    # これは、`agent` ノードが呼び出された後に取られるエッジを意味する。
    "agent",
    # 次に、次にどのノードが呼び出されるかを決定する関数を渡す。
    should_continue,
    # 最後にマッピングを渡す。
    # キーは文字列で、値は他のノードである。
    # END はグラフが終了することを示す特別なノード。
    # ここで行われるのは `should_continue` を呼び出し、その出力が
    # このマッピングのキーと一致するかを確認すること。
    # 一致したキーに対応するノードが次に呼び出される。
    {
        # `tools` の場合、ツールノードを呼び出す。
        "continue": "tools",
        # それ以外の場合は終了する。
        "end": END,
    },
)

# `tools` から `agent` への通常のエッジを追加
# これは `tools` が呼び出された後、次に `agent` ノードが呼び出されることを意味する。
workflow.add_edge("tools", "agent")

# 最後にコンパイルする!
# これにより、LangChainのRunnableにコンパイルされ、
# 他の実行可能なオブジェクトと同様に使用できるようになる。
return workflow.compile(
    checkpointer=checkpointer,
    interrupt_before=interrupt_before,
    interrupt_after=interrupt_after,
    debug=debug,
)
mah_labmah_lab

agentノードで実行される関数は以下の通り。

# モデルを呼び出す関数を定義
def call_model(
    state: AgentState,
    config: RunnableConfig,
):
    messages = state["messages"]
    response = model_runnable.invoke(messages, config)
    # 最終ステップでツール呼び出しがある場合
    if state["is_last_step"] and response.tool_calls:
        return {
            "messages": [
                AIMessage(
                    id=response.id,
                    content="Sorry, need more steps to process this request.",
                )
            ]
        }
    # リストを返す。これは既存のリストに追加される。
    return {"messages": [response]}

# 非同期版のモデルを呼び出す関数を定義
async def acall_model(state: AgentState, config: RunnableConfig):
    messages = state["messages"]
    response = await model_runnable.ainvoke(messages, config)
    # 最終ステップでツール呼び出しがある場合
    if state["is_last_step"] and response.tool_calls:
        return {
            "messages": [
                AIMessage(
                    id=response.id,
                    content="Sorry, need more steps to process this request.",
                )
            ]
        }
    # リストを返す。これは既存のリストに追加される。
    return {"messages": [response]}
mah_labmah_lab

tool_condition

態 (state) 内のメッセージにツール呼び出しが含まれているかどうかを判定するための関数。条件付きエッジ(conditional edge)で使用あれ、ツール呼び出しが存在する場合はtools、存在しない場合は__end__をそれぞれ文字列で返す。

from typing import Union, Literal, Any

from langchain_core.messages import AnyMessage


def tools_condition(
    state: Union[list[AnyMessage], dict[str, Any]],
) -> Literal["tools", "__end__"]:
    """条件付きエッジで使用され、最後のメッセージにツール呼び出しが含まれている場合はToolNodeにルーティングし、
    含まれていない場合は終了にルーティングする。

    Args:
        state (Union[list[AnyMessage], dict[str, Any]]): ツール呼び出しのチェックを行う状態。メッセージのリスト(MessageGraph)または "messages" キーを持つ(StateGraph)。

    Returns:
        Literal["tools", "__end__"]: 次にルーティングするノード。
    """
    if isinstance(state, list):
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"tool_edge の入力状態にメッセージが見つかりません: {state}")

    # ツール呼び出しがあるかどうかを判定
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return "__end__"

使用例

from langchain_anthropic import ChatAnthropic
from langchain_core.tools import tool
from langgraph.graph import MessageGraph
from langgraph.prebuilt import ToolNode, tools_condition

@tool
def divide(a: float, b: float) -> int:
    """Return a / b."""
    return a / b

llm = ChatAnthropic(model="claude-3-haiku-20240307")
tools = [divide]

graph_builder = MessageGraph()
graph_builder.add_node("tools", ToolNode(tools))
graph_builder.add_node("chatbot", llm.bind_tools(tools))
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_conditional_edges(
    "chatbot", tools_condition
)
graph_builder.set_entry_point("chatbot")
graph = graph_builder.compile()
graph.invoke([("user", "What's 329993 divided by 13662?")])

あらかじめtoolsという名前でノードを定義しておく必要がある。
(暗黙にtoolsという名前でノードがあることを期待している仕組みなので、あまり良い設計には思えない・・・)

mah_labmah_lab

ValidationNode

ツール呼び出しが指定されたスキーマに従っているかどうか検証するノード。検証に成功した場合はToolMessagesが返り、失敗した場合はLLMにフィードバックするためのエラーメッセージが返る。

使用例

from typing import Literal
from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, validator
from langgraph.graph import END, START, MessageGraph
from langgraph.prebuilt import ValidationNode

class SelectNumber(BaseModel):
    a: int

    @validator("a")
    def a_must_be_meaningful(cls, v):
        if v != 37:
            raise ValueError("Only 37 is allowed")
        return v

builder = MessageGraph()
llm = ChatAnthropic(model="claude-3-haiku-20240307").bind_tools([SelectNumber])
builder.add_node("model", llm)
builder.add_node("validation", ValidationNode([SelectNumber]))
builder.add_edge(START, "model")

def should_validate(state: list) -> Literal["validation", "__end__"]:
    if state[-1].tool_calls:
        return "validation"
    return END

builder.add_conditional_edges("model", should_validate)

graph = builder.compile()
res = graph.invoke(("user", "Select a number"))

# 正常に検証された結果を表示
for msg in res:
    msg.pretty_print()
{
    "content": "{\"a\": 37}",
    "name": "SelectNumber",
    "tool_call_id": "unique_id"
}