🐷

Ollamaによるチャンクサイズとモデル精度の関係を検証

2024/10/29に公開

はじめに

@SNAMGNさんより「チャンクサイズ」が回答精度に影響する可能性があるとのご指南をいただき、貴重なアドバイスをもらいました。この場を借りてお礼申し上げます!

また、前回の記事の公開後、「1か月以内に新しい記事を出す」と宣言していたので、何とかまとめることができてよかったです。

前回の記事:

https://qiita.com/ogi_kimura/items/695d0d067b62501b753f

実は、今回は自分のPCのGPUをフル活用して処理を高速化する予定でしたが、残念ながらうまくいかず、結局CPUのみでの処理となっています。

GPU利用に失敗した記事:

https://qiita.com/ogi_kimura/items/d5ea9b9cf377dcbfe218

私のPCのGPUのバージョンが古いためだと思われるのですが、手順などは分かりやすく載せているつもりです。興味がありましたら、ご覧ください。

評価対象

今回は、前回の記事で採用した5つのembeddingモデルについて、それぞれの「チャンクサイズ」を比較してみることにしました。特に、チャンクサイズの変更が「処理速度」と「回答性能」にどのような影響を及ぼすかを検証していきます。

llmモデル

今回は、llmモデルをllama3.2に統一し、同一条件で評価を行うことにしました。

embeddingモデル

モデル名 256 512 1025 2048 4096 8192
nomic-embed-text
mxbai-embed-large
all-minilm
unclemusclez/jina-embeddings-v2-base-code
bge-m3

「チャンクサイズ」については、全モデルに対して256512で検証を行います。また、より大きな「チャンクサイズ」に対応できるモデルに対しては、最大の「チャンクサイズ」を1つ選んで検証する予定です。

さらに、各embeddingモデルと「チャンクサイズ」の組み合わせごとに、4つの質問を投げて評価を行います。

No. 質問内容
1 「組合せ処置およびその方法」の概要を教えて
2 「フューリンインヒビター」の概要を教えて
3 用語「アルキル」の意味は?
4 用語「アゴニスト」とは?

回答精度の定義は以下としました。

得点 説明
3 正確に要約されている
2 観点は異なるが、納得できる要約
1 キーワードのみ含まれている要約
0 キーワードすら含まれていない

「質問内容」と「回答精度」は、以前の投稿記事と同じです。

https://qiita.com/ogi_kimura/items/695d0d067b62501b753f

プログラム

今回使用したプログラムは、前回の記事のコードを一部修正したものです。

プログラム(ベクトルデータベース生成)

各処理に長時間かかるため、embeddingモデルと「チャンクサイズ」をリスト化し、それを順番に実行して結果をテキストファイルに出力するようにしました。これで朝までぐっすり眠ることができました。

retrieve.py
import glob
import os
import xml.etree.ElementTree as ET
from dotenv import load_dotenv
from langchain.text_splitter import CharacterTextSplitter
from langchain_chroma import Chroma
import ollama
from datetime import datetime

load_dotenv()

docs = []

# 取り出したい名前空間-タグ名
name_spaces_tag_names = [
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}PublicationNumber",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}PublicationDate",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}RegistrationDate",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}ApplicationNumberText",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}PartyIdentifier",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}EntityName",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}PostalAddressText",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}PatentCitationText",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}PersonFullName",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}P",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Common}FigureReference",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}PlainLanguageDesignationText",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}FilingDate",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}InventionTitle",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}MainClassification",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}FurtherClassification",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}PatentClassificationText",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}SearchFieldText",
    "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}ClaimText",
]

model_chunksize = [
    ("nomic-embed-text",256),
    ("nomic-embed-text",512),
    ("nomic-embed-text",2048),
    ("mxbai-embed-large",256),
    ("mxbai-embed-large",512),
    ("all-minilm",256),
    ("all-minilm",512),
    ("unclemusclez/jina-embeddings-v2-base-code",256),
    ("unclemusclez/jina-embeddings-v2-base-code",512),
    ("unclemusclez/jina-embeddings-v2-base-code",8192),
    ("bge-m3",256),
    ("bge-m3",512),
    ("bge-m3",8192),
]

# 埋め込み関数のラッパーを作成
class OllamaEmbeddingFunction:
    def __init__(self, model):
        self.model = model
    def embed_documents(self, texts):
        embeddings = []
        for text in texts:
            response = ollama.embeddings(model=self.model, prompt=text)
            embeddings.append(response['embedding'])
        return embeddings  # ここで計算した埋め込みを返します

def set_element(level, trees, el):
    trees.append({"tag" : el.tag, "attrib" : el.attrib, "content_page" :el.text})

def set_child(level, trees, el):
    set_element(level, trees, el)
    for child in el:
        set_child(level+1, trees, child)

def parse_and_get_element(input_file):
    tmp_elements = []
    new_elements = []
    tree = ET.parse(input_file)
    root = tree.getroot()
    set_child(1, tmp_elements, root)
    for name_space_tag_name in name_spaces_tag_names:
        for tmp_element in tmp_elements:
            if tmp_element["tag"] == name_space_tag_name:
                new_elements.append(tmp_element)
    return new_elements

def execute():
    title = ""
    entryName = ""
    patentCitationText = ""

    for model_name, embedding_length in model_chunksize:
        files = glob.glob(os.path.join("C:/Users/ogiki/JPB_2024999", "**/*.*"), recursive=True)
        
        formatted_time = datetime.now().strftime("%H:%M:%S")
        print("開始時刻:", formatted_time)
    
        for file in files:
            base, ext = os.path.splitext(file)
            if ext == '.xml':
                topic_name = os.path.splitext(os.path.basename(file))[0]
                print(file)

                text_splitter = CharacterTextSplitter(chunk_size=embedding_length, chunk_overlap=0)
                new_elements = parse_and_get_element(file)
                for new_element in new_elements:
                    try:
                        text = new_element["content_page"]
                        tag = new_element["tag"]
                        title = text if tag == "{http://www.wipo.int/standards/XMLSchema/ST96/Patent}InventionTitle" else ""
                        entryName = text if tag == "{http://www.wipo.int/standards/XMLSchema/ST96/Common}EntityName" else ""
                        patentCitationText = text if tag == "{http://www.wipo.int/standards/XMLSchema/ST96/Common}PatentCitationText" else ""

                        documents = text_splitter.create_documents(texts=[text], metadatas=[{
                            "name": topic_name, 
                            "source": file, 
                            "tag": tag, 
                            "title": title,
                            "entry_name": entryName, 
                            "patent_citation_text" : patentCitationText}]
                        )
                        docs.extend(documents)
                    except Exception as e:
                        continue

        # OllamaEmbeddingFunctionのインスタンスを作成
        embedding_function = OllamaEmbeddingFunction(model=model_name)
        db = Chroma(persist_directory=f"C:/Users/ogiki/vectorDB/{model_name}-{embedding_length}", embedding_function=embedding_function)

        intv = 500
        ln = len(docs)
        max_loop = int(ln / intv) + 1
        for i in range(max_loop):
            splitted_documents = text_splitter.split_documents(docs[intv * i : intv * (i+1)])
            db.add_documents(splitted_documents)
            
        formatted_time = datetime.now().strftime("%H:%M:%S")
        print("修了時刻:", formatted_time)

if __name__ == "__main__":
    execute()

embeddingモデル名と「チャンクサイズ」をリストにまとめ、forループでそれぞれの組み合わせを順に処理して、ベクトルデータベースを作成する方式を採用しました。

model_chunksize = [
    ("nomic-embed-text",256),
    ("nomic-embed-text",512),
    ("nomic-embed-text",2048),
    ("mxbai-embed-large",256),
    ("mxbai-embed-large",512),
    ("all-minilm",256),
    ("all-minilm",512),
    ("unclemusclez/jina-embeddings-v2-base-code",256),
    ("unclemusclez/jina-embeddings-v2-base-code",512),
    ("unclemusclez/jina-embeddings-v2-base-code",8192),
    ("bge-m3",256),
    ("bge-m3",512),
    ("bge-m3",8192),
]

この処理が完了すると、embeddingモデル数 × 「チャンクサイズ」の種類数 に応じたベクトルデータベースが生成されます。

プログラム(回答生成)

さらに、生成AIへの問い合わせも毎回streamlitから手動で行うのは非効率だと感じたため、こちらもforループで一括処理することにしました。

from langchain_community.chat_models.ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain.schema import HumanMessage
from langchain.vectorstores import Chroma
import ollama
from datetime import datetime

FILE_PATH = "C:/Users/ogiki/Desktop/result.txt"

model_chunksize = [
    ("nomic-embed-text",256),
    ("nomic-embed-text",512),
    ("nomic-embed-text",2048),
    ("mxbai-embed-large",256),
    ("mxbai-embed-large",512),
    ("all-minilm",256),
    ("all-minilm",512),
    ("unclemusclez/jina-embeddings-v2-base-code",256),
    ("unclemusclez/jina-embeddings-v2-base-code",512),
    ("unclemusclez/jina-embeddings-v2-base-code",8192),
    ("bge-m3",256),
    ("bge-m3",512),
    ("bge-m3",8192),
]

input_messages = [
    ("0007350118","「組合せ処置およびその方法」の概要を教えて"),
    ("0007350061","「フューリンインヒビター」の概要を教えて"),
    ("0007350061","用語「アルキル」の意味は?"),
    ("0007350118","用語「アゴニスト」はどういうもの?"),
]


# 埋め込み関数のラッパーを作成
class OllamaEmbeddingFunction:
    def __init__(self, model):
        self.model = model
    def embed_documents(self, texts):
        embeddings = []
        for text in texts:
            response = ollama.embeddings(model=self.model, prompt=text)
            embeddings.append(response['embedding'])
        return embeddings  # ここで計算した埋め込みを返します
    def embed_query(self, query):
        response = ollama.embeddings(model=self.model, prompt=query)
        return response['embedding']  # クエリの埋め込みを返す


file = open(FILE_PATH, "w", encoding="utf-8")
file.write("")
file.close()

for model_name, embedding_length in model_chunksize:
    
    file = open(FILE_PATH, "a", encoding="utf-8")
    
    embedding_function = OllamaEmbeddingFunction(model=model_name)
    chat = ChatOllama(model="llama3.2", temperature=0)
    database = Chroma(
        persist_directory=f"C:/Users/ogiki/vectorDB/{model_name}-{embedding_length}", 
        embedding_function=embedding_function
    )
    
    prompt = PromptTemplate(template="""文章を元に質問に答えてください。 

    文章: 
    {document}

    質問: {query}
    """, input_variables=["document", "query"])


    # =====================================================
    for regist_no, input_message in input_messages:
        
        file.write(f"モデル:{model_name}\n")
        file.write(f"エンベッディング:{embedding_length}\n")
        file.write(f"チャンク数:{str(embedding_length)}")
        file.write(f"登録番号:{regist_no}\n")
        file.write(f"質問:{input_message}\n")
        formatted_time = datetime.now().strftime("%H:%M:%S")
        file.write(f"開始時刻:{formatted_time}\n")
    
        documents = database.similarity_search_with_score(input_message, k=3, filter={"name":regist_no})
        documents_string = ""
        for document in documents:
            print("---------------document.metadata---------------")
            print(document[0].metadata)
            print(document[1])
            documents_string += f"""
                ---------------------------
                {document[0].page_content}
                """
        print("---------------documents_string---------------")
        print(input_message)
        print(documents_string)
        # ----- プロンプトを基に回答をもらう (ローカルLLMを利用) -----
        result = chat([
            HumanMessage(content=prompt.format(document=documents_string,
                                            query=input_message))
        ])
        file.write(f"回答:{result.content}\n")
        formatted_time = datetime.now().strftime("%H:%M:%S")
        file.write(f"修了時刻:{formatted_time}\n")
        file.write("\n")
        file.write("\n")
        file.write("\n")

    file.close()

では、プログラムの詳細を見ていきます。

まず、標準出力をテキストファイルに保存する設定を行いました。

FILE_PATH = "C:/Users/ogiki/Desktop/result.txt"

次に、「ベクトルデータベース生成」のプログラム同様、embeddingモデルと「チャンクサイズ」の組み合わせリストを設定します。

model_chunksize = [
    ("nomic-embed-text", 256),
    ("nomic-embed-text", 512),
    ("nomic-embed-text", 2048),
    ("mxbai-embed-large", 256),
    ("mxbai-embed-large", 512),
    # ...
]

生成AIには4つの質問を投げかけるため、こちらもリスト化しました。

input_messages = [
    ("0007350118", "「組合せ処置およびその方法」の概要を教えて"),
    ("0007350061", "「フューリンインヒビター」の概要を教えて"),
    ("0007350061", "用語「アルキル」の意味は?"),
    ("0007350118", "用語「アゴニスト」はどういうもの?"),
]

リスト内の最初の数字は、質問対象の絞り込み用として使用する「特許登録番号」です。

llmモデルはllama3.2で固定しています。

chat = ChatOllama(model="llama3.2", temperature=0)

実行

処理速度(ベクトルデータベース作成)

モデル名 256 512 1025 2048 4096 8192 平均
nomic-embed-text 1,014 1,014 1,014 1,014
mxbai-embed-large 2,060 2,148 2,104
all-minilm 70 70 70
unclemusclez/jina-embeddings-v2-base-code 1,236 1,258 1,258 1,251
bge-m3 2,500 2,559 2,491 2,517

以上のことから、all-minilnの処理速度が非常に速いことが明らかになりました。また、「チャンクサイズ」が処理速度に与える影響はほとんど見られないことが、表から読み取ることができました。

処理速度(回答生成)

モデル名 256 512 1025 2048 4096 8192 平均
nomic-embed-text 85 77 71 78
mxbai-embed-large 158 158 158
all-minilm 49 51 50
unclemusclez/jina-embeddings-v2-base-code 5,413 5,320 5,248 5,327
bge-m3 152 145 145 147

やはり、all-minilnの処理速度は非常に優れています。ただし、nomic-embed-textもそれほど大きな差は見受けられませんでした。また、「チャンクサイズ」による処理速度の顕著な違いは表から確認できませんでした。

回答精度

4回の質問の平均です。(小数点あり)

モデル名 256 512 1025 2048 4096 8192 平均
nomic-embed-text 2.3 1.3 2.3 2.3
mxbai-embed-large 2.5 2.5 2.5
all-minilm 2.8 2.0 2.4
unclemusclez/jina-embeddings-v2-base-code 1.5 1.5 1.5 1.5
bge-m3 2.3 2.1 1.9 2.1

最後に「回答精度」の比較を行います。ここでの平均はあまり意味を持たないことに注意が必要です。なぜなら、サンプル数(母数)が異なるからです。そのため、全体的な傾向を確認するだけになりますが、やはり「チャンクサイズ」の違いによる「回答精度」の顕著な差は見受けられませんでした。

おわりに

私の計測方法が不十分だったのか、期待した結果には至りませんでした。うーん…
次に比較すべきは'chunk_overlap'の違いでしょうか。反省点としては、「回答精度」においてRAGが良い結果を出した場合と、RAGにはないがllmに問い合わせた時に良い回答が得られた場合では、意味が異なるということです。したがって、類似性のスコアを確認する方法も考慮すべきだったかもしれません。体力があれば、今度その点にも挑戦してみたいと思います。

Discussion