🍵

RAGのTutorialやってみた part5

2024/12/21に公開

前回したこと

  • StateやNode、Control FLowなどについて理解
  • graphの作成
  • 実際にRAGをgraphを使用して体験

前回は作成したgraphを使用して実際にクエリを入力してsyncとstreamで出力を得ました。
今回は入力テンプレートの変更と入力クエリをわかりやすい形に変換したら、ドキュメントのメタデータに項目をつけるなどしてみたいと思います。

テンプレートのカスタマイズ

入力テンプレートのカスタマイズは簡単です。前回定義したStateを元にテンプレートをカスタマイズしていきます。
Stateの定義

# Define state for application
class State(TypedDict):
    question: str
    context: List[Document]
    answer: str

このように定義されていましたね。テンプレートにquestion, contextを当てはめていきます。

from langchain_core.prompts import PromptTemplate

template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.

{context}

Question: {question}

Helpful Answer:"""
custom_rag_prompt = PromptTemplate.from_template(template)

この例では自作のtemplateの雛形を作成して、{context}と{quetion}を当てはめるようなテンプレートとなっています。このcustom_rag_promptも前回使用したpromptインスタンスと同様の使用方法でcustom_rag_prompt.invoke(...)のように使用できます。

クエリ分析処理

クエリ分析とはretrieveする前、retrieveに条件を持たせてドキュメントのフィルタリングの時などに使用します。実際にドキュメントにsectionという項目をmetadataに持たせて今回はこれをフィルタリング対象にしましょう。

ドキュメントにsection項目を作成する

何もしていないとall_splitsの各ドキュメントはこのようになっています。

all_splits[0].metadata

出力: 'source': 'https://lilianweng.github.io/posts/2023-06-23-agent/'}
ここにsectionという項目を入れるために以下を実行します。

total_documents = len(all_splits)
third = total_documents // 3

for i, document in enumerate(all_splits):
    if i < third:
        document.metadata["section"] = "beginning"
    elif i < 2 * third:
        document.metadata["section"] = "middle"
    else:
        document.metadata["section"] = "end"

このようにすることで、documentのmetadataプロパティのsectionという項目が追加されました。
出力: {'source': 'https://lilianweng.github.io/posts/2023-06-23-agent/', 'section': 'beginning'}
最後にvectore_storeを更新します。

from langchain_core.vectorstores import InMemoryVectorStore

vector_store = InMemoryVectorStore(embeddings)
_ = vector_store.add_documents(all_splits)

sectionを利用したクエリ分析

流れ

  1. ユーザーがクエリを入力する
  2. llmを用いて自動的にクエリにsectionを割り当てる
  3. sectionが割り当てられたクエリをStateという形でretrieveに渡す
  4. いつも通り、retrieve->generateの流れ

このような流れとなっています。
そこでStateを再定義しましょう

from typing import Literal

from typing_extensions import Annotated


class Search(TypedDict):
    """Search query."""

    query: Annotated[str, ..., "Search query to run."]
    section: Annotated[
        Literal["beginning", "middle", "end"],
        ...,
        "Section to query.",
    ]
class State(TypedDict):
    question: str
    query: Search
    context: List[Document]
    answer: str

このようにStateを定義し直しました。変化項目はqueryという項目が増えたことです。このqueryというはSearch型です。さらにSearch型はqueryとsectionプロパティを持ちます。しかし、ユーザーは自分のクエリがどのsectionに属しているかやそもそもどのようなsectionが存在しているかなどわからないので、ここはllmで自動的に補完する必要があります。実際にそのコードを見てみましょう。

def analyze_query(state: State):
   structured_llm = llm.with_structured_output(Search)
   query = structured_llm.invoke(state["question"])
   return {"query": query}

このようになっています。最初のstructured_llm = llm.with_structured_output(Search)は出力がSearch型のを出力するようにしています。
次にquery = structured_llm.invoke(state["question"])ですが、入力としてquestionを受け取って、Search型の定義のAnnotatedに書かれているようにstructured_llmが出力を返却します。queryの部分は要約されたquestionが入り、section部分は"beginning", "middle", "end"のいづれかが入ります。(structured_llmが判断)これによってstructured_llm.invokeの返却値はSearch型のインスタンスを返すことができます。最後に、いつも通りの決まった形で{"query": query}を出力しgraphが自動的にSateのqueryの部分を更新して次のステップへ進みます。

sectionを用いたretrieve

次にretrieveは以下のようになります

def retrieve(state: State):
    query = state["query"]
    retrieved_docs = vector_store.similarity_search(
        query["query"],
        filter=lambda doc: doc.metadata.get("section") == query["section"],
    )
    return {"context": retrieved_docs}

基本的にやっていることが前回と同じです。今回はsimilarity_searchメソッドに渡すのが、questionではなく一度structured_llmで要約されたqueryを用いているのと、filterという引数に関数を渡していることが異なります。
流れとして、queryに似たドキュメントをいくつか取得してその中でsectionが同じものを選んでいるだけだと思います。

graphインスタンスを作成

ここまで来れば前回と同様です

def analyze_query(state: State):
    structured_llm = llm.with_structured_output(Search)
    query = structured_llm.invoke(state["question"])
    return {"query": query}


def retrieve(state: State):
    query = state["query"]
    retrieved_docs = vector_store.similarity_search(
        query["query"],
        filter=lambda doc: doc.metadata.get("section") == query["section"],
    )
    return {"context": retrieved_docs}


def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}


graph_builder = StateGraph(State).add_sequence([analyze_query, retrieve, generate])
graph_builder.add_edge(START, "analyze_query")
graph = graph_builder.compile()

スタート地点がanalyze_queryからになっただけですね

RAGでクエリを送る

最後に実査にクエリを送ってみましょう。

for step in graph.stream(
    {"question": "What does the beginning of the post say about Task Decomposition?"},
    stream_mode="updates",
):
    print(f"{step}\n\n----------------\n")

出力

出力の中身を見てみるとanalyze_queryのノードのqueryの部分がTask Decompositionと要約され、自動的にsectionがbeginningになっています。文章から内容を察してbeginningにしたのでしょう。また、ちゃんとbeginningのsectionから類似した文章を取得できているのもretrieveのノードから理解できます。

基本的なチュートリアルはここまでです。次回からはもう少し発展的な使用方法について説明します。
読んでいて気になる点があれば、公式ドキュメントを読んでみるのがいいかと思います。
https://python.langchain.com/docs/tutorials/rag/

Discussion