🌟

【RAG】LangChainでつくるRAGチャットボット ~会話履歴を考慮する~

2024/07/18に公開

RAG (Retrieval-Augmented Generation) に関しては、ドキュメントのチャンク分割方法やデータベースからの検索方法などがよく論じられます。一方で、チャットボットとして提供するのであれば、会話履歴付きチャット機能をどのように実現するかも重要だと感じています。そこで、RAGベースのチャットボットにおいて、会話履歴をどのように管理すべきか、また、ナレッジの検索を毎回するのかなど、複数回の会話のやり取りを前提としたRAGチャットボットの実現方法について、いろいろと試してみました。

なお、前回の記事で、LCEL記法での会話履歴付きチャットボットの作り方をまとめました。今回は、このチャットボットにRAG機能を付け加えたものをベースとして検討していきます。

https://zenn.dev/khisa/articles/7f56f4e66cae43

RAGチャットボットにおける会話履歴

RAGでは、ユーザーからの入力(質問)に関連する文書をナレッジから取得し、コンテキストとしてLLMに与えることで、ナレッジを活用した回答をLLMに生成させることが基本となります。

ユーザーの質問→LLMの回答という1ターンだけで完結するのであればこれだけでよいのですが、チャットボットとして実現するとなると、LLMの回答に対するさらなる質問に対しても答えられないといけません。このように複数回の会話のやり取りに的確に回答するには、ナレッジから取得したコンテキストに加えて、これまでの会話履歴もLLMに渡して、会話の流れを理解したうえで回答を生成する必要があります。

このようなチャットボットにおいて、RAGを実現するうえで考慮しなくてはならない事項として、以下があげられます。

  • 会話履歴をどのように保持するか?
  • ナレッジの検索はどのタイミングで行うのか?

以降では、実際にLangChainでRAGチャットボットを実装して、その動作を確認していきます。この記事では、会話履歴なしの基本的なRAGと、単純に会話履歴を付け加えたRAGチャットボットの動作を見ていきます。

動作環境・準備

以下の環境で動作確認をしています。

  • Windows10
  • Python 3.11.6
  • LangChain 0.2.8
  • ChromaDB 0.5.4

今回は、ベクトルDBのEmbeddingにOpenAIの text-embedding-3-smallを、LLMに Google Gemini 1.5 Flash を利用しています。これらを利用するために、APIキーを環境変数に設定しておいてください。

OPENAI_API_KEY=<OpenAI API Key>
GOOGLE_API_KEY=<Google Gemini API Key>

ベクトルDBには、Chromaを利用します。今回は、比較的多くのドキュメントを格納した状態で動作確認をしたかったので、自前のWordpressのブログの全記事をナレッジとして入れています。

この記事で掲載するコードを動作させるには、手元に適当なドキュメントを格納したChromaDBを用意して、以下の例のように環境変数にChromaDBのディレクトリとコレクション名を設定してください。

CHROMA_PERSIST_DIRECTORY="./chroma-db"
CHROMA_COLLECTION_NAME="wpchatbot"

あるいは、コードのChromaDBの部分を適当なベクトルDBなどに置き換えていただいても構いません。最終的に、LangChainのretrieverが用意できれば動作するはずです。

なお、コードは以下のリポジトリに置いてあります。必要に応じて参照してください。

https://github.com/kzhisa/rag-chatbot

会話履歴なしの単純なRAGチャットボット

まずは、ベースとなる単純なRAGを実装して、動作を確認していきます。チャットボットのように複数回の会話ができますが、単にループさせているだけで、会話の履歴機能はありません。

具体的には、以下の図のように動作します。

これを実装していきます。

実装

コードの全体は以下のとおりです。前述のリポジトリにrag_chatbot1.pyとして置いてあります。

import chromadb
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import OpenAIEmbeddings

# OpenAI embedding model
EMBEDDING_MODEL = "text-embedding-3-small"

# ChromaDB
CHROMA_PERSIST_DIRECTORY = os.environ.get("CHROMA_PERSIST_DIRECTORY")
CHROMA_COLLECTION_NAME = os.environ.get("CHROMA_COLLECTION_NAME")

# Retriever settings
TOP_K_VECTOR = 8

# 既存のChromaDBを読み込みVector Retrieverを作成
def vector_retriever(top_k: int = TOP_K_VECTOR):
    """Create base vector retriever from ChromaDB

    Returns:
        Vector Retriever
    """

    # chroma db
    embeddings = OpenAIEmbeddings(model=EMBEDDING_MODEL)
    client = chromadb.PersistentClient(path=CHROMA_PERSIST_DIRECTORY)
    vectordb = Chroma(
        collection_name=CHROMA_COLLECTION_NAME,
        embedding_function=embeddings,
        client=client,
    )

    # base retriever (vector retriever)
    vector_retriever = vectordb.as_retriever(
        search_kwargs={"k": top_k},
    )

    return vector_retriever


# プロンプトテンプレート
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know. Use three sentences maximum and keep the "
    "answer concise."
    "\n\n"
    "{context}"
)
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

# 実際の応答生成の例
def chat_with_bot(session_id: str):

    # LLM
    chat_model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.0)

    # Vector Retriever
    retriever = vector_retriever()

    # RAG Chain
    basic_qa_chain = create_stuff_documents_chain(
        llm = chat_model,
        prompt = prompt_template,
    )
    rag_chain = create_retrieval_chain(retriever, basic_qa_chain)

    count = 0
    while True:
        print("---")
        input_message = input(f"[{count}]あなた: ")
        if input_message.lower() == "終了":
            break

        # プロンプトテンプレートに基づいて応答を生成
        response = rag_chain.invoke(
            {"input": input_message},
            config={"configurable": {"session_id": session_id}}
        )
        
        print(f"AI: {response['answer']}")
        count += 1


if __name__ == "__main__":

    # チャットセッションの開始
    session_id = "example_session"
    chat_with_bot(session_id)

RAGのChainを定義しているのは以下の箇所です。

# RAG Chain
basic_qa_chain = create_stuff_documents_chain(
    llm = chat_model,
    prompt = prompt_template,
)
rag_chain = create_retrieval_chain(retriever, basic_qa_chain)

create_stuff_documents_chainは、複数のドキュメントDocumentのリストを一つのプロンプトにまとめてLLMに渡すChainを生成する関数です。Documentは引数document_variable_nameで指定された変数名で渡されますが、デフォルトではcontextとなっていますので、ここではデフォルト値を利用しています。

https://api.python.langchain.com/en/latest/chains/langchain.chains.combine_documents.stuff.create_stuff_documents_chain.html

create_stuff_documents_chainを利用して生成したChain basic_qa_chain の前に、ChromaDBからドキュメントを取得するretrievercreate_retrieval_chain関数で連結して、最終的なRAGを実現するChainとしてrag_chainを作成しています。

実際にRAGを実行しているのは以下の部分です。

# プロンプトテンプレートに基づいて応答を生成
response = rag_chain.invoke(
    {"input": input_message},
    config={"configurable": {"session_id": session_id}}
)

inputにユーザーの入力を渡して、先ほど定義したRAGのChainであるrag_chainをinvokeします。session_idは会話履歴を保持するために会話のセッションを識別するIDですが、このコードでは利用していません。

動作確認

それでは動作確認をしてみます。ナレッジとしては、私が運営する鉄道関連(青春18きっぷ関連)のブログで公開している全記事です。記事のHTMLを<h2>と<h3>でチャンク分割していて、チャック当たりの文字数は平均500文字程度です。

https://www.kzlifelog.com/seishun18/

青春18きっぷについて聞いてみます。

$ python .\rag_chatbot1.py
---
[0]あなた: 2024年春の青春18きっぷはいつ利用できますか?
AI: 2024年春の青春18きっぷは、2024年3月1日(金)から4月10日(水)まで利用できます。 

---
[1]あなた: 夏は?
AI: 夏は「青春18きっぷ」が発売されます。夏は7月20日から9月10日まで利用できます。「青春18きっぷ」は、JR全線の普通列車・快速列車に乗車できるきっぷです。

---
[2]あなた: 価格は?
AI: 普通列車グリーン車のグリーン料金は、101キロまでは1,550円、101キロ以上は2,000円です。 

---
[3]あなた: 青春18きっぷの価格は?
AI: 青春18きっぷの価格は12,050円(税込)で、5回分(5日分)セットでの販売です。1回分(1日分)あたり2,410円になります。

[0]は正確に答えられていて、RAGとして正常に動作していることがわかります。

[1]で「夏は?」と聞いていますが、これはもちろん「夏の青春18きっぷはいつ利用できますか?」という意図での質問です。『夏は「青春18きっぷ」が発売されます。』は回答として少し変ではありますが、夏に利用できる期間については正確に答えています。

[2]の「価格は?」の質問は、もちろん「青春18きっぷの価格は?」という意図ですが、ここではグリーン車の料金をAIが回答しています。おそらく「価格は?」というクエリに対して、グリーン車の料金の記事がナレッジから返されたのでしょう。

改めて[3]で「青春18きっぷの価格は?」と聞いてみると、今度は正確に12,050円と回答できています。

このように、会話履歴がないRAGチャットボットでは、過去の質問の文脈を考慮した回答ができず、単独の質問として成立する場合にしか、正確な回答ができません。

会話履歴付きのRAGチャットボット

会話履歴のないRAGでは、これまでの質疑の文脈を理解できません。そのためLLMの回答に対してさらに質問をする場合、ナレッジを検索するためのキーワードを質問文に含めないと、適切な文書を検索することができませんでした。

そこで、2回目以降の質問に対して、会話履歴をプロンプトに含めてLLMに渡すことで、LLMにこれまでの文脈を理解させる「会話履歴付きRAGチャットボット」に拡張してみます。

具体的には、以下の図のように、会話履歴(図中の Chat history)をプロンプトに含めるようにします。図には書かれていませんが、LLMが回答を返すたびに、ユーザーの質問とLLMの回答を Chat history DB に追加していきます。

それでは、会話履歴付きのRAGチャットボットを実装していきます。

実装

チャットボットの会話履歴については、以下の記事で紹介したRunnableWithMessageHistoryを利用します。RunnableWithMessageHistoryの使い方については、以下の記事をご覧ください。

https://zenn.dev/khisa/articles/7f56f4e66cae43

ここでは、前述の会話履歴なしのRAGチャットボットからの修正点を中心に、ポイントだけを見ていきます。コードの全体については、リポジトリにあるrag_chatbot2.pyをご覧ください。

https://github.com/kzhisa/rag-chatbot

会話履歴を保存するクラスとしては、LangChainにあるChatMessageHistoryを利用しますが、ここでは会話履歴の保持数の上限を制限するLimitedChatMessageHistoryに拡張して利用します。

# 会話履歴数をmax_lengthに制限するLimitedChatMessageHistoryクラス
DEFAULT_MAX_MESSAGES = 4
class LimitedChatMessageHistory(ChatMessageHistory):

    # 会話履歴の保持数
    max_messages: int = DEFAULT_MAX_MESSAGES

    def __init__(self, max_messages=DEFAULT_MAX_MESSAGES):
        super().__init__()
        self.max_messages = max_messages

    def add_message(self, message):
        super().add_message(message)
        # 会話履歴数を制限
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages:]

    def get_messages(self):
        return self.messages

会話履歴の保存数はDEFAULT_MAX_MESSAGES = 4としています。ユーザーの質問、LLMの回答はそれぞれ1つと数えますので、過去2ターン分の会話履歴を保存することになります。

次に、このLimitedChatMessageHistoryのインスタンスを、会話のセッションごとに保存しておくstoreと、その会話履歴を取得するための関数get_session_historyを準備します。

# 会話履歴のストア
store = {}

# セッションIDごとの会話履歴の取得
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = LimitedChatMessageHistory()
    return store[session_id]

これらを準備したうえで、RAGを実現するChainrag_chainを、RunnableWithMessageHistoryでラップします。

# RAG Chain
basic_qa_chain = create_stuff_documents_chain(
    llm = chat_model,
    prompt = prompt_template,
)
rag_chain = create_retrieval_chain(retriever, basic_qa_chain)

# Runnable chain を RunnableWithMessageHistory でラップ
runnable_with_history = RunnableWithMessageHistory(
    runnable=rag_chain,
    get_session_history=get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

rag_chainの定義は前述の会話履歴なしのチャットボットと同じです。ここには示していませんが、プロンプトに会話履歴を入れるプレースフォルダだけを追加しています。

RunnableWithMessageHistoryは、チャットボットのChainをラップする形で、プロンプトに会話履歴を追加したり、ユーザーの質問やLLMの回答を会話履歴に追加したりといったことを実現できます。

RunnableWithMessageHistoryの引数などについては、以下の記事で紹介していますので、必要に応じて参照してください。

https://zenn.dev/khisa/articles/7f56f4e66cae43

なお、RAGですので、ユーザーの入力(質問)に加えて、ナレッジから取得したコンテキスト(文書)もLLMに送られますが、会話履歴としては、ユーザーの入力とLLMの回答のみを保存しています。一般にコンテキストの分量は多く、これを会話履歴として毎回保存し、LLMに送信していると、あっという間にトークン数の上限に達してしまうためです。

動作確認

それでは実行してみます。先ほどの会話履歴なしのチャットボットと同じ質問をしてみます。

$ python .\rag_chatbot2.py
---
[0]あなた: 2024年春の青春18きっぷはいつ利用できますか? 
AI: 2024年春の青春18きっぷは、2024年3月1日(金)~4月10日(水)の期間に利用できます。 

---
[1]あなた: 夏は?
AI: 2024年夏の青春18きっぷは、7月20日(土)~9月10日(火)の期間に利用できます。 

---
[2]あなた: 価格は?
AI: 2024-25年シーズンの青春18きっぷの価格は、12,050円(おとな・こども同額)です。 

[1]の「夏は?」に対して、夏の青春18きっぷの利用期間を、[2]の「価格は?」に対して、青春18きっぷの価格を正確に答えています。

最初の[0]の質問「2024年春の青春18きっぷはいつ利用できますか?」や、その回答「2024年春の青春18きっぷは、2024年3月1日(金)~4月10日(水)の期間に利用できます。」が会話履歴としてLLMに渡されていますので、LLMはその文脈を理解して、[1]では「夏の青春18きっぷはいつ利用できるか?」、[2]では「青春18きっぷの価格は?」という質問であると正しく解釈しています。

もう一つやってみましょう。

$ python .\rag_chatbot2.py
---
[0]あなた: 青春18きっぷで本州から北海道へ移動する方法は?
AI: 青春18きっぷで本州から北海道へ移動するには、北海道新幹線オプション券を利用して、青森から函館まで移動する方法があります。ただし、乗り継ぎに時間がかかる場合があるので、フェリーなどの他の手段も検討しましょう。  フェリーは、夜行便を利用すれば宿泊費も節約できます。

---
[1]あなた: 一番安い方法は?
AI: 青春18きっぷで本州から北海道へ行く場合、一番安い方法はフェリーです。フェリーは、夜行便を利用すれば宿泊費も節約できます。

---
[2]あなた: おすすめのフェリーは?
AI: 青春18きっぷを使って本州から北海道へ行く場合、おすすめのフェリーは「青函フェリー」か「津軽海峡フェリー」です。どちらも青森港と函館港を結んでおり、2等運賃が2,000円前後とリーズナブルです。

---
[3]あなた: 鉄道で移動したい場合は?
AI: 鉄道で移動したい場合は、青春18きっぷと北海道新幹線オプション券を組み合わせるのがおすすめです。ただし、乗り継ぎが複雑で時間がかかるため、時間に余裕がある場合にのみ検討しましょう。

青春18きっぷの旅で本州~北海道を移動する方法を最初に質問しています。その後に「一番安い方法は?」「おすすめのフェリーは?」「鉄道で移動したい場合は?」と簡潔に質問していますが、いずれも「青春18きっぷの旅で本州と北海道を移動する場合」を前提に解答してくれています。

単純に会話履歴を渡すだけで、LLMが文脈を理解して回答してくれることが確認できました。

まとめ

RAGチャットボットでの会話履歴の効果を、実際にLangChainを用いて実装することで確認してみました。

ナレッジから取得したコンテキストを除外して、ユーザーの入力とLLMの回答を会話履歴として保存するだけでも、それなりに効果があることが確認できました。

この記事で紹介した例以外にもいろいろ試してみましたが、明らかに文脈を理解できていない回答はほとんどありませんでした。ナレッジの精度の関係で、回答が間違っていることはありますが、それは単独で質問しても間違えるので、会話履歴のせいではなさそうです。

一方で、LangChainのWebサイトでは、さらにRAGの性能を上げる方法が紹介されていますので、次回はそれを試してみたいと思います。

関連記事

会話履歴付きのチャットボットをLangChainのLCEL記法で実装する方法をまとめた記事です。この記事でも利用したRunnableWithMessageHistoryの使い方を中心に紹介しています。

https://zenn.dev/khisa/articles/7f56f4e66cae43

Discussion