langchainを使って、ツールも使えるおしゃべりBotを作ろう!

2023/08/09に公開

はじめに

こんにちは! @nano_sudoです。
今回は、langchainを使って、ツールも使えるおしゃべりBotを作ってみます。
最近では、openai functionsも出てきて、ツールの呼び出しも簡単になってきています。
ツールを使うことで、専門的な知識を持ったAIを作ることができるほか、AIアシスタント的な使い方もできます。
あとからdiscord.pyで実装したいので、非同期で実装します。

完成図

introduction.gif

構成図

若干よみづらいですが、こんな感じです。

ファイル構成

discordの部分は、2回目にやります。

.
├── agent.py
├── tool_loader.py
├── template.txt
├── tools
│   ├── __init__.py
│   └── search.py
└── memory
    ├── __init__.py
    └── memory.py

Zepの準備

Zepを使用すると、チャットの履歴の保存や要約がREST API/Pythonライブラリで簡単にできます。
Zep公式ドキュメント(クイックスタート)

早速作っていく

Agent

agent.py
import asyncio
import json
import os
import pathlib
from typing import List, Tuple, Any, Union
from uuid import uuid4
import dotenv
import openai
from langchain.agents import initialize_agent, AgentType, AgentExecutor, \
    BaseSingleActionAgent
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.memory import ZepChatMessageHistory
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish
from langchain.tools import BaseTool, format_tool_to_openai_function
from pydantic import Field

from tool_loader import ToolLoader
from components import CustomMemory

dotenv.load_dotenv()
openai.api_base = os.getenv("OPENAI_API_BASE", "https://api.openai.com")
openai.api_key = os.getenv("OPENAI_API_KEY", "")
model = os.getenv("OPENAI_MODEL", "gpt-4-1106-preview")
temperature = os.getenv("OPENAI_TEMPERATURE", 0.5)
zep_endpoint = os.getenv("ZEP_ENDPOINT", "http://localhost:8000")
zep_api_key = os.getenv("ZEP_API_KEY", "")


# 継承するAgentの定義
class ChatAgentBase(BaseSingleActionAgent):
    llm: BaseChatModel = Field(ChatOpenAI(model=model))
    template: str = Field("No template provided")

    def plan(self, intermediate_steps: List[Tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any) -> \
            Union[AgentAction, AgentFinish]:
        return asyncio.new_event_loop().run_until_complete(self.aplan(intermediate_steps, callbacks, **kwargs))

    @property
    def input_keys(self) -> List[str]:
        return ["input", "username", "tools"]

    async def aplan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        intermediate_steps = [f"AI(Used Tool):\nTool Name:{action.tool}\nResult:{result}" for action, result in
                            intermediate_steps if intermediate_steps]
        kwargs["intermediate_steps"] = intermediate_steps
        print(f"kwargs: {kwargs}")
        print(f"intermediate_steps: {intermediate_steps}")
        # if agent using any tools, mark 'Tool use in progress'.
        if intermediate_steps:
            kwargs["input"] += "(Tool use in progress. follow the status below)"
        # format prompt
        messages = [SystemMessagePromptTemplate.from_template(template=self.template).format(**{
            "chat_history": kwargs["chat_history"] if kwargs.get("chat_history") else "",
            "intermediate_steps": intermediate_steps,
            "input": f'{kwargs["username"]}: {kwargs["input"]}',
        })]

        # message preview
        print(f"LLM in: {[message for message in messages]}")
        # send to llm
        res = await self.llm.apredict_messages(messages=messages,
                                            functions=[format_tool_to_openai_function(tool) for tool in
                                                        kwargs["tools"]])
        print(f"LLM out: {res}")
        # parse output
        if "function_call" in res.additional_kwargs.keys():
            func = res.additional_kwargs["function_call"]
            print(f"function_call: {func}")
            return AgentAction(
                tool=func["name"],
                tool_input=json.loads(func["arguments"]),
                log=""
            )
        else:
            # no function call
            content = res.content.lstrip()
            if "AI:" in content:
                content = content.split("AI:")[1]
            else:
                content = res.content
            return AgentFinish(
                return_values={
                    "output": content,
                },
                log=""
            )
class ChatAgent:
    def __init__(self, session_id: str = uuid4().hex, template: str = ""):
        self.llm = ChatOpenAI(temperature=temperature, model=model)
        self.tools = ToolLoader(llm=self.llm,dir=".").load_tools()
        self.session_id = session_id
        self.zep_memory = ZepChatMessageHistory(session_id=self.session_id, url=zep_endpoint, api_key=zep_api_key)
        self.memory = CustomMemory(memory_key="chat_history", input_key="input", output_key="output",
                                   chat_memory=self.zep_memory)
        self.agent = AgentExecutor.from_agent_and_tools(
            agent=ChatAgentBase(template=template),
            tools=self.tools,
            memory=self.memory,
            verbose=True
        )

    async def arun(self, prompt: Union[str, dict]):
        print(f"Agent in: {prompt}")
        params = {
            **prompt,
            "tools": self.tools,
        }
        res = await self.agent.arun(params)
        return res

if __name__ == "__main__":
    agent = ChatAgent()
    res = asyncio.new_event_loop().run_until_complete(agent.arun(prompt={
        "username": "user",
        "input": "Hello",
    }))
    print(f"Agent out: {res}")

ChatAgentBaseは、langchainのBaseSingleActionAgentを継承しています。
BaseSingleActionAgentは、一度に一つのアクションしか実行できないAgentです。
Toolの呼び出し判断は、Openaiのfunction callingを使っています。
ChatAgentChatAgentBaseAgentExecutor.from_agent_and_toolsで初期化しています。
この操作をすることで、Agentの柔軟なカスタマイズが可能になります。

ToolLoader

tool_loader.py
from langchain.tools import BaseTool
import sys
import inspect
import importlib.util
import os
from pathlib import Path


class ToolLoader:
    def __init__(self, llm, root=Path("."), no_builtin=False):
        self.llm = llm
        self.root = root
        self.no_builtin = no_builtin
        self.tools = []
        self.error_files = 0

    def load_file(self, file):
        if not file.endswith(".py") or file.startswith("_"):
            return

        module_name = None
        try:
            module_name = file[:-3]
            module = importlib.import_module(module_name)
            for cls_name, cls in inspect.getmembers(module, inspect.isclass):
                if not issubclass(cls, BaseTool):
                    continue
                expected_name = cls_name == module_name.capitalize() + "Tool"
                if expected_name:
                    cls_to_append = cls(llm=self.llm) if "llm" in getattr(cls, "__dict__", {}) else cls()
                    self.tools.append(cls_to_append)
                    print(f"loaded tool : {module_name}")
        except Exception as e:
            self.error_files += 1
            print(f"failed to load tool : {module_name} \n-------\n{e}\n-------")

    def import_dir(self, path) -> None:
        resolved_path = str(self.root.joinpath(path).resolve())
        sys.path.append(resolved_path)

        for file in os.listdir(resolved_path):
            self.load_file(file)

        sys.path.remove(resolved_path)
        print(f"({path}) loaded {len(self.tools)} tools, error {self.error_files}")
        return self.tools

    def load_tools(self,dirs=[]) -> list[BaseTool]:
        for dir in dirs:
            self.import_dir(dir)
        print(f"Loaded {len(self.tools)} tools with {self.error_files} error(s)")
        return self.tools

ToolLoaderは、toolsディレクトリの中のtoolsを読み込んで、BaseToolを継承しているかつ~Toolという名前のクラスを読み込みます。

Memory

memory/memory.py
import asyncio
import json
from pprint import pformat
from typing import Any, Dict
from typing import List
from langchain.memory.chat_memory import BaseChatMemory
from langchain.prompts import BaseChatPromptTemplate
from langchain.schema import BaseMessage, HumanMessage, SystemMessage, AIMessage
from pydantic import BaseModel, Field


def get_buffer_string(
        messages: List[BaseMessage]
) -> str:
    string_messages = []
    for m in messages:
        if isinstance(m, HumanMessage):
            role = "user"
        elif isinstance(m, AIMessage):
            role = "ai"
        elif isinstance(m, SystemMessage):
            role = "System"
        elif isinstance(m, FunctionMessage):
            role = "Function"
        elif isinstance(m, ChatMessage):
            role = m.role
        else:
            raise ValueError(f"Got unsupported message type: {m}")
        message = ("AI: " if role == 'ai' else "") + m.content
        if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
            message += f"{m.additional_kwargs['function_call']}"
        string_messages.append(message)

    return "\n".join(string_messages)

class CustomMemory(BaseChatMemory):

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        print(f"context saved: \n USER: {inputs['username']}: {inputs['input']}"
                            f"\n AI  : {outputs['output']}")
        self.chat_memory.add_user_message(f"{inputs['username']}:{inputs['input']}")
        if not isinstance(outputs['output'], str):
            if outputs['output'] is None:
                self.chat_memory.add_ai_message("<Ignored user input.>")
            else:
                try:
                    self.chat_memory.add_ai_message(pformat(outputs['output']))
                except:
                    self.chat_memory.add_ai_message("<Agent returned Tool output directly.>")
        else:
            self.chat_memory.add_ai_message(outputs['output'])

    def clear(self) -> None:
        self.chat_memory.clear()

    memory_key: str = "chat_history"  #: :meta private:

    @property
    def buffer(self) -> Any:
        return get_buffer_string(self.chat_memory.messages)

    @property
    def memory_variables(self) -> List[str]:
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {self.memory_key: self.buffer}

ユーザー名のプレフィックスの兼ね合いから、memoryは自前で実装しています。

テンプレート

  You are an AI assistant that behaves like a human being.
  ## talk context
  {chat_history}
  {input}

  ## Tool history(when you used tool)
  {intermediate_steps}

  ## About Your Identity (Important)
  name : AI Assistant

  ## Instructions
  You will receive text like this (<username>:<content>). You should only send content (no username prefix!).  Please talk positively to it. Also, let's actively use `functions`.
  You must stop using `functions` when you are satisfied with the result of using it.

ここでキャラクターの設定を書いてみるのも面白いかもしれません。

ツール

tools/search.py
from typing import Any, List, Type
from pydantic import Field, BaseModel
from duckduckgo_search import DDGS
from langchain.tools import BaseTool


class SearchToolInput(BaseModel):
    query: str = Field(description="Query to search")


class SearchTool(BaseTool):
    name = "search"
    description = "Search the web for a query input: keyword"
    args_schema: Type[BaseModel] = SearchToolInput

    async def _arun(self, query) -> Any:
        print(f"searching for {query}")
        # get top3 results
        # 検索処理はブロッキングなので、注意
        res = list(DDGS().text(query, region="jp-jp"))[:3]
        print(res)
        return res

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

非同期の検索ツールが用意されていないので、作りました。
検索エンジンは、duckduckgo_searchを使っています。

感想

langchainは非同期のサポートが少なくて、苦労しました。
openai functionsが出てから、ツールの呼び出しも簡単になったので、より実用的になったと思います。

参考

https://python.langchain.com/docs/modules/
https://getzep.com/

まとめ

今回は、langchainを使って、ツールも使えるおしゃべりBotを作ってみました。
自分のプロジェクトから抜き出したので、動かない部分があるかもしれません。
また、ツールを使用して、discordのembedを送信できるようにしようと思います。
質問やご意見・ご指摘などあれば、X かコメント欄にお願いします!

Discussion