😺

[LangGraph] Sendを使った動的なNodeの並列実行

2024/12/02に公開

はじめに

こんにちは。PharmaXでエンジニアをしている諸岡(@hakoten)です。

この記事では、LangGraphを活用してアプリケーションを実装する際に、動的にノードを並列実行する方法である Send APIについて解説します。

※厳密には「並行 (concurrent)」と「並列 (parallel)」は異なる動作を指しますが、この記事ではLangGraphのノードを非同期で同時に処理するという意味で「並列」という言葉を使用しています。ご了承ください。

LangGraphにおけるNodeの並列実行の基本

ここでは、LangGraphでのNodeの並列実行について簡単に解説します。
LangGraphでは、同じステップに定義されたNodeが並列に実行されます。

例えば、以下のグラフでは、node_anode_bnode_cは、start_nodeの実行直後に並列で処理され、それぞれの結果がマージされた状態(State)としてend_nodeに渡されます。

この並列実行の仕組みは、以下のようにstart_nodeから同じステップ内のNodeとしてnode_anode_bnode_cを定義することで実現できます。

...
graph_builder.add_node('start_node', start_node)
graph_builder.add_node('node_a', node_a)
graph_builder.add_node('node_b', node_b)
graph_builder.add_node('node_c', node_c)
graph_builder.add_node('end_node', end_node)

graph_builder.set_entry_point('start_node')
graph_builder.add_edge('start_node', 'node_a')
graph_builder.add_edge('start_node', 'node_b')
graph_builder.add_edge('start_node', 'node_c')
graph_builder.add_edge(['node_a', 'node_b', 'node_c'], 'end_node')
graph_builder.set_finish_point('end_node')
...
全てのコード
from operator import add
from typing import Annotated

from langgraph.graph import StateGraph
from typing_extensions import TypedDict

class State(TypedDict):
    path: Annotated[list[str], add]

graph_builder = StateGraph(State)

def start_node(state: State) -> State:
    return {'path': ['start_node']}

def node_a(state: State) -> State:
    print('log ----> a start')
    return {'path': ['node_a']}

def node_b(state: State) -> State:
    print('log ----> b start')
    return {'path': ['node_b']}

def node_c(state: State) -> State:
    print('log ----> c start')
    return {'path': ['node_c']}

def end_node(state: State) -> State:
    return {'path': ['end_node']}

graph_builder.add_node('start_node', start_node)
graph_builder.add_node('node_a', node_a)
graph_builder.add_node('node_b', node_b)
graph_builder.add_node('node_c', node_c)
graph_builder.add_node('end_node', end_node)

graph_builder.set_entry_point('start_node')
graph_builder.add_edge('start_node', 'node_a')
graph_builder.add_edge('start_node', 'node_b')
graph_builder.add_edge('start_node', 'node_c')
graph_builder.add_edge(['node_a', 'node_b', 'node_c'], 'end_node')
graph_builder.set_finish_point('end_node')

graph = graph_builder.compile()
graph.invoke({'path': []})

より詳しい内容は、以下の記事も参考にしてください。

https://zenn.dev/pharmax/articles/78f2e6a51a459e

Sendによる動的な並列Node処理

前述の add_edge を用いた並列処理の定義方法は、処理するNodeの数が事前に決まっている場合には有効です。

例えば、外部APIからの結果の数に応じて並列で実行するノードの数が動的に変化するようなユースケースでは、add_edge を使って実行Nodeを事前に定義することができません。

こうした場合には、Send APIと add_conditional_edges を組み合わせることで、動的なNodeの実行を実現できます。

Sendの使い方の概要

ここでは、以下のようなグラフを例に、Send を活用した使い方を解説します。このグラフでは、parallel_node が並列で実行されるNodeとして設定されています。

上記のグラフを実現するコード例は次の通りです。

実際のコード
# グラフ全体のState
class OverallState(TypedDict):
    paths: Annotated[list[str], operator.add]


# 並列ノード間でのみ値を受け渡すState
class ParallelState(TypedDict):
    node_path: str
    paths: Annotated[list[str], operator.add]


graph_builder = StateGraph(OverallState)


def start_node(state: OverallState, config: RunnableConfig):
    print(f'start_node: {state}')
    return {'paths': ['start_node']}


def parallel_node(state: ParallelState, config: RunnableConfig):
    print(f'parallel_node: {state}')
    return {'paths': [state['node_path']]}


def end_node(state: OverallState, config: RunnableConfig):
    print(f'end_node: {state}')
    return {'paths': ['end_node']}


def routing_parallel_node(state: OverallState):
    # 並列で実行したいNodeの名前を第一引数に指定する
    # 第二引数には、並列で実行したいNodeに渡したいStateを指定する(今回は、グラフ全体のStateに並列ノードの名前を追加して渡す)
    return [Send('parallel_node', state | {'node_path': f'parallel_node_{i + 1}'}) for i in range(3)]


graph_builder.add_node('start_node', start_node)
# 動的に実行するノードをadd_nodeで追加しておく
graph_builder.add_node('parallel_node', parallel_node)
graph_builder.add_node('end_node', end_node)

graph_builder.set_entry_point('start_node')
# 条件付きエッジで、Sendを指定する関数を定義
graph_builder.add_conditional_edges('start_node', routing_parallel_node, ['parallel_node'])
graph_builder.add_edge('parallel_node', 'end_node')
graph_builder.set_finish_point('end_node')

graph = graph_builder.compile()

print(graph.invoke({'paths': []}))

このコードでは、start_node の実行後に routing_parallel_node 関数を使って動的に並列ノードが生成されます。それぞれの parallel_node は独立して実行され、最終的に結果が end_node に渡されます。

1.動的に実行するNodeを add_node で登録する

まずは、通常のノードを実行する場合と同様に、Graphに対して add_node を用いてNodeを登録する必要があります。

# グラフ全体のState
class OverallState(TypedDict):
    paths: Annotated[list[str], operator.add]

# 並列ノード間でのみ値を受け渡すState
class ParallelState(TypedDict):
    node_path: str
    paths: Annotated[list[str], operator.add]

graph_builder = StateGraph(OverallState)

...

# 並列で実行する対象のNode
def parallel_node(state: ParallelState, config: RunnableConfig):
    print(f'parallel_node: {state}')
    return {'paths': [state['node_path']]}

...
# 動的に実行するノードをadd_nodeで追加しておく
graph_builder.add_node('parallel_node', parallel_node)
...

2.add_conditional_edgesで前のステップのNode終了後に動的なNodeの実行を定義する

Send を使用するのは、add_conditional_edges で指定する関数の中です。この関数では、次に実行するNodeを条件付きで定義することができます。具体的には、この関数内で Send クラスのインスタンスをリストで返すことで、特定のNodeを動的に並列実行することができます。

Send クラスのコンストラクタ(イニシャライザ)は、以下の通りです。

  • 第一引数: 実行するNodeの名前を指定します。
  • 第二引数: 実行するNodeに渡す引数(State)を指定します。
...
graph_builder = StateGraph(OverallState)

def parallel_node(state: ParallelState, config: RunnableConfig):
    print(f'parallel_node: {state}')
    return {'paths': [state['node_path']]}

def routing_parallel_node(state: OverallState):
    # 並列で実行したいNodeの名前を第一引数に指定する
    # 第二引数には、並列で実行したいNodeに渡したいStateを指定する(今回は、グラフ全体のStateに並列ノードの名前を追加して渡す)
    return [Send('parallel_node', state | {'node_path': f'parallel_node_{i}'}) for i in range(3)]

...
# 動的に実行するノードをadd_nodeで追加しておく
graph_builder.add_node('parallel_node', parallel_node)
...
# 条件付きエッジで、Sendを指定する関数を定義
graph_builder.add_conditional_edges('start_node', routing_parallel_node, ['parallel_node'])

実行結果

以下は、このグラフを実行した際の結果です。動的に生成された3つの Send インスタンスにより、parallel_node が3回実行されていることが確認できます。

start_node: {'paths': []}
parallel_node: {'paths': ['start_node'], 'node_path': 'parallel_node_1'}
parallel_node: {'paths': ['start_node'], 'node_path': 'parallel_node_2'}
parallel_node: {'paths': ['start_node'], 'node_path': 'parallel_node_3'}
end_node: {'paths': ['start_node', 'parallel_node_1', 'parallel_node_2', 'parallel_node_3']}
{'paths': ['start_node', 'parallel_node_1', 'parallel_node_2', 'parallel_node_3', 'end_node']}

実行の流れは以下の通りです。

  1. 前ステップのNode(start_node)が実行される
  2. parallel_node_1 ~ parallel_node_3までが並列に実行される
  3. parallel_node_1 ~ parallel_node_3の実行結果がグラフ全体のStateにマージされる
  4. end_nodeが実行され、結果がマージされたStateが渡される

Sendを使う時のポイント

Send を利用する際に注意すべきポイントを紹介します。

実行グラフのStateは、Sendで実行されるNodeには自動的に渡されない

通常、LangGraphのNodeでは、Graph全体のStateが前のステップから次のステップへと自動的に引き継がれます。 しかし、Send で実行されるNodeにはGraph全体のStateが渡されません

そのため、Send を使用する際には、必要なStateを明示的に指定してNodeに渡す必要があります。

前述の例では、paths という共通のStateを更新するために、グラフ全体のStateを意図的にNodeに渡しています。

def routing_parallel_node(state: OverallState):
    # グラフ全体のStateをSendで指定し、Nodeに渡す
    return [Send('parallel_node', state | {'node_path': f'parallel_node_{i}'}) for i in range(3)]

上記のように、グラフのStateを利用して、Nodeを実行する場合は、Sendに明示的にパラメタを渡す必要があります。

これは、別の言い方をすると Send で実行されるNodeは、グラフ全体のStateとは無関係な独立したStateを持てるということでもあります。

以下のコードは、並列のNodeのStateとグラフ全体のStateを別で定義しています。

 # グラフ全体のState
class OverallState(TypedDict):
    paths: Annotated[list[str], operator.add]

# 並列Nodeでのみ参照するState
class ParallelState(TypedDict):
    parallel_paths: Annotated[list[str], operator.add]

graph_builder = StateGraph(OverallState)

def start_node(state: OverallState, config: RunnableConfig):
    return {'paths': ['start_node']}

# Sendで指定する並列Nodeは全体のStateとは独立してState更新ができる
def parallel_node(state: ParallelState, config: RunnableConfig):
    return {'parallel_paths': state['node_path']}

このようにSendで実行するNodeは、グラフ全体のStateに依存せず、自身のStateのみを参照して独立して動作することが可能です。

Sendで実行されるNodeの戻り値のプロパティは、GraphのStateにも定義する必要がある

Send によって実行されるNodeのStateは独立していますが、その結果を次のNodeで利用する場合には、グラフ全体のStateにもその結果のプロパティが定義されている必要があります。

以下の例では、parallel_paths というプロパティを OverallState(グラフ全体のState)と ParallelState(並列Node専用のState)の両方に定義しています。

class OverallState(TypedDict):
    paths: Annotated[list[str], operator.add]
    # 並列ノードが処理した結果を受け取るために共通のプロパティ定義が必要
    parallel_paths: Annotated[list[str], operator.add]

class ParallelState(TypedDict):
    parallel_paths: Annotated[list[str], operator.add]

graph_builder = StateGraph(OverallState)

def start_node(state: OverallState, config: RunnableConfig):
    return {'paths': ['start_node']}

def parallel_node(state: ParallelState, config: RunnableConfig):
    return {'parallel_paths': [state['node_path']]}

LangGraphでは、ステップ終了時に同じプロパティを持つStateの値が自動的にマージされます。この挙動は、グラフを階層的に呼び出すSubGraphでも共通です。

したがって、Send を使ったNodeの結果を次のNodeで利用する場合は、グラフ全体のStateと並列NodeのStateの両方でプロパティが定義されているか?を確認してください。

Stateの更新にはReducerが必要

Send に限った話ではありませんが、並列実行されるNodeを使用する際には、Stateにreducerを指定する必要があります。

class ParallelState(TypedDict):
    # Annotatedでreducerを指定する
    parallel_paths: Annotated[list[str], operator.add]

公式ページ

LangGraphの公式ページでは、Sendを使ったMapReduceの実装サンプルが紹介されています。
より実践的な使い方を知りたい場合は、ぜひこちらも参照してください。

https://langchain-ai.github.io/langgraph/how-tos/map-reduce/

その他の並列処理の選択肢

LangGraphにおいて、Sendを使わずに動的に並列処理を行うには、Pythonのconcurrent.futuresを始めとする非同期APIを利用するか、LangChainの並列処理の仕組みである RunnableParallelを使うといった選択肢が考えられます。

Send を使用する場合は、Nodeの事前宣言やStateの管理が必要なため、ユースケースによっては上記の方法の方が適している場合があります。そのため、アプリケーションの要件に応じて柔軟に選択することをおすすめします。

一方で、Nodeとして実装したい機能が明確で、LangGraph内での並列処理や結果のマージを簡潔に実現したい場合は、Send を活用するのが良いかと思います。

終わりに

この記事では、LangGraphで動的にNodeの並列処理を行う方法(Send API)について、解説しました。この記事が、少しでも皆さんの参考になれば幸いです。

PharmaXでは、AIやLLMに関連する技術の活用を積極的に進めています。もし、この記事に興味を持たれた方や、LangGraphの活用に関心がある方は、ぜひ私のXアカウント(@hakoten)やコメントで気軽にお声がけください。PharmaXのエンジニアチームで一緒に働けることを楽しみにしています。

まずはカジュアルにお話できることを楽しみにしています!

PharmaXテックブログ

Discussion