LangGraphの基本的な要素を日本語でわかりやすくまとめる
はじめに
LangGraphの基本要素であるgraph,state,node,edgeについてまとめます。
クイックスタート見ながら雰囲気でやっているけれどそろろそ限界、という方の助けになればと思います。
内容
LangGraph公式のLow Level Conceptual Guideをもとに、LangGraphの主要要素のgraph,state,node,edgeについてまとめます。
せっかくなので以下のクイックスタート part2のコードを参考に、少し変更もいれながら進めたいと思います。
今回使うコードの全文は以下です。
quick_start_custom.py
import getpass
import os
from typing import Annotated
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import TypedDict
from langgraph.graph import START,END
def _set_if_undefined(var: str) -> None:
# 環境変数が未設定の場合、ユーザーに入力を促す
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"Please provide your {var}")
# 必要な環境変数を設定
_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("LANGCHAIN_API_KEY")
_set_if_undefined("TAVILY_API_KEY")
# Optional, add tracing in LangSmith
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "LangGraph Tutorial"
class State(TypedDict):
messages: Annotated[list, add_messages]
graph_builder = StateGraph(State)
tool = TavilySearchResults(max_results=2)
tools = [tool]
llm = ChatOpenAI(model="gpt-4o")
llm_with_tools = llm.bind_tools(tools)
# stateを返すノード
# def chatbot(state: State)
# state["messages"].append(llm_with_tools.invoke(state["messages"]))
# return state
# stetaの更新箇所のみを返すノード
def chatbot(state: State):
return {"messages": [llm_with_tools.invoke(state["messages"])]}
graph_builder.add_node("chatbot", chatbot)
tool_node = ToolNode(tools=[tool])
graph_builder.add_node("tools", tool_node)
graph_builder.add_conditional_edges(
"chatbot",
tools_condition,
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.set_entry_point("chatbot")
graph = graph_builder.compile()
# edgeをシンプルに定義したバージョン
# graph_builder.add_edge(START, "chatbot")
# graph_builder.add_edge("chatbot", "tools")
# graph_builder.add_edge("tools", END)
# graph = graph_builder.compile()
# 画像を保存する
graph_image = graph.get_graph(xray=True).draw_mermaid_png()
with open("graph.png", "wb") as f:
f.write(graph_image)
for event in graph.stream({"messages": [("user", "こんにちは")]}, stream_mode="values"):
print(event["messages"])
graphとしてはユーザからの問い合わせに応じて適宜Web検索をするchatbotです。
toolsにはtavily apiを利用してWeb検索するツールが登録されており、chatbotはtoolを使うかどうかを判断して使う必要があればtoolsを使い、なければ終了します。エージェントの動作内容自体は詳しくは解説しませんので気になる方はQuick Startをご一読ください。
graph
まずは一番大きな概念のGraphです。これはエージェントのワークフローをモデル化したものです。クイックスタートpart2の例でいうと以下です。
Quick Start Part2のGraph
graphには大きくstate,node,edgeの3要素が含まれます。それぞれ詳しくは後述しますが、ざっくりいうと
-
state: エージェントの状態を定義したもの(データ)です。上記のgraphの図には描画されていません。このgraphにinputとして与えられ、各処理のinput/outputとなりながら更新され、最終的にoutputされるものです。
-
node: chatbotやtoolsなどgraphの中の各処理です。stateを受け取り、自分の処理によってstateを更新し、stateを返します。
-
edge: 各処理のつながりを定義するものです。最初はAノード、Aノードの次はBノード、Bノードのあとは終わり、などです。条件分岐も定義できます。
state
おそらくここが一番重要なのではないかと思います。
stateとはエージェントの状態を定義しているクラスと公式では説明されていますが、私の理解ではgraphにinputとして渡され、各nodeを経由しながら中身が更新され、最終的なアウトプットになるというエージェントフロー全体のデータそのものです。
node間で受け渡ししたい、共有したい情報が入ります。
state自体はTypedDictまたはPydantic BaseModelであることが多いです。要するにdict形式のクラスです。stateに定義されるもっとも代表的な要素はmepsagesでクイックスタートでも以下のようにlist形式のmessagesが定義されています。
class State(TypedDict):
messages: Annotated[list, add_messages]
※add_messagesの部分は口述するのでいったん無視ししてください。
たとえばクイックスタートで「こんにちは」と入れるとweb検索は必要ないので、そのままchatbotがレスポンスを作成します。
for event in graph.stream({"messages": [("user", "こんにちは")]}, stream_mode="values"):
print(event["messages"])
その場合、messagesは以下のようになります。(eventはclassをdit型にしたものが出力されます)
[
HumanMessage(content='こんにちは', id='a0211b85-2569-4a3f-b963-8bf37d29692c'),
AIMessage(content='こんにちは!今日はどんなお手伝いが必要ですか?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 80, 'total_tokens': 95}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_25624ae3a5', 'finish_reason': 'stop', 'logprobs': None}, id='run-995dcd80-f44e-4a6d-94ad-e6b2fba85208-0', usage_metadata={'input_tokens': 80, 'output_tokens': 15, 'total_tokens': 95})
]
最初の入力のHumanMessageに対してchtbotノードがAIMessageを追加しています。
このstateクラスを基にしたエージェントフローを構築するという意味で、graphはStateで初期化します。
graph_builder = StateGraph(State)
ちなみに今回のようにstateがmessagesしかない場合はStateGraphではなくMessageGraphを使ってもいいようです。ただ、公式でもほとんどのアプリケーションでは状態がもっと複雑なためchatbotくらいでしか使われないと記載されています。
This class is rarely used except for chatbots, as most applications require the State to be more complex than a list of messages.
node
nodeはgraph内における1つ1つの処理の塊を示します。
実態としてはstateを引数として受け取り、更新したstateを返す関数です。関数であればよいので、単純にテキスト生成するchainでもagentでもルールベースのロジックでもよいです。
たとえば、chatbotであれば以下のように定義できます。
def chatbot(state: State):
state["messages"].append(llm_with_tools.invoke(state["messages"]))
return state
ただし、LangGraphは親切なので、stateの中の更新したいキー、バリューだけでもよいです。
def chatbot(state: State):
return {"messages": [llm_with_tools.invoke(state["messages"])]}
これはすごく便利なのですが、最初に構造を理解する際にはかえってわかりづらいかもしれません。
ちなみにクイックスタートの中だとtoolsはToolNodeというもともとlangchainで定義されているノードを使っています。
tools_by_name = {tool.name: tool for tool in tools}
def tool_node(state: dict):
result = []
for tool_call in state["messages"][-1].tool_calls:
tool = tools_by_name[tool_call["name"]]
observation = tool.invoke(tool_call["args"])
result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
return {"messages": result}
stateを引数として受け取り、その直近のmessagesから実行するtoolを取得し、そのtoolを実行して結果を返すという処理を行っています。
これも事前にToolNodeが用意されているのは簡単なのですが、初期の理解のためにはわかりづらいかもしれません。
nodeをgraphに追加するときは以下のようにします。
graph_builder.add_node("chatbot", chatbot)
State > reducers
ここでいったんstateに話を戻します。後述します、と言っていたadd_messagesの部分です。
class State(TypedDict):
messages: Annotated[list, add_messages]
nodeがreturnするときに
return {"messages": [llm_with_tools.invoke(state["messages"])]}
のように返すと普通はmessagesは上書きされてしまい、直近の内容しか残らなくなってしまいます。すると、のちのnodeが処理するときにそれ以前の内容を参照できなくなってしまいます。
今回の例でいうと、toolsノードが処理をした段階でmessagesの内容がToolMessageの内容(toolの実行結果)になってしまい、その次にその内容をもとに回答を作成するはずのchatbotノードは最初にユーザからなんと聞かれたかわからない状態になってしまいます。
そのため各nodeがmessagesの内容を返すときは上書きではなく、追加するようにしたいです。
このようにstateの各要素をどのように更新するかを定義するのがreducersです。
クイックスタートで定義されているadd_messagesはlangchainに用意されているもので、messagesを追加するものです。
Examples:
```pycon
>>> from langchain_core.messages import AIMessage, HumanMessage
>>> msgs1 = [HumanMessage(content="Hello", id="1")]
>>> msgs2 = [AIMessage(content="Hi there!", id="2")]
>>> add_messages(msgs1, msgs2)
[HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]
edge
最後にedgeです。これはnode同士のつながりを示すものです。
たとえばクイックスタートの例を少しいじって以下のように定義します。
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", "tools")
graph_builder.add_edge("tools", END)
graph = graph_builder.compile()
すると以下のようなgraphになります。
edgeを簡単に定義した時のgraph
これは.add_edge
で簡単にSTARTのあとはchatbot、chatbotのあとはtools、toolsのあとはENDという順番で処理を行うように定義しています。これはとても分かりやすいです。
もちろん、条件分岐もできます。条件分岐をしているのがクイックスタートのadd_conditional_edge
の部分です。ここではchatbotの返り値からtools_condition関数を使って次のnodeを判断しています。
graph_builder.add_conditional_edges(
"chatbot",
tools_condition,
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.set_entry_point("chatbot")
graph = graph_builder.compile()
※STARTの部分が.set_entry_point
メソッドで指定されていますがSTART
と同じようなものです。
またこのtools_conditionもlangchainで事前用意されいるのですが中身は以下です。
def tools_condition(
state: Union[list[AnyMessage], dict[str, Any], BaseModel],
) -> Literal["tools", "__end__"]:
:
if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
return "tools"
return "__end__"
chatbotの出力したstateから直近の出力(chatbotの出力、AIMessage)を参照してtool_calls
が指定されていたらtools
を返し、それ以外は__end__
を返します。
tools
ならtoolsノードの転送され、__end__
なら__end__に転送され終了します。
おわりに
いかがでしょうか。
この要素以外にもcheckpointerなど重要な要素が他にもありますが、私はこれらの要素を理解するだけでとてもクリアになりました。
LangGraphはいろいろな処理がラッパーされており、コードを書くのは楽な反面、処理の理解は逆にしづらい部分もあります。
ですが、英語ではあるものの、Low Level Conceptual Guideに丁寧に書いてあるのでそこを読むだけでもかなり理解が進むと思います。
ご興味ある方は読んでみてください。
Discussion