🍇

【LangGraph】LangGraphの勉強ことはじめ

2024/07/15に公開

Graphの概念

公式よりもこちらの解説のほうがムダなく説明してくれて非常にわかりやすいです。
https://www.youtube.com/watch?v=LSCgHdSEbqI

toolが2つあった時の書き方

from langchain_core.tools import tool
from langchain_core.messages import ToolMessage, AIMessage
from langgraph.graph import END

@tool
def fake_database_api(query: str) -> str:
    """パーソナル情報を格納したデータベースを検索するAPI"""
    return "にゃんたは毎日8時間寝ます"

# ツールを追加
@tool
def fake_database_api2(query: str) -> str:
    """今日の番組情報を格納したデータベースを検索するAPI"""
    return "12時からOsakan Hot 100、15時からChillin Sundayです"

class State(TypedDict):
    messages: Annotated[list, add_messages]


llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
llm_with_tools = llm.bind_tools([fake_database_api, fake_database_api2]) # 変更

def llm_agent(state):
    state["messages"].append(llm_with_tools.invoke(state["messages"]))
    return state

# 関数toolの書き方が少し変わります
def tool(state):
    tool_by_name = {
        "fake_database_api": fake_database_api,
        "fake_database_api2": fake_database_api2
    }
    last_message = state["messages"][-1]
    for tool_call in last_message.tool_calls:
        tool_function = tool_by_name[tool_call["name"]]
        tool_output = tool_function.invoke(tool_call["args"])
        state["messages"].append(ToolMessage(content=tool_output, tool_call_id=tool_call["id"]))
    return state

# routerも書き方が少し変わります
def router(state):
    last_message = state["messages"][-1]
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "tool"
    else:
        return "__end__"

graph = StateGraph(State)

graph.add_node("llm_agent", llm_agent)
graph.add_node("tool", tool)

graph.set_entry_point("llm_agent")
graph.add_conditional_edges("llm_agent",
                            router,
                            {"tool":"tool", "__end__": END})

graph.add_edge("tool", "llm_agent")

runner = graph.compile()

messages = [HumanMessage(content="にゃんたの睡眠時間と今日の番組情報を教えてください")]
result = runner.invoke({"messages": messages})

永続性を持ってグラフを実行

SqliteSaverを使ってメッセージのやり取りを保存します

実行するには、コンパイル時にSqliteSaverをインスタンス化したmemoryを与えます。
また、メッセージを投げる時にスレッドのキーを含んだconfigを与えます

from langgraph.checkpoint.sqlite import SqliteSaver
import time

memory = SqliteSaver.from_conn_string(":memory:")

# memory付きでgraphをコンパイル
runner2 = graph.compile(checkpointer=memory)

# 実行時はスレッドのキーを含んだconfigを引数に入れる
config = {"configurable": {
      "thread_id": "1",
    }}

messages = [HumanMessage(content="にゃんたの睡眠時間と今日の番組情報を教えてください")]

results = runner2.stream(
    {'messages': messages},
    config,
    stream_mode="values"
)
for result in results:
  result['messages'][-1].pretty_print()

#->
================================ Human Message =================================

にゃんたの睡眠時間と今日の番組情報を教えてください
================================== Ai Message ==================================
Tool Calls:
  fake_database_api (call_cX7eyyMuqGkVG1LK6fJAoVST)
 Call ID: call_cX7eyyMuqGkVG1LK6fJAoVST
  Args:
    query: にゃんたの睡眠時間を教えて
  fake_database_api2 (call_mbM27gbGJ4qSwbRVQLeApKBZ)
 Call ID: call_mbM27gbGJ4qSwbRVQLeApKBZ
  Args:
    query: 今日の番組情報を教えて
================================= Tool Message =================================

12時からOsakan Hot 10015時からChillin Sundayです
================================== Ai Message ==================================

にゃんたは毎日8時間睡眠します。今日の番組情報は、12時からOsakan Hot 10015時からChillin Sundayです。

メモリに保存されているか確認するため、もうひとつメッセージを送ってみましょう。
この時に与えるconfigは先ほど与えたものと同じであることに注意してください

messages2 = [HumanMessage(content="さっき何と言いましたか")]

results2 = runner2.stream(
    {'messages': messages2},
    config,
    stream_mode="values"
)
for result in results2:
  result['messages'][-1].pretty_print()

#->
================================ Human Message =================================

さっき何と言いましたか
================================== Ai Message ==================================

にゃんたは毎日8時間睡眠します。今日の番組情報は、12時からOsakan Hot 10015時からChillin Sundayです。

人間参加型

後日執筆予定

Discussion