🧚‍♂️

Langchain で Chat + RAG

2024/04/24に公開

概要

Langchain で RAG をしたうえで更にチャットもしたかった.
chainで繋げればいいのかと思っていたが, チャットしながら文書を検索できなかったのでそれっぽく実装してみた.
検索範囲を個別に指定したかったため逐次dbを構築する形にしている.
aws の opensearch をしようとしたけど,
結構お金がかかるようなので chromadb を選択した

コード全体

データを入れなくてもとりあえずチャットだけできる状態にしてある.
次の書で細かく見ていく

import chromadb
from langchain.prompts import *
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from openai import OpenAI


class EmbeddedChat:
   model = 'text-embedding-3-small'  # OpenAIのembeddingモデル

   MAX_INPUT_LENGTH = 2000  # 入力の最大長
   N_RESULTS = 5  # 検索結果の数

   # チャットのプロンプトテンプレート
   chat_template = '''
   # シチュエーション
   あなたは優秀なAIです. 
   ユーザーからの質問に対して, DBに保存された情報を利用して適切な回答を返すことができます.

   # DBから取得した情報
   {db_content}

   # 今までの会話内容
   {history}
   人間: {input}
   AI: 
   '''

   def __init__(self, collection_name='talk_collection'):
       self.ephemeralClient = chromadb.EphemeralClient()  # ChromaDBクライアントの作成
       self.collection_name = collection_name  # コレクション名
       self.collection = self.ephemeralClient.create_collection(name=self.collection_name)  # コレクションの作成
       self.llm = ChatOpenAI(model_name="gpt-3.5-turbo")  # OpenAIのチャットモデル

       # プロンプトテンプレートの作成
       self.chat_prompt = PromptTemplate(
           input_variables=['db_content', 'history', 'input'],
           template=self.chat_template
       )
       self.client = OpenAI()  # OpenAIクライアントの作成

       self.history = []  # 会話履歴

   def __del__(self):
       self.ephemeralClient.delete_collection(self.collection_name)  # コレクションの削除

   def add_history(self, message: tuple[str, str]):
       self.history.append(message)  # 会話履歴に追加

   def set_history(self, history: list[tuple[str, str]]):
       self.history = history  # 会話履歴の設定

   def add(self, index: str, text: str, embedded_vector: list[float]):
       self.collection.add(
           embeddings=[embedded_vector],
           documents=[text],
           ids=[index]
       )  # コレクションにデータを追加

   def generate_embedded_vector(self, text: str) -> list:
       """
       テキストをembeddingする
       """
       response = self.client.embeddings.create(model=self.model, input=text)  # embeddingの生成
       embedded_vector = response.data[0].embedding
       return embedded_vector

   def query(self, embedded_vector: len, n_results: int):
       result = self.collection.query(
           query_embeddings=[embedded_vector],
           n_results=n_results
       )  # embeddingを使ってコレクションを検索

       documents = []
       for r in result['documents']:
           documents.append(r)
       return documents

   def get_input_history(self) -> list[tuple[str, str]]:
       """
       history の内容から 4000 文字まで取得
       最新のものからカウントする
       """
       history = []
       total_length = 0
       for i in range(len(self.history) - 1, -1, -1):
           input_text, output_text = self.history[i]
           total_length += len(input_text) + len(output_text)
           if total_length > self.MAX_INPUT_LENGTH:
               break
           history.append((input_text, output_text))

       return history

   def chat(self, input_text: str) -> str:
       # 入力文章をembedding
       embedded_vector = self.generate_embedded_vector(input_text)

       db_content = self.query(embedded_vector, self.N_RESULTS)

       input_history = self.get_input_history()  # 入力履歴の取得

       # プロンプトの作成
       prompt_text = self.chat_prompt.format(
           db_content=db_content,
           history=input_history,
           input=input_text
       )
       prediction: BaseMessage = self.llm.invoke(prompt_text)  # プロンプトを使ってチャットモデルを実行

       self.add_history((input_text, prediction.content))  # 会話履歴に追加
       return prediction.content


def main():
   chat = EmbeddedChat()
   while True:
       message = input('人間:')
       response = chat.chat(message)
       print('AI:', response)


if __name__ == '__main__':
   main()

コンストラクタ & デストラクタ

   def __init__(self, collection_name='talk_collection'):
       self.ephemeralClient = chromadb.EphemeralClient()  # ChromaDBクライアントの作成
       self.collection_name = collection_name  # コレクション名
       self.collection = self.ephemeralClient.create_collection(name=self.collection_name)  # コレクションの作成
       self.llm = ChatOpenAI(model_name="gpt-3.5-turbo")  # OpenAIのチャットモデル

       # プロンプトテンプレートの作成
       self.chat_prompt = PromptTemplate(
           input_variables=['db_content', 'history', 'input'],
           template=self.chat_template
       )
       self.client = OpenAI()  # OpenAIクライアントの作成

       self.history = []  # 会話履歴
   def __del__(self):
       self.ephemeralClient.delete_collection(self.collection_name)  # コレクションの削除

コンストラクタでインメモリの DB を作成して, デストラクタで削除している.
テンプレートにはdbから取得したデータ, チャット履歴, 新しい入力の設定が必要.
今回 GPT3.5 を使用する関係上, トークン上限にならないようにdbからの取得数や, 履歴の仕様を制限している.

検索用データと履歴

   def add_history(self, message: tuple[str, str]):
       self.history.append(message)  # 会話履歴に追加

   def set_history(self, history: list[tuple[str, str]]):
       self.history = history  # 会話履歴の設定

   def add(self, index: str, text: str, embedded_vector: list[float]):
       self.collection.add(
           embeddings=[embedded_vector],
           documents=[text],
           ids=[index]
       )  # コレクションにデータを追加

add 関数で検索対象となるデータを格納する.
index には, あとから取得したデータが分かるような任意のid.自分は uuid を使用した.
text はGPTに食わせるテキスト. そして embedded_vector は埋め込みベクトルで今回は openAI の "text-embedding-3-small" である.

検索

   def generate_embedded_vector(self, text: str) -> list:
       """
       テキストをembeddingする
       """
       response = self.client.embeddings.create(model=self.model, input=text)  # embeddingの生成
       embedded_vector = response.data[0].embedding
       return embedded_vector

   def query(self, embedded_vector: len, n_results: int):
       result = self.collection.query(
           query_embeddings=[embedded_vector],
           n_results=n_results
       )  # embeddingを使ってコレクションを検索

       documents = []
       for r in result['documents']:
           documents.append(r)
       return documents

generate_embedded_vector 入力したテキストをベクトル化して, query 関数でそのベクトルに近いテキストを検索する.
ここで, n_results は db から取得するテキストの塊の個数である.

チャット履歴の取得

   def get_input_history(self) -> list[tuple[str, str]]:
       """
       history の内容から 4000 文字まで取得
       最新のものからカウントする
       """
       history = []
       total_length = 0
       for i in range(len(self.history) - 1, -1, -1):
           input_text, output_text = self.history[i]
           total_length += len(input_text) + len(output_text)
           if total_length > self.MAX_INPUT_LENGTH:
               break
           history.append((input_text, output_text))

       return history

おの関数でチャット履歴を取得している.
これは新しい方から指定したサイズの文字数だけ取得している.
トークン上限に引っかからないように指定する必要がある.

チャット実行

   def chat(self, input_text: str) -> str:
       # 1. 入力文章をembedding
       embedded_vector = self.generate_embedded_vector(input_text)

       # 2. テキストを検索
       db_content = self.query(embedded_vector, self.N_RESULTS)

       # 3. チャット履歴を取得
       input_history = self.get_input_history()  # 入力履歴の取得

       # 4. プロンプトの作成
       prompt_text = self.chat_prompt.format(
           db_content=db_content,
           history=input_history,
           input=input_text
       )
       # 5. 推論を実行
       prediction: BaseMessage = self.llm.invoke(prompt_text)  # プロンプトを使ってチャットモデルを実行

       self.add_history((input_text, prediction.content))  # 会話履歴に追加
       return prediction.content

処理の流れとしては,

  1. 入力文章をembedding
  2. テキストを検索
  3. チャット履歴を取得
  4. プロンプトの作成
  5. 推論を実行
    である.

最後に

langchain を使ってもっと良い方法があれば知りたい.
ユーザごとに個別のDBを作成したかったのでこのような形にした.
ユーザーがすべて同じDBにアクセスする場合はもっときれいな方法があるかもしれない。

Discussion