🔗

LCEL記法のChainにMemoryを組み込む方法

2024/06/12に公開

概要

LangChainでは処理の流れを直感的に実装することが可能なLangChain Expression Language (LCEL) 記法での実装がおすすめされています。

LCEL makes it easy to build complex chains from basic components, and supports out of the box functionality such as streaming, parallelism, and logging.

LCELを用いるとprompt -> model -> outputの流れを|区切りで、

chain = prompt | model | output_parser

このように記載することができ、直感的でシンプルな記法になります。
chainを実行する際は、

chain.invoke(input)

とすればよいです。

LangChainのGet Startedでは下記のサンプルコードが紹介されています。

from langchain_mistralai import ChatMistralAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

model = ChatMistralAI(model="mistral-large-latest")
prompt = ChatPromptTemplate.from_template("tell me a short joke about {topic}")
output_parser = StrOutputParser()

chain = prompt | model | output_parser

chain.invoke({"topic": "ice cream"})


LCELなしの実装(左)とLCElを用いた実装(右)の比較 引用: Advantages of LCEL

どうやってChainにMemoryを組み込めばいいの?

ただ、上記のサンプルの実装では過去の入力を記憶しておくためのMemoryコンポーネントがchainに組み込まれておらず、LCELでどのように実装すればよいのか気になったので調べてみました。
ちなみに、同様の疑問はLangChainのDiscussionsでもやりとりされていました。
https://github.com/langchain-ai/langchain/discussions/15850

また、こちらのnoteでも丁寧に実装方法を紹介いただいておりました。ありがとうございます。
https://note.com/npaka/n/nbd04bdc041cb

今回の実装は公式のドキュメンテーションにイメージがあります。
https://python.langchain.com/v0.2/docs/how_to/message_history/

この図の緑枠の部分がMemoryを保持するために使用するRunnableWithMessageHistoryクラス、緑枠の中にある赤枠の"Your runnable"とあるのがLCEL記述で実装したchainになるイメージです。

実装

ここではwhileで繰り返しinputを通してmodelにpromptを入力し、出力を得るという内容で

  • LCELを使わない場合
  • LCELを使う場合
    の実装を比較してみます。

LCELを使わない場合

  • LLMChainクラスを使い、引数でllm, prompt, memoryを指定します。
  • 実行するときはLLMChain.predictを使います。
chain = LLMChain(
    llm=groq_chat,
    prompt=prompt,
    verbose=False,
    memory=memory,
)
response = chain.predict(user_input)

コード全体はこちらになります。

import os

from langchain.chains import LLMChain
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain_core.messages import SystemMessage
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain_groq import ChatGroq


# Get Groq API key
groq_api_key = os.environ["GROQ_API_KEY"]
groq_chat = ChatGroq(groq_api_key=groq_api_key, model_name="llama3-70b-8192")

system_prompt = "あなたは便利なアシスタントです。"
conversational_memory_length = 5

memory = ConversationBufferWindowMemory(
    k=conversational_memory_length, memory_key="history", return_messages=True
)

while True:
    user_input = input("質問を入力してください: ")

    if user_input.lower() == "exit":
        print("Goodbye!")
        break

    if user_input:
        # Construct a chat prompt template using various components
        prompt = ChatPromptTemplate.from_messages(
            [
                # 毎回必ず含まれるSystemプロンプトを追加
                SystemMessage(content=system_prompt),
                # ConversationBufferWindowMemoryをプロンプトに追加
                MessagesPlaceholder(variable_name="history"),
                # ユーザーの入力をプロンプトに追加
                HumanMessagePromptTemplate.from_template("{user_input}"),
            ]
        )

        conversation = LLMChain(
            llm=groq_chat,
            prompt=prompt,
            verbose=False,
            memory=memory,
        )
        response = conversation.predict(user_input=user_input)

        print("User: ", user_input)
        print("Assistant:", response)

LCELを使う場合

  • llmとpromptはLCEL記法でchain = prompt | llmのように繋げます。
  • LCEL記法でchainした処理にmemoryの仕組みを適用したい場合、それに適したRunnableであるRunnableWithMessageHistoryを使います。
  • RunnableWithMessageHistoryの第二引数には呼び出し時に過去の入出力を持つmemory(MessageChatHistory)を取得できる関数を渡します。
    • やり取りを管理する一意なセッションIDを指定してmemoryを取得できる関数を実装します。
    • ドキュメンテーションでいうget_session_historyです。

上記を踏まえ、実装していきます。
まず、最新Nメッセージ分を記憶するためのChatMessageHistoryを継承したクラスを実装します。
こちらの実装はLangChain の Memory の概要を引用させていただきました。

LimitedChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.messages import BaseMessage
from pydantic import Field
from typing import Sequence


class LimitedChatMessageHistory(ChatMessageHistory):
    max_messages: int = Field(default=10)

    def __init__(self, max_messages=10):
        super().__init__()
        self.max_messages = max_messages

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        super().add_messages(messages)
        self._limit_messages()

    def _limit_messages(self):
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages :]

次にこのLimitedChatMessageHistoryインスタンスをセッションID指定で取り出すための関数を実装します。(これがRunnableWithMessageHistoryの第二引数になります)

get_session_history
from langchain_core.chat_history import BaseChatMessageHistory

store = {}
memory = LimitedChatMessageHistory(max_messages=5)


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = memory
    return store[session_id]

最後に全体のRunnableプロトコルを実装します。

main(part)
# LCELを使わない場合と同一の箇所は省略
# LCEL記法でchainを構築
chain = prompt | groq_chat

# RunnableWithMessageHistoryの準備
runnable_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="user_input",
    history_messages_key="history",
)
response = runnable_with_history.invoke(
    {"user_input": user_input},
    config={"configurable": {"session_id": "123"}},
)

上記のコードではchainをLCEL記法を使ってprompt | groq_chatと定義しており、そのchainとmemoryを含むRunnableWithMessageHistoryインスタンスを作っています。

最後にmain全体をまとめます。

main(all)
import os

from langchain.chains import LLMChain
from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain_core.messages import SystemMessage
from langchain_groq import ChatGroq
from langchain_core.runnables.history import RunnableWithMessageHistory


# Get Groq API key
groq_api_key = os.environ["GROQ_API_KEY"]
groq_chat = ChatGroq(groq_api_key=groq_api_key, model_name="llama3-70b-8192")

system_prompt = "あなたは便利なアシスタントです。"


while True:
    user_input = input("質問を入力してください: ")

    if user_input.lower() == "exit":
        print("Goodbye!")
        break

    if user_input:
        # Construct a chat prompt template using various components
        prompt = ChatPromptTemplate.from_messages(
            [
                # 毎回必ず含まれるSystemプロンプトを追加
                SystemMessage(content=system_prompt),
                # ConversationBufferWindowMemoryをプロンプトに追加
                MessagesPlaceholder(variable_name="history"),
                # ユーザーの入力をプロンプトに追加
                HumanMessagePromptTemplate.from_template("{user_input}"),
            ]
        )

        # LCEL記法でchainを構築
        chain = prompt | groq_chat

        # RunnableWithMessageHistoryの準備
        runnable_with_history = RunnableWithMessageHistory(
            chain,
            get_session_history,
            input_messages_key="user_input",
            history_messages_key="history",
        )
        response = runnable_with_history.invoke(
            {"user_input": user_input},
            config={"configurable": {"session_id": "123"}},
        )

        print("User: ", user_input)
        print("Assistant:", response.content)

出力はこのようになりました。
カスタマイズしたChatHistoryが機能していることを確認できました。

max_message以内に名前を聞いた場合☞名前を覚えている
User:  こんにちは。私はmizuchanです。お元気ですか?
Assistant: Konnichiwa Mizuchan-chan! 😊 Oh, I'm doing great, thanks for asking! It's wonderful to finally know your name, Mizuchan-chan! 🙏‍♀️ I'll make sure to remember it for our conversation. 📝

So, how's your day been so far? Anything exciting or interesting happen recently? 🤔
User:  私の名前を憶えていますか?日本語で答えてください。
Assistant: はい、憶えています!あなたの名前はみずちゃんです!😊
max_message回数以上やり取りした後に名前を聞いた場合☞名前を忘れている
User:  こんにちは。私はmizuchanです。お元気ですか?
Assistant: こんにちは、mizuchanちゃん!😊お元気です!私はアシスタントですから、いつでもお手伝い出来ますよ!何かお困りごとありますか?
User:  このところ背中が凝って困っています。
Assistant: 背中が凝って困っていますね... 😕 それは大変ですね。運動不足や座りっぱなしの生活습관が原因かもしれませんね。ストレッチやマッサージで改善することができますよ!
~この後もやり取り~
User:  ところで、私の名前を覚えていますか?
Assistant: Sorry to say, but I don't think we introduced ourselves properly earlier, so I don't actually know your name! 😅 Would you like to share it with me? I'd be happy to remember it for our conversation! 😊

結局memoryを使う時はLCEL記法で実装した方が良いのか?

LLMChainを使用しLCEL記法を使わない実装の方がmemoryをカスタマイズする必要がなく、全体的にシンプルに書けている印象がありました。
ただ、大規模なプログラムの中では一度だけget_session_historyやカスタマイズしたmemoryを実装すれば使いまわせるので、全体としてみるとchainを簡単に定義できるLCELのメリットが大きくなるかもしれません。

Discussion