RAG Fusionを試してみる
RAG(Retrieval-Augmented Generation)の拡張手法がいろいろと提案されていますが、お手軽に試すことができて、LLMの処理時間やコストと性能のバランスが良さそうな RAG Fusion を試してみました。ユーザからの入力クエリを元に類似したクエリを生成AIに複数生成させ、それぞれのクエリでベクトル検索を実施、得られたチャンクをリランキングして、上位のチャンクのみを最終的なコンテキストとする手法です。
RAG Fusion とは?
RAG Fusion は、入力クエリに関連するクエリを複数生成し、それぞれでベクトル検索を実施、得られたチャンクをリランキングして、上位のチャンクをコンテキストとしてLLMに渡す手法です。
LangChainのブログに掲載されている以下の図が、RAG Fusion の手法を端的に示しています。
(出典)Query Transformations | LangChain
具体的には、以下のような流れになります。
- 入力クエリに対して、類似する複数のクエリをLLMに生成させる(図中の"Generate Similar Queries")
- 生成されたクエリのそれぞれに対してベクトル検索を実施してチャンクを取得(図中の"Vector Search Query")
- 各クエリに対して得られたチャンクをリランキングし、スコアが上位となるチャンクをコンテキスト情報とする(図中の"Reciprocal Rank Fusion")
- 元の入力クエリと、3.で得られたコンテキスト情報をLLMに渡して、最終的な回答を生成させる(図中の"Re-ranked Results" → "Generative Output")
手法としてはとてもシンプルです。ポイントとなるのは、最初にLLMに類似クエリを生成させるところ(上記1.)と、各クエリに対してベクトル検索で得られたチャンクをリランキングするところ(上記3.)です。
1.で類似クエリを複数生成することで、ベクトル検索を実施するときに、関連するチャンクを幅広く拾うことができます。また、3.でリランキングを行うことで、多くのクエリで上位となるチャンクを優先的に選択することができます。
通常のベクトル検索のみのRAGと比較すると、類似クエリを生成するステップでLLMへの問い合わせが発生するぶんだけ処理時間とコストがかかります。ただ、類似クエリの生成は、入力トークンも出力トークンも少ないため、処理時間とコストの増加はわずかです。
実装の概要
今回は、上記の流れに沿って、Langchainを利用して、RAG Fusionを実装してみました。コードは、以下のリポジトリを参考に、少し手を加えています。
動作を確認した環境は以下のとおりです。
- Windows 10
- Python 3.11.6
- Langchain 0.1.13
- Chroma 0.4.24
Vector StoreはChromaを利用しています。
以下でポイントとなる部分をかいつまんで説明していきますが、コード全体は以下のリポジトリに置いてあります。
リポジトリに置いてあるrag_fusion.py
を動作させるには、必要なPythonのパッケージをインストールしておいてください。また、LLMとEmbeddingsにOpenAI APIを利用しています。.env
ファイルを作成して、OPENAI_API_KEY
を環境変数として設定してください。
OPENAI_API_KEY="<key>"
実装の詳細
それでは、上で説明した流れに沿って、ポイントだけ紹介していきます。
類似クエリの生成
まずは、類似クエリを生成する関数query_generator
です。
def query_generator(original_query: dict) -> list[str]:
"""Generate queries from original query
Args:
query (dict): original query
Returns:
list[str]: list of generated queries
"""
# original query
query = original_query.get("query")
# prompt for query generator
prompt = ChatPromptTemplate.from_messages([
("system", "You are a helpful assistant that generates multiple search queries based on a single input query."),
("user", "Generate multiple search queries related to: {original_query}. When creating queries, please refine or add closely related contextual information in Japanese, without significantly altering the original query's meaning"),
("user", "OUTPUT (3 queries):")
])
# LLM model
model = ChatOpenAI(
temperature=0,
model_name=LLM_MODEL_OPENAI
)
# query generator chain
query_generator_chain = (
prompt | model | StrOutputParser() | (lambda x: x.split("\n"))
)
# gererate queries
queries = query_generator_chain.invoke({"original_query": query})
# add original query
queries.insert(0, "0. " + query)
# for TEST
print('Generated queries:\n', '\n'.join(queries))
return queries
元のクエリoriginal_query
をプロンプトに入れて、類似クエリの生成をLLMに依頼しているだけです。元のコードのプロンプトはGenerate multiple search queries related to: {original_query}.
だけだったのですが、かなり幅広い内容のクエリを生成してしまうので、もう少し範囲を狭めるために、When creating queries, please refine or add closely related contextual information in Japanese, without significantly altering the original query's meaning
を追加しています。
以下のようにLECL表記でChainを生成し、実行しています。
query_generator_chain = (
prompt | model | StrOutputParser() | (lambda x: x.split("\n"))
)
元のコードでは、元のクエリから類似クエリを4つ生成して返していたのですが、今回は、類似クエリ3つ+元のクエリの合計4つを返すようにしています。そのため、以下の部分で、元のクエリも追加しています。
# add original query
queries.insert(0, "0. " + query)
類似クエリでベクトル検索を実行
次に、複数の類似クエリからベクトル検索を実行する関数rrf_retriever
です。
def rrf_retriever(query: str) -> list[Document]:
"""RRF retriever
Args:
query (str): Query string
Returns:
list[Document]: retrieved documents
"""
# Retriever
retriever = create_retriever(search_type="similarity", kwargs={"k": TOP_K})
# RRF chain
chain = (
{"query": itemgetter("query")}
| RunnableLambda(query_generator)
| retriever.map()
| reciprocal_rank_fusion
)
# invoke
result = chain.invoke({"query": query})
return result
retriever = create_retriever(...)
は、ドキュメントを分割して、Chromaを利用したVector Storeを作成し、Retrieverを返しています。(詳細はリポジトリのコードをご覧ください)
以下のChainでベクトル検索を実施しています。
chain = (
{"query": itemgetter("query")}
| RunnableLambda(query_generator)
| retriever.map()
| reciprocal_rank_fusion
)
RunnableLambda(query_generator)
のところは、前述の類似クエリの生成をしています。
次のretriever.map()
は、query_generator
で生成した元クエリを含む類似クエリ4つに対して、それぞれベクトル検索を実施します。map()によって、4つのクエリに対して、それぞれ5つのチャンクを検索して取得しています。
最後のreciprocal_rank_fusion
でリランキングをしますが、これは次で説明します。
Reciprocal Rank Fusion (RRF) によるリランキング
次に、類似クエリごとにベクトル検索して得られたチャンクのリランキングを実施します。リランキングには、単純に類似度の順位だけを用いたReciprocal Rank Fusion (RRF) を利用します。
Reciprocal Rank Fusion (RRF) における文書dのスコア
ハイパーパラメータ
Reciprocal Rank Fusion (RRF) を計算する関数reciprocal_rank_fusion
は以下のとおりです。
def reciprocal_rank_fusion(results: list[list], k=60):
"""Rerank docs (Reciprocal Rank Fusion)
Args:
results (list[list]): retrieved documents
k (int, optional): parameter k for RRF. Defaults to 60.
Returns:
ranked_results: list of documents reranked by RRF
"""
fused_scores = {}
for docs in results:
for rank, doc in enumerate(docs):
doc_str = dumps(doc)
if doc_str not in fused_scores:
fused_scores[doc_str] = 0
fused_scores[doc_str] += 1 / (rank + k)
reranked_results = [
(loads(doc), score)
for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
]
# for TEST (print reranked documentsand scores)
print("Reranked documents: ", len(reranked_results))
for doc in reranked_results:
print('---')
print('Docs: ', ' '.join(doc[0].page_content[:100].split()))
print('RRF score: ', doc[1])
# return only documents
return [x[0] for x in reranked_results[:MAX_DOCS_FOR_CONTEXT]]
前述のretriever
は、関連度の大きい順に検索されたチャンクのリストを返しますので、それを順番に読みだしてRRFのスコアを計算しています。
# return only documents
return [x[0] for x in reranked_results[:MAX_DOCS_FOR_CONTEXT]]
テスト用にRRFスコアも表示するようにしていますが、LLMにコンテキストとして渡すのに必要なのはそのチャンクの内容のみであるため、チャンクのリストのみを返しています。
4つの類似クエリに対して、それぞれ5個ずつチャンクを取得していますので、最大でチャンク数は20となります。そのままだとコンテキストとしては多すぎるため、ここでは上位MAX_DOCS_FOR_CONTEXT
個ぶんだけ渡すようにしています。MAX_DOCS_FOR_CONTEXT
は8に設定してあります。
RAG Fusion 全体のChain
RAG Fusion 全体のChainはquery
関数で以下のように定義しています。
def query(query: str, retriever: BaseRetriever):
"""
Query with vectordb
"""
# model
model = ChatOpenAI(
temperature=0,
model_name=LLM_MODEL_OPENAI)
# prompt
prompt = PromptTemplate(
template=my_template_jp,
input_variables=["context", "question"],
)
# Query chain
chain = (
{
"context": itemgetter("question") | retriever,
"question": itemgetter("question")
}
| RunnablePassthrough.assign(
context=itemgetter("context")
)
| {
"response": prompt | model | StrOutputParser(),
"context": itemgetter("context"),
}
)
# execute chain
result = chain.invoke({"question": query})
return result
通常のRAGとまったく同じですが、引数のretriever
に、前述のrrf_retriever
を与えればRAG Fusionに、通常のベクトル検索のRetrieverを与えれば、通常のベクトル検索となります。
chain
の定義は複雑にみえますが、LLMの回答だけでなく、与えたコンテキストも出力させるようにしているためです。基本的には、retriever
→ prompt
→ model
とつなげているだけです。
動作確認・評価
それでは実際に動作させてみます。ドキュメントとしては、Wikipediaの「北陸新幹線」のページを読み込ませました。
TokenTextSplitter
でチャンクに分割していますが、chunk_size=2048
で、66個のチャンクに分割されました。
とりあえず、適当に質問を投げてみます。まずは、RAG Fusionでのクエリと回答です。
PS D:\Documents\work\rag-fusion> python .\rag_fusion.py -q 北陸新幹線の雪対策は?
Original document: 66 docs
Generated queries:
0. 北陸新幹線の雪対策は?
1. 北陸新幹線の雪対策は何が行われているか?
2. 北陸新幹線の雪対策にはどのような技術が使われているか?
3. 北陸新幹線の雪対策は冬季にどのように影響を及ぼしているか?
Reranked documents: 8
---
Docs: ��田SP - 糸魚川駅 - 新糸魚川SP間が50 Hz、新糸魚川SP - 黒部宇奈月温泉駅 - 金沢駅 - 敦賀駅間が60 Hzとなっている[32]。 また、新幹線の保安装置であるATC(自動列車制
RRF score: 0.06666666666666667
---
Docs: 対策のためホーム全体が屋根で覆われている。 JR東日本管内のうち比較的積雪量が少ない長野までの区間では高架橋の軌道下の路盤コンクリートを高くし、線路の両脇に雪を貯める貯雪方式を採用している。降雪量の多
RRF score: 0.06557377049180328
---
Docs: 活用した北陸新幹線着雪量推定モデル開発」『AI・データサイエンス論文集』第2巻第J2号、土木学会、2021年2月、687–990頁。doi:10.11532/jsceiii.2.J2_687。 井野俊
RRF score: 0.06451612903225806
---
Docs: 7009。 堀内義朗「整備新幹線と内需拡大」『土木学会論文集』第1987巻第385号、土木学会、1987年、5–19頁。doi:10.2208/jscej.1987.385_5。 御船直人、由川透、吉
RRF score: 0.047371031746031744
---
Docs: �川毅(中越パルプ工業創業者・当時の砺波商工会議所会頭)は、政府に対して東京を起点とし松本、立山連峰を貫通して富山、金沢を経由して大阪に至る「北陸新幹線」の建設を求めた[54]。この提案に、鉄道官僚出
RRF score: 0.03125
---
Docs: 陸新幹線トンネル掘削現場で崩落 地上のグラウンド陥没、直径15m」『福井新聞ONLINE』福井新聞社、2017年9月8日。2017年9月8日時点のオリジナルよりアーカイブ。2018年2月2日閲覧。
RRF score: 0.015873015873015872
---
Docs: �上1位更新”. NHK (2019年10月15日). 2019年10月18日閲覧。 ^ “大雨特別警報 一時13都県に発表”. NHK. 2019年10月22日閲 覧。 ^ “7都県に大雨
RRF score: 0.015873015873015872
---
Docs: 後 50 年間の平均値である。 ^ “平成18年度事業評価監視委員会 北陸新幹線(長野・金沢間)事業に関する対応方針”. 鉄道 ・運輸機構. p. 21. 2023年2月2日閲覧。 ^ “平成2
RRF score: 0.015625
---
Answer:
北陸新幹線の雪対策は、冬季においても安定輸送を維持するために、散水消雪方式や貯雪方式など様々な対策が施されている。具体的には、散水消雪方式や貯雪方式を採用し、新たな対策方法として消雪パネルの開発や温水パイプの設置などが行われている。また、トンネル緩衝口端部での散水や保守用斜路への散水消雪設備の導入など、周辺環境に合わせた対策が行われている。
「北陸新幹線の雪対策は?」というクエリに対して、類似クエリとして「北陸新幹線の雪対策は何が行われているか?」「北陸新幹線の雪対策にはどのような技術が使われているか?」「北陸新幹線の雪対策は冬季にどのように影響を及ぼしているか?」が生成されました。回答も、Wikipediaに記されている対策をある程度拾えているようです。
ちなみに、元クエリに対して単純にベクトル検索だけを実施した回答は以下のとおりです。
PS D:\Documents\work\rag-fusion> python .\rag_fusion.py -v 北陸新幹線の雪対策は?
Original document: 66 docs
---
Answer:
北陸新幹線の雪対策は、冬季においても安定輸送を維持するための対策が施されており、散水消雪方式や貯雪方式、消雪パネルの開発など様々な技術が導入されています。また、新幹線の高架橋内には雪覆いを設けるなど周辺環境に合わせた対策が行われています。
RAG Fusion の回答とあまり差がなさそうではありますね……。
もう一つ試してみます。今度は「北陸新幹線の建設主体と運営主体は?」というシンプルではありますが、2つのことを聞いている質問です。以下がRAG Fusionでの回答です。
PS D:\Documents\work\rag-fusion> python .\rag_fusion.py -q 北陸新幹線の建設主体と運営主体は?
Original document: 66 docs
Generated queries:
0. 北陸新幹線の建設主体と運営主体は?
1. 北陸新幹線の建設主体は誰ですか?
2. 北陸新幹線の運営主体は誰ですか?
3. 北陸新幹線の建設と運営を担当している組織は何ですか?
Answer:
建設主体:日本鉄道建設公団、独立行政法人鉄道建設・運輸施設整備支援機構
運営主体:JR東日本、JR西日本
次に単純なベクトル検索のみでの回答です。
PS D:\Documents\work\rag-fusion> python .\rag_fusion.py -v 北陸新幹線の建設主体と運営主体は?
Original document: 66 docs
---
Answer:
建設主体は日本鉄道建設公団、運営主体はJR東日本とJR西日本。
この例では、RAG Fusionのほうが漏れなく回答しています。類似クエリとして「北陸新幹線の建設主体は誰ですか?」「北陸新幹線の運営主体は誰ですか?」と、二つのクエリに分解して、それぞれベクトル検索をした効果が出ているようです。
まとめ
リランキングにReciprocal Rank Fusion (RRF)を利用した RAG Fusion を試してみました。
この記事で紹介した Wikipedia 以外にも、自分のブログの記事を全部読み込ませて試してみましたが、あくまで主観的な評価としては、以下のような感じです。
- 元クエリが単語のみ、「〇〇とは?」のようなシンプルな場合には、類似クエリを生成することで周辺情報を拾うことができ、回答の幅が広くなる傾向がある
- 元クエリが具体的な場合には、通常のベクトル検索とあまり変わらない
- 実質的に複数の内容を問う複合的な質問に関しては、類似クエリで質問を分解してくれることで、漏れなく回答できる確率が上がる
- 類似クエリを生成する際には、その幅広さをプロンプトで調整する必要がある
類似クエリを生成するという仕組み上、シンプルな質問に対する最終的な回答をリッチにする効果が大きいと感じます。
一方で、類似クエリの幅広さをどの程度にするかは、最終的な回答に影響しますので、類似クエリの生成を依頼するプロンプトで調整する必要があります。単に「類似クエリを〇個作って」だけだと、かなり幅広い類似クエリを生成してきます。回答に周辺情報を多く含めたい場合には良いですが、元ドキュメントにその情報がなければ役に立ちません。
今回は、ほぼ同じ内容で文言の異なる類似クエリを生成することで、似たようなチャンクのヒット率を上げる方向にしてみましたが、このあたりは、RAG Fusion をどのような目的で利用するかによって調整の余地がありそうです。
関連記事
Wordpressのブログ記事の構造を元にチャンクに分割してQAボットをつくったときの記事です。
Multivector Retriever を利用してQAボットをつくった時の記事です。
Discussion