LangGraphのcreate_react_agentについてのメモ

LangChainエージェントからLangGraphエージェントへの移行が促されている。エージェント系の実装はLangChainからLangGraphに移行していく流れっぽい。
参考資料

シンプルな実装例。
エージェントの出力にツール呼び出しがなくなるまで処理が続行される。
create_react_agent
の返り値の型はCompiledGraph
なので、通常のコンパイル済みGraphと同様にinvoke
等のメソッドでエージェントの実行が可能。
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
# モデルの準備
model = ChatOpenAI(model="gpt-4o")
# ツールの準備
@tool
def magic_function(input: int) -> int:
"""Applies a magic function to an input."""
return input + 2
tools = [magic_function]
# クエリの準備
query = "what is the value of magic_function(3)?"
# ReactAgentExecutorの準備
app = create_react_agent(model, tools)
# ReactAgentExecutorの実行
messages = app.invoke({"messages": [("human", query)]})
{
"input": query,
"output": messages["messages"][-1].content,
}
{'input': 'what is the value of magic_function(3)?',
'output': 'The value of `magic_function(3)` is 5.'}

create_react_agent
の入力パラメータは以下の通り。
パラメータ | 型 | 説明 | デフォルト |
---|---|---|---|
model |
LanguageModelLike |
ツール呼び出しをサポートするLangChainチャットモデル。 | 必須 |
tools |
Union[ToolExecutor, Sequence[BaseTool]] |
ツールのリストまたはToolExecutor インスタンス。 |
必須 |
messages_modifier |
Optional[Union[SystemMessage, str, Callable, Runnable]] |
任意のメッセージ修正機能。メッセージがLLMに渡される前に適用される。以下の形式が可能: - SystemMessage : メッセージリストの先頭に追加される。 - str : SystemMessage に変換され、メッセージリストの先頭に追加される。 - Callable : メッセージのリストを入力として受け取り、出力が言語モデルに渡される関数。 - Runnable : メッセージのリストを入力として受け取り、出力が言語モデルに渡されるRunnable。 |
None |
checkpointer |
Optional[BaseCheckpointSaver] |
任意のチェックポイントセーバーオブジェクト。グラフの状態(例: チャットメモリ)の永続化に使用される。 | None |
interrupt_before |
Optional[Sequence[str]] |
任意の中断前ノード名のリスト。次のいずれかの値を指定可能: "agent", "tools"。アクションを実行する前にユーザー確認や他の中断を追加する場合に使用。 | None |
interrupt_after |
Optional[Sequence[str]] |
任意の中断後ノード名のリスト。次のいずれかの値を指定可能: "agent", "tools"。出力を返す前に、直接戻るか追加処理を行う場合に使用。 | None |
debug |
bool |
デバッグモードを有効にするかを示すフラグ。 | False |
返り値の型はLangGraphの型であるCompiledGraph
を返すようになっている。
戻り値 | 型 | 説明 |
---|---|---|
CompiledGraph |
CompiledGraph |
チャットインタラクションに使用できるLangChain実行可能のコンパイル済みグラフ。 |

create_react_agent
の実装は以下の場所にある。
この関数によって生成されたLangGraphエージェントのAgentState
は以下のように定義されている。
class AgentState(TypedDict):
"""The state of the agent."""
messages: Annotated[Sequence[BaseMessage], add_messages]
is_last_step: IsLastStep
invoke
を実行した際には上記のステートが返る。

create_react_agent
によって作成されるグラフの全体像。
# 新しいグラフを定義
workflow = StateGraph(AgentState)
# ループする2つのノードを定義
workflow.add_node("agent", RunnableLambda(call_model, acall_model))
workflow.add_node("tools", ToolNode(tools))
# エントリーポイントを `agent` に設定
# これは最初に呼び出されるノードを意味する
workflow.set_entry_point("agent")
# 条件付きエッジを追加
workflow.add_conditional_edges(
# 最初に開始ノードを定義。ここでは `agent` を使用。
# これは、`agent` ノードが呼び出された後に取られるエッジを意味する。
"agent",
# 次に、次にどのノードが呼び出されるかを決定する関数を渡す。
should_continue,
# 最後にマッピングを渡す。
# キーは文字列で、値は他のノードである。
# END はグラフが終了することを示す特別なノード。
# ここで行われるのは `should_continue` を呼び出し、その出力が
# このマッピングのキーと一致するかを確認すること。
# 一致したキーに対応するノードが次に呼び出される。
{
# `tools` の場合、ツールノードを呼び出す。
"continue": "tools",
# それ以外の場合は終了する。
"end": END,
},
)
# `tools` から `agent` への通常のエッジを追加
# これは `tools` が呼び出された後、次に `agent` ノードが呼び出されることを意味する。
workflow.add_edge("tools", "agent")
# 最後にコンパイルする!
# これにより、LangChainのRunnableにコンパイルされ、
# 他の実行可能なオブジェクトと同様に使用できるようになる。
return workflow.compile(
checkpointer=checkpointer,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
)

agentノードで実行される関数は以下の通り。
# モデルを呼び出す関数を定義
def call_model(
state: AgentState,
config: RunnableConfig,
):
messages = state["messages"]
response = model_runnable.invoke(messages, config)
# 最終ステップでツール呼び出しがある場合
if state["is_last_step"] and response.tool_calls:
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, need more steps to process this request.",
)
]
}
# リストを返す。これは既存のリストに追加される。
return {"messages": [response]}
# 非同期版のモデルを呼び出す関数を定義
async def acall_model(state: AgentState, config: RunnableConfig):
messages = state["messages"]
response = await model_runnable.ainvoke(messages, config)
# 最終ステップでツール呼び出しがある場合
if state["is_last_step"] and response.tool_calls:
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, need more steps to process this request.",
)
]
}
# リストを返す。これは既存のリストに追加される。
return {"messages": [response]}

tool_condition
態 (state) 内のメッセージにツール呼び出しが含まれているかどうかを判定するための関数。条件付きエッジ(conditional edge)で使用あれ、ツール呼び出しが存在する場合はtools
、存在しない場合は__end__
をそれぞれ文字列で返す。
from typing import Union, Literal, Any
from langchain_core.messages import AnyMessage
def tools_condition(
state: Union[list[AnyMessage], dict[str, Any]],
) -> Literal["tools", "__end__"]:
"""条件付きエッジで使用され、最後のメッセージにツール呼び出しが含まれている場合はToolNodeにルーティングし、
含まれていない場合は終了にルーティングする。
Args:
state (Union[list[AnyMessage], dict[str, Any]]): ツール呼び出しのチェックを行う状態。メッセージのリスト(MessageGraph)または "messages" キーを持つ(StateGraph)。
Returns:
Literal["tools", "__end__"]: 次にルーティングするノード。
"""
if isinstance(state, list):
ai_message = state[-1]
elif messages := state.get("messages", []):
ai_message = messages[-1]
else:
raise ValueError(f"tool_edge の入力状態にメッセージが見つかりません: {state}")
# ツール呼び出しがあるかどうかを判定
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
return "tools"
return "__end__"
使用例
from langchain_anthropic import ChatAnthropic
from langchain_core.tools import tool
from langgraph.graph import MessageGraph
from langgraph.prebuilt import ToolNode, tools_condition
@tool
def divide(a: float, b: float) -> int:
"""Return a / b."""
return a / b
llm = ChatAnthropic(model="claude-3-haiku-20240307")
tools = [divide]
graph_builder = MessageGraph()
graph_builder.add_node("tools", ToolNode(tools))
graph_builder.add_node("chatbot", llm.bind_tools(tools))
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_conditional_edges(
"chatbot", tools_condition
)
graph_builder.set_entry_point("chatbot")
graph = graph_builder.compile()
graph.invoke([("user", "What's 329993 divided by 13662?")])
あらかじめtools
という名前でノードを定義しておく必要がある。
(暗黙にtoolsという名前でノードがあることを期待している仕組みなので、あまり良い設計には思えない・・・)

ValidationNode
ツール呼び出しが指定されたスキーマに従っているかどうか検証するノード。検証に成功した場合はToolMessagesが返り、失敗した場合はLLMにフィードバックするためのエラーメッセージが返る。
使用例
from typing import Literal
from langchain_anthropic import ChatAnthropic
from langchain_core.pydantic_v1 import BaseModel, validator
from langgraph.graph import END, START, MessageGraph
from langgraph.prebuilt import ValidationNode
class SelectNumber(BaseModel):
a: int
@validator("a")
def a_must_be_meaningful(cls, v):
if v != 37:
raise ValueError("Only 37 is allowed")
return v
builder = MessageGraph()
llm = ChatAnthropic(model="claude-3-haiku-20240307").bind_tools([SelectNumber])
builder.add_node("model", llm)
builder.add_node("validation", ValidationNode([SelectNumber]))
builder.add_edge(START, "model")
def should_validate(state: list) -> Literal["validation", "__end__"]:
if state[-1].tool_calls:
return "validation"
return END
builder.add_conditional_edges("model", should_validate)
graph = builder.compile()
res = graph.invoke(("user", "Select a number"))
# 正常に検証された結果を表示
for msg in res:
msg.pretty_print()
{
"content": "{\"a\": 37}",
"name": "SelectNumber",
"tool_call_id": "unique_id"
}