Langchain で Chat + RAG
概要
Langchain で RAG をしたうえで更にチャットもしたかった.
chainで繋げればいいのかと思っていたが, チャットしながら文書を検索できなかったのでそれっぽく実装してみた.
検索範囲を個別に指定したかったため逐次dbを構築する形にしている.
aws の opensearch をしようとしたけど,
結構お金がかかるようなので chromadb を選択した
コード全体
データを入れなくてもとりあえずチャットだけできる状態にしてある.
次の書で細かく見ていく
import chromadb
from langchain.prompts import *
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from openai import OpenAI
class EmbeddedChat:
model = 'text-embedding-3-small' # OpenAIのembeddingモデル
MAX_INPUT_LENGTH = 2000 # 入力の最大長
N_RESULTS = 5 # 検索結果の数
# チャットのプロンプトテンプレート
chat_template = '''
# シチュエーション
あなたは優秀なAIです.
ユーザーからの質問に対して, DBに保存された情報を利用して適切な回答を返すことができます.
# DBから取得した情報
{db_content}
# 今までの会話内容
{history}
人間: {input}
AI:
'''
def __init__(self, collection_name='talk_collection'):
self.ephemeralClient = chromadb.EphemeralClient() # ChromaDBクライアントの作成
self.collection_name = collection_name # コレクション名
self.collection = self.ephemeralClient.create_collection(name=self.collection_name) # コレクションの作成
self.llm = ChatOpenAI(model_name="gpt-3.5-turbo") # OpenAIのチャットモデル
# プロンプトテンプレートの作成
self.chat_prompt = PromptTemplate(
input_variables=['db_content', 'history', 'input'],
template=self.chat_template
)
self.client = OpenAI() # OpenAIクライアントの作成
self.history = [] # 会話履歴
def __del__(self):
self.ephemeralClient.delete_collection(self.collection_name) # コレクションの削除
def add_history(self, message: tuple[str, str]):
self.history.append(message) # 会話履歴に追加
def set_history(self, history: list[tuple[str, str]]):
self.history = history # 会話履歴の設定
def add(self, index: str, text: str, embedded_vector: list[float]):
self.collection.add(
embeddings=[embedded_vector],
documents=[text],
ids=[index]
) # コレクションにデータを追加
def generate_embedded_vector(self, text: str) -> list:
"""
テキストをembeddingする
"""
response = self.client.embeddings.create(model=self.model, input=text) # embeddingの生成
embedded_vector = response.data[0].embedding
return embedded_vector
def query(self, embedded_vector: len, n_results: int):
result = self.collection.query(
query_embeddings=[embedded_vector],
n_results=n_results
) # embeddingを使ってコレクションを検索
documents = []
for r in result['documents']:
documents.append(r)
return documents
def get_input_history(self) -> list[tuple[str, str]]:
"""
history の内容から 4000 文字まで取得
最新のものからカウントする
"""
history = []
total_length = 0
for i in range(len(self.history) - 1, -1, -1):
input_text, output_text = self.history[i]
total_length += len(input_text) + len(output_text)
if total_length > self.MAX_INPUT_LENGTH:
break
history.append((input_text, output_text))
return history
def chat(self, input_text: str) -> str:
# 入力文章をembedding
embedded_vector = self.generate_embedded_vector(input_text)
db_content = self.query(embedded_vector, self.N_RESULTS)
input_history = self.get_input_history() # 入力履歴の取得
# プロンプトの作成
prompt_text = self.chat_prompt.format(
db_content=db_content,
history=input_history,
input=input_text
)
prediction: BaseMessage = self.llm.invoke(prompt_text) # プロンプトを使ってチャットモデルを実行
self.add_history((input_text, prediction.content)) # 会話履歴に追加
return prediction.content
def main():
chat = EmbeddedChat()
while True:
message = input('人間:')
response = chat.chat(message)
print('AI:', response)
if __name__ == '__main__':
main()
コンストラクタ & デストラクタ
def __init__(self, collection_name='talk_collection'):
self.ephemeralClient = chromadb.EphemeralClient() # ChromaDBクライアントの作成
self.collection_name = collection_name # コレクション名
self.collection = self.ephemeralClient.create_collection(name=self.collection_name) # コレクションの作成
self.llm = ChatOpenAI(model_name="gpt-3.5-turbo") # OpenAIのチャットモデル
# プロンプトテンプレートの作成
self.chat_prompt = PromptTemplate(
input_variables=['db_content', 'history', 'input'],
template=self.chat_template
)
self.client = OpenAI() # OpenAIクライアントの作成
self.history = [] # 会話履歴
def __del__(self):
self.ephemeralClient.delete_collection(self.collection_name) # コレクションの削除
コンストラクタでインメモリの DB を作成して, デストラクタで削除している.
テンプレートにはdbから取得したデータ, チャット履歴, 新しい入力の設定が必要.
今回 GPT3.5 を使用する関係上, トークン上限にならないようにdbからの取得数や, 履歴の仕様を制限している.
検索用データと履歴
def add_history(self, message: tuple[str, str]):
self.history.append(message) # 会話履歴に追加
def set_history(self, history: list[tuple[str, str]]):
self.history = history # 会話履歴の設定
def add(self, index: str, text: str, embedded_vector: list[float]):
self.collection.add(
embeddings=[embedded_vector],
documents=[text],
ids=[index]
) # コレクションにデータを追加
add 関数で検索対象となるデータを格納する.
index には, あとから取得したデータが分かるような任意のid.自分は uuid を使用した.
text はGPTに食わせるテキスト. そして embedded_vector は埋め込みベクトルで今回は openAI の "text-embedding-3-small" である.
検索
def generate_embedded_vector(self, text: str) -> list:
"""
テキストをembeddingする
"""
response = self.client.embeddings.create(model=self.model, input=text) # embeddingの生成
embedded_vector = response.data[0].embedding
return embedded_vector
def query(self, embedded_vector: len, n_results: int):
result = self.collection.query(
query_embeddings=[embedded_vector],
n_results=n_results
) # embeddingを使ってコレクションを検索
documents = []
for r in result['documents']:
documents.append(r)
return documents
generate_embedded_vector 入力したテキストをベクトル化して, query 関数でそのベクトルに近いテキストを検索する.
ここで, n_results は db から取得するテキストの塊の個数である.
チャット履歴の取得
def get_input_history(self) -> list[tuple[str, str]]:
"""
history の内容から 4000 文字まで取得
最新のものからカウントする
"""
history = []
total_length = 0
for i in range(len(self.history) - 1, -1, -1):
input_text, output_text = self.history[i]
total_length += len(input_text) + len(output_text)
if total_length > self.MAX_INPUT_LENGTH:
break
history.append((input_text, output_text))
return history
おの関数でチャット履歴を取得している.
これは新しい方から指定したサイズの文字数だけ取得している.
トークン上限に引っかからないように指定する必要がある.
チャット実行
def chat(self, input_text: str) -> str:
# 1. 入力文章をembedding
embedded_vector = self.generate_embedded_vector(input_text)
# 2. テキストを検索
db_content = self.query(embedded_vector, self.N_RESULTS)
# 3. チャット履歴を取得
input_history = self.get_input_history() # 入力履歴の取得
# 4. プロンプトの作成
prompt_text = self.chat_prompt.format(
db_content=db_content,
history=input_history,
input=input_text
)
# 5. 推論を実行
prediction: BaseMessage = self.llm.invoke(prompt_text) # プロンプトを使ってチャットモデルを実行
self.add_history((input_text, prediction.content)) # 会話履歴に追加
return prediction.content
処理の流れとしては,
- 入力文章をembedding
- テキストを検索
- チャット履歴を取得
- プロンプトの作成
- 推論を実行
である.
最後に
langchain を使ってもっと良い方法があれば知りたい.
ユーザごとに個別のDBを作成したかったのでこのような形にした.
ユーザーがすべて同じDBにアクセスする場合はもっときれいな方法があるかもしれない。
Discussion