Embeddingsを使ってローカルでテキストをクラスタリングする(Multilingual-E5)
EmbeddingsとSentence Transformers
Sentence Transformersは、テキストをEmbeddings(埋め込み)と呼ばれるベクトル表現に変換するためのライブラリです。OpenAIの "text-embedding-ada002" も、Embeddingsを生成するモデルです。 テキストの意味をベクトルで表現すると、コサイン類似度などで意味の類似度が簡単に計算できるため、下記のようなタスクが容易になります。
- テキストの類似度算出
- 分類(Classifying)
- クラスタリング
- セマンティック検索(意味に基づいた検索)
今回は、ローカルで動作させることができる "Multilingual-E5" というモデルを使って、短いテキストを分類してみます。
このモデルは、Leaderboradでも好成績を収めています。 largeモデルは、"text-embedding-ada002"よりハイスコアです。ただし扱えるシーケンス長は512トークンまでで、ada(8192)より短くなります。また、多言語モデルなので、言語が違っても、翻訳することなく意味の類似度などを計算できます。環境
Windows 11, RTX3060 12GB
Python 3.10.11
sentence-transformers 2.2.2
torch 2.0.1+cu118
事前準備
パッケージのインストール
GPUを使う場合は事前にPytorchをインストールします。
# pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install sentence_transformers
ダミーデータの作成
今回はPCに関するクレームを短文で表したものを用意します。GPT-4に作成してもらいました。
(n = 75)
# data/claim.txt
ディスプレイがつかない。
画面が真っ暗です。
モニターのライトが点かない。
PCの画面が起動しない。
スクリーンが点灯しません。
ディスプレイに何も映らない。
モニターが真っ黒のままです。
画面が反応しません。
電源は入るが、画面だけ映らない。
スクリーンが暗いままです。
...
Embeddingsの作成
データの読み込み
import pandas as pd
df = pd.read_table("data/claim.txt", header=None, names=["claim"])
df.head()
claim | |
---|---|
0 | ディスプレイがつかない。 |
1 | 画面が真っ暗です。 |
2 | モニターのライトが点かない。 |
3 | PCの画面が起動しない。 |
4 | スクリーンが点灯しません。 |
. | ... |
モデルの指定
PCのスペックやデータ数に応じてモデルを選択します。small, base, largeの3つのモデルが提供されています。Embeddingsの処理は軽量なので、smallなら低スペックのノートPCでも高速に動作します。GPUがあるならlargeでも余裕で動くと思います。
# model_name = "intfloat/multilingual-e5-small"
# model_name = "intfloat/multilingual-e5-base"
model_name = "intfloat/multilingual-e5-large"
# モデルの読み込み
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
参考
今回の環境では、smallとbaseでベクトル化の時間に差はありませんでした。
モデルはパフォーマンス要件と性能のトレードオフで選択すると良いでしょう。
モデル | モデルサイズ | 次元 | ベクトル化にかかった時間※ |
---|---|---|---|
small | 471MB | 384 | 8.5s |
base | 1.11GB | 768 | 8.5s |
large | 2.24GB | 1024 | 16.0s |
※環境
GPU: RTX3060 12GB
データ数: 1000
ベクトル化
テキストをmodel.encode()
に渡すだけでベクトル化できます。超簡単。
df["vector"] = df["claim"].apply(model.encode)
df.head()
claim | vector | |
---|---|---|
0 | ディスプレイがつかない。 | [0.030703284, 0.0027665994, -0.02013304, -0.06... |
1 | 画面が真っ暗です。 | [0.030521175, 0.017222878, -0.012971448, -0.05... |
2 | モニターのライトが点かない。 | [0.029229984, 0.0129130315, -0.013014274, -0.0... |
3 | PCの画面が起動しない。 | [0.040510327, 0.003165508, -0.021918021, -0.05... |
4 | スクリーンが点灯しません。 | [0.03320634, -0.0026000645, -0.025136635, -0.0... |
. | ... | ... |
# ベクトルの次元数
print(len(df.at[0, "vector"]))
# 1024
1024次元のベクトルデータが生成されていることがわかります。
クラスタリング
今回は、一般的なK-Means法を使います。
シルエット分析
最適なクラスター数を探る方法はいくつかあるようです。今回はシルエット分析を使いました。
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
vectors = df["vector"].tolist()
# 探索するクラスタ数の範囲を設定
range_n_clusters = range(2, 30) # 例: 2から30まで
best_score = -1
best_k = None
scores = [] # シルエットスコアを保存するためのリスト
for k in range_n_clusters:
kmeans = KMeans(n_clusters=k, n_init="auto", random_state=0)
labels = kmeans.fit_predict(vectors)
# シルエットスコアを計算
score = silhouette_score(vectors, labels)
scores.append(score)
# ベストスコアの更新
if score > best_score:
best_score = score
best_k = k
print("Best k:", best_k)
print("Best silhouette score:", best_score)
# Best k: 22
# Best silhouette score: 0.077055216
# シルエットスコアのプロット
import matplotlib.pyplot as plt
plt.figure(figsize=(10,6))
plt.plot(range_n_clusters, scores, marker='o')
plt.xlabel('Number of clusters')
plt.ylabel('Silhouette Score')
plt.title('Silhouette Score for Different Number of Clusters')
plt.show()
クラスタリング
n = 75 なので、k = 22だとクラスタリングする意味が薄いです。シルエットスコアを見ながらクラスター数を探り、今回は10に決定しました。答えはないと思うので、データの性質や業務要件によって選択するしかないでしょう。
色々試してみた感想では、人間の感覚の倍くらいのクラスター数にするといい感じになる気がします。それより少ないと、意味の違う文が混ざってしまう印象でした。
# k値の指定
k = 10
# クラスタリング
kmeans = KMeans(n_clusters=k, n_init="auto", random_state=0)
df["label"] = kmeans.fit_predict(vectors)
# 保存
df.sort_values("label").drop(columns="vector").to_csv("data/output.csv", index=None)
claim | label |
---|---|
ビデオ再生がカクカクする。 | 0 |
インターネットのアイコンにエクスクラメーションマークがついてる。 | 0 |
ソフトがスムーズに動作しない。 | 1 |
アプリケーションが遅く開く。 | 1 |
処理速度がかなり落ちてきた。 | 1 |
突然の遅延が頻繁に発生。 | 1 |
タスクの完了が遅くなった。 | 1 |
起動に時間がかかる。 | 1 |
キーボードエラーが表示される。 | 2 |
接続エラーが出る。 | 2 |
DNSエラーが頻繁に発生。 | 2 |
ネットワークトラブルのポップアップが出る。 | 2 |
ノートPCのキーボードが故障してる? | 3 |
キータイプしても文字が入力されない。 | 3 |
キーの一部が使えない。 | 3 |
キーボードがフリーズしている。 | 3 |
文字入力ができない状態。 | 3 |
キーボードからの入力が受け付けられない。 | 3 |
特定のキーだけ動かない。 | 3 |
USBキーボードも認識しない。 | 3 |
キーボードのライトは点いてるけど動かない。 | 3 |
キータッチ音はするのに、文字が出ない。 | 3 |
キーボードのドライバトラブルか、反応が悪い。 | 3 |
キーボードが動作しない。 | 3 |
少し怪しいところもありますが、それなりにうまく行っています。
推論
sklearn.cluster.KMeans.predict()
を使うと、新しいテキストデータのラベルを推論できます。
new_claims = ["動画再生がカクカクする。",
"PCの起動が遅い",
"キーボードが効かない"]
new_vectors = list(map(model.encode, new_claims))
print(kmeans.predict(new_vectors))
# [0 1 3]
色々応用ができそうですね!
Discussion