Building an Advanced Fusion Retriever from Scratch 解説
Building an Advanced Fusion Retriever from Scratch 解説
- 元記事の解説
- Query Fusion が気になる方向け
- Jupyter Notebook 上で実行
- Windows 環境
概要
このチュートリアルでは、スクラッチから高度な「QueryFusionRetriever」を構築する方法を説明します。このプロセスは、RAG-fusionのリポジトリに大きく触発されています。
セットアップ
まず、文書を読み込み、シンプルなベクトルインデックスを構築します。
!pip install rank-bm25 pymupdf
import nest_asyncio
nest_asyncio.apply()
ドキュメントの読み込み
!mkdir data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
または
import os
import requests
# データを保存するディレクトリを作成
os.makedirs('data', exist_ok=True)
# ダウンロードするファイルのURL
url = "https://arxiv.org/pdf/2307.09288.pdf"
# ダウンロードとファイルの保存
response = requests.get(url)
with open('data/llama2.pdf', 'wb') as f:
f.write(response.content)
必要があればインストールする。
!pip install llama-index
!pip install llama_hub
import openai
openai.api_key = "sk-..."
ここで、PDF文書をダウンロードし、それを読み込みます。
from pathlib import Path
from llama_hub.file.pymu_pdf.base import PyMuPDFReader
loader = PyMuPDFReader()
documents = loader.load(file_path="./data/llama2.pdf")
ベクトルストアへの読み込み
from llama_index import VectorStoreIndex, ServiceContext
service_context = ServiceContext.from_defaults(chunk_size=1024)
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
読み込んだ文書をベクトルストアにロードします。
LLM(大規模言語モデル)の定義
from llama_index.llms import OpenAI
llm = OpenAI(model="gpt-3.5-turbo")
ここで、OpenAIのGPT-3.5モデルを使用するLLMを定義します。
高度なレトリバーの定義
高度なレトリバーは、以下のステップで機能します。
-
クエリ生成/リライティング: 元のユーザークエリに基づいて複数のクエリを生成します。
-
各クエリに対する検索の実行: 複数のレトリバーを用いて各クエリに対して検索を行います。
-
リランキング/融合: 全てのクエリからの結果を融合し、トップの関連結果にリランキングステップを適用します。
このチュートリアルのステップは、複数のクエリに対するベクトル検索を実行するプロセスを説明しています。各ステップを詳しく見ていきましょう。
ステップ1: クエリ生成/リライティング
このステップでは、元のクエリから複数の関連クエリを生成し、検索結果の精度と再現率を向上させることを目指します。これは、ChatGPTを使ったプロンプト作成によって行います。
from llama_index import PromptTemplate
query_str = "How do the models developed in this work compare to open-source chat models based on the benchmarks tested?"
query_gen_prompt_str = (
"You are a helpful assistant that generates multiple search queries based on a "
"single input query. Generate {num_queries} search queries, one on each line, "
"related to the following input query:\n"
"Query: {query}\n"
"Queries:\n"
)
query_gen_prompt = PromptTemplate(query_gen_prompt_str)
def generate_queries(llm, query_str: str, num_queries: int = 4):
fmt_prompt = query_gen_prompt.format(
num_queries=num_queries - 1, query=query_str
)
response = llm.complete(fmt_prompt)
queries = response.text.split("\n")
return queries
queries = generate_queries(llm, query_str, num_queries=4)
print(queries)
['1. What are the benchmarks used to evaluate open-source chat models?', '2. Can you provide a comparison between the models developed in this work and existing open-source chat models?', '3. Are there any notable differences in performance between the models developed in this work and open-source chat models based on the benchmarks tested?']
このコードは、指定された入力クエリに基づいて複数の検索クエリを生成することを目的としています。生成されたクエリは、元のクエリに関連しており、異なる側面や詳細を探求することができます。
ステップ2: 各クエリに対するベクトル検索の実行
このステップでは、生成された各クエリに対して検索を実行します。これは、各ベクトルストアから最も関連性の高いトップkの結果を取得することを意味します。
from tqdm.asyncio import tqdm
import asyncio
async def run_queries(queries, retrievers):
tasks = []
for query in queries:
for i, retriever in enumerate(retrievers):
tasks.append(retriever.aretrieve(query))
task_results = await tqdm.gather(*tasks)
results_dict = {}
for i, (query, query_result) in enumerate(zip(queries, task_results)):
results_dict[(query, i)] = query_result
return results_dict
このコードは、非同期プログラミングとasyncio
ライブラリを使用して、複数のクエリを複数のレトリバーで同時に実行し、結果を集めるプロセスを実装しています。それぞれの部分について詳しく説明します。
非同期処理分かんないマン向け
- async def
async def
は、非同期関数を定義するためのキーワードです。これは、関数内で非同期処理(await
を使用した処理など)を行うことを示します。非同期関数は、関数の実行を待機せずに、他のタスクを同時に実行することができます。
async def run_queries()
関数は、複数のクエリを同時に実行します。非同期プログラミングを使用することで、各クエリの実行を個別のタスクとして扱い、これらのタスクを並行して実行することが可能になります。
async def run_queries(queries, retrievers):
この行で、run_queries
という非同期関数を定義しています。
- tqdm.asyncio
tqdm
はプログレスバーを表示するためのライブラリです。tqdm.asyncio
はtqdm
の非同期処理に特化したバージョンで、非同期タスクの進行状況を表示します。
- await tqdm.gather(*tasks)
await
は、非同期関数(asyncで定義された関数)やコルーチン(協調的ルーチンの略)の実行が完了するまで待機するために使用されるキーワードです。これにより、その関数が終わるまでプログラムの実行を一時停止し、他のタスクにCPUの処理を切り替えることができます。
asyncio.gather()
は、複数の非同期タスク(コルーチン)を同時に開始し、すべてのタスクが完了するのを待つ関数です。これにより、複数のタスクを効率的に並行して実行できます。
*tasks
は、リストやタプルのようなイテラブルのすべての要素を個別の引数として展開するために使用されるアスタリスク(スプラット)演算子です。この場合、*tasks
はtasks
リスト内のすべての非同期タスクをasyncio.gather()
に個別の引数として渡します。
ここでawait
は、asyncio.gather
の呼び出しが完了するまでプログラムの実行を一時停止します。つまり、リスト内のすべての非同期タスクが完了するまで待機します。
task_results = await tqdm.gather(*tasks)
この行では、tasks
リスト内のすべての非同期タスクをasyncio.gather()
を使用して同時に実行し、それらが完了するのを待機しています。その進行状況はtqdm
のプログレスバーで表示されます。
結果の集約
関数の残りの部分では、各タスクの結果(クエリの実行結果)を集め、クエリごとに結果を格納するための辞書(results_dict
)を作成しています。これにより、後で結果を処理しやすくなります。
# get retrievers
from llama_index.retrievers import BM25Retriever
# vector retriever
vector_retriever = index.as_retriever(similarity_top_k=2)
# bm25 retriever
bm25_retriever = BM25Retriever.from_defaults(
docstore=index.docstore, similarity_top_k=2
)
results_dict = await run_queries(queries, [vector_retriever, bm25_retriever])
ここでのポイントは、複数のレトリバーを使用して各クエリに対して検索を実行し、それぞれのレトリバーからの結果を集約することです。このプロセスでは、非同期処理を使用して効率的にタスクを処理しています。最終的に、すべてのクエリの結果がresults_dict
に保存されます。
このチュートリアルでは、実際に2種類の異なる検索手法を使用しています。一つはベクトル検索、もう一つはTF-IDFベースの検索です。これらの手法は、テキストデータを検索するための異なるアプローチを提供します。
ベクトル検索(Vector Retrieval)
-
原理:
- ベクトル検索では、ドキュメントやクエリを多次元のベクトル空間内の点として表現します。この表現は、通常、機械学習モデル(特に深層学習モデル)によって生成されます。
- ベクトルは、テキストの意味的な特徴を捉え、それらを数値の形で表現します。
-
検索プロセス:
- クエリのベクトルとドキュメントのベクトル間の類似度を計算します(例えば、コサイン類似度など)。
- 類似度スコアに基づいて、最も関連性の高いドキュメント(トップk)を選択します。
-
利点:
- 意味的な関連性に基づいて検索が可能。
- 大量のデータセットに対して効率的。
TF-IDFベースの検索(BM25 Retrieval)
-
原理:
- TF-IDF(Term Frequency-Inverse Document Frequency)は、単語の重要性を評価するための古典的な手法です。BM25はTF-IDFの改良版と考えることができます。
- TF-IDF/BM25では、特定の単語がドキュメント内でどの程度頻繁に登場し、その単語がどの程度希少か(他のドキュメントにはあまり登場しないか)に基づいて重み付けします。
-
検索プロセス:
- クエリ内の各単語に対してTF-IDF/BM25スコアを計算し、これらを用いてドキュメントとの関連度を評価します。
- スコアに基づいて、最も関連性の高いドキュメントを選択します。
-
利点:
- 単純で解釈しやすい。
- 計算効率が高く、小規模から中規模のデータセットに適しています。
組み合わせによる利点
ベクトル検索とTF-IDF/BM25検索を組み合わせることで、検索結果の品質を向上させることができます。ベクトル検索は意味的な関連性に優れている一方で、TF-IDF/BM25検索は特定のキーワードやフレーズに基づく検索に強いです。この二つを組み合わせることで、より包括的で正確な検索結果を得ることが可能になります。
ステップ3: 結果の融合(Fusion)
このステップでは、複数のレトリバーからの結果を組み合わせて再ランキングする作業を行います。異なるレトリバーから同じノード(情報の単位)が複数回取得される可能性があるため、重複を取り除き、複数の取得結果に基づいてノードを再ランキングする方法が必要です。
「相互ランク融合」の実行方法
- 各ノードに対して、それが取得されたすべてのリストでの逆順位(reciprocal rank)を加算します。
- スコアが最も高いノードから最も低いノードへと順位を並べ替えます。
コードの詳細解説
def fuse_results(results_dict, similarity_top_k: int = 2):
k = 60.0 # 外れ値のランキングの影響を制御するパラメータ
fused_scores = {} # 統合されたスコアを格納する辞書
text_to_node = {} # テキスト内容とノードの対応を格納する辞書
# 逆順位スコアの計算
for nodes_with_scores in results_dict.values():
for rank, node_with_score in enumerate(
sorted(
nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True
)
):
text = node_with_score.node.get_content()
text_to_node[text] = node_with_score
if text not in fused_scores:
fused_scores[text] = 0.0
fused_scores[text] += 1.0 / (rank + k)
# スコアに基づいて結果を降順で並べ替える
reranked_results = dict(
sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
)
# ノードスコアの調整
reranked_nodes: List[NodeWithScore] = []
for text, score in reranked_results.items():
reranked_nodes.append(text_to_node[text])
reranked_nodes[-1].score = score
return reranked_nodes[:similarity_top_k]
この部分では、各ノードに対して「逆順位スコア」を計算しています。つまり、ノードが複数のリストに登場する場合、それぞれのリストでの順位の逆数を加算してスコアを求めます。これにより、複数のリストで高いランクを得たノードが高いスコアを得ることになります。
最終結果の表示
final_results = fuse_results(results_dict)
from llama_index.response.notebook_utils import display_source_node
for n in final_results:
display_source_node(n, source_length=500)
ここでは、融合された最終結果を取得し、それらのノードを表示しています。これにより、複数の検索手法を組み合わせた結果を効果的に活用できます。
Node ID: d92e53b7-1f27-4129-8d5d-dd06638b1f2d
Similarity: 0.04972677595628415
Text: Figure 12: Human evaluation results for Llama 2-Chat models compared to open- and closed-source models across ~4,000 helpfulness prompts with three raters per prompt. The largest Llama 2-Chat model is competitive with ChatGPT. Llama 2-Chat 70B model has a win rate of 36% and a tie rate of 31.5% relative to ChatGPT. Llama 2-Chat 70B model outperforms PaLM-bison chat model by a large percentage on our prompt set. More results and analysis is available in Section A.3.7. Inter-Rater Reliability (…
Node ID: 20d32df8-e16e-45fb-957a-e08175e188e8
Similarity: 0.016666666666666666
Text: Figure 1: Helpfulness human evaluation results for Llama 2-Chat compared to other open-source and closed-source models. Human raters compared model generations on ~4k prompts consisting of both single and multi-turn prompts. The 95% confidence intervals for this evaluation are between 1% and 2%. More details in Section 3.4.2. While reviewing these results, it is important to note that human evaluations can be noisy due to limitations of the prompt set, subjectivity of the review guidelines, s…
- Node ID: これは各検索結果(ノード)を一意に識別するためのIDです。このIDは、通常、ノードの内容(テキスト)のハッシュ値などに基づいて生成されます。ユニークなIDにより、同じ内容を持つノードが複数回取得された場合でも、それらを区別し、重複を避けることができます。
-
Similarity: これは、検索クエリと各ノードの関連性を数値化したスコアです。このスコアは、
fuse_results
関数によって計算された「逆順位スコア」に基づいており、複数のレトリバーからの結果を統合した後の関連性の尺度を示しています。
-
スコアの集計:
- 各ノードに対して、複数のリストでのランクに基づいてスコアを計算します。
k
は外れ値のランキングの影響を制御するためのパラメータです。
- 各ノードに対して、複数のリストでのランクに基づいてスコアを計算します。
-
結果の並べ替え:
- 統合したスコアに基づいて、結果を降順で並べ替えます。
-
ノードスコアの調整:
- 並べ替えた結果に基づいて、各ノードのスコアを調整します。
結論
このステップでは、異なるレトリバーからの結果を効果的に統合することで、より関連性の高い結果を取得することを目指しています。この方法は、情報の取得精度を向上させ、特定のクエリに対してより有用な結果を提供するために重要です。相互ランク融合により、各ノードの重要性を総合的に評価し、最終的なランキングを形成しています。
Plug into RetrieverQueryEngine
このチュートリアルのセクションでは、カスタムレトリバー(FusionRetriever)を定義し、これをRetrieverQueryEngine
に組み込むプロセスを説明しています。RetrieverQueryEngine
は、情報の検索と回答の生成(合成)を行うためのエンジンです。
FusionRetrieverクラスの定義
FusionRetriever
は、複数のレトリバーからの結果を統合するアンサンブルレトリバーです。このクラスはBaseRetriever
から派生しており、カスタムの検索ロジックを実装します。
from llama_index import QueryBundle
from llama_index.retrievers import BaseRetriever
from typing import Any, List
from llama_index.schema import NodeWithScore
class FusionRetriever(BaseRetriever):
def __init__(self, llm, retrievers: List[BaseRetriever], similarity_top_k: int = 2):
self._retrievers = retrievers
self._similarity_top_k = similarity_top_k
super().__init__()
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
queries = generate_queries(llm, query_str, num_queries=4)
results = run_queries(queries, [vector_retriever, bm25_retriever])
final_results = fuse_results(results_dict, similarity_top_k=self._similarity_top_k)
return final_results
このクラスでは、以下のステップが行われます:
- クエリの生成: 元のクエリから複数の関連クエリを生成します。
- クエリの実行: 生成されたクエリに対して、指定された複数のレトリバーで検索を実行します。
- 結果の融合: 検索結果を統合し、最終結果を生成します。
QueryBundleについて
-
QueryBundle
は、検索クエリに関連するデータを一つにまとめたデータクラスです。 -
query_str
: ユーザーが指定した元のクエリ文字列。これは、埋め込み(embedding)ベースでないすべてのクエリで使用されます。 -
custom_embedding_strs
: クエリを埋め込むために使用される文字列のリスト。これは、埋め込みベースのクエリで使用されます。 -
embedding
: クエリの埋め込みを保存するための浮動小数点数のリスト。 -
embedding_image
とembedding_strs
: クエリを画像検索やカスタムの埋め込み文字列を使用する検索に適用するためのプロパティ。
QueryBundle
は、異なる種類の検索(テキストベース、画像ベース、カスタム埋め込みベース)をサポートするための柔軟性を提供し、検索エンジンにクエリ情報を効率的に渡すために使用されます。
RetrieverQueryEngineの使用
FusionRetriever
をRetrieverQueryEngine
に組み込み、クエリに対する応答を生成します。
from llama_index.query_engine import RetrieverQueryEngine
fusion_retriever = FusionRetriever(llm, [vector_retriever, bm25_retriever], similarity_top_k=2)
query_engine = RetrieverQueryEngine(fusion_retriever)
response = query_engine.query(query_str)
print(str(response))
このコードでは、query_engine.query(query_str)
を使用してクエリを実行し、print(str(response))
で応答を表示します。
実行結果の解析
The models developed in this work, specifically the Llama 2-Chat models, are competitive with open-source chat models based on the benchmarks tested. The largest Llama 2-Chat model has a win rate of 36% and a tie rate of 31.5% relative to ChatGPT, which indicates that it performs well in comparison. Additionally, the Llama 2-Chat 70B model outperforms the PaLM-bison chat model by a large percentage on the prompt set used for evaluation. While it is important to note the limitations of the benchmarks and the subjective nature of human evaluations, the results suggest that the Llama 2-Chat models are on par with or even outperform open-source chat models in certain aspects.
最終的に得られる応答は、「Llama 2-Chat」モデルがオープンソースのチャットモデルと比較して競争力があること、特に評価に使用されたプロンプトセットにおいて「Llama 2-Chat 70B」モデルが「PaLM-bison」チャットモデルを大きく上回っていることを示しています。
注意点
-
run_queries
は非同期関数であり、await
キーワードを使って呼び出す必要があります。これが行われていない場合、RuntimeWarning
が発生します。 - このコードは実際の応答生成において複数の検索手法を組み合わせ、より包括的で精度の高い情報を提供することを目指しています。
Discussion