LangChainのLCEL (LangChain Expression Language)を使って、RAGエージェントを作ってみる。
LangChain Expression Language (LCEL) 公式ドキュメント
このスクラップを書く動機
LCELでRAGを実装してみた系の記事は沢山あるが、以前のAgentのようにLLMがRAGを使用するかどうかを判断するAgent型のRAGを作ってる記事がなかった。
LCELでRAGを実装してみた系の記事
公式Cookbook
# Prompt template
template = """Answer the question based only on the following context, which can include text and tables:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
# RAG pipeline
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| prompt
| print_content
| llm
| StrOutputParser()
)
chain.invoke(
"富士山の高さはいくら?"
)
[Document(page_content='~~~')]
'提供された文書には、富士山の高さに関する情報は含まれていません。'
公式でもAgentのような実装方法ではないので、Retrieverと関係ない情報の場合はRetieverを使用しないという選択ができない。
作りたいRAG Agentのフロー
回答に使用したmetadataをクライアントへ返却したい。例えば、metadataの中に商品urlを仕込んでおいて、参照したドキュメントの商品のイメージurlを返却するということがやりたい。また、Streamingを使用して、クライアントに素早く返却したい、
使用したドキュメントの詳細を取得することは以前のAgentで行うことができない(多分)。LLMのinputにmetadataを入れて、Agentのコールバックで無理やり取得するという荒技も可能だが、LLMに無駄なinputをしてしまう。なのでLCELを使用する。
LangChain Expression Language (LCEL)ってなんなの?
公式ドキュメント
記述方法が特徴的で、| セパレーターでチェーンを繋いでいく
prompt = ChatPromptTemplate.from_template(
"""
{japanese}を英語で言うと?
English:
"""
)
llm = ChatOpenAI(
model = "gpt-4o"
)
chain = prompt | llm | StrOutputParser() # 👈
chain.invoke({"japanese":"リンゴ"})
非同期サポート、並列実行、再試行とフォールバックをメインの機能として挙げている。今までのLangchainの記述方法では、中身がブラックボックスのものが多く、さらにカスタマイズするには向いていなかった(今回の例など)。
並列実行もサポートしているので、ベクトル検索とキーワード検索を並列で実行して実行時間を短縮するといったこともやりやすくなりそう。
本題
LCELは一本道の処理は得意。
なので、どうやって「LLMがドキュメントが必要な回答かどうかを判断」した後に、分岐させるか?が大変になりそう
RunnableBranchという処理によって分岐させられる機能があるらしい
from langchain_core.runnables import RunnableBranch
branch = RunnableBranch(
(lambda x: "anthropic" in x["topic"].lower(), anthropic_chain),
(lambda x: "langchain" in x["topic"].lower(), langchain_chain),
general_chain,
)
full_chain = {"topic": chain, "question": lambda x: x["question"]} | branch
full_chain.invoke({"question": "how do I use Anthropic?"})
公式の例では、question内にAnthropicという文字列があるかどうかでanthropic_chainかlangchain_chainを使うかどうか分岐させている。
カスタム関数を使っても良い
def route(info):
if "anthropic" in info["topic"].lower():
return anthropic_chain
elif "langchain" in info["topic"].lower():
return langchain_chain
else:
return general_chain
from langchain_core.runnables import RunnableLambda
full_chain = {"topic": chain, "question": lambda x: x["question"]} | RunnableLambda(
route
)
RAGを作るための下準備
databricks-qa-jaというデータセットで事前にベクトルデータベースは作成済み。
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain.tools.retriever import create_retriever_tool
load_dotenv()
# LLMの作成
llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model_name=os.getenv("OPENAI_MODEL_NAME", "gpt-4o"), temperature=0.9, streaming=True)
# ベクトルストアの読み込み
top_k = 5
embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"), model="text-embedding-3-large")
vector_store = FAISS.load_local("./test/vectorstore/databricks-qa", embeddings=embeddings, allow_dangerous_deserialization=True)
retriever = vector_store.as_retriever(search_kwargs={'k': top_k})
# Toolの作成
description = "This tool retrieves information on AI and data science"
retrieval_tool = create_retriever_tool(vector_store.as_retriever(search_kwargs={'k': top_k}), name="information-retrieval-tool", description=description)
tools = [retrieval_tool]
llm_with_tools = llm.bind_tools(tools)
まずは、RunnableBranchを試してみる
from langchain_core.runnables import RunnableBranch
from langchain_core.output_parsers import StrOutputParser
prompt1 = ChatPromptTemplate.from_template("Explain: {topic}")
prompt2 = ChatPromptTemplate.from_template("What kind of {topic}")
prompt3 = ChatPromptTemplate.from_template("What does this sentence explain?\n{explanation}")
llm = ChatOpenAI(model='gpt-3.5-turbo')
chain = RunnableBranch(
(
lambda x: " " not in x['topic'],
prompt1,
),
(
lambda x: " " in x['topic'],
{'explanation': lambda x: x['topic']}
|prompt3
),
prompt2 # Default
) | llm | StrOutputParser()
chain.invoke({"topic": "python is most popular programming language in the world"})
topicに空白が含まれていなかったら(単語だったら)prompt1を使用して、文章だったらprompt3を使用するデモ。上記の二つのどちらにも含まれなかったら、prompt2が使用される。
本題
from langchain_core.documents import Document
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.messages import BaseMessage, SystemMessage, AIMessage, HumanMessage, AIMessageChunk
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnablePassthrough, RunnableParallel, RunnableGenerator
def is_tool_call_args(input):
if input.tool_calls:
print(input.tool_calls[0]["name"])
return input
def is_content(input):
if input.content:
print("not using tool")
return input
def document_format(docs: list[Document]) -> str:
return "\n".join([f"{i+1}. {doc.page_content}" for i, doc in enumerate(docs)])
def print_content(x):
print(x)
return x
prompt1: ChatPromptTemplate = ChatPromptTemplate.from_messages([
# SystemMessage(os.getenv("PROMPT_SYSTEM_MESSAGE", "You are a helpful assistant.")),
SystemMessage("You are a helpful assistant. Answer the following questions with 1 or 2 sentences."),
MessagesPlaceholder("input"),
])
prompt2 = ChatPromptTemplate.from_messages([
SystemMessage("You are a helpful assistant. Answer the following questions with 1 or 2 sentences."),
MessagesPlaceholder("document"),
MessagesPlaceholder("input"),
])
query = ''
def save_query(x):
global query
query = x["input"]
print(f"Saved query: {query}")
return x
branch = RunnableBranch(
(
is_tool_call_args,
JsonOutputKeyToolsParser(key_name="information-retrieval-tool")
| (lambda x: x[0]['query'])
| retriever
| print_content
| document_format
| print_content
| (lambda x: {"document": [x], "input": query})
| prompt2
| llm
| StrOutputParser()
),
is_content | StrOutputParser()
)
chain = (save_query | prompt1 | llm_with_tools | branch )
chain.invoke({"input": [{"role": "user", "content": "セマンティック検索とは何?"}]}):
is_tool_call_args
でLLMの返答がtoolを使用するかどうか判断する。
JsonOutputKeyToolsParser
はLLMのレスポンスを扱いやすい形に変換してくれる→[{'query': 'セマンティック検索とは何'}]
print_content
は、途中経過を確認しするために使用している。
document_format
はLLMにインプットするためのコンテンツのみを抜き出している。
prompt2
では、クエリとRAGしたドキュンメントをプロンプトに渡している。
llm
, StrOutputParser
で最終出力、といった流れ。
query
はprompt2
で使用するため変数として格納することにした。全てchain内で完結させることもできるが面倒なのでなし。
全く関係ない質問にはRAGを使用しないことが確認できる。
chain.invoke({"input": [{"role": "user", "content": "富士山の高さは?"}]}):
not using tool
'富士山の高さは3,776メートルです。'
この時点でRetrieverを使用する場合は、streamingが実装できている。
for chunk in chain.stream({"input": [{"role": "user", "content": "セマンティック検索とは何?"}]}):
print(chunk, end="|", flush=True)
Saved query: [{'role': 'user', 'content': 'セマンティック検索とは何?'}]
information-retrieval-tool
[Document(page_content='マイクロサービスは検索フレーズのようなリクエストを受け取り、レスポンスを返す軽量アプリケーションです。モデルとマイクロサービス内で検索するエンベディングをパッケージすることで、提供する検索機能を多くのアプリケーションからアクセスできるようにするだけではなく、多くのマイクロサービスインフラストラクチャソリューションは弾力性のあるスケーラビリティを提供しているので、需要の増減に追従できるようにサービスにリソースを割り当てることができます。', metadata={'context': '', 'source': 'https://qiita.com/taka_yayoi/items/d52518874318343aab04', 'instruction': 'マイクロサービスとはなんでしょうか?', 'category': 'closed_qa', 'chunk_no': 658}), Document(page_content='製品カタログの検索にLLMを使えば、商品説明や文章、音声の記録などに目を通し、ユーザーの検索に応えて、その内容に関連するものを提案するよう、モデルに課すことができます。ユーザーは、探しているものを見つけるために正確な用語を必要とせず、LLMがニーズに合わせて方向付けることができる一般的な説明だけでよいのです。その結果、ユーザーがサイトを利用する際に、まるでパーソナライズされた専門家のガイダンスを受けたかのような感覚に陥る、パワフルな新しい体験が得られました。', metadata={'context': '', 'source': 'https://www.databricks.com/jp/blog/enhancing-product-search-large-language-models-llms.html', 'instruction': '製品カタログの検索にLLMを活用したソリューションはどんなもの?', 'category': 'closed_qa', 'chunk_no': 722}), Document(page_content='特定の文書群を効果的に検索するためには、その文書に特化して学習させる必要すらありません。', metadata={'context': '', 'source': 'https://www.databricks.com/jp/blog/enhancing-product-search-large-language-models-llms.html', 'instruction': 'LLMを活用して特定の文書群を効果的に検索するために何か必要ですか?', 'category': 'closed_qa', 'chunk_no': 723}), Document(page_content='SparkElasticsearchとは、ドキュメント指向および半構造化データを格納、取得、管理するNoSQL分散データベースです。GitHubオープンソースであるElasticsearchは、ApacheLuceneをベースに構築され、Apacheライセンスの条件下でリリースされたRESTfulな検索エンジンでもあります。', metadata={'context': '', 'source': 'https://www.databricks.com/jp/glossary', 'instruction': 'Spark Elasticsearch とは?', 'category': 'closed_qa', 'chunk_no': 877}), Document(page_content='モデルを検索する際には、少なくとも読み取り権限を持っているモデルのみが返却されます。', metadata={'context': '', 'source': 'https://qiita.com/taka_yayoi/items/e1688ff127a22fcd76ff', 'instruction': 'モデルを検索する際にどんな権限が必要ですか?', 'category': 'closed_qa', 'chunk_no': 580})]
1. マイクロサービスは検索フレーズのようなリクエストを受け取り、レスポンスを返す軽量アプリケーションです。モデルとマイクロサービス内で検索するエンベディングをパッケージすることで、提供する検索機能を多くのアプリケーションからアクセスできるようにするだけではなく、多くのマイクロサービスインフラストラクチャソリューションは弾力性のあるスケーラビリティを提供しているので、需要の増減に追従できるようにサービスにリソースを割り当てることができます。
2. 製品カタログの検索にLLMを使えば、商品説明や文章、音声の記録などに目を通し、ユーザーの検索に応えて、その内容に関連するものを提案するよう、モデルに課すことができます。ユーザーは、探しているものを見つけるために正確な用語を必要とせず、LLMがニーズに合わせて方向付けることができる一般的な説明だけでよいのです。その結果、ユーザーがサイトを利用する際に、まるでパーソナライズされた専門家のガイダンスを受けたかのような感覚に陥る、パワフルな新しい体験が得られました。
3. 特定の文書群を効果的に検索するためには、その文書に特化して学習させる必要すらありません。
4. SparkElasticsearchとは、ドキュメント指向および半構造化データを格納、取得、管理するNoSQL分散データベースです。GitHubオープンソースであるElasticsearchは、ApacheLuceneをベースに構築され、Apacheライセンスの条件下でリリースされたRESTfulな検索エンジンでもあります。
5. モデルを検索する際には、少なくとも読み取り権限を持っているモデルのみが返却されます。
|セ|マン|ティ|ック|検索|とは|、|ユー|ザー|の|意|図|や|検索|の|文|脈|を|理解|し|、|関連|性|の|高|い|結果|を|提供|する|検索|技|術|です|。|単|純|な|キ|ーワード|マ|ッチ|ング|では|なく|、|意味|や|関|係|性|を|考|慮|に|入|れて|情報|を|検索|します|。||
しかし、Retrieverを使用しない場合、streamingで結果を返さない。
これは、RunnableBranchがそもそもstreaminigに対応していないのでは...?
def iterator(x):
str_list = ["あいうえお", "かきくけこ", "さしすせそ"]
for i in str_list:
yield i
def streaming(input_stream):
print("streaming")
for input in input_stream:
yield input
branch = RunnableBranch(
(
streaming,
(lambda x: x)
),
(lambda x: x)
)
chain = (iterator | branch )
for chunk in chain.stream({}):
print(chunk, end="|", flush=True)
streaming
あいうえおかきくけこさしすせそ|
Strem対応していれば、ちゃんと出力される
chain = (iterator | RunnableGenerator(streaming) | StrOutputParser() )
for chunk in chain.stream({}):
print(chunk, end="|", flush=True)
streaming|あいうえお|streaming|かきくけこ|streaming|さしすせそ|
なんとか実装してみた。下に解説が続きます。
# デバッグ用のprint関数
def print_content(x):
print(x)
return x
# LLMに返却して欲しいフォーマットを定義
class Response(BaseModel):
answer: str = Field(description="Natural language answer")
predictions: list[str] = Field(description="The next three questions, each of no more than 20 characters")
docs: Optional[list[int]] = Field(description="Return a list of the document numbers that were referenced. If no documents were referenced, return nothing. It is possible to specify multiple numbers.")
# フォーマットを使用してパーサーを定義
output_parser = JsonOutputParser(pydantic_object=Response, diff=True)
prompt1: ChatPromptTemplate = ChatPromptTemplate.from_messages([
SystemMessage(f"You are a helpful assistant. Answer the following questions with 1 or 2 sentences. {output_parser.get_format_instructions()}"),
MessagesPlaceholder("input"),
])
prompt2 = ChatPromptTemplate.from_messages([
SystemMessage(f"You are a helpful assistant. Answer the following questions with 1 or 2 sentences. {output_parser.get_format_instructions()}"),
HumanMessagePromptTemplate.from_template("# Document: {document}"),
MessagesPlaceholder("input"),
])
class QueryManager:
def __init__(self):
self.query = ''
def save_query(self, x: Dict[str, Any]) -> Dict[str, Any]:
self.query = x["input"]
print(f"Saved query: {self.query}")
return x
def get_query(self):
return self.query
class DocumentManager:
def __init__(self):
self.documents = []
self.formatted_docs = ""
def save_documents(self, docs: List[Document]) -> List[Document]:
self.documents = docs
self.formatted_docs = self.format_documents(docs)
print(f"Saved {len(docs)} documents")
return docs
def format_documents(self, docs: List[Document]) -> List[str]:
return [f"{i+1}. {doc.page_content}" for i, doc in enumerate(docs)]
def get_documents(self) -> List[Document]:
return self.documents
def get_formatted_docs(self) -> str:
return self.formatted_docs
# LLMのレスポンスがtool callか判断する
def is_tool_calls(input: AIMessageChunk | AIMessage) -> bool:
if isinstance(input, AIMessageChunk):
return bool(input.tool_call_chunks)
elif isinstance(input, AIMessage):
return bool(input.tool_calls)
elif input.response_metadata and input.response_metadata.get("finish_reason") == "tool_calls":
return True
# tool callの場合、レスポンスからtool callに必要なクエリを取得。JsonOutputToolsParserでもいいかも
def extract_args(input: AIMessageChunk | AIMessage) -> str:
if isinstance(input, AIMessageChunk) and input.tool_call_chunks:
return input.tool_call_chunks[0]["args"]
elif isinstance(input, AIMessage) and input.tool_calls:
return json.dumps(input.tool_calls[0]["args"], ensure_ascii=False)
return None
query_manager = QueryManager()
document_manager = DocumentManager()
# 分岐後のチェーン
inner_branch = (
retriever
| RunnableLambda(lambda x: document_manager.save_documents(x))
| print_content
| (lambda x: {"document": [document_manager.get_formatted_docs()], "input": query_manager.get_query()})
| prompt2
| llm
)
# LLMの返答がtool callか判断するチェーン
def branch(input_stream: Iterable[AIMessageChunk | AIMessage]):
args_list = []
for input in input_stream:
# print(input)
if is_tool_calls(input):
args_str = extract_args(input)
if args_str:
args_list.append(args_str)
else:
yield input
if args_list:
combined_args = "".join(args_list)
print(f"Combined args: {combined_args}")
try:
parsed_args = json.loads(combined_args)
for result in inner_branch.stream(parsed_args['query']):
yield result
except json.JSONDecodeError:
print("Error: Invalid JSON in combined args")
except KeyError:
print("Error: 'query' key not found in parsed args")
chain = (query_manager.save_query | prompt1 | llm_with_tools | branch | output_parser)
少し解説
yieldで返すジェネレータ関数を作成することで、streamingに対応できる。RunnableBranchの代わりに自作の関数を使うことで、streming対応しながら、LLMのレスポンスがtool callなのか、通常の文字列なのかを判断する。streamingだけでなく、invokeの場合もきちんと動作する。
def branch(input_stream: Iterable[AIMessageChunk | AIMessage]):
for input in input_stream:
``
これは公式でも紹介されているテクニックなので覚えて損はないと思う。
https://python.langchain.com/v0.1/docs/expression_language/primitives/functions/
JsonOutputParserを使用することで、LLMが使用したdocsの番号をLLM自身に指定させる。このドキュメントの番号を後工程で使用して、urlをmtadataから取得してくる。
class Response(BaseModel):
answer: str = Field(description="Natural language answer")
predictions: list[str] = Field(description="The next three questions, each of no more than 20 characters")
docs: Optional[list[int]] = Field(description="Return a list of the document numbers that were referenced. If no documents were referenced, return nothing. It is possible to specify multiple numbers.")
output_parser = JsonOutputParser(pydantic_object=Response, diff=True)
LCELは手続き的な一直線の処理が得意だが、前工程で使用したデータを後で使用するといった複雑なデータフローを扱うことは難しい。そのため、QueryManageとDocumentManagerをChainの外でインスタンス化し、必要なデータを外部で保持・管理する。
class QueryManager:
def __init__(self):
self.query = ''
def save_query(self, x: Dict[str, Any]) -> Dict[str, Any]:
self.query = x["input"]
print(f"Saved query: {self.query}")
return x
def get_query(self):
return self.query
class DocumentManager:
def __init__(self):
self.documents = []
self.formatted_docs = ""
def save_documents(self, docs: List[Document]) -> List[Document]:
self.documents = docs
self.formatted_docs = self.format_documents(docs)
print(f"Saved {len(docs)} documents")
return docs
def format_documents(self, docs: List[Document]) -> List[str]:
return [f"{i+1}. {doc.page_content}" for i, doc in enumerate(docs)]
def get_documents(self) -> List[Document]:
return self.documents
def get_formatted_docs(self) -> str:
return self.formatted_docs
わかったこと
- LCELは一本道の手続的プログラミングが得意。手続的とは、Aを実行したらBを実行して…というように順番通りに実行していくプログラミングである。
- RunnableBranchはストリーミング未対応。今回のように分岐をStreamingでやりたい場合、自分で実装するしかない。
- LCELはRunnableBranchなどで適切に表現できるもの以外の処理は向いていない。今回のような複雑なフローの時には辛いことになる。
- ただ、chain.stream()とchain.invoke()処理を切り変えられるのはめっちゃ便利
「もっとこうした方がいい!」というアドバイスあったらお願いします!
参考文献
3つのRunnable〇〇を理解する
scikit-learn とのアナロジーから見る LangChain Expression Language (LCEL)
実行例で理解する Runnable
の継承者たち in langchain
LangChainのOutput Parserを試す