Ollamaによるチャンクサイズとモデル精度の関係を検証
はじめに
@SNAMGNさんより「チャンクサイズ」が回答精度に影響する可能性があるとのご指南をいただき、貴重なアドバイスをもらいました。この場を借りてお礼申し上げます!
また、前回の記事の公開後、「1か月以内に新しい記事を出す」と宣言していたので、何とかまとめることができてよかったです。
前回の記事:
実は、今回は自分のPCのGPUをフル活用して処理を高速化する予定でしたが、残念ながらうまくいかず、結局CPUのみでの処理となっています。
GPU利用に失敗した記事:
私のPCのGPUのバージョンが古いためだと思われるのですが、手順などは分かりやすく載せているつもりです。興味がありましたら、ご覧ください。
評価対象
今回は、前回の記事で採用した5つのembedding
モデルについて、それぞれの「チャンクサイズ」を比較してみることにしました。特に、チャンクサイズの変更が「処理速度」と「回答性能」にどのような影響を及ぼすかを検証していきます。
llm
モデル
今回は、llm
モデルをllama3.2
に統一し、同一条件で評価を行うことにしました。
embedding
モデル
モデル名 | 256 | 512 | 1025 | 2048 | 4096 | 8192 |
---|---|---|---|---|---|---|
nomic-embed-text | 〇 | 〇 | 〇 | |||
mxbai-embed-large | 〇 | 〇 | ||||
all-minilm | 〇 | 〇 | ||||
unclemusclez/jina-embeddings-v2-base-code | 〇 | 〇 | 〇 | |||
bge-m3 | 〇 | 〇 | 〇 |
「チャンクサイズ」については、全モデルに対して256
と512
で検証を行います。また、より大きな「チャンクサイズ」に対応できるモデルに対しては、最大の「チャンクサイズ」を1つ選んで検証する予定です。
さらに、各embedding
モデルと「チャンクサイズ」の組み合わせごとに、4つの質問を投げて評価を行います。
No. | 質問内容 |
---|---|
1 | 「組合せ処置およびその方法」の概要を教えて |
2 | 「フューリンインヒビター」の概要を教えて |
3 | 用語「アルキル」の意味は? |
4 | 用語「アゴニスト」とは? |
回答精度の定義は以下としました。
得点 | 説明 |
---|---|
3 | 正確に要約されている |
2 | 観点は異なるが、納得できる要約 |
1 | キーワードのみ含まれている要約 |
0 | キーワードすら含まれていない |
「質問内容」と「回答精度」は、以前の投稿記事と同じです。
プログラム
今回使用したプログラムは、前回の記事のコードを一部修正したものです。
プログラム(ベクトルデータベース生成)
各処理に長時間かかるため、embedding
モデルと「チャンクサイズ」をリスト化し、それを順番に実行して結果をテキストファイルに出力するようにしました。これで朝までぐっすり眠ることができました。
import glob
import os
import xml.etree.ElementTree as ET
from dotenv import load_dotenv
from langchain.text_splitter import CharacterTextSplitter
from langchain_chroma import Chroma
import ollama
from datetime import datetime
load_dotenv()
docs = []
# 取り出したい名前空間-タグ名
name_spaces_tag_names = [
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}PublicationNumber",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}PublicationDate",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}RegistrationDate",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}ApplicationNumberText",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}PartyIdentifier",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}EntityName",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}PostalAddressText",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}PatentCitationText",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}PersonFullName",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}P",
"{http://www.wipo.int/standards/XMLSchema/ST96/Common}FigureReference",
"{http://www.wipo.int/standards/XMLSchema/ST96/Patent}PlainLanguageDesignationText",
"{http://www.wipo.int/standards/XMLSchema/ST96/Patent}FilingDate",
"{http://www.wipo.int/standards/XMLSchema/ST96/Patent}InventionTitle",
"{http://www.wipo.int/standards/XMLSchema/ST96/Patent}MainClassification",
"{http://www.wipo.int/standards/XMLSchema/ST96/Patent}FurtherClassification",
"{http://www.wipo.int/standards/XMLSchema/ST96/Patent}PatentClassificationText",
"{http://www.wipo.int/standards/XMLSchema/ST96/Patent}SearchFieldText",
"{http://www.wipo.int/standards/XMLSchema/ST96/Patent}ClaimText",
]
model_chunksize = [
("nomic-embed-text",256),
("nomic-embed-text",512),
("nomic-embed-text",2048),
("mxbai-embed-large",256),
("mxbai-embed-large",512),
("all-minilm",256),
("all-minilm",512),
("unclemusclez/jina-embeddings-v2-base-code",256),
("unclemusclez/jina-embeddings-v2-base-code",512),
("unclemusclez/jina-embeddings-v2-base-code",8192),
("bge-m3",256),
("bge-m3",512),
("bge-m3",8192),
]
# 埋め込み関数のラッパーを作成
class OllamaEmbeddingFunction:
def __init__(self, model):
self.model = model
def embed_documents(self, texts):
embeddings = []
for text in texts:
response = ollama.embeddings(model=self.model, prompt=text)
embeddings.append(response['embedding'])
return embeddings # ここで計算した埋め込みを返します
def set_element(level, trees, el):
trees.append({"tag" : el.tag, "attrib" : el.attrib, "content_page" :el.text})
def set_child(level, trees, el):
set_element(level, trees, el)
for child in el:
set_child(level+1, trees, child)
def parse_and_get_element(input_file):
tmp_elements = []
new_elements = []
tree = ET.parse(input_file)
root = tree.getroot()
set_child(1, tmp_elements, root)
for name_space_tag_name in name_spaces_tag_names:
for tmp_element in tmp_elements:
if tmp_element["tag"] == name_space_tag_name:
new_elements.append(tmp_element)
return new_elements
def execute():
title = ""
entryName = ""
patentCitationText = ""
for model_name, embedding_length in model_chunksize:
files = glob.glob(os.path.join("C:/Users/ogiki/JPB_2024999", "**/*.*"), recursive=True)
formatted_time = datetime.now().strftime("%H:%M:%S")
print("開始時刻:", formatted_time)
for file in files:
base, ext = os.path.splitext(file)
if ext == '.xml':
topic_name = os.path.splitext(os.path.basename(file))[0]
print(file)
text_splitter = CharacterTextSplitter(chunk_size=embedding_length, chunk_overlap=0)
new_elements = parse_and_get_element(file)
for new_element in new_elements:
try:
text = new_element["content_page"]
tag = new_element["tag"]
title = text if tag == "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}InventionTitle" else ""
entryName = text if tag == "{http://www.wipo.int/standards/XMLSchema/ST96/Common}EntityName" else ""
patentCitationText = text if tag == "{http://www.wipo.int/standards/XMLSchema/ST96/Common}PatentCitationText" else ""
documents = text_splitter.create_documents(texts=[text], metadatas=[{
"name": topic_name,
"source": file,
"tag": tag,
"title": title,
"entry_name": entryName,
"patent_citation_text" : patentCitationText}]
)
docs.extend(documents)
except Exception as e:
continue
# OllamaEmbeddingFunctionのインスタンスを作成
embedding_function = OllamaEmbeddingFunction(model=model_name)
db = Chroma(persist_directory=f"C:/Users/ogiki/vectorDB/{model_name}-{embedding_length}", embedding_function=embedding_function)
intv = 500
ln = len(docs)
max_loop = int(ln / intv) + 1
for i in range(max_loop):
splitted_documents = text_splitter.split_documents(docs[intv * i : intv * (i+1)])
db.add_documents(splitted_documents)
formatted_time = datetime.now().strftime("%H:%M:%S")
print("修了時刻:", formatted_time)
if __name__ == "__main__":
execute()
embedding
モデル名と「チャンクサイズ」をリストにまとめ、for
ループでそれぞれの組み合わせを順に処理して、ベクトルデータベースを作成する方式を採用しました。
model_chunksize = [
("nomic-embed-text",256),
("nomic-embed-text",512),
("nomic-embed-text",2048),
("mxbai-embed-large",256),
("mxbai-embed-large",512),
("all-minilm",256),
("all-minilm",512),
("unclemusclez/jina-embeddings-v2-base-code",256),
("unclemusclez/jina-embeddings-v2-base-code",512),
("unclemusclez/jina-embeddings-v2-base-code",8192),
("bge-m3",256),
("bge-m3",512),
("bge-m3",8192),
]
この処理が完了すると、embedding
モデル数 × 「チャンクサイズ」の種類数 に応じたベクトルデータベースが生成されます。
プログラム(回答生成)
さらに、生成AIへの問い合わせも毎回streamlit
から手動で行うのは非効率だと感じたため、こちらもfor
ループで一括処理することにしました。
from langchain_community.chat_models.ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain.schema import HumanMessage
from langchain.vectorstores import Chroma
import ollama
from datetime import datetime
FILE_PATH = "C:/Users/ogiki/Desktop/result.txt"
model_chunksize = [
("nomic-embed-text",256),
("nomic-embed-text",512),
("nomic-embed-text",2048),
("mxbai-embed-large",256),
("mxbai-embed-large",512),
("all-minilm",256),
("all-minilm",512),
("unclemusclez/jina-embeddings-v2-base-code",256),
("unclemusclez/jina-embeddings-v2-base-code",512),
("unclemusclez/jina-embeddings-v2-base-code",8192),
("bge-m3",256),
("bge-m3",512),
("bge-m3",8192),
]
input_messages = [
("0007350118","「組合せ処置およびその方法」の概要を教えて"),
("0007350061","「フューリンインヒビター」の概要を教えて"),
("0007350061","用語「アルキル」の意味は?"),
("0007350118","用語「アゴニスト」はどういうもの?"),
]
# 埋め込み関数のラッパーを作成
class OllamaEmbeddingFunction:
def __init__(self, model):
self.model = model
def embed_documents(self, texts):
embeddings = []
for text in texts:
response = ollama.embeddings(model=self.model, prompt=text)
embeddings.append(response['embedding'])
return embeddings # ここで計算した埋め込みを返します
def embed_query(self, query):
response = ollama.embeddings(model=self.model, prompt=query)
return response['embedding'] # クエリの埋め込みを返す
file = open(FILE_PATH, "w", encoding="utf-8")
file.write("")
file.close()
for model_name, embedding_length in model_chunksize:
file = open(FILE_PATH, "a", encoding="utf-8")
embedding_function = OllamaEmbeddingFunction(model=model_name)
chat = ChatOllama(model="llama3.2", temperature=0)
database = Chroma(
persist_directory=f"C:/Users/ogiki/vectorDB/{model_name}-{embedding_length}",
embedding_function=embedding_function
)
prompt = PromptTemplate(template="""文章を元に質問に答えてください。
文章:
{document}
質問: {query}
""", input_variables=["document", "query"])
# =====================================================
for regist_no, input_message in input_messages:
file.write(f"モデル:{model_name}\n")
file.write(f"エンベッディング:{embedding_length}\n")
file.write(f"チャンク数:{str(embedding_length)}")
file.write(f"登録番号:{regist_no}\n")
file.write(f"質問:{input_message}\n")
formatted_time = datetime.now().strftime("%H:%M:%S")
file.write(f"開始時刻:{formatted_time}\n")
documents = database.similarity_search_with_score(input_message, k=3, filter={"name":regist_no})
documents_string = ""
for document in documents:
print("---------------document.metadata---------------")
print(document[0].metadata)
print(document[1])
documents_string += f"""
---------------------------
{document[0].page_content}
"""
print("---------------documents_string---------------")
print(input_message)
print(documents_string)
# ----- プロンプトを基に回答をもらう (ローカルLLMを利用) -----
result = chat([
HumanMessage(content=prompt.format(document=documents_string,
query=input_message))
])
file.write(f"回答:{result.content}\n")
formatted_time = datetime.now().strftime("%H:%M:%S")
file.write(f"修了時刻:{formatted_time}\n")
file.write("\n")
file.write("\n")
file.write("\n")
file.close()
では、プログラムの詳細を見ていきます。
まず、標準出力をテキストファイルに保存する設定を行いました。
FILE_PATH = "C:/Users/ogiki/Desktop/result.txt"
次に、「ベクトルデータベース生成」のプログラム同様、embedding
モデルと「チャンクサイズ」の組み合わせリストを設定します。
model_chunksize = [
("nomic-embed-text", 256),
("nomic-embed-text", 512),
("nomic-embed-text", 2048),
("mxbai-embed-large", 256),
("mxbai-embed-large", 512),
# ...
]
生成AIには4つの質問を投げかけるため、こちらもリスト化しました。
input_messages = [
("0007350118", "「組合せ処置およびその方法」の概要を教えて"),
("0007350061", "「フューリンインヒビター」の概要を教えて"),
("0007350061", "用語「アルキル」の意味は?"),
("0007350118", "用語「アゴニスト」はどういうもの?"),
]
リスト内の最初の数字は、質問対象の絞り込み用として使用する「特許登録番号」です。
llm
モデルはllama3.2
で固定しています。
chat = ChatOllama(model="llama3.2", temperature=0)
実行
処理速度(ベクトルデータベース作成)
モデル名 | 256 | 512 | 1025 | 2048 | 4096 | 8192 | 平均 |
---|---|---|---|---|---|---|---|
nomic-embed-text | 1,014 | 1,014 | 1,014 | 1,014 | |||
mxbai-embed-large | 2,060 | 2,148 | 2,104 | ||||
all-minilm | 70 | 70 | 70 | ||||
unclemusclez/jina-embeddings-v2-base-code | 1,236 | 1,258 | 1,258 | 1,251 | |||
bge-m3 | 2,500 | 2,559 | 2,491 | 2,517 |
以上のことから、all-miniln
の処理速度が非常に速いことが明らかになりました。また、「チャンクサイズ」が処理速度に与える影響はほとんど見られないことが、表から読み取ることができました。
処理速度(回答生成)
モデル名 | 256 | 512 | 1025 | 2048 | 4096 | 8192 | 平均 |
---|---|---|---|---|---|---|---|
nomic-embed-text | 85 | 77 | 71 | 78 | |||
mxbai-embed-large | 158 | 158 | 158 | ||||
all-minilm | 49 | 51 | 50 | ||||
unclemusclez/jina-embeddings-v2-base-code | 5,413 | 5,320 | 5,248 | 5,327 | |||
bge-m3 | 152 | 145 | 145 | 147 |
やはり、all-miniln
の処理速度は非常に優れています。ただし、nomic-embed-text
もそれほど大きな差は見受けられませんでした。また、「チャンクサイズ」による処理速度の顕著な違いは表から確認できませんでした。
回答精度
4回の質問の平均です。(小数点あり)
モデル名 | 256 | 512 | 1025 | 2048 | 4096 | 8192 | 平均 |
---|---|---|---|---|---|---|---|
nomic-embed-text | 2.3 | 1.3 | 2.3 | 2.3 | |||
mxbai-embed-large | 2.5 | 2.5 | 2.5 | ||||
all-minilm | 2.8 | 2.0 | 2.4 | ||||
unclemusclez/jina-embeddings-v2-base-code | 1.5 | 1.5 | 1.5 | 1.5 | |||
bge-m3 | 2.3 | 2.1 | 1.9 | 2.1 |
最後に「回答精度」の比較を行います。ここでの平均はあまり意味を持たないことに注意が必要です。なぜなら、サンプル数(母数)が異なるからです。そのため、全体的な傾向を確認するだけになりますが、やはり「チャンクサイズ」の違いによる「回答精度」の顕著な差は見受けられませんでした。
おわりに
私の計測方法が不十分だったのか、期待した結果には至りませんでした。うーん…
次に比較すべきは'chunk_overlap'の違いでしょうか。反省点としては、「回答精度」においてRAGが良い結果を出した場合と、RAGにはないがllmに問い合わせた時に良い回答が得られた場合では、意味が異なるということです。したがって、類似性のスコアを確認する方法も考慮すべきだったかもしれません。体力があれば、今度その点にも挑戦してみたいと思います。
Discussion