LCEL記法のChainにMemoryを組み込む方法
概要
LangChainでは処理の流れを直感的に実装することが可能なLangChain Expression Language (LCEL) 記法での実装がおすすめされています。
LCEL makes it easy to build complex chains from basic components, and supports out of the box functionality such as streaming, parallelism, and logging.
LCELを用いるとprompt -> model -> outputの流れを|
区切りで、
chain = prompt | model | output_parser
このように記載することができ、直感的でシンプルな記法になります。
chainを実行する際は、
chain.invoke(input)
とすればよいです。
LangChainのGet Startedでは下記のサンプルコードが紹介されています。
from langchain_mistralai import ChatMistralAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
model = ChatMistralAI(model="mistral-large-latest")
prompt = ChatPromptTemplate.from_template("tell me a short joke about {topic}")
output_parser = StrOutputParser()
chain = prompt | model | output_parser
chain.invoke({"topic": "ice cream"})
LCELなしの実装(左)とLCElを用いた実装(右)の比較 引用: Advantages of LCEL
どうやってChainにMemoryを組み込めばいいの?
ただ、上記のサンプルの実装では過去の入力を記憶しておくためのMemoryコンポーネントがchainに組み込まれておらず、LCELでどのように実装すればよいのか気になったので調べてみました。
ちなみに、同様の疑問はLangChainのDiscussionsでもやりとりされていました。
また、こちらのnoteでも丁寧に実装方法を紹介いただいておりました。ありがとうございます。
今回の実装は公式のドキュメンテーションにイメージがあります。
この図の緑枠の部分がMemory
を保持するために使用するRunnableWithMessageHistory
クラス、緑枠の中にある赤枠の"Your runnable"とあるのがLCEL記述で実装したchain
になるイメージです。
実装
ここではwhileで繰り返しinputを通してmodelにpromptを入力し、出力を得るという内容で
- LCELを使わない場合
- LCELを使う場合
の実装を比較してみます。
LCELを使わない場合
-
LLMChain
クラスを使い、引数でllm, prompt, memoryを指定します。 - 実行するときは
LLMChain.predict
を使います。
chain = LLMChain(
llm=groq_chat,
prompt=prompt,
verbose=False,
memory=memory,
)
response = chain.predict(user_input)
コード全体はこちらになります。
import os
from langchain.chains import LLMChain
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_core.messages import SystemMessage
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain_groq import ChatGroq
# Get Groq API key
groq_api_key = os.environ["GROQ_API_KEY"]
groq_chat = ChatGroq(groq_api_key=groq_api_key, model_name="llama3-70b-8192")
system_prompt = "あなたは便利なアシスタントです。"
conversational_memory_length = 5
memory = ConversationBufferWindowMemory(
k=conversational_memory_length, memory_key="history", return_messages=True
)
while True:
user_input = input("質問を入力してください: ")
if user_input.lower() == "exit":
print("Goodbye!")
break
if user_input:
# Construct a chat prompt template using various components
prompt = ChatPromptTemplate.from_messages(
[
# 毎回必ず含まれるSystemプロンプトを追加
SystemMessage(content=system_prompt),
# ConversationBufferWindowMemoryをプロンプトに追加
MessagesPlaceholder(variable_name="history"),
# ユーザーの入力をプロンプトに追加
HumanMessagePromptTemplate.from_template("{user_input}"),
]
)
conversation = LLMChain(
llm=groq_chat,
prompt=prompt,
verbose=False,
memory=memory,
)
response = conversation.predict(user_input=user_input)
print("User: ", user_input)
print("Assistant:", response)
LCELを使う場合
- llmとpromptはLCEL記法で
chain = prompt | llm
のように繋げます。- このようにmodelやprompt、およびそれらを繋げたchainは
Runnable
プロトコルとしての特徴を持ち、stream
,invoke
,batch
といったインターフェースを提供しています。 - https://python.langchain.com/v0.1/docs/expression_language/interface/
- このようにmodelやprompt、およびそれらを繋げたchainは
- LCEL記法でchainした処理にmemoryの仕組みを適用したい場合、それに適した
Runnable
であるRunnableWithMessageHistory
を使います。 -
RunnableWithMessageHistory
の第二引数には呼び出し時に過去の入出力を持つmemory(MessageChatHistory
)を取得できる関数を渡します。- やり取りを管理する一意なセッションIDを指定してmemoryを取得できる関数を実装します。
- ドキュメンテーションでいう
get_session_history
です。
上記を踏まえ、実装していきます。
まず、最新Nメッセージ分を記憶するためのChatMessageHistory
を継承したクラスを実装します。
こちらの実装はLangChain の Memory の概要を引用させていただきました。
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.messages import BaseMessage
from pydantic import Field
from typing import Sequence
class LimitedChatMessageHistory(ChatMessageHistory):
max_messages: int = Field(default=10)
def __init__(self, max_messages=10):
super().__init__()
self.max_messages = max_messages
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
super().add_messages(messages)
self._limit_messages()
def _limit_messages(self):
if len(self.messages) > self.max_messages:
self.messages = self.messages[-self.max_messages :]
次にこのLimitedChatMessageHistory
インスタンスをセッションID指定で取り出すための関数を実装します。(これがRunnableWithMessageHistory
の第二引数になります)
from langchain_core.chat_history import BaseChatMessageHistory
store = {}
memory = LimitedChatMessageHistory(max_messages=5)
def get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in store:
store[session_id] = memory
return store[session_id]
最後に全体のRunnableプロトコルを実装します。
# LCELを使わない場合と同一の箇所は省略
# LCEL記法でchainを構築
chain = prompt | groq_chat
# RunnableWithMessageHistoryの準備
runnable_with_history = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="user_input",
history_messages_key="history",
)
response = runnable_with_history.invoke(
{"user_input": user_input},
config={"configurable": {"session_id": "123"}},
)
上記のコードではchainをLCEL記法を使ってprompt | groq_chat
と定義しており、そのchainとmemoryを含むRunnableWithMessageHistory
インスタンスを作っています。
最後にmain全体をまとめます。
import os
from langchain.chains import LLMChain
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_core.messages import SystemMessage
from langchain_groq import ChatGroq
from langchain_core.runnables.history import RunnableWithMessageHistory
# Get Groq API key
groq_api_key = os.environ["GROQ_API_KEY"]
groq_chat = ChatGroq(groq_api_key=groq_api_key, model_name="llama3-70b-8192")
system_prompt = "あなたは便利なアシスタントです。"
while True:
user_input = input("質問を入力してください: ")
if user_input.lower() == "exit":
print("Goodbye!")
break
if user_input:
# Construct a chat prompt template using various components
prompt = ChatPromptTemplate.from_messages(
[
# 毎回必ず含まれるSystemプロンプトを追加
SystemMessage(content=system_prompt),
# ConversationBufferWindowMemoryをプロンプトに追加
MessagesPlaceholder(variable_name="history"),
# ユーザーの入力をプロンプトに追加
HumanMessagePromptTemplate.from_template("{user_input}"),
]
)
# LCEL記法でchainを構築
chain = prompt | groq_chat
# RunnableWithMessageHistoryの準備
runnable_with_history = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="user_input",
history_messages_key="history",
)
response = runnable_with_history.invoke(
{"user_input": user_input},
config={"configurable": {"session_id": "123"}},
)
print("User: ", user_input)
print("Assistant:", response.content)
出力はこのようになりました。
カスタマイズしたChatHistoryが機能していることを確認できました。
User: こんにちは。私はmizuchanです。お元気ですか?
Assistant: Konnichiwa Mizuchan-chan! 😊 Oh, I'm doing great, thanks for asking! It's wonderful to finally know your name, Mizuchan-chan! 🙏♀️ I'll make sure to remember it for our conversation. 📝
So, how's your day been so far? Anything exciting or interesting happen recently? 🤔
User: 私の名前を憶えていますか?日本語で答えてください。
Assistant: はい、憶えています!あなたの名前はみずちゃんです!😊
User: こんにちは。私はmizuchanです。お元気ですか?
Assistant: こんにちは、mizuchanちゃん!😊お元気です!私はアシスタントですから、いつでもお手伝い出来ますよ!何かお困りごとありますか?
User: このところ背中が凝って困っています。
Assistant: 背中が凝って困っていますね... 😕 それは大変ですね。運動不足や座りっぱなしの生活습관が原因かもしれませんね。ストレッチやマッサージで改善することができますよ!
~この後もやり取り~
User: ところで、私の名前を覚えていますか?
Assistant: Sorry to say, but I don't think we introduced ourselves properly earlier, so I don't actually know your name! 😅 Would you like to share it with me? I'd be happy to remember it for our conversation! 😊
結局memoryを使う時はLCEL記法で実装した方が良いのか?
LLMChainを使用しLCEL記法を使わない実装の方がmemoryをカスタマイズする必要がなく、全体的にシンプルに書けている印象がありました。
ただ、大規模なプログラムの中では一度だけget_session_history
やカスタマイズしたmemoryを実装すれば使いまわせるので、全体としてみるとchainを簡単に定義できるLCELのメリットが大きくなるかもしれません。
Discussion