Corrective RAG(CRAG)の概念と実装方法
こんにちは、酒井です!
株式会社 EGGHEAD(エッグヘッド)という「製造業で生成 AI を活用したシステム開発」をしている会社の代表をしております。
はじめに
本記事では、検索結果を自己評価し、必要に応じて追加情報源を活用する「Corrective RAG(CRAG)」について解説し、LangGraph を使った実装方法を紹介します。
Corrective RAG(CRAG)とは
CRAG の基本概念
Corrective RAG(CRAG)は、従来の RAG システムを拡張し、検索結果がユーザーの質問と関連しているかを評価(自己反省)することでRAGの精度を上げるための戦略です。CRAG では検索されたドキュメントを「Correct(正しい)」、「Incorrect(不正確)」、「Ambiguous(曖昧)」と三つに評価します。
評価後のアクション
Correct
検索結果が正しいと判断された場合、検索されたドキュメントを知識分解、フィルタリング、再構成によって、より正確な知識に絞り込みます。
Incorrect
検索結果が不正確だと判断された場合、検索されたドキュメントを破棄し、ウェブ検索を知識源として利用します。
Ambiguous
判断が難しい場合、CorrectとIncorrectの両方のアクションを組み合わせて、内部知識と外部知識を組み合わせます。
そして、検索された知識を再構築(decompose-then-recompose)して、さらに精度を高めます。これをした後にその知識とユーザーの質問からプロンプトを作成してLLMを使って回答を生成します。
CRAG のワークフロー
CRAG の基本的なワークフローは以下のステップで構成されます。
- ユーザーからの質問を受け取る
- ベクトルデータベースから関連文書を検索する
- 検索結果の関連性を評価する
- 関連性の高い文書が存在する場合は回答を生成する
- 関連性の高い文書が不足している場合は、クエリを最適化して Web 検索を行う
- Web 検索結果を含めて回答を生成する
※ 簡略化のため、あいまい(Ambiguous)の場合を省略しています。実際の論文中の図が以下です。
従来の RAG との違い
従来の RAG システムと比較した CRAG の主な違いは以下の通りです。
特徴 | 従来のRAG | CRAG |
---|---|---|
検索結果の使い方 | 検索結果をそのまま使う | 検索結果が質問に関係あるか確認する |
情報の整理 | 特に行わない | 文書を小さく分け、必要な情報だけを選び出す |
外部情報の利用 | 特に行わない | 検索結果だけでは足りない場合、Webなどから追加情報を集める |
検索キーワードの改善 | 特に行わない | より良い検索結果を得るために検索キーワードを工夫する |
CRAG の実装
実装準備
Google Colab での実装を前提とします。なお、ここからはLangGraphのチュートリアルを踏襲しています。
必要なライブラリのインストール
CRAG を実装するために、以下のライブラリをインストールします。
pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python
API キーの設定
OpenAI API と Tavily Search API のキーを設定します。
import getpass
import os
def _set_env(key: str):
if key not in os.environ:
os.environ[key] = getpass.getpass(f"{key}:")
_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")
インデックスの作成
テスト用のデータとして、僕の前のブログ記事をインデックス化します。
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
urls = [
"https://zenn.dev/shun_sakai/articles/22966db52b604d",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
# ベクトルDBに追加
vectorstore = Chroma.from_documents(
documents=doc_splits,
collection_name="rag-chroma",
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
CRAG の主要コンポーネント
検索結果の評価機能
検索結果の関連性を評価するためのコンポーネントを実装します。
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
# データモデル
class GradeDocuments(BaseModel):
"""検索された文書の関連性を評価するためのバイナリスコア"""
binary_score: str = Field(
description="文書が質問に関連しているかどうか、'yes'または'no'で評価"
)
# 関数呼び出し機能を持つLLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# プロンプト
system = """あなたは検索された文書がユーザーの質問に関連しているかを評価する採点者です。
文書に質問に関連するキーワードや意味的な内容が含まれている場合、関連性があると評価してください。
文書が質問に関連しているかどうかを'yes'または'no'のバイナリスコアで示してください。"""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "検索された文書: \n\n {document} \n\n ユーザーの質問: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
このコードでは、Pydantic モデルを使用して構造化された出力を定義し、LLM に文書の関連性を評価させています。
クエリ書き換え機能
検索結果が不十分な場合に、より効果的な Web 検索のためにクエリを最適化します。
# LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
# プロンプト
system = """あなたは質問を書き換えて、Web検索に最適化された形式に変換する専門家です。
入力された質問を分析し、その背後にある意味的な意図を理解して、より良い検索結果が得られるように書き換えてください。"""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
(
"human",
"元の質問: \n\n {question} \n より良い質問に書き換えてください。",
),
]
)
from langchain_core.output_parsers import StrOutputParser
question_rewriter = re_write_prompt | llm | StrOutputParser()
Web 検索統合機能
Tavily Search API を使用して Web 検索を実装します:
from langchain_community.tools.tavily_search import TavilySearchResults
web_search_tool = TavilySearchResults(k=3)
回答生成機能
検索結果を基に回答を生成する機能を実装します:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
# プロンプト
prompt = hub.pull("rlm/rag-prompt")
# LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
# 後処理
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
# チェーン
rag_chain = prompt | llm | StrOutputParser()
LangGraph による CRAG の実装
グラフ状態の定義
LangGraph を使用して CRAG のワークフローを実装するために、まずグラフの状態を定義します。
from typing import List
from typing_extensions import TypedDict
class GraphState(TypedDict):
"""
グラフの状態を表現します。
属性:
question: 質問
generation: LLMの生成結果
web_search: Web検索を追加するかどうか
documents: 文書のリスト
"""
question: str
generation: str
web_search: str
documents: List[str]
ノードの実装
グラフのノードとなる各機能を実装します。
from langchain.schema import Document
def retrieve(state):
"""
文書を検索します
Args:
state (dict): 現在のグラフ状態
Returns:
state (dict): 検索された文書を含む更新された状態
"""
print("---検索中---")
question = state["question"]
# 検索
documents = retriever.invoke(question)
return {"documents": documents, "question": question}
def generate(state):
"""
回答を生成します
Args:
state (dict): 現在のグラフ状態
Returns:
state (dict): 生成された回答を含む更新された状態
"""
print("---回答生成中---")
question = state["question"]
documents = state["documents"]
# RAG生成
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
こちらはドキュメントの評価をするメソッドですが、抽出した複数のドキュメントの内一つでも関係ないものが含まれていたらWeb検索する仕様になっています。
def grade_documents(state):
"""
検索された文書が質問に関連しているかを評価します。
Args:
state (dict): 現在のグラフ状態
Returns:
state (dict): フィルタリングされた関連文書を含む更新された状態
"""
print("---文書の関連性を評価中---")
question = state["question"]
documents = state["documents"]
# 各文書を評価
filtered_docs = []
web_search = "No"
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---評価: 文書は関連性あり---")
filtered_docs.append(d)
else:
print("---評価: 文書は関連性なし---")
web_search = "Yes"
continue
return {"documents": filtered_docs, "question": question, "web_search": web_search}
def transform_query(state):
"""
より良い質問に変換します。
Args:
state (dict): 現在のグラフ状態
Returns:
state (dict): 書き換えられた質問を含む更新された状態
"""
print("---クエリ変換中---")
question = state["question"]
documents = state["documents"]
# 質問を書き換え
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}
def web_search(state):
"""
書き換えられた質問に基づいてWeb検索を実行します。
Args:
state (dict): 現在のグラフ状態
Returns:
state (dict): Web検索結果を追加した更新された状態
"""
print("---Web検索中---")
question = state["question"]
documents = state["documents"]
# Web検索
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
documents.append(web_results)
return {"documents": documents, "question": question}
エッジと条件分岐の設定
ノード間の遷移を決定する条件分岐を実装します。
def decide_to_generate(state):
"""
回答を生成するか、質問を再生成するかを決定します。
Args:
state (dict): 現在のグラフ状態
Returns:
str: 次に呼び出すノードを示すバイナリ決定
"""
print("---評価された文書を分析中---")
state["question"]
web_search = state["web_search"]
state["documents"]
if web_search == "Yes":
# すべての文書がフィルタリングされた場合
# 新しいクエリを生成します
print(
"---決定: すべての文書が質問に関連していないため、クエリを変換します---"
)
return "transform_query"
else:
# 関連する文書があるため、回答を生成します
print("---決定: 回答を生成します---")
return "generate"
グラフのコンパイルと実行
最後に、定義したノードとエッジを使用してグラフを構築し、コンパイルします。
from langgraph.graph import END, StateGraph, START
workflow = StateGraph(GraphState)
# ノードを定義
workflow.add_node("retrieve", retrieve) # 検索
workflow.add_node("grade_documents", grade_documents) # 文書評価
workflow.add_node("generate", generate) # 生成
workflow.add_node("transform_query", transform_query) # クエリ変換
workflow.add_node("web_search_node", web_search) # Web検索
# グラフを構築
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)
# コンパイル
app = workflow.compile()
実行例と評価
基本的な質問への対応
記事の内容をストレートに聞いてみます。
from pprint import pprint
# 実行
inputs = {"question": "AgenticRAGとは"}
for output in app.stream(inputs):
for key, value in output.items():
# ノード
pprint(f"ノード '{key}':")
pprint("\n---\n")
# 最終的な生成結果
pprint(value["generation"])
実行結果:
---検索中---
"ノード 'retrieve':"
'\n---\n'
---文書の関連性を評価中---
---評価: 文書は関連性あり---
---評価: 文書は関連性あり---
---評価: 文書は関連性あり---
---評価: 文書は関連性あり---
関連性のあるドキュメント数: 4/4
---ASSESS GRADED DOCUMENTS---
---決定: 回答を生成します---
"ノード 'grade_documents':"
'\n---\n'
---GENERATE---
"ノード 'generate':"
'\n---\n'
('AgenticRAGは、AIエージェントベースのRAG実装を指します。現在のRAGに関する概念と実装方法について説明されています。Agentic '
'RAGは複数の情報源からの検索や検索結果の検証を行うことができます。')
4つの情報源から情報を取得して、関連性があるので、Web検索をせずに回答を生成しています。
関連のない質問への対応
全く関連のない質問をしてみます。
# 実行
inputs = {"question": "明日の天気について教えて"}
for output in app.stream(inputs):
for key, value in output.items():
# ノード
pprint(f"ノード '{key}':")
pprint("\n---\n")
# 最終的な生成結果
pprint(value["generation"])
実行結果:
---検索中---
"ノード 'retrieve':"
'\n---\n'
---文書の関連性を評価中---
---評価: 文書は関連性なし---
---評価: 文書は関連性なし---
---評価: 文書は関連性なし---
---評価: 文書は関連性なし---
関連性のあるドキュメント数: 0/4
---ASSESS GRADED DOCUMENTS---
---決定: すべての文書が質問に関連していないため、クエリを変換します---
"ノード 'grade_documents':"
'\n---\n'
---TRANSFORM QUERY---
元の質問: 明日の天気について教えて
書き換えられた質問: 明日の天気予報を教えてください。
"ノード 'transform_query':"
'\n---\n'
---Web検索中---
ウェブ検索結果: 08日(土) 明日 08日(土) 明日 信頼度 - - - C B A B C 信頼度 - - - C B A B C 08日(土) 明日 信頼度 - - - C B A C C 08日(土) 明日 信頼度 - - - A B C A A 08日(土) 明日 信頼度 - -...
"ノード 'web_search_node':"
'\n---\n'
---GENERATE---
"ノード 'generate':"
'\n---\n'
'明日の天気は、東北では雪や雨が降り、太平洋側では大雪のおそれがあります。関東は雨が降る見込みで、東海や西日本は曇りが続くでしょう。全国の天気情報によると、明日の天気は各地で異なる状況が予想されます。'
きちんとWeb検索して正しい回答を返していますね。
まとめ
ここまで読んでいただきありがとうございました!
今回の実装では、抽出した情報がすべて関連性があると判断された時のみ、抽出した文章のみ使い、一つでも関連性がないと判断されたらWeb検索を行い、その結果をドキュメントに追加していました。この辺りのロジックは実際にRAGを行うときの要件によっても変わってきそうです。
Discussion