streamlitでLangGraphによる自己修正RAGを実装してみよう!
この記事は以下記事のRAGシステムをstreamlitで実行したものになります。
ワークフローの詳細はこちらで説明してますのでぜひこちらもチェックして下さい!
Streamlitとは?
Streamlitは、PythonでWebアプリケーションを素早く作成するためのオープンソースライブラリです。データ分析、可視化、機械学習モデルのデモなどに特に適しており、コーディングの専門知識がなくても使いやすいツールです。
前回記事のおさらい
ワークフローは以下のようになります。
Adaptive RAGは以下の機能を有しています。
- クエリ分析
- 取得したドキュメントの分析
- 回答のハルシネーションチェック
- 回答の有用性チェック
- クエリの調整
これをStreamlitで実装していきます!
Streamlitでの実行の肝
st.status
LangGraphの進捗管理は以下のコンポーネントを使用します。
st.statusは、主に長時間実行されるタスクやプロセスのステータスをユーザーにリアルタイムで伝えるために使用されます。with st.status(label="**GO!!**", expanded=True,state="running") as st.session_state.status:
st.session_state.placeholder = st.empty()
このコードで以下写真のようなステータスコンテナを表示できます。
expandedでステータスコンテナを折り畳み有無を指定できます。
stateには"running", "complete", "error"が選べてアイコンが変わります。
これを使用してLangGraphの進捗をリアルタイム表示します!
ステータスコンテナーを更新する際には以下のコードで更新できます。
st.session_state.status.update(label=f"**---DECISION: GENERATE---**", state="running", expanded=False)
後述しますが非同期関数の中でステータスコンテナをアップデートさせるので、st.session_stateにステータスコンテナを格納しています。
st.empty
ステータスコンテナにメッセージを表示するのはst.emptyを使います!
下で言う赤枠のことですね
st.session_state.placeholder = st.empty()
st.emptyは動的なデータや情報をリアルタイムで更新する必要がある場合に使用されます。
ステータスコンテナにはst.markdownなどでメッセージを表示すると、メッセージが残ったままになって見にくくなっちゃうんですね。
メッセージを更新する際には以下のようにします。
st.session_state.placeholder.markdown("---ROUTE QUESTION TO WEB SEARCH---")
こちらもst.session_stateに格納しています。
試行回数の管理
LangGraphではノード間でループ出来ますが、無限にループしてしまうと使用トークン量も天井突破してしまいます。
(*デフォルトでは20回までの制限となっています。変更可能です。)
しかし現実的には現在時点でのドキュメントのベクトル化の限界という点からあまり試行回数増やしたからと言ったといて正確無比な回答になるとは限らないので、このStreamlitアプリでは回答のやり直しをさせるのは1回までとの制限を設けています。あと先述のLangGraph設定の上限値に引っかかるとエラーになるので。
st.session_state.number_trialという変数を定義して各ノードで数値を調整してその値によって条件分岐させるということにしています。
非同期処理
LangGraphには非同期処理を導入しています。Streamlitで非同期処理を扱う場合にはasyncを使用します。
非同期処理を使用しない場合、Streamlitでは実行されるタスクや外部データの取得が完了するまでアプリがブロックされてしまいます。そのためRAGが回答生成するまでアプリが固まりっぱなしになるなんでことになってしまいます。
非同期処理を導入することによりユーザーインターフェースの応答性が向上し、長時間実行されるタスクでもアプリケーションがフリーズしません!次に、複数の処理を並行して実行でき、リアルタイムでの進捗状況の更新が可能になるのでユーザーエクスペリエンスが向上します。
中間ステップの表示
前回の出力結果ですが、以下のようになっていました。
出力
---ROUTE QUESTION---
---ROUTE QUESTION TO RAG---
---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
"Node 'grade_documents':"
'\n---\n'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION DOES NOT ADDRESS QUESTION---
"Node 'generate':"
'\n---\n'
---TRANSFORM QUERY---
"Node 'transform_query':"
'\n---\n'
---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
"Node 'grade_documents':"
'\n---\n'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
"Node 'generate':"
'\n---\n'
('エージェントメモリには、以下のような種類が存在します。\n'
'\n'
'1. **感覚記憶 (Sensory Memory)**:\n'
' - '
'これは、視覚や聴覚などの感覚情報の印象を保持する最初の段階の記憶です。例えば、目の前で何かが一瞬見えた後、そのイメージを数秒間保持することができます。\n'
'\n'
'2. **短期記憶 (Short-Term Memory) または 作業記憶 (Working Memory)**:\n'
' - '
'現在意識している情報を保持し、学習や推論などの複雑な認知タスクを実行するために必要な記憶です。例えば、電話番号を一時的に覚えておくことが短期記憶の一例です。\n'
'\n'
'3. **長期記憶 (Long-Term Memory)**:\n'
' - 情報を長期間保存することができ、数日から数十年にわたって保持される記憶です。長期記憶には以下の2つのサブタイプがあります:\n'
' - **明示的記憶 (Explicit / Declarative Memory)**: '
'事実や出来事を意識的に思い出すことができる記憶。例えば、特定の歴史的出来事や自分の誕生日を思い出すことが含まれます。\n'
' - **暗黙的記憶 (Implicit / Procedural Memory)**: '
'無意識的に行われるスキルやルーチンに関する記憶。例えば、自転車に乗ることやタイピングのスキルがこれに該当します。\n'
'\n'
'これらのメモリの種類は、エージェントが過去の経験を基に行動を決定し、他のエージェントと相互作用するために重要な役割を果たします。')
回答が提示された後もこのようなログが見れると便利ですよね、しかしこれがそのまま表示されると見づらくもなるので以下コンポーネントに格納しておきます。
状態が遷移するたびにログを追加していきます。
with st.popover("ログ"):
st.markdown(st.session_state.log)
これでカーソルを合わせた時にログ確認できるようになるのでユーザー目線からも回答のデバッグが容易になります。
Streamlitコード
必要なライブラリのインストール
pip install -U langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python streamlit
コード
全体のコードは以下になります。少し長くなりますが
基本的には前回記事で紹介したコードに前述の肝を盛り込んだものになります。
import streamlit as st
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langgraph.graph import StateGraph
from typing import List, Annotated, Literal, Sequence, TypedDict
from langgraph.graph import END, StateGraph, START
import asyncio
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document
class RouteQuery(BaseModel):
"""ユーザーのクエリを最も関連性の高いデータソースにルーティングします。"""
datasource: Literal["vectorstore", "web_search"] = Field(
...,
description="ユーザーの質問に応じて、ウェブ検索またはベクターストアにルーティングします。",
)
class GradeDocuments(BaseModel):
"""取得された文書の関連性チェックのためのバイナリスコア。"""
binary_score: str = Field(
description="文書が質問に関連しているかどうか、「yes」または「no」"
)
class GradeHallucinations(BaseModel):
"""生成された回答における幻覚の有無を示すバイナリスコア。"""
binary_score: str = Field(
description="回答が事実に基づいているかどうか、「yes」または「no」"
)
class GradeAnswer(BaseModel):
"""回答が質問に対処しているかどうかを評価するバイナリスコア。"""
binary_score: str = Field(
description="回答が質問に対処しているかどうか、「yes」または「no」"
)
class GraphState(TypedDict):
"""
グラフの状態を表します。
属性:
question: 質問
generation: LLM生成
documents: 文書のリスト
"""
question: str
generation: str
documents: List[str]
async def route_question(state):
st.session_state.status.update(label=f"**---ROUTE QUESTION---**", state="running", expanded=True)
st.session_state.log += "---ROUTE QUESTION---" + "\n\n"
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_router = llm.with_structured_output(RouteQuery)
system = """あなたはユーザーの質問をベクターストアまたはウェブ検索にルーティングする専門家です。
ベクターストアにはエージェント、プロンプトエンジニアリング、アドバーサリアルアタックに関連する文書が含まれています。
これらのトピックに関する質問にはベクターストアを使用し、それ以外の場合はウェブ検索を使用してください。"""
route_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "{question}"),
]
)
question_router = route_prompt | structured_llm_router
question = state["question"]
source = question_router.invoke({"question": question})
if source.datasource == "web_search":
st.session_state.log += "---ROUTE QUESTION TO WEB SEARCH---" + "\n\n"
st.session_state.placeholder.markdown("---ROUTE QUESTION TO WEB SEARCH---")
return "web_search"
elif source.datasource == "vectorstore":
st.session_state.placeholder.markdown("ROUTE QUESTION TO RAG")
st.session_state.log += "ROUTE QUESTION TO RAG" + "\n\n"
return "vectorstore"
async def retrieve(state):
st.session_state.status.update(label=f"**---RETRIEVE---**", state="running", expanded=True)
st.session_state.placeholder.markdown(f"RETRIEVING…\n\nKEY WORD:{state['question']}")
st.session_state.log += f"RETRIEVING…\n\nKEY WORD:{state['question']}" + "\n\n"
embd = OpenAIEmbeddings()
urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=500, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
vectorstore = Chroma.from_documents(
documents=doc_splits,
collection_name="rag-chroma",
embedding=embd,
)
retriever = vectorstore.as_retriever()
question = state["question"]
documents = retriever.invoke(question)
st.session_state.placeholder.markdown("RETRIEVE SUCCESS!!")
return {"documents": documents, "question": question}
async def web_search(state):
st.session_state.status.update(label=f"**---WEB SEARCH---**", state="running", expanded=True)
st.session_state.placeholder.markdown(f"WEB SEARCH…\n\nKEY WORD:{state['question']}")
st.session_state.log += f"WEB SEARCH…\n\nKEY WORD:{state['question']}" + "\n\n"
question = state["question"]
web_search_tool = TavilySearchResults(k=3)
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
return {"documents": web_results, "question": question}
async def grade_documents(state):
st.session_state.number_trial += 1
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
system = """あなたは、ユーザーの質問に対して取得されたドキュメントの関連性を評価する採点者です。
ドキュメントにユーザーの質問に関連するキーワードや意味が含まれている場合、それを関連性があると評価してください。
目的は明らかに誤った取得を排除することです。厳密なテストである必要はありません。
ドキュメントが質問に関連しているかどうかを示すために、バイナリスコア「yes」または「no」を与えてください。"""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
st.session_state.status.update(label=f"**---CHECK DOCUMENT RELEVANCE TO QUESTION---**", state="running", expanded=False)
st.session_state.log += "**---CHECK DOCUMENT RELEVANCE TO QUESTION---**" + "\n\n"
question = state["question"]
documents = state["documents"]
filtered_docs = []
i = 0
for d in documents:
if st.session_state.number_trial <= 2:
file_name = d.metadata["source"]
file_name = os.path.basename(file_name.replace("\\","/"))
i += 1
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
st.session_state.status.update(label=f"**---GRADE: DOCUMENT RELEVANT---**", state="running", expanded=True)
st.session_state.placeholder.markdown(f"DOC {i}/{len(documents)} : **RELEVANT**\n\n")
st.session_state.log += "---GRADE: DOCUMENT RELEVANT---" + "\n\n"
st.session_state.log += f"doc {i}/{len(documents)} : RELEVANT\n\n"
filtered_docs.append(d)
else:
st.session_state.status.update(label=f"**---GRADE: DOCUMENT NOT RELEVANT---**", state="error", expanded=True)
st.session_state.placeholder.markdown(f"DOC {i}/{len(documents)} : **NOT RELEVANT**\n\n")
st.session_state.log += "---GRADE: DOCUMENT NOT RELEVANT---" + "\n\n"
st.session_state.log += f"DOC {i}/{len(documents)} : NOT RELEVANT\n\n"
else:
filtered_docs.append(d)
if not st.session_state.number_trial <= 2:
st.session_state.status.update(label=f"**---NO NEED TO CHECK---**", state="running", expanded=True)
st.session_state.placeholder.markdown("QUERY TRANSFORMATION HAS BEEN COMPLETED")
st.session_state.log += "QUERY TRANSFORMATION HAS BEEN COMPLETED" + "\n\n"
return {"documents": filtered_docs, "question": question}
async def generate(state):
st.session_state.status.update(label=f"**---GENERATE---**", state="running", expanded=False)
st.session_state.log += "---GENERATE---" + "\n\n"
prompt = ChatPromptTemplate.from_messages(
[
("system", """ユーザーから与えられたコンテキストを参考に質問に対し答えて下さい。"""),
("human", """Question: {question}
Context: {context}"""),
]
)
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
rag_chain = prompt | llm | StrOutputParser()
question = state["question"]
documents = state["documents"]
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
async def transform_query(state):
st.session_state.status.update(label=f"**---TRANSFORM QUERY---**", state="running", expanded=True)
st.session_state.placeholder.empty()
st.session_state.log += "---TRANSFORM QUERY---" + "\n\n"
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
system = """あなたは、入力された質問をベクトルストア検索に最適化されたより良いバージョンに変換する質問リライターです。
質問を見て、質問者の意図/意味について推論してより良いベクトル検索の為の質問を作成してください。"""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
(
"human",
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
),
]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
question = state["question"]
documents = state["documents"]
better_question = question_rewriter.invoke({"question": question})
st.session_state.log += f"better_question : {better_question}\n\n"
st.session_state.placeholder.markdown(f"better_question : {better_question}")
return {"documents": documents, "question": better_question}
async def decide_to_generate(state):
filtered_documents = state["documents"]
if not filtered_documents:
st.session_state.status.update(label=f"**---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---**", state="error", expanded=False)
st.session_state.log += "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---" + "\n\n"
return "transform_query"
else:
st.session_state.status.update(label=f"**---DECISION: GENERATE---**", state="running", expanded=False)
st.session_state.log += "---DECISION: GENERATE---" + "\n\n"
return "generate"
async def grade_generation_v_documents_and_question(state):
st.session_state.number_trial += 1
st.session_state.status.update(label=f"**---CHECK HALLUCINATIONS---**", state="running", expanded=False)
st.session_state.log += "---CHECK HALLUCINATIONS---" + "\n\n"
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
system = """あなたは、LLMの生成が取得された事実のセットに基づいているか/サポートされているかを評価する採点者です。
バイナリスコア「yes」または「no」を与えてください。「yes」は、回答が事実のセットに基づいている/サポートされていることを意味します。"""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
]
)
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)
system = """あなたは、回答が質問に対処しているか/解決しているかを評価する採点者です。
バイナリスコア「yes」または「no」を与えてください。「yes」は、回答が質問を解決していることを意味します。"""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grader
hallucination_grader = hallucination_prompt | structured_llm_grader
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = hallucination_grader.invoke(
{"documents": documents, "generation": generation}
)
grade = score.binary_score
if st.session_state.number_trial <= 3:
if grade == "yes":
st.session_state.placeholder.markdown("DECISION: ANSWER IS BASED ON A SET OF FACTS")
st.session_state.log += "---DECISION: ANSWER IS BASED ON A SET OF FACTS---" + "\n\n"
st.session_state.log += "---GRADE GENERATION vs QUESTION---" + "\n\n"
score = answer_grader.invoke({"question": question, "generation": generation})
grade = score.binary_score
st.session_state.status.update(label=f"**---GRADE GENERATION vs QUESTION---**", state="running", expanded=True)
if grade == "yes":
st.session_state.status.update(label=f"**---DECISION: GENERATION ADDRESSES QUESTION---**", state="running", expanded=True)
with st.session_state.placeholder:
st.markdown("**USEFUL!!**")
st.markdown(f"question : {question}")
st.markdown(f"generation : {generation}")
st.session_state.log += "---DECISION: GENERATION ADDRESSES QUESTION---" + "\n\n"
st.session_state.log += f"USEFUL!!\n\n"
st.session_state.log += f"question:{question}\n\n"
st.session_state.log += f"generation:{generation}\n\n"
return "useful"
else:
st.session_state.number_trial -= 1
st.session_state.status.update(label=f"**---DECISION: GENERATION DOES NOT ADDRESS QUESTION---**", state="error", expanded=True)
with st.session_state.placeholder:
st.markdown("**NOT USEFUL**")
st.markdown(f"question:{question}")
st.markdown(f"generation:{generation}")
st.session_state.log += "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---" + "\n\n"
st.session_state.log += f"NOT USEFUL\n\n"
st.session_state.log += f"question:{question}\n\n"
st.session_state.log += f"generation:{generation}\n\n"
return "not useful"
else:
st.session_state.status.update(label=f"**---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---**", state="error", expanded=True)
with st.session_state.placeholder:
st.markdown("not grounded")
st.markdown(f"question:{question}")
st.markdown(f"generation:{generation}")
st.session_state.log += "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---" + "\n\n"
st.session_state.log += f"not grounded\n\n"
st.session_state.log += f"question:{question}\n\n"
st.session_state.log += f"generation:{generation}\n\n"
return "not supported"
else:
st.session_state.status.update(label=f"**---NO NEED TO CHECK---**", state="running", expanded=True)
st.session_state.placeholder.markdown("TRIAL LIMIT EXCEEDED")
st.session_state.log += "---NO NEED TO CHECK---" + "\n\n"
st.session_state.log += "TRIAL LIMIT EXCEEDED" + "\n\n"
return "useful"
async def run_workflow(inputs):
st.session_state.number_trial = 0
with st.status(label="**GO!!**", expanded=True,state="running") as st.session_state.status:
st.session_state.placeholder = st.empty()
value = await st.session_state.workflow.ainvoke(inputs)
st.session_state.placeholder.empty()
st.session_state.message_placeholder = st.empty()
st.session_state.status.update(label="**FINISH!!**", state="complete", expanded=False)
st.session_state.message_placeholder.markdown(value["generation"])
with st.popover("ログ"):
st.markdown(st.session_state.log)
def st_rag_langgraph():
if 'log' not in st.session_state:
st.session_state.log = ""
if 'status_container' not in st.session_state:
st.session_state.status_container = st.empty()
if not hasattr(st.session_state, "workflow"):
workflow = StateGraph(GraphState)
workflow.add_node("web_search", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_conditional_edges(
START,
route_question,
{
"vectorstore": "retrieve",
"web_search": "web_search",
},
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
"generate",
grade_generation_v_documents_and_question,
{
"not supported": "generate",
"useful": END,
"not useful": "transform_query",
},
)
app = workflow.compile()
app = app.with_config(recursion_limit=10,run_name="Agent",tags=["Agent"])
app.name = "Agent"
st.session_state.workflow = app
st.title("Adaptive RAG by LangGraph")
if prompt := st.chat_input("質問を入力してください"):
st.session_state.log = ""
with st.chat_message("user", avatar="😊"):
st.markdown(prompt)
inputs = {"question": prompt}
asyncio.run(run_workflow(inputs))
if __name__ == "__main__":
st_rag_langgraph()
実行する際は以下をターミナルで実行して下さい!
streamlit run XXXXX.py
必要な環境変数については前回記事と同様のものになります。
最後に
いかがでしたでしょうか?LangGraphによるRAGをStreamlitで実行する方法を紹介しました!
Adaptive RAGは実用的なテクニックですが実行に時間がかかるのも特徴です。そんな時に動的に実行h状態を通知する仕組みを構築しました。これによりユーザーエクスペリエンスの向上につながります!
この記事が参考になれば幸いです。それでは、次回の記事でお会いしましょう!
Discussion