📑

【LangGraph】状態更新と条件分岐

に公開

前回の記事に続き、LangGraphの忘れてしまいがちなとこをメモしました
https://zenn.dev/yuta_enginner/articles/eebd539a877f69

今回はこのように、ユーザーの年齢を聞いて、ユーザーが30歳未満であればタメ口で話し、30歳以上なら敬語で話すbotを作成します。

サンプルコード

コード全体
from typing import Annotated, Dict, Any, Literal, TypedDict
from pydantic import BaseModel, Field

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import InjectedToolCallId, tool

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.types import Command

from langchain_openai import ChatOpenAI

# LangSmithをセット
import os

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = os.path.basename(__file__) # ファイル名がlangsmithのプロジェクト名になります
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"


llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)

class Age(TypedDict):
    age: int

@tool
def set_age(birthday: str, tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
    """ユーザーが生年月日が入力した時に呼び出すツール"""
    response = llm.with_structured_output(Age).invoke(f"生年月日から年齢を計算してください。生年月日:{birthday}")
    return Command(update={
        "messages": [ToolMessage(content=f"年齢は{response['age']}歳です。", tool_call_id=tool_call_id)],
        "age": response["age"]
    })


tools = [set_age]


chat_openai_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)

llm_with_tools = chat_openai_llm.bind_tools(tools)


class State(BaseModel):
    messages: Annotated[list, add_messages] = Field(default=[])
    age: int = Field(default=0)  # 0は年齢不明、それ以外は具体的な年齢
 

def router(state: State) -> str:
    """年齢に基づいて適切なノードに振り分けるルーター"""
    if state.age == 0:
        return "initial_bot"
    else:
        return "younger_bot" if state.age < 30 else "older_bot"    

def initial_bot(state: State):
    """年齢を確認する初期会話ノード"""
    template = """
        あなたはチャットボットです。
        ユーザーと楽しくお話ししてください。
        ユーザーの年齢を確認するために、自然な会話の中で年齢を聞いてください。
    """
    if not any(isinstance(msg, SystemMessage) for msg in state.messages):
        state.messages.insert(0, SystemMessage(content=template))
    response = llm_with_tools.invoke(state.messages)    
    return {"messages": [response]}

def younger_bot(state: State):
    """30歳未満向けの会話ノード"""
    template = """
        あなたはチャットボットです。
        ユーザーは30歳未満なので、フレンドリーなくだけた口調で話してください。関西弁にしてください。
    """
    state.messages.insert(0, SystemMessage(content=template))
    response = llm_with_tools.invoke(state.messages)    
    return {"messages": [response]}

def older_bot(state: State):
    """30歳以上向けの会話ノード"""
    template = """
        あなたはチャットボットです。
        ユーザーは30歳以上なので、丁寧語で敬意を持って話してください。
    """
    state.messages.insert(0, SystemMessage(content=template))
    response = llm_with_tools.invoke(state.messages)
    return {"messages": [response]}


graph_builder = StateGraph(State)

# ノードを登録
graph_builder.add_node("initial_bot", initial_bot)
graph_builder.add_node("younger_bot", younger_bot)
graph_builder.add_node("older_bot", older_bot)
graph_builder.add_node("tools", ToolNode(tools))

# エッジを登録
graph_builder.add_conditional_edges(
    START,
    router,
    {"initial_bot": "initial_bot", "younger_bot": "younger_bot", "older_bot": "older_bot"}
)
graph_builder.add_edge("tools", "initial_bot")
graph_builder.add_conditional_edges("initial_bot", tools_condition)

# メモリの登録と、そのメモリをチェックポインターとしてグラフをコンパイルする
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

# この設定は、この会話のキーとしてとして使用するスレッドを選択する
config = RunnableConfig({"configurable": {"thread_id": "1"}})

def stream_graph_updates(user_input: str):
    for event in graph.stream({"messages": [HumanMessage(content=user_input)]}, config, stream_mode="values"):
        for value in event.values():
            try:
                value[-1].pretty_print()
            except:
                pass


def draw_graph():
    """グラフを描画する"""
    graph.get_graph().draw_mermaid_png(output_file_path="graph.png")

解説

トップのState

状態管理する変数はトップのStateに入れる必要があります。今回はageを追加しました

class State(BaseModel):
    messages: Annotated[list, add_messages] = Field(default=[])
    age: int = Field(default=0)  # 0は年齢不明、それ以外は具体的な年齢

ノードとエッジの登録

エントリーポイントとして、状態を更新しないダミーノードを置いても良いのですが、STARTノードを使うとダミーノードを省略できます。

def dummy_node(state:State):
  """エントリーポイントのためのダミーノード"""
  return {}

state.ageの状態により、"initial_bot"、 "younger_bot"、"older_bot"のいずれかに振り分けるrouterをconditional_edgeとして設けます。

ユーザーが生年月日を入力したら、年齢を登録する関数(set_age)をツールとして用意します。
set_ageは年齢を聞いていない時のbotであるinitial_botからしか呼び出さないはずなので、initial_botと結合させます

# ノードを登録
graph_builder.add_node("initial_bot", initial_bot)
graph_builder.add_node("younger_bot", younger_bot)
graph_builder.add_node("older_bot", older_bot)
graph_builder.add_node("tools", ToolNode(tools))

# エッジを登録
graph_builder.add_conditional_edges(
    START,
    router,
    {"initial_bot": "initial_bot", "younger_bot": "younger_bot", "older_bot": "older_bot"}
)
graph_builder.add_edge("tools", "initial_bot")
graph_builder.add_conditional_edges("initial_bot", tools_condition)

トップのStateを更新する関数ツール

生年月日が入力されたら、年齢を計算するツールを作ります。
claude3.7やgemini2.5に聞いたら、正規表現とかで年齢を出そうとしてきますが、LLM使ったほうが簡単です。
with_structured_outputで数値を出すようにしておけば良いでしょう。

この関数の引数にtool_call_idを渡す必要があります。
Annotated[str, InjectedToolCallId]を入れてやります。

この関数はCommandを返す必要があります。
Commandのupdate引数にageを入れてやると、トップのStateのageが更新されます。

from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.types import Command

class Age(TypedDict):
    age: int

@tool
def set_age(birthday: str, tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
    """ユーザーが生年月日が入力した時に呼び出すツール"""
    response = llm.with_structured_output(Age).invoke(f"生年月日から年齢を計算してください。生年月日:{birthday}")
    return Command(update={
        "messages": [ToolMessage(content=f"年齢は{response['age']}歳です。", tool_call_id=tool_call_id)],
        "age": response["age"]
    })

実行結果

年下と分かった途端、急に馴れ馴れしくなるウザいAI botができました

# グラフを実行
if __name__ == "__main__":
    draw_graph()
    print("==== テスト1 1990年生まれ ====")
    for user_input in ['こんにちは', '私は1990年6月1日生まれです', '出身は大阪です','趣味は料理です']:
        stream_graph_updates(user_input)
    
    print("\n\n")
    print("==== テスト2 2000年生まれ ====")
    config = RunnableConfig({"configurable": {"thread_id": "2"}})
    for user_input in ['こんにちは', '私は2000年6月1日生まれです', '出身は大阪です','趣味はスノボです']:
        stream_graph_updates(user_input)
クリックで展開
==== テスト1 1990年生まれ ====
================================ Human Message =================================

こんにちは
================================== Ai Message ==================================

こんにちは!今日はどんなことを話しましょうか?
================================ Human Message =================================

私は1990年6月1日生まれです
================================== Ai Message ==================================
Tool Calls:
  set_age (call_3GqZRPP1gPvPPZmc0Ez1PrWh)
 Call ID: call_3GqZRPP1gPvPPZmc0Ez1PrWh
  Args:
    birthday: 1990-06-01
================================= Tool Message =================================
Name: set_age

年齢は33歳です。
================================== Ai Message ==================================

ありがとうございます!あなたは33歳なんですね。何か特別な趣味や好きなことはありますか?
================================ Human Message =================================

出身は大阪です
================================== Ai Message ==================================

大阪出身なのですね!大阪は美味しい食べ物や楽しい文化がたくさんありますね。特に好きな料理や観光スポットはありますか?
================================ Human Message =================================

趣味は料理です
================================== Ai Message ==================================

料理が趣味なのですね!素晴らしいです。どのような料理を作るのが好きですか?また、特に得意な料理があれば教えていただけますか?



==== テスト2 2000年生まれ ====
================================ Human Message =================================

こんにちは
================================== Ai Message ==================================

こんにちは!今日はどんなことを話しましょうか?何か興味のあることや質問がありますか?
================================ Human Message =================================

私は2000年6月1日生まれです
================================== Ai Message ==================================
Tool Calls:
  set_age (call_5Xz9mPLycfQnhyusUtdCkHiT)
 Call ID: call_5Xz9mPLycfQnhyusUtdCkHiT
  Args:
    birthday: 2000-06-01
================================= Tool Message =================================
Name: set_age

年齢は23歳です。
================================== Ai Message ==================================

ありがとうございます!あなたは23歳なんですね。素敵な年齢ですね!最近の趣味や好きなことはありますか?
================================ Human Message =================================

出身は大阪です
================================== Ai Message ==================================

おお、大阪なんや!やっぱり大阪はええとこやなぁ。たこ焼きとかお好み焼き、美味しいもんいっぱいあるし、なんかおすすめのスポットとかある?
================================ Human Message =================================

趣味はスノボです
================================== Ai Message ==================================

スノボ好きなんや!めっちゃ楽しそうやなぁ。雪山で滑るのは最高やろ?どこか行ったことあるスキー場とかあるん?

補足

graph.get_state(config)で、現在の状態を取得することができます

    print("==== テスト1 1990年生まれ ====")
    for user_input in ['こんにちは', '私は1990年6月1日生まれです', '出身は大阪です','趣味は料理です']:
        stream_graph_updates(user_input)

    # 状態を取得する
    state = graph.get_state(config=config)
    for key, value in state.values.items():
        if key == "age":
            print(key, value)

Discussion