🎉

Self-RAGについて

2024/02/17に公開

langgraph_self_rag.ipynb

Self-RAGとは、自己反映を利用して情報検索や生成の質を向上させるための手法。
質問に関する情報を検索し、それが質問に関連しているか、生成された回答が正確か、そして回答が質問に役立つかどうかを評価するプロセスを経ることで、より良い結果を得ることを目指す。
各ステップでトークンを生成し、質問の文脈で情報の断片(チャンク)が役立つかどうかを評価します。
評価は、関連性、支持度、有用性に基づいて行われ、この判断によって情報の断片が取り込まれるか、破棄されるかが決定される。
これはグラフとして表現され、意思決定プロセスが視覚化されます。

この画像に描かれている脳は、Self-RAGのプロセスの一部である「評価フェーズ」を象徴的に表しています。
ここでの「脳」は、生成された回答が情報の断片に基づいているか(Support)、そしてその回答が質問に対して有用であるか(Use)を評価する、モデルの意思決定プロセスの中心的な部分を表していると考えられます。つまり、モデルが情報の断片から生成した回答の質と関連性を「考える」部分です。

この図は「Self RAG LangGraph」というプロセスを表しています。このプロセスは、質問に基づいてドキュメントを取得(retrieve)し、そのドキュメントの質を評価(grade_documents)します。質の高いドキュメントがあれば、それを基にして回答を生成(generate)し、生成された回答がドキュメントを支持しているか(grade_generation_v_documents)を評価します。その後、生成された回答が質問に対して有用かどうかを評価(grade_generation_v_question)し、有用でなければ質問を変換(transform_query)してプロセスを繰り返します。各ステップは条件付きで進み、この流れによって最終的な回答が決定されます。

ここからコード

retrieve
与えられた質問に関連するドキュメントを取得する責任があります。システムの現在の状態(質問を含む)を取り、関連するドキュメントをフェッチするためのツールを使用します。これらのドキュメントはその後、状態に追加されます。

def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"keys": {"documents": documents, "question": question}}

generate
ドキュメントが取得された後、この関数は質問に対する答えをGPT-3.5のような言語モデルを使って生成します。LangChainのハブからプロンプトを使用し、それを言語モデルに送信し、出力から生成された答えを抽出して解析します。

def generate(state):a
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }

ここでのkeysは例としてこんなパラメータが入る

state = {
    "keys": {
        "question": "日本の首都はどこですか?",
        "documents": [
            {"title": "東京について", "page_content": "東京は日本の首都であり、最大の都市です。"},
            {"title": "日本の都市", "page_content": "日本には多くの都市がありますが、首都は東京です。"}
        ],
        "generation": "東京は、日本の首都であり、豊かな歴史と文化を持つ大都市です。"
    }
}

grade_documents
取得したドキュメントが質問に関連があるかどうかを評価する関数です。言語モデルとバイナリ評価システムを使用して、ドキュメントを関連があるかないかとしてマークします。

def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with relevant documents
    """

    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Data model
    class grade(BaseModel):
        """Binary score for relevance check."""

        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    # LLM
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)

    # Tool
    grade_tool_oai = convert_to_openai_tool(grade)

    # LLM with tool and enforce invocation
    llm_with_tool = model.bind(
        tools=[convert_to_openai_tool(grade_tool_oai)],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )

    # Parser
    parser_tool = PydanticToolsParser(tools=[grade])

    # Prompt
    prompt = PromptTemplate(
      """
        あなたは、検索された文書とユーザーの質問との関連性を評価する採点者です。\n 
        これが検索された文書です: \コンテキスト \
        これがユーザの質問です: {question}です。\n
        文書がユーザの質問に関連するキーワードや意味的な意味を含む場合、関連性があると評定する。\n
        文書が質問に関連しているかどうかを示すために、'yes'か'no'のバイナリスコアを与えます。
      """
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | llm_with_tool | parser_tool

    # Score
    filtered_docs = []
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        grade = score[0].binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue

    return {"keys": {"documents": filtered_docs, "question": question}}

transform_query
取得したドキュメントが関連していない場合や、生成された答えが満足できない場合、この関数はより良い質問にクエリを変換するために使用され、次の反復での検索または生成の結果を改善することを目指します。

def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """

    print("---TRANSFORM QUERY---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Create a prompt template with format instructions and the query
    prompt = PromptTemplate(
        template="""
あなたは検索に最適化された質問を生成しています。\n 
        入力を見て、根底にある意味的な意図/意味を推論しようとする。\n 
        これが最初の質問です:
        \質問
        質問 
        \疑問文
        改善された質問を作成します:
""",
        input_variables=["question"],
    )

    # Grader
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)

    # Prompt
    chain = prompt | model | StrOutputParser()
    better_question = chain.invoke({"question": question})

    return {"keys": {"documents": documents, "question": better_question}}

prepare_for_final_grade
最終評価ステップの準備のためのパススルー(入力をそのまま出力として返す関数)関数です。
また、実行の流れを決定する「Edges」というものがあります

def prepare_for_final_grade(state):
    """
    Passthrough state for final grade.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): The current graph state
    """

    print("---FINAL GRADE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]

    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }

decide_to_generate
ドキュメントの関連性に基づいて、答えを生成するか、クエリを変換するかを決定します。

def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current state of the agent, including all keys.

    Returns:
        str: Next node to call
    """

    print("---DECIDE TO GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    filtered_documents = state_dict["documents"]

    if not filtered_documents:
        # All documents have been filtered check_relevance
        # We will re-generate a new query
        print("---DECISION: TRANSFORM QUERY---")
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

grade_generation_v_documents
生成された答えが取得したドキュメントに基づいているかを評価します。

def grade_generation_v_documents(state):
    """
    Determines whether the generation is grounded in the document.

    Args:
        state (dict): The current state of the agent, including all keys.

    Returns:
        str: Binary decision
    """

    print("---GRADE GENERATION vs DOCUMENTS---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]

    # Data model
    class grade(BaseModel):
        """Binary score for relevance check."""

        binary_score: str = Field(description="Supported score 'yes' or 'no'")

    # LLM
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)

    # Tool
    grade_tool_oai = convert_to_openai_tool(grade)

    # LLM with tool and enforce invocation
    llm_with_tool = model.bind(
        tools=[convert_to_openai_tool(grade_tool_oai)],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )

    # Parser
    parser_tool = PydanticToolsParser(tools=[grade])

    # Prompt
    prompt = PromptTemplate(
        template="""
あなたは採点者であり、答えが一連の事実に基づいているかどうかを評価する。\n 
        これが事実です:
        \n ------- \n
        {documents} 
        \n ------- \n
        これが答えです: 世代
        答えが一連の事実に根拠があるか/裏付けがあるかを示すために、「はい」または「いいえ」の二値スコアを与えてください。
       """,
        input_variables=["generation", "documents"],
    )

    # Chain
    chain = prompt | llm_with_tool | parser_tool

    score = chain.invoke({"generation": generation, "documents": documents})
    grade = score[0].binary_score

    if grade == "yes":
        print("---DECISION: SUPPORTED, MOVE TO FINAL GRADE---")
        return "supported"
    else:
        print("---DECISION: NOT SUPPORTED, GENERATE AGAIN---")
        return "not supported"

grade_generation_v_question
生成された答えが元の質問を解決するのに役立っているかどうかを決定します。

def grade_generation_v_question(state):
    """
    その世代が質問に対応しているかどうかを判断する。

    Args:
        state (dict): The current state of the agent, including all keys.

    Returns:
        str: Binary decision
    """

    print("---GRADE GENERATION vs QUESTION---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]

    # Data model
    class grade(BaseModel):
        """Binary score for relevance check."""

        binary_score: str = Field(description="Useful score 'yes' or 'no'")

    # LLM
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)

    # Tool
    grade_tool_oai = convert_to_openai_tool(grade)

    # LLM with tool and enforce invocation
    llm_with_tool = model.bind(
        tools=[convert_to_openai_tool(grade_tool_oai)],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )

    # Parser
    parser_tool = PydanticToolsParser(tools=[grade])

    # Prompt
    prompt = PromptTemplate(
        template="""You are a grader assessing whether an answer is useful to resolve a question. \n 
        Here is the answer:
        \n ------- \n
        {generation} 
        \n ------- \n
        Here is the question: {question}
        Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question.""",
        input_variables=["generation", "question"],
    )

    # Prompt
    chain = prompt | llm_with_tool | parser_tool

    score = chain.invoke({"generation": generation, "question": question})
    grade = score[0].binary_score

    if grade == "yes":
        print("---DECISION: USEFUL---")
        return "useful"
    else:
        print("---DECISION: NOT USEFUL---")
        return "not useful"

グラフを作ってみる

このコードは、LangGraphを使用して特定のワークフローを構築し、実行するためのものです。ここでは、情報検索、文書の評価、質問への回答生成、質問の変換、最終評価といった一連のプロセスを定義し、これらのプロセスを経て、入力された質問に対する適切な回答を生成することを目指しています。

ワークフローの構築: StateGraph を使って、各種プロセス(ノード)を定義し、これらのノード間でのデータフロー(エッジ)を設定します。各ノードは特定の機能(例: retrieve, grade_documents など)を実行します。

ノードの追加: add_node メソッドを使って、各プロセスをワークフローに追加します。これらのプロセスは前述のように、文書の取得、文書の評価、回答の生成、質問の変換、最終評価などです。

エッジの設定: ノード間の接続(データフロー)を add_edge と add_conditional_edges メソッドを使って設定します。条件付きエッジは、特定の条件に基づいて次に実行するノードを決定するために使用されます。

ワークフローの実行: compile メソッドでワークフローをコンパイルした後、stream メソッドを使って入力データに対してワークフローを実行します。実行中、各ノードで得られた結果が次のノードへと渡され、最終的に質問に対する回答が生成されます。

import pprint

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generatae
workflow.add_node("transform_query", transform_query)  # transform_query
workflow.add_node("prepare_for_final_grade", prepare_for_final_grade)  # passthrough

# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents,
    {
        "supported": "prepare_for_final_grade",
        "not supported": "generate",
    },
)
workflow.add_conditional_edges(
    "prepare_for_final_grade",
    grade_generation_v_question,
    {
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile
app = workflow.compile()
# Run
inputs = {"keys": {"question": "Explain how the different types of agent memory work?"}}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint.pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

# Final generation
pprint.pprint(value["keys"]["generation"])

Discussion