🐰

LangChain で実装する RAG の仕組み

2024/05/06に公開

LangChain で RAG を試してみて内部の仕組みが分かったので備忘用にメモ

登場人物

名前 説明
RAG Chain RAG(Retrieval-Augmented Generation)を実現するための入出力を行う
QA Chain LLMへのQA(Question-Answering)の入出力を行う
Retriever Vectorstore に対して入力に関連した文書を取り出す役割
Vectorstore 事前に作成したドキュメントDB (文書を分割+Embeddingしたもの)

チャット履歴がない場合の動作

  • ユーザー入力をそのまま Vectorstore に対する問い合わせに使用する
  • コンテキストは Vectorstore から取得したドキュメントをの内容を単純に繋げた文字列
  • 事前に定義したプロンプトテンプレートを用いて以下のように組み合わせて問い合わせる
ロール
system アシスタントへの命令 + コンテキスト情報(ドキュメント)
human ユーザー入力

チャット履歴がある場合の動作

  • 履歴がある場合は Vectorstore への問い合わせを LLM を用いて生成する
    • ユーザー入力とチャット履歴を組み合わせて文脈を考慮した問い合わせ文にする
    • コンテキスト生成のための Vectorstore への問い合わせにのみ使用されるため最終的な LLM への問い合わせに使用されない
  • 事前に定義したプロンプトテンプレートを用いて以下のように組み合わせて問い合わせる
ロール
system アシスタントへの命令 + コンテキスト情報(ドキュメント)
human 質問:チャット履歴
ai 応答:チャット履歴
... (チャット履歴が続く)
human ユーザー入力

サンプルコード

import streamlit as st
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS

openai_api_key = st.secrets["openai"]["api_key"]


llm = llm = ChatOpenAI(
    temperature=0,
    model_name="gpt-3.5-turbo-1106",
    api_key=openai_api_key,
    max_retries=5,
    timeout=60,
)
vector = FAISS.load_local(
    "vectorstore/documents_faiss",
    embeddings=OpenAIEmbeddings(api_key=openai_api_key),
    allow_dangerous_deserialization=True,
)
retriever = vector.as_retriever()


# chat_history が与えられたときにLLMを用いてRetrieverに問いかけるクエリを生成する際に使用するプロンプト
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
retriever_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, retriever_prompt
)


# LLMを用いてQAを行う際に使用するプロンプト
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\

{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)


def response_generator(question: str, chat_history=[]):
    response = rag_chain.invoke({"input": question, "chat_history": chat_history})
    chat_history.extend(
        [HumanMessage(content=question), AIMessage(content=response["answer"])]
    )
    return response, chat_history

参考

https://python.langchain.com/docs/use_cases/question_answering/

Discussion