Closed9

BAAI/bge-m3を試す&LlamaIndexでインデックス作成してretrieval

kun432kun432
  • Multi-Linguality: It can support more than 100 working languages.
  • Multi-Granularity: It is able to process inputs of different granularities, spanning from short sentences to long documents of up to 8192 tokens.
kun432kun432

Colaboratoryで。とりあえずランタイムは標準の「CPU」で。

!pip install -U FlagEmbedding

モデルをロードする。

from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel(
    'BAAI/bge-m3',  
    use_fp16=False, # Setting use_fp16 to True speeds up computation with a slight performance degradation
)
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: 
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
Fetching 19 files: 100%
19/19 [01:01<00:00, 5.51s/it]
config.json: 100%
687/687 [00:00<00:00, 5.18kB/s]
README.md: 100%
11.7k/11.7k [00:00<00:00, 168kB/s]
1_Pooling/config.json: 100%
191/191 [00:00<00:00, 1.45kB/s]
config_sentence_transformers.json: 100%
123/123 [00:00<00:00, 828B/s]
imgs/long.jpg: 100%
218k/218k [00:00<00:00, 1.33MB/s]
imgs/miracl.jpg: 100%
201k/201k [00:00<00:00, 1.07MB/s]
.gitattributes: 100%
1.57k/1.57k [00:00<00:00, 34.6kB/s]
imgs/mkqa.jpg: 100%
326k/326k [00:00<00:00, 2.48MB/s]
sentence_bert_config.json: 100%
54.0/54.0 [00:00<00:00, 665B/s]
colbert_linear.pt: 100%
2.10M/2.10M [00:00<00:00, 7.06MB/s]
modules.json: 100%
349/349 [00:00<00:00, 3.43kB/s]
imgs/nqa.jpg: 100%
97.4k/97.4k [00:00<00:00, 761kB/s]
special_tokens_map.json: 100%
964/964 [00:00<00:00, 9.32kB/s]
tokenizer_config.json: 100%
1.31k/1.31k [00:00<00:00, 15.8kB/s]
model.safetensors: 100%
2.27G/2.27G [00:42<00:00, 75.3MB/s]
pytorch_model.bin: 100%
2.27G/2.27G [00:59<00:00, 54.9MB/s]
sentencepiece.bpe.model: 100%
5.07M/5.07M [00:00<00:00, 22.6MB/s]
sparse_linear.pt: 100%
3.52k/3.52k [00:00<00:00, 26.7kB/s]
tokenizer.json: 100%
17.1M/17.1M [00:00<00:00, 47.3MB/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-004240603cb2> in <cell line: 3>()
      1 from FlagEmbedding import BGEM3FlagModel
      2 
----> 3 model = BGEM3FlagModel(
      4     'BAAI/bge-m3',
      5     use_fp16=False, # Setting use_fp16 to True speeds up computation with a slight performance degradation

5 frames
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in _from_pretrained(cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, token, cache_dir, local_files_only, _commit_hash, _is_local, *init_inputs, **kwargs)
   2247             if added_tokens_map != {} and init_kwargs[key] is not None:
   2248                 if key != "additional_special_tokens":
-> 2249                     init_kwargs[key] = added_tokens_map.get(init_kwargs[key], init_kwargs[key])
   2250 
   2251         init_kwargs["added_tokens_decoder"] = added_tokens_decoder

TypeError: unhashable type: 'dict'

怒られる。

transformersをアップデートする。

!pip install -U transformers
(snip)
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.35.2
    Uninstalling transformers-4.35.2:
      Successfully uninstalled transformers-4.35.2
Successfully installed transformers-4.37.2
WARNING: The following packages were previously imported in this runtime:
  [transformers]
You must restart the runtime in order to use newly installed versions.

ランタイムをリスタートして再度実行するといけた。

ではEmbeddingsを生成してみる。

sentences_1 = [
    "What is BGE M3?",
    "Defination of BM25"
]
sentences_2 = [
    "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.", 
    "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
]

embeddings_1 = model.encode(
    sentences_1, 
    batch_size=12, 
    max_length=8192, # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
)['dense_vecs']
embeddings_2 = model.encode(sentences_2)['dense_vecs']

print()

print(embeddings_1.shape)
print(embeddings_2.shape)

similarity = embeddings_1 @ embeddings_2.T
print()
print(similarity)
# [[0.6265, 0.3477], [0.3499, 0.678 ]]
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
encoding: 100%|██████████| 1/1 [00:01<00:00,  1.09s/it]
encoding: 100%|██████████| 1/1 [00:01<00:00,  1.37s/it]
(2, 1024)
(2, 1024)

[[0.62590367 0.3474958 ]
 [0.34986818 0.6782464 ]]
kun432kun432

A100にして再度試してみた。モデルロード直後のVRAM状態。

use_fp16=False

Sun Feb  4 06:26:38 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0              50W / 400W |   2609MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

use_fp16=True

Sun Feb  4 06:29:43 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0              54W / 400W |   1513MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
kun432kun432

LlamaIndexで。HuggingFaceEmbedding を使えば簡単。
ただし、ColaboratoryだとこちらもTransformerをアップデートしておく必要がある。

from llama_index.embeddings import HuggingFaceEmbedding
from llama_index import ServiceContext, VectorStoreIndex

embed_model = HuggingFaceEmbedding(
    model_name='BAAI/bge-m3',
    embed_batch_size=12,
    max_length=8192
)

service_context = ServiceContext.from_defaults(
    embed_model=embed_model
)

index = VectorStoreIndex(
    nodes,
    service_context=service_context,
    show_progress=True
)

ノードは以下のデータを使った。

https://linecorp.com/ja/csr/newslist/ja/2020/260

ノードの作り方は以下

上記LINEのオープンデータからLlamaIndexのノードを作成する
!wget https://d.line-scdn.net/stf/linecorp/ja/csr/dataset_.zip
!!unzip dataset_.zip
import pandas as pd

df = pd.read_excel("dataset_.xlsx")
df.rename(columns={
    'サンプルID': 'SampleID',
    'サンプル 問い合わせ文': 'Q',
    'サンプル 応答文': 'A',
    'カテゴリ1': 'Category',
    'カテゴリ2': 'Cat2',
    '出典': 'Ref',
    '<参考>UMカテゴリタグ': 'Tag',
    '<参考>UMサービスメニュー\n(標準的な行政サービス名称)': 'Service'
}, inplace=True)
df.drop(columns=["ID", "Cat2","Ref","Tag","Service"], inplace=True)
df["Context"] = "Q: " + df["Q"] + "\nA: " + df["A"]

max_digits = df["SampleID"].astype(str).str.len().max()
df['Q_ID'] = df['SampleID'].apply(lambda x: "Q_" + str(x).zfill(max_digits))
df['A_ID'] = df['SampleID'].apply(lambda x: "A_" + str(x).zfill(max_digits))
from llama_index.schema import TextNode, NodeRelationship, RelatedNodeInfo

nodes = []
for idx, row in df.iterrows():
    node = TextNode(text=row["A"], id_=row["A_ID"])
    nodes.append(node)

retrieverを作成

k=10
text = "住民税の計算方法を教えてください"     # Q_662

for r in index.as_retriever(similarity_top_k=k).retrieve(text):
    print(r.id_)
    print(r.get_score())
    print(r.get_content().replace("\n","")[:100])
    print("-----")

A_662 # ★これが正解
0.7440262592731685
住民税の算出方法についてはこちらをご覧ください。(自治体HP内関連ページのURL)
-----
A_569
0.6745045379175899
住民税は、その年の1月1日に住所(住民票)がある市区町村で課税されますので、証明書もその市区町村で発行されます。2019年度の証明書は、2019年1月1日に住所があった市区町村に申請してください。○○
-----
A_648
0.6306537189675594
○○年度の住民税の申告についてはこちらをご覧ください。(自治体HP内関連ページのURL)
-----
A_659
0.625395760497799
住民税の納税証明書は、(申請場所(自治体の担当課、支所・出張所等))で発行しています。申請方法等、詳しくはこちらをご覧ください。(自治体HP内関連ページのURL)
-----
A_333
0.6116227282864792
(1/1時点で海外に在住・在勤していた場合の保育利用料の取り扱いを記載してください。)例「保育利用料の軽減に関して、1月1日時点で海外に在住・在勤していたため、住民税が課税されていない場合、非課税扱い
-----
A_100
0.5924761112815058
所得証明書(住民税課税(非課税)証明書)は住民税担当課で請求ができます。転出先で必要となる証明書の年度をご確認のうえ、必要年度の前年度の1月1日に在住していた市区町村の住民税担当課で請求してください。
-----
A_236
0.5910326681887337
住民税課税(非課税)証明書は、(申請場所(自治体の担当課、支所・出張所等))で申請できます。申請には本人確認のための身分証明書が必要です。▼詳しくはこちら(自治体HP内関連ページのURL)
-----
A_305
0.5783871298038831
(海外勤務等の場合の取り扱いを記載してください。)例「海外勤務等で住民税が課税されない方の認可保育園保育料は、年間収入申告書(市指定様式)を提出して頂き、保育料を決定します。詳しくは、○○課(電話番号
-----
A_150
0.5716120814942951
幼稚園などで配布している申請書「(補助金申請書のタイトル)」に必要事項を明記し、幼稚園へ提出してください。当該年度に○○市の住民税が課税されていない人は、当該年度の「市区町村民税納税(税額)通知書」ま
-----
A_326
0.5695823549252746
(保育利用料軽減について記載してください。)例「保育利用料軽減金額は、世帯の住民税等により異なるため、申し訳ありませんが、具体的な金額はお答えできません。」
text = "母子手帳を受け取りたいのですが、手続きを教えてください。"   # Q_001
A_002
0.772968359983863
母子手帳は、○○市役所本庁舎△△階××課窓口、◎◎出張所、………(その他の受け取り場所を適宜記載)………で受け取れます。▼詳しくはこちら(自治体HP内関連ページのURL)
-----
A_003
0.7666359479983332
母子手帳は、妊娠届の内容を確認させていただき、その場でお渡しします。▼詳しくはこちら(自治体HP内関連ページのURL)
-----
A_108
0.7535058083471671
母子手帳をなくしたときは、再交付を受けてください。お子さんが出生前の母子手帳については、(再交付を受けられる場所)で再交付を受けられます。お子さんが出生後の母子手帳については、(再交付を受けられる場所
-----
A_450
0.7428489092288214
妊娠したら妊娠届を○○課窓口(または支所・出張所窓口)に提出し、母子手帳を受け取ってください。▼詳しくはこちら(自治体HP内関連ページのURL)
-----
A_001  # ★これが正解(ただし検索上位を見るとこの正解データよりもマッチしている感あり)
0.7213666886120362
窓口で妊娠届をご記入いただき、母子手帳をお渡しします。住民票の世帯が別の方が代理で窓口に来られる場合は、委任状が必要になります。▼詳しくはこちら(自治体HP内関連ページのURL)
-----
A_036
0.714482910149047
産前は母子手帳以外の手続きは特にありません。産後に、出生の届出や出生通知書の提出、(自治体が行う出産助成等)の申請をお願いします。
-----
A_165
0.6914882134271054
母子手帳は住所が変わってもそのままお使いいただけます。再発行等の手続は必要ありません。◆お問い合わせ(自治体の担当課や子育てセンター等の名称)(電話番号)/(開庁時間)
-----
A_274
0.6817114703963848
(手続きの説明を記載してください。)例「保育園での面接、健康診断を受けていただきます。面接時には、「支給認定証」、「母子健康手帳」をお持ちください。」
-----
A_252
0.680508975359388
夜間・休日窓口の場合、母子手帳の証明や届書の受理証明書などの発行、 子どもに関する手当・助成の受付はしていませんので、通常窓口で手続き・申請してください。▼詳しくはこちら(自治体HP内関連ページのUR
-----
A_014
0.6623347290013195
出産後に必要な手続きは出生届・出生通知票の提出、児童手当、子ども医療費助成の申請等があります。▼詳しくはこちら(自治体HP内関連ページのURL)
-----
text = "ひとり親家庭への手当・助成の種類を教えてください。"     # Q_201
A_201  # ★これが正解
0.7730284028275043
ひとり親家庭への手当・助成としては、児童扶養手当、児童育成手当、ひとり親医療費助成制度があります。▼詳しくはこちら(自治体HP内関連ページのURL)
-----
A_466
0.7142369920668331
児童扶養手当などの、ひとり親の手当・助成や、その他支援などが受けられる場合があります。▼詳しくはこちら(自治体HP内関連ページのURL)
-----
A_203
0.6861757430615038
ひとり親家庭への手当の所得制限は、それぞれの制度によって異なります。詳しくは下記をご覧ください。▼各手当所得制限限度額(自治体HP内関連ページのURL)
-----
A_207
0.6757537662516389
ひとり親手当は未婚でも受給できますが、そのほか受給条件があります。詳しくは以下のページをご覧いただくか、または(自治体の担当課等の名称)までお問い合わせください。(自治体HP内関連ページのURL)◆お
-----
A_216
0.6376144770599661
ひとり親手当は子どもと別居していても、必要書類を提出すれば受けられます。ただし、児童扶養手当は父子の別居監護についてはできかねます。
-----
A_206
0.6320580991247859
児童扶養手当(ひとり親手当)の申請は、窓口での面談が必要になります。(自治体の支所・出張所等)では申請できません。(自治体の担当課等の名称)までお越しください。
-----
A_205
0.6294295587089813
児童扶養手当(ひとり親手当)の申請は、窓口での面談が必要になります。(自治体の担当課等の名称)までお越しください。
-----
A_641
0.6154077122764182
児童扶養手当とは、離婚によるひとり親家庭などの生活の安定・自立促進に寄与することにより、その家庭において養育されている子どもの福祉増進のために支給される手当です。▼詳しくはこちら(自治体HP内関連ペー
-----
A_174
0.6145850134957446
(家事支援等を行っている場合は記載してください。)例「中学生以下の児童を扶養しているひとり親家庭を対象に家事援助者派遣事業を実施しています。家事など日常生活に支障があるとき、保護者が在宅している時間帯
-----
A_394
0.6136560845052453
児童扶養手当は、離婚によるひとり親家庭などの生活の安定・自立促進に寄与することにより、その家庭において養育されている子どもの福祉増進のために支給される手当です。▼詳しくはこちら(自治体HP内関連ページ
-----

非常に良い感じではないだろうか。

kun432kun432

余談

全然関係ないけど、LlamaIndexだとHuggingFaceEmbeddingモジュールで使える。ただし、カスタムにEmbeddingモジュールを定義することもできる。以下はFlagEmbeddingを使ったカスタムなBAAI/bge-m3のEmbeddingモデルの例。

from typing import Any, List
from FlagEmbedding import BGEM3FlagModel

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.embeddings.base import BaseEmbedding

class BGE_M3Embeddings(BaseEmbedding):
    _model: BGEM3FlagModel = PrivateAttr()
    _encode_options: dict = PrivateAttr()

    def __init__(
        self,
        model_name: str = 'BAAI/bge-m3',
        use_fp16: bool = True,
        **kwargs: Any,
    ) -> None:
        self._model = BGEM3FlagModel(model_name, use_fp16=use_fp16)
        self._encode_options = kwargs
        super().__init__(**kwargs)

    @classmethod
    def class_name(cls) -> str:
        return "bgem3flag"

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        embeddings = self._model.encode([query])['dense_vecs']
        return embeddings[0].tolist()

    def _get_text_embedding(self, text: str) -> List[float]:
        embeddings = self._model.encode([text], **self._encode_options)['dense_vecs']
        return embeddings[0].tolist()

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        embeddings = self._model.encode(texts, **self._encode_options)['dense_vecs']
        return [embedding.tolist() for embedding in embeddings]

テスト

import unittest
from typing import List


class TestBGE_M3Embeddings(unittest.TestCase):

    def setUp(self):
        self.model = BGE_M3Embeddings(model_name='BAAI/bge-m3', use_fp16=True, batch_size=12, max_length=8192)

    def test_instance_creation(self):
        self.assertIsInstance(self.model, BGE_M3Embeddings)

    def test_query_embedding(self):
        result = self.model._get_query_embedding("test query")
        self.assertIsInstance(result, List)
        self.assertTrue(all(isinstance(x, float) for x in result))

    def test_text_embedding(self):
        result = self.model._get_text_embedding("test text")
        self.assertIsInstance(result, List)
        self.assertTrue(all(isinstance(x, float) for x in result))

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

Fetching 19 files: 100%
19/19 [00:00<00:00, 935.21it/s]
.loading existing colbert_linear and sparse_linear---------
Fetching 19 files: 100%
19/19 [00:00<00:00, 1087.30it/s]
loading existing colbert_linear and sparse_linear---------
encoding: 100%|██████████| 1/1 [00:01<00:00,  1.12s/it]
.
Fetching 19 files: 100%
19/19 [00:00<00:00, 1152.96it/s]
loading existing colbert_linear and sparse_linear---------
encoding: 100%|██████████| 1/1 [00:01<00:00,  1.09s/it]
.
----------------------------------------------------------------------
Ran 3 tests in 8.833s

OK

使うとき

embedding_model = BGE_M3Embeddings(model_name='BAAI/bge-m3', use_fp16=True, batch_size=12, max_length=8192)
kun432kun432

上と同じデータ使って、ranxで評価してみた。

2024/02/04追記: infloat/multilingual-e5-largeを追加。ただしちょっと数字が良くなくて、もしかするとテストの仕方が良くないのかもしれない、ということでmultilingual-e5-largについては参考程度に。

##### @2 #####

#    Model                   Hits@2    MAP@2     MRR@2
---  ----------------------  --------  --------  --------
a    bge-m3                  0.816ᵇᶜᵈ  0.769ᵇᶜᵈ  0.769ᵇᶜᵈ
b    multilingual-e5-large   0.739     0.690     0.690
c    text-embedding-ada-002  0.752     0.697     0.697
d    text-embedding-3-small  0.770     0.731ᶜ    0.731ᶜ
e    text-embedding-3-large  0.834ᵇᶜᵈ  0.785ᵇᶜᵈ  0.785ᵇᶜᵈ

##### @3 #####

#    Model                   Hits@3    MAP@3     MRR@3
---  ----------------------  --------  --------  --------
a    bge-m3                  0.852ᵇᶜ   0.781ᵇᶜᵈ  0.781ᵇᶜᵈ
b    multilingual-e5-large   0.784     0.705     0.705
c    text-embedding-ada-002  0.796     0.712     0.712
d    text-embedding-3-small  0.826ᵇ    0.750ᵇᶜ   0.750ᵇᶜ
e    text-embedding-3-large  0.872ᵇᶜᵈ  0.798ᵇᶜᵈ  0.798ᵇᶜᵈ

##### @5 #####

#    Model                   Hits@5    MAP@5     MRR@5
---  ----------------------  --------  --------  --------
a    bge-m3                  0.881ᵇᶜ   0.787ᵇᶜᵈ  0.787ᵇᶜᵈ
b    multilingual-e5-large   0.838     0.717     0.717
c    text-embedding-ada-002  0.849     0.724     0.724
d    text-embedding-3-small  0.872     0.760ᵇᶜ   0.760ᵇᶜ
e    text-embedding-3-large  0.905ᵇᶜᵈ  0.806ᵇᶜᵈ  0.806ᵇᶜᵈ

##### @10 #####

#    Model                   Hits@10    MAP@10    MRR@10
---  ----------------------  ---------  --------  --------
a    bge-m3                  0.912ᶜ     0.792ᵇᶜᵈ  0.792ᵇᶜᵈ
b    multilingual-e5-large   0.893      0.724     0.724
c    text-embedding-ada-002  0.885      0.729     0.729
d    text-embedding-3-small  0.905      0.765ᵇᶜ   0.765ᵇᶜ
e    text-embedding-3-large  0.932ᵇᶜᵈ   0.809ᵇᶜᵈ  0.809ᵇᶜᵈ

評価指標どれがいいのかさっぱりわからない。。。

あと今気づいたけど、text-embedding-3-largeのdimensionsいじってない、多分デフォルトの1536になってそう。ドキュメント見てみたら、largeは3072、smallは1536がデフォルトみたい。

ただこれだけでも、BAAI/bge-m3かなり良い感じに思える。text-embedding-3-largeが一番強い感はあるけども。

denseでも十分良いけど、sparseとcolbertとのハイブリッド検索でさらに良くなる(ただしcolbertはtoken数分の1024次元ベクトルが必要でデータ量が多い)

ってのも気になる。

評価用のコードは後で書く。

評価用コード

LlamaIndexでインデックス+retrieverを作成して、ranx評価データを作成。

!pip install -U llama-index typing_extensions transformers
from google.colab import drive
drive.mount('/content/drive')

ranx_data_path="/content/drive/MyDrive/ranx"
!wget https://d.line-scdn.net/stf/linecorp/ja/csr/dataset_.zip
!!unzip dataset_.zip
import pandas as pd

df = pd.read_excel("dataset_.xlsx")
df.rename(columns={
    'サンプルID': 'SampleID',
    'サンプル 問い合わせ文': 'Q',
    'サンプル 応答文': 'A',
    'カテゴリ1': 'Category',
    'カテゴリ2': 'Cat2',
    '出典': 'Ref',
    '<参考>UMカテゴリタグ': 'Tag',
    '<参考>UMサービスメニュー\n(標準的な行政サービス名称)': 'Service'
}, inplace=True)
df.drop(columns=["ID", "Cat2","Ref","Tag","Service"], inplace=True)
df["Context"] = "Q: " + df["Q"] + "\nA: " + df["A"]

max_digits = df["SampleID"].astype(str).str.len().max()
df['Q_ID'] = df['SampleID'].apply(lambda x: "Q_" + str(x).zfill(max_digits))
df['A_ID'] = df['SampleID'].apply(lambda x: "A_" + str(x).zfill(max_digits))
from google.colab import userdata
import os

os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')
from llama_index.schema import TextNode, NodeRelationship, RelatedNodeInfo

nodes = []
qrels = []
for idx, row in df.iterrows():
    node = TextNode(text=row["A"], id_=row["A_ID"])
    qrel = "{} 0 {} 1".format(row["Q_ID"], row["A_ID"])
    nodes.append(node)
    qrels.append(qrel)
with open(f'{ranx_data_path}/qrels.trec', 'w') as file:
    for qrel in qrels:
        file.write(f"{qrel}\n")
from tqdm.auto import tqdm
from llama_index.embeddings import HuggingFaceEmbedding, OpenAIEmbedding
from llama_index import ServiceContext, VectorStoreIndex

k = 100

contexts_indexes = {
    "text-embedding-ada-002": ServiceContext.from_defaults(embed_model=OpenAIEmbedding(model="text-embedding-ada-002")),
    "text-embedding-3-small": ServiceContext.from_defaults(embed_model=OpenAIEmbedding(model="text-embedding-3-small")),
    "text-embedding-3-large": ServiceContext.from_defaults(embed_model=OpenAIEmbedding(model="text-embedding-3-large")),
    "bge-m3": ServiceContext.from_defaults(embed_model=HuggingFaceEmbedding(model_name='BAAI/bge-m3', embed_batch_size=12, max_length=8192)),
    "multilingual-e5-large": ServiceContext.from_defaults(embed_model=HuggingFaceEmbedding(model_name="intfloat/multilingual-e5-large")),
}

indexes = {model_name: VectorStoreIndex(nodes, service_context=context, show_progress=True) for model_name, context in contexts_indexes.items()}

runs = {model_name: [] for model_name in contexts_indexes}

for model_name, index in indexes.items():
    for idx, row in tqdm(df.iterrows(), total=df.shape[0], desc=f"Processing {model_name}"):
        query = row["Q"]
        if model_name == "multilingual-e5-large":
            # multilingual-e5の場合は"query: "プレフィックスを付ける。データ側には"passage: "プレフィックスを付ける
            #  べきだが、自分の検証では"query"プレフィクスあり+passageプレフィクスなしが最もスコアが良かったので
            # クエリだけにしている
            query = "query: " + row["Q"]
        q_id = row["Q_ID"]
        for r_idx, r in enumerate(index.as_retriever(similarity_top_k=k).retrieve(query), start=1):
            a_rank = r_idx
            a_id = r.id_
            a_score = r.get_score()
            run = "{} 0 {} {} {} {}".format(q_id, a_id, a_rank, a_score, model_name)
            runs[model_name].append(run)

各Embeddingモデルごとにranx用テスト結果データを作成

for model_name, model_runs in runs.items():
    file_path = f'{ranx_data_path}/run_{model_name}.trec'
    with open(file_path, 'w') as file:
        for run in model_runs:
            file.write(f"{run}\n")

ranxで評価

!pip install ranx
from google.colab import drive
drive.mount('/content/drive')

ranx_data_path="/content/drive/MyDrive/ranx"
from ranx import Qrels, Run, evaluate, compare

qrels = Qrels.from_file(f"{ranx_data_path}/qrels.trec", kind="trec")
run_bge_m3 = Run.from_file(f"{ranx_data_path}/run_bge-m3.trec", kind="trec")
run_ada002 = Run.from_file(f"{ranx_data_path}/run_text-embedding-ada-002.trec", kind="trec")
run_3lg = Run.from_file(f"{ranx_data_path}/run_text-embedding-3-small.trec", kind="trec")
run_3sm = Run.from_file(f"{ranx_data_path}/run_text-embedding-3-large.trec", kind="trec")
run_me5l = Run.from_file(f"{ranx_data_path}/run_multilingual-e5-large.trec", kind="trec")

for k in [2, 3, 5, 10]:
    print(f"##### @{k} #####")
    print()
    report = compare(
        qrels,
        runs=[run_bge_m3, run_me5l, run_ada002, run_3lg , run_3sm],
        metrics=[f"hits@{k}", f"map@{k}", f"mrr@{k}", f"map@{k}"],
        max_p=0.01,  # P-value threshold
    )
    print(report)
    print()
このスクラップは2024/02/04にクローズされました