👯♀️
BERTとfaissで類似文検索してみる
めっちゃ簡単。出てくる結果はこんな感じです。
#input
index_text = ["私は犬です。", "私は犬が好きです。", "昨日の犬がかわいかった。", "犬と猫だとどっちが強い?"]
query_text = "犬はとてもかわいい"
#output_ranked
昨日の犬がかわいかった。
犬と猫だとどっちが強い?
私は犬です。
私は犬が好きです。
犬と猫だとどっちが強い?
使ったライブラリとか
- transformers:事前学習済みBERTを利用するためのライブラリ
- faiss:近傍探索ライブラリ
手順
こんな感じです。
- 検索インデックス用の文章の文ベクトルを獲得
- ベクトルをインデックスに追加(検索インデックスの作成)
- クエリとなる文のベクトル表現を獲得
- 検索
順番に説明していきます。
1. 文ベクトルの獲得
モデルは東北大BERTを利用しています。
import torch
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
from transformers.tokenization_utils import BatchEncoding
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
device="cuda:0"
model_name="cl-tohoku/bert-base-japanese-whole-word-masking"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval().to(device)
index_text = ["私は犬です。", "私は犬が好きです。", "昨日の犬がかわいかった。", "犬と猫だとどっちが強い?"]
def collate_fn(batch: List[str]) -> BatchEncoding:
return tokenizer(
batch,
padding=True,
truncation=True,
return_tensors="pt",
max_length=512,
)
dl = DataLoader(index_text, collate_fn=collate_fn, batch_size=16, num_workers=2)
vec_list=[]
for batch in tqdm(dl):
output = model(**batch.to(device))
vec_list.append(output["pooler_output"].cpu())
df_ex=pd.DataFrame({"text":index_text})
df_ex.to_json( "exemplars.jsonl",
orient="records",
force_ascii=False,
lines=True,
default_handler=str
)
vec_array = F.normalize(torch.cat(vec_list, dim=0))
文のベクトル表現はモデルの出力のうち"pooler_output"から獲得しています。
コサイン類似度を利用して類似度を計算したいのですが、faissのIndexFlatIPではベクトル同士の内積を距離関数としているため、計算前にベクトルに対し正規化を行っています。(これをしないとベクトルの大きさに依存して類似度が計算されてしまう)
2. 検索インデックスの作成
faissには様々な検索インデックスが用意されていますが、ここではあまり大きなデータを扱わないことから総当たり的な検索を行うIndexFlatIPというメソッドを利用してみました。
先ほど獲得した文ベクトルたちを突っ込むことで検索インデックスの作成が完了します。
index = faiss.IndexFlatIP(vec_array.shape[1])
index.add(vec_array)
3. クエリとなる文のベクトル表現を獲得
手順1と同じことをしています
emb = tokenizer(
"犬はとてもかわいい",
padding=True,
truncation=True,
return_tensors="pt",
max_length=512,
)
with torch.inference_mode():
output=F.normalize(model(**emb.to(device))["pooler_output"].cpu())
4. 検索
num_of_preds=4
D, I = index.search(output.detach().numpy(), num_of_preds)
preds=I[0]
for id, pred in enumerate(preds):
print(df.to_dict("records")[pred]["text"])
print(D[0][id])
num_of_predsで上位何例を検索するか決定しています。検索結果のテキストの他に類似度Dも表示してみました。
結果
以上の結果がこのようになります。
昨日の犬がかわいかった。
0.767721
犬と猫だとどっちが強い?
0.7617792
私は犬です。
0.60839754
私は犬が好きです。
0.5737625
犬がかわいいという同じ主張をしている文が1位となりました。しかし類似度を見てみると2位の「犬と猫だとどっちが強い?」という全く違う文意のものとかなり近い値となっています。
おわり
今回は簡単に動かせることを伝えるという点に重きを置くため、かなり短く、また数も少ないデータで試しています。他のデータやモデルを利用してみるとまた違った結果が見れて面白そうだなと思っています。
Discussion