🎓

LangGraphにおけるStateの分割方法

2024/10/02に公開

はじめに

こんにちは。PharmaXでエンジニアをしている諸岡(@hakoten)です。

今回は、LLMアプリケーション開発をサポートするグラフ管理ツールであるLangGraphのStateの管理方法について、いくつかの事例を交えながらご紹介します。

なお、LangGraphそのものの使い方に興味のある方は、こちらの記事もぜひご参照ください。

https://zenn.dev/pharmax/articles/8796b892eed183

環境

この記事執筆時点では、以下のバージョンで実施しています。
とくにLangChain周りは非常に開発速度が早いため、現在の最新バージョンを合わせてご確認ください

  • langgraph: 0.2.28
  • Python: 3.12.4

LangGraphのStateの基本的な使い方

まずは、LangGraphにおけるStateの基本的な使い方について簡単に説明します。

LangGraphでは、StateGraph というクラスを使ってグラフを初期化します。基本的なStateの宣言は次のようになります。

(StateGraphの宣言例)

from typing_extensions import TypedDict
from langgraph.graph import StateGraph

# Stateを宣言
class State(TypedDict):
    value: str

# Stateを引数としてGraphを初期化
graph = StateGraph(State)

ここで定義したStateは、StateGraph内の各Nodeで利用されます。

from typing import Annotated
from typing_extensions import TypedDict
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph

...

# Nodeを宣言
# Node関数の第一引数にStateが渡されます
def node(state: State, config: RunnableConfig):
    return {"value": "1"}
    
graph = StateGraph(State)
# ノードをグラフに追加
graph.add_node("node", node)

そして、グラフを実行する際には、初期値としてStateをInputに渡します。

from typing import Annotated
from typing_extensions import TypedDict
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph

...
    
graph = StateGraph(State)
graph.add_node("node", node)
graph_builder.set_entry_point("node")

# Graphの実行(引数にはStateの初期値を渡す)
graph.invoke({"value": ""})
全てのコード
from langgraph.graph import StateGraph
from langgraph.graph.graph import RunnableConfig
from typing_extensions import TypedDict

class State(TypedDict):
    value: str


# Nodeを宣言
def node(state: State, config: RunnableConfig):
    return {'value': '1'}


# Graphの作成
graph_builder = StateGraph(State)

# Nodeの追加
graph_builder.add_node('node', node)

# Graphの始点を宣言
graph_builder.set_entry_point('node')

# Graphをコンパイル
graph = graph_builder.compile()

# Graphの実行(引数にはStateの初期値を渡す)
graph.invoke({'value': ''})

このように、最初にStateGraphで定義したStateを各ノードで利用し、更新していくのが基本的な使い方です。しかし、この方法ではグラフの規模が大きくなると、Stateに含まれるプロパティが増えすぎてしまい、管理が煩雑になるという問題が発生します。

そこで、以降では、グラフで参照するStateをいくつかの方法で分割し、効率的に管理する手法をご紹介します。なお、これらの方法はLangGraphのHow-to Guidesでも解説されていますので、ぜひそちらも併せてご参照ください。

InputとOutputを分割する

StateGraphに渡すStateは、そのままの状態だと、Inputで使うプロパティとOutputで使うプロパティが混在してしまうため、グラフが大きくなると管理が煩雑になるという問題があります。

この問題を解決するために、StateGraphのコンストラクタにはinputoutputの2つの引数が用意されています。これらを使うことで、Inputで使用するState(invokeの引数として渡されるState)と、Outputで使用するStateを分離して指定することが可能です。

# inputのState
class InputState(TypedDict):
    input_value: str

# outputのState
class OutputState(TypedDict):
    output_value: str

# 全てのプロパティを持つ型を作りstate_schemaに渡す
class OverallState(InputState, OutputState):
    pass

# Graphの作成
graph_builder = StateGraph(
    state_schema=OverallState,
    input=InputState,
    output=OutputState,
)
...
# nodeには、InputStateが渡される
def node(state: InputState, config: RunnableConfig):
    # OutPutStateに書き込む
    return {'output_value': '2'}
...
# invokeの引数には、InputStateを渡す
print(graph.invoke({'input_value': '1'}))
{'output_value': '2'}

inputとoutputのスキーマを指定することで、invokeの実行結果はOutputStateのプロパティだけにフィルタリングされるため、結果が明確になります。また、invokeの引数に何を渡すべきかが明確になるため、Stateの管理がより簡単になります。

全てのコード
from langgraph.graph import StateGraph
from langgraph.graph.graph import RunnableConfig
from typing_extensions import TypedDict

class InputState(TypedDict):
    input_value: str

class OutputState(TypedDict):
    output_value: str

class OverallState(InputState, OutputState):
    pass

# Graphの作成
graph_builder = StateGraph(
    state_schema=OverallState,
    input=InputState,
    output=OutputState,
)

# Nodeを宣言
def node(state: InputState, config: RunnableConfig):
    return {'output_value': '2'}

# Nodeの追加
graph_builder.add_node('node', node)

# Graphの始点を宣言
graph_builder.set_entry_point('node')

# Graphをコンパイル
graph = graph_builder.compile()

# Graphの実行(引数にはStateの初期値を渡す)
print(graph.invoke({'input_value': '1'}))

ノード間でしか扱わないStateを分割する

Stateの中には特定のノード間でのみ使用される値が存在する場合があります。そういったプロパティは、InputやOutputには含めずに、そのノード内だけで使用する「プライベートなState」として扱うことが可能です。

class InputState(TypedDict):
    input_value: str

class OutputState(TypedDict):
    output_value: str

class OverallState(InputState, OutputState):
    pass

# Node間で使う中間の状態
class PrivateState(TypedDict):
    private_value: str

graph_builder = StateGraph(
    state_schema=OverallState,
    input=InputState,
    output=OutputState,
)

# PrivateStateは、nodeと node2の間でのみ値を受け渡す
def node(state: InputState, config: RunnableConfig):
    print(f'node: {state}')
    # PrivateStateに書き込み
    return {'private_value': '2'}

# Nodeの引数としてPrivateStateを受け取る
def node2(state: PrivateState, config: RunnableConfig):
    print(f'node2: {state}')
    return {'output_value': '3'}
...
print(graph.invoke({'input_value': '1'}))

実行結果は以下のとおりです。node2には、PrivateStateの値だけがフィルタリングされて渡されていることが確認できます。

node: {'input_value': '1'}
node2: {'private_value': '2'}
{'output_value': '3'}

このように、LangGraphでは、StateGraphに指定したState以外でも、各ノードの関数の引数に特定のStateを指定することができます。

特定のノード間でのみ使用するStateを別のクラスに分けることで、全体のStateの構造がより見やすくなり、管理しやすくなります。

全てのコード
from langgraph.graph import StateGraph
from langgraph.graph.graph import RunnableConfig
from typing_extensions import TypedDict

class InputState(TypedDict):
    input_value: str

class OutputState(TypedDict):
    output_value: str

class OverallState(InputState, OutputState):
    pass

class PrivateState(TypedDict):
    private_value: str

graph_builder = StateGraph(
    state_schema=OverallState,
    input=InputState,
    output=OutputState,
)

# PrivateStateは、nodeと node2の間でのみ値を受け渡す
def node(state: InputState, config: RunnableConfig):
    print(f'node: {state}')
    # PrivateStateに書き込み
    return {'private_value': '2'}

# PrivateStateは、nodeと node2の間でのみ値を受け渡す
def node2(state: PrivateState, config: RunnableConfig):
    print(f'node2: {state}')
    return {'output_value': '3'}

# Nodeの追加
graph_builder.add_node('node', node)
graph_builder.add_node('node2', node2)

# edgeの定義
graph_builder.set_entry_point('node')
graph_builder.add_edge('node', 'node2')

# Graphをコンパイル
graph = graph_builder.compile()

# Graphの実行(引数にはStateの初期値を渡す)
print(graph.invoke({'input_value': '1'}))

ノードへのStateの受け渡しの仕組みをコードで確認する

Nodeに渡されるStateは、PrivateStateのように、指定された型のプロパティでフィルタリングされた状態で渡されます。このフィルタリングがどのように行われているのか、少し調べてみました。

各スキーマの保持方法

StateGraphの初期化時に渡される state_schemainputoutput は、それぞれ内部で個別のスキーマ情報として保持されます。

(実装コードはこのあたりです)
https://github.com/langchain-ai/langgraph/blob/716d23e2699957ddf32117ec321ad716c6337c27/libs/langgraph/langgraph/graph/state.py#L174-L204

この情報は、StateGraphのコンストラクタ時に呼び出される _add_schema メソッドによって保存されます。

ノード追加(add_node)の時にもスキーマは保存される

ノードを追加する際の add_node メソッドでは、関数のシグネチャをチェックし、その情報をスキーマに登録しています。

(実装コードはこのあたりです)
https://github.com/langchain-ai/langgraph/blob/716d23e2699957ddf32117ec321ad716c6337c27/libs/langgraph/langgraph/graph/state.py#L341-L355

この仕組みにより、StateGraphの初期値に含まれていないプライベートなStateについても、適切にフィルタリングされているようです。

Node実行時は必要なプロパティのみが取得されて引数に渡される

ノードが実行される際、保存されているスキーマ情報から、そのノードに必要なプロパティだけが抽出され、Input(State)として渡されます。

(実装コードはこのあたりです)
https://github.com/langchain-ai/langgraph/blob/716d23e2699957ddf32117ec321ad716c6337c27/libs/langgraph/langgraph/pregel/algo.py#L585-L629

処理が深い部分にあるため、すべてを詳細に読み解いてはいませんが、最終的に必要なプロパティだけが抽出されて渡されているようです。

おわりに

以上、LangGraphにおけるStateの扱い方について紹介しました。グラフ構造の中で一つのStateをそのまま使い続けると、どのノードでどのプロパティが利用されているかが分かりづらくなりがちです。今回ご紹介した方法を活用して、Stateの管理を効率化を考えてみると良いと思います。

PharmaXでは、AIやLLMに関連する技術の活用を積極的に進めています。もし、この記事が興味を引いた方や、LangGraphの活用に関心がある方は、ぜひ私のXアカウント(@hakoten)やコメントで気軽にお声がけください。PharmaXのエンジニアチームで一緒に働けることを楽しみにしています。

まずはカジュアルにお話できることを楽しみにしています!

PharmaXテックブログ

Discussion