🦔

【LangGraph】ツールの認識と実行を理解する

に公開

LangGraphの基本的な書き方

詳細は割愛します。
LangGraph公式チュートリアルの、Get Startedを参考に書きました

クリックして展開
"""
LangGraphのチャットボット
"""

from typing import Annotated

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.tools import tool
from langchain_core.runnables import RunnableConfig

from pydantic import BaseModel, Field

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from langchain_openai import ChatOpenAI

# LangSmithをセット
import os

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = os.path.basename(__file__) # ファイル名がlangsmithのプロジェクト名になります
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"


chat_openai_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)


system_template = """
あなたはチャットボットです。
ユーザーと楽しくお話ししてください
"""

system_message = SystemMessage(content=system_template)

class State(BaseModel):
    messages: Annotated[list, add_messages] = Field(default=[system_message])


def chatbot(state: State):
    response = chat_openai_llm.invoke(state.messages)
    return {"messages": [response]}

# グラフの作成
graph_builder = StateGraph(State)

# ノードを登録
graph_builder.add_node("chatbot", chatbot)

# エッジを登録
graph_builder.set_entry_point("chatbot")
graph_builder.add_edge("chatbot", END)

# メモリの登録と、そのメモリをチェックポインターとしてグラフをコンパイルする
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

# この設定は、この会話のキーとしてとして使用するスレッドを選択する
config = RunnableConfig({"configurable": {"thread_id": "1"}})

def stream_graph_updates(user_input: str):
    for event in graph.stream({"messages": [HumanMessage(content=user_input)]}, config, stream_mode="values"):
        for value in event.values():
            value[-1].pretty_print()

# グラフを実行
if __name__ == "__main__":
    for user_input in ('こんにちは','出身地は大阪','誕生日は6月12日','趣味はスノボ'):
        try:
            stream_graph_updates(user_input)
        except Exception as e:
            print(e)
            break

関数のツールを導入する

クリックして展開
from typing import Annotated

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.tools import tool
from langchain_core.runnables import RunnableConfig

from pydantic import BaseModel, Field

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from langchain_openai import ChatOpenAI

# LangSmithをセット
import os

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = os.path.basename(__file__) # ファイル名がlangsmithのプロジェクト名になります
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"


# ツールを定義
@tool
def record_hobby(hobby: str):
    """趣味を記録する"""
    print(f"趣味: {hobby}")
    return f"返り値: {hobby}"


chat_openai_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)

llm_with_tools = chat_openai_llm.bind_tools([record_hobby])

system_template = """
あなたはチャットボットです。
ユーザーと楽しくお話ししてください
"""

system_message = SystemMessage(content=system_template)

class State(BaseModel):
    messages: Annotated[list, add_messages] = Field(default=[system_message])


def chatbot(state: State):
    response = llm_with_tools.invoke(state.messages)
    return {"messages": [response]}

graph_builder = StateGraph(State)

# ノードを登録
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", ToolNode([record_hobby]))

# エッジを登録
graph_builder.set_entry_point("chatbot")
graph_builder.add_edge("chatbot", END) # これはダメ。解説参照

# メモリの登録と、そのメモリをチェックポインターとしてグラフをコンパイルする
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

# この設定は、この会話のキーとしてとして使用するスレッドを選択する
config = RunnableConfig({"configurable": {"thread_id": "1"}})

def stream_graph_updates(user_input: str):
    for event in graph.stream({"messages": [HumanMessage(content=user_input)]}, config, stream_mode="values"):
        for value in event.values():
            value[-1].pretty_print()

# グラフを実行
if __name__ == "__main__":
    for user_input in ('こんにちは','出身地は大阪','誕生日は6月12日','趣味はスノボ'):
        try:
            stream_graph_updates(user_input)
        except Exception as e:
            print(e)
            break

関数のツールを用意

  • デコレーターを使って用意します。
  • 引数はAIが良しなに決めてくれます。
  • 返り値はあってもなくても構いません。(例えばDBにデータ保存する時とかはなくても良いです)
from langchain_core.tools import tool

@tool
def record_hobby(hobby: str):
    """趣味を記録する"""
    print(f"趣味: {hobby}")
    return f"返り値: {hobby}"

LLMにツールを認識させる

  • 然るべきタイミングでツールを呼び起こせるよう、LLMにバインドします

chat_openai_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)

llm_with_tools = chat_openai_llm.bind_tools([record_hobby])

さて、この状態で実行すると以下のような会話になるはずです。

================================ Human Message =================================

こんにちは
================================== Ai Message ==================================

こんにちは!今日はどんなことをお手伝いできますか?
================================ Human Message =================================

出身地は大阪
================================== Ai Message ==================================

大阪出身なんですね!大阪は美味しい食べ物や楽しい文化がたくさんありますよね。何か特別な思い出や好きな場所がありますか?
================================ Human Message =================================

誕生日は6月12日
================================== Ai Message ==================================

6月12日ですね!誕生日が近づくとワクワクしますよね。何か特別な計画や希望があれば教えてください!
================================ Human Message =================================

趣味はスノボ
================================== Ai Message ==================================
Tool Calls:
  record_hobby (call_lolQlFN4DnvUYpus5SzEdMzd)
 Call ID: call_lolQlFN4DnvUYpus5SzEdMzd
  Args:
    hobby: スノボ

気をつけなければならないのは、AI MessageとしてTool Calls:が記録されていますがツールは実行されてないということです。
(その証拠に、本来record_hobbyが実行されるとprint(f"趣味: {hobby}")が表示されるはずですが、表示されていません)

今の状態はツールをAIに認識させただけですので、実行させる必要があります

ツールを実行させるノードを作成

from langgraph.prebuilt import ToolNode, tools_condition # 追加

# ノードを登録
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", ToolNode([record_hobby])) # 追加

# エッジを登録
# graph_builder.add_edge("chatbot", END) # 削除
graph_builder.add_edge("tools", "chatbot") # 追加
graph_builder.add_conditional_edges("chatbot", tools_condition) # 追加

実行結果

================================ Human Message =================================

こんにちは
================================== Ai Message ==================================

こんにちは!今日はどんなことをお手伝いできますか?
================================ Human Message =================================

出身地は大阪
================================== Ai Message ==================================

大阪出身なんですね!大阪は美味しい食べ物や楽しい文化がたくさんありますよね。何か特別なことや思い出があれば教えてください!
================================ Human Message =================================

誕生日は6月12日
================================== Ai Message ==================================

6月12日なんですね!誕生日は特別な日ですし、何か毎年楽しみにしていることや、特別な思い出がありますか?
================================ Human Message =================================

趣味はスノボ
================================== Ai Message ==================================
Tool Calls:
  record_hobby (call_ia1wqtVXZ8mVbXMfn0jZM4V9)
 Call ID: call_ia1wqtVXZ8mVbXMfn0jZM4V9
  Args:
    hobby: スノボ
趣味: スノボ
================================= Tool Message =================================
Name: record_hobby

返り値: スノボ
================================== Ai Message ==================================

スノボが趣味なんですね!冬のスポーツは楽しいですよね。どのくらいの頻度でスノボに行くのですか?また、お気に入りのスキー場などがあれば教えてください!

Tool Messageが追加されており、返り値が帰ってきているのがわかります。
また、その直前のAI Messageにおいても趣味: スノボの通り、ツールが実行されているのがわかります

クラスのツールを導入する

ツールには上記で述べた関数の他、クラスを使うこともできます。

サンプルコード
from typing import Annotated

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.runnables import RunnableConfig

from pydantic import BaseModel, Field

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from langchain_openai import ChatOpenAI

# LangSmithをセット
import os

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = os.path.basename(__file__) # ファイル名がlangsmithのプロジェクト名になります
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"

class Profile(BaseModel):
    birthday: str
    birthplace: str


chat_openai_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)

llm_with_tools = chat_openai_llm.bind_tools([Profile])

system_template = """
あなたはチャットボットです。
ユーザーと楽しくお話ししてください
"""

system_message = SystemMessage(content=system_template)

class State(BaseModel):
    messages: Annotated[list, add_messages] = Field(default=[system_message])


def chatbot(state: State):
    response = llm_with_tools.invoke(state.messages)    
    return {"messages": [response]}


graph_builder = StateGraph(State)

# ノードを登録
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", ToolNode([Profile]))

# エッジを登録
graph_builder.set_entry_point("chatbot")
# graph_builder.add_edge("chatbot", END)
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_conditional_edges("chatbot", tools_condition)

# メモリの登録と、そのメモリをチェックポインターとしてグラフをコンパイルする
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

# この設定は、この会話のキーとしてとして使用するスレッドを選択する
config = RunnableConfig({"configurable": {"thread_id": "1"}})

def stream_graph_updates(user_input: str):
    global is_break
    for event in graph.stream({"messages": [HumanMessage(content=user_input)]}, config, stream_mode="values"):
        for value in event.values():
            value[-1].pretty_print()


# グラフを実行
if __name__ == "__main__":
    for user_input in ('こんにちは','出身地は大阪','誕生日は6月12日','趣味はスノボ'):
        try:
            stream_graph_updates(user_input)

        except Exception as e:
            print(e)
            break

以下のようにBaseModelを継承したProfileクラスを用意します。(TypedDict継承でも良いですが、私は分かりやすさの観点からBaseModelを使うことが多いです)

class Profile(BaseModel):
    birthday: str
    birthplace: str


chat_openai_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)

llm_with_tools = chat_openai_llm.bind_tools([Profile])

ダメな例ですが、toolsノードを用意せずにchatbotノードから直ENDに結びつけて実行してみましょう。

# エッジを登録
graph_builder.set_entry_point("chatbot")
graph_builder.add_edge("chatbot", END)

以下のようにエラーが出ます。

================================ Human Message =================================

こんにちは
================================== Ai Message ==================================

こんにちは!今日はどのようにお手伝いできますか?
================================ Human Message =================================

出身地は大阪
================================== Ai Message ==================================

ありがとうございます!出身地が大阪ですね。何か特別なことについてお話ししたいことがありますか?例えば、大阪の文化や観光スポット、食べ物などについてですか?それとも他のことについて知りたいですか?
================================ Human Message =================================

誕生日は6月12日
================================== Ai Message ==================================
Tool Calls:
  Profile (call_wlBAXoMbbzcBAsivPwepvs6e)
 Call ID: call_wlBAXoMbbzcBAsivPwepvs6e
  Args:
    birthday: 6月12日
    birthplace: 大阪
================================ Human Message =================================

趣味はスノボ
Error code: 400 - {'error': {'message': "An assistant message with 'tool_calls' must be followed by tool messages responding to each 'tool_call_id'. The following tool_call_ids did not have response messages: call_wlBAXoMbbzcBAsivPwepvs6e", 'type': 'invalid_request_error', 'param': 'messages.[6].role', 'code': None}}

Errorの意味としては、ツールが呼び出された後に、そのツールの実行結果をメッセージとして追加する必要があることを示しています。

先ほどと同じようにtoolsノードを定義して実行しましょう。

# エッジを登録
graph_builder.set_entry_point("chatbot")
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_conditional_edges("chatbot", tools_condition)

実行結果。Tool Messageが追加されます。

================================ Human Message =================================

こんにちは
================================== Ai Message ==================================

こんにちは!今日はどのようにお手伝いできますか?
================================ Human Message =================================

出身地は大阪
================================== Ai Message ==================================

出身地は大阪ですね。何か特別なことについてお話ししたいことがありますか?それとも、他にお手伝いできることがありますか?
================================ Human Message =================================

誕生日は6月12日
================================== Ai Message ==================================
Tool Calls:
  Profile (call_9wDMqZ21CcYR13wD7tRxTt0k)
 Call ID: call_9wDMqZ21CcYR13wD7tRxTt0k
  Args:
    birthday: 6月12日
    birthplace: 大阪
================================= Tool Message =================================
Name: Profile

birthday='6月12日' birthplace='大阪'
================================== Ai Message ==================================

あなたのプロフィールは以下の通りです:

- 誕生日: 6月12日
- 出身地: 大阪

何か他に知りたいことやお手伝いできることはありますか?
================================ Human Message =================================

趣味はスノボ
================================== Ai Message ==================================

スノーボードが趣味なんですね!冬のスポーツとしてとても楽しいですよね。どのようなスノーボードのスタイルが好きですか?また、行きつけのスキー場などはありますか?

出身地を言ったときにはツールが呼び出されておらず、誕生日を言ったときに初めて呼び出されています。

データが揃ったら、会話を終了させる

これを利用して、例えば出身地と誕生日さえ聞き出したら会話をストップさせることもできます

is_break = False # 追加

def stream_graph_updates(user_input: str):
    global is_break
    for event in graph.stream({"messages": [HumanMessage(content=user_input)]}, config, stream_mode="values"):
        for value in event.values():
            value[-1].pretty_print()
            # AIがProfileを呼び出したら、is_breakをTrueにする
            if isinstance(value[-1], AIMessage) and value[-1].additional_kwargs.get('tool_calls'):
                for tool_call in value[-1].additional_kwargs['tool_calls']:
                    if tool_call['function']['name'] == "Profile":
                        is_break = True

# グラフを実行
if __name__ == "__main__":
    for user_input in ('こんにちは','出身地は大阪','誕生日は6月12日','趣味はスノボ'):
        try:
            stream_graph_updates(user_input)
            if is_break: # 追加
                break    # 追加

        except Exception as e:
            print(e)
            break

大事な部分はツール呼び出しを取得する以下の部分です。
ツールを呼び出すのはAIで、呼び出されたツールのリストはadditional_kwargs['tool_calls']に格納されているので、ツール一つ一つの名前をチェックしてProfileがあればbreakさせています

isinstance(value[-1], AIMessage) and value[-1].additional_kwargs.get('tool_calls')

実行結果

================================ Human Message =================================

こんにちは
================================== Ai Message ==================================

こんにちは!今日はどのようなことをお手伝いできますか?
================================ Human Message =================================

出身地は大阪
================================== Ai Message ==================================

出身地が大阪なんですね!大阪について何か特別なことや、知りたいことがありますか?それとも、他にお手伝いできることがあれば教えてください。
================================ Human Message =================================

誕生日は6月12日
================================== Ai Message ==================================
Tool Calls:
  Profile (call_VWPrcxAicjrzcdkED5LVJaAL)
 Call ID: call_VWPrcxAicjrzcdkED5LVJaAL
  Args:
    birthday: 6月12日
    birthplace: 大阪
================================= Tool Message =================================
Name: Profile

birthday='6月12日' birthplace='大阪'
================================== Ai Message ==================================

あなたの誕生日は6月12日で、出身地は大阪ですね!何か特別なことを知りたいですか?例えば、誕生日に関連するイベントや大阪の文化についてなど、どんなことでもお答えします!

まとめ

  • LangGraphのtoolsには、関数とクラスを登録することができる。
  • toolsはllmにバインドして認識させるのと、toolsノードで実行させる2つの手順が必要
  • 関数は何かしらのアクションをしたい時に使う
  • クラスは構造化データに使う。クラスが呼び出されるのは、全ての項目が埋まった時。

Discussion