🌊

ベクトル検索ライブラリFaissで自然言語の近傍検索を試す

2024/02/12に公開

LangChainの勉強で出てきたfaissについて、理解を深めるため色々試してみました。
https://python.langchain.com/docs/get_started/quickstart

こちらの記事の内容が古くなっていたので、2024年2月時点で動くコードを作成しています
https://note.com/npaka/n/nb766e344a4fc

環境設定

google colaboratoryで動かしています。

  • 2024年2月現在、pythonの最新バージョンは3.12であるが、faissは3.11でしか動作しない。
    そのため、現在は3.11で動作を確認している
  • デフォルトは3.10なので、以下コマンドで3.11をインストールする
!sudo apt install python3.11

!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1
  • 3.11に設定したらデフォルトのpipが使えなくなるので、改めて現在のpythonバージョンに合うpipをインストールする必要がある
# 現在のバージョンに紐づいているpipをインストール
!curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
!python get-pip.py

openaiのパッケージインストール

# パッケージのインストール
!pip install openai

# クライアントの有効化
from openai import OpenAI

client = OpenAI(
    api_key= "************"
)

Embedding作成を試す

  • openAIを使えば、既に学習されたモデルを用いて、テキストからベクトルを作成できる
texts = ["これはテストです。"]
response = client.embeddings.create(
    input=texts, model="text-embedding-ada-002"
) 
  • responseはCreateEmbeddingResponseオブジェクト
  • データの中身はresponse.dataにList<Embedding>型で格納されている(渡したtextの数だけdataがある。今回は一つだけなので、dataは1つ)
for e in response.data:
  print(e.embedding)

# [-0.0019862281624227762, -0.011034601368010044, -0.0024470852222293615。。。
  • ちなみにembeddingの数は1536個。これはテキストがどんなに長くても短くても変わらない。
  • openai的には、あらゆる文章を1536次元で考えてるってこと。人間だと、(面白さ、明るさ、奇抜さ・・・)みたいな軸で考えるでしょうけど、これがAIだt1535要素ある、みたいなものかな。そう考えるとかなり多いですね。

faissで検索してみる

インデックスの作成

import faiss
import numpy as np

# インデックスの生成
index = faiss.IndexFlatL2(1536)
  • IndexFlatL2は探し方の一つで、総当たり検索ってこと。他にもコサイン類似度や多少の見逃しを許容する代わりに高速化を目指したIndexIVFFlatなど、いくつかある。

対象データを作成

  • テキストを配列で用意する
  • それぞれのテキストをEmbeddingした1536個の数値配列をnumpyの配列にして、indexに放り込む
# 対象テキストの作成
target_texts = [
    "好きな食べ物は何ですか?",
    "どこにお住まいですか?",
    "朝の電車は混みますね",
    "今日は良いお天気ですね",
    "最近景気悪いですね"
]

target_res = client.embeddings.create(
    input=target_texts, model="text-embedding-ada-002"
)

target_embeds = [e.embedding for e in target_res.data]

target_embedds_array = np.array(target_embeds).astype('float32')

# インデックスに追加する
index.add(target_embedds_array)
  • ポイント。np.arrayに渡すデータは、list<list<float>>の二次元配列。
[[ 0.00800769 -0.03133659  0.00396046 ... -0.00316403  0.01176981
  -0.00809446]
 [ 0.01001792  0.00272776 -0.00045281 ... -0.00241392 -0.00255445
  -0.01284093]
 [-0.00152516  0.0020203   0.00113932 ... -0.00600459  0.0035703
  -0.01050556]
 [-0.00107278  0.00933294 -0.0069663  ... -0.00970193  0.00934566
  -0.00994368]
 [-0.03182344 -0.01483458  0.01383242 ...  0.00432902  0.02192463
  -0.00989199]]

クエリの作成

  • 検索用のテキストも、embedした後でnumpy配列にします。こうすることで、faissで近傍検索ができるようになります
input_texts = ["大阪に住んでいます"]

input_res = client.embeddings.create(
    input=input_texts, model="text-embedding-ada-002"
)

input_embeds = [e.embedding for e in input_res.data]

query = np.array(input_embeds).astype('float32')

検索の実行と結果

  • 近傍探索の実行。結果はタプルで返され、ひとつ目はdistanceを、二つ目はIndicesを示す
# 結果は一つだけ返して欲しいので、1を渡している
D, I = index.search(query, 1)

# 確認
print(D)
print(I)
  • 結果としては、以下のようにD(distance)が距離についての配列を、I(Indecies)がベクトルのインデックスの配列となっている。
    (今回は検索用に一つだけのtextを渡したので、DistanceもIndeciesもひとつだけ)
[[0.23695067]]
[[1]]
  • faissから返される結果は"1"という数値のみ。(faissに渡したのはベクトルデータだけで、テキストデータは渡してないから当然です)
    これだとわかりにくいので、1番近かったのはどのテキストかを調べるには、もとのtarget_textsをから引っ張ってくる。
result_text = target_texts(I[0][0])

faissのデータ追加、削除

追加

  • index.addでデータを追加することができます
# 新しいテキストデータ
new_texts = ['ペットの名前は何ですか?']

# 新しいテキストデータの埋め込みを取得
# ここで client は OpenAI API のクライアントインスタンスと仮定
new_target_res = client.embeddings.create(
    input = new_texts,
    model="text-embedding-ada-002"
)

# 新しいテキストデータの埋め込みベクトルをNumPy配列に変換
new_embeddings_array = np.array([e.embedding for e in new_target_res.data]).astype('float32')

# 新しい埋め込みベクトルをfaissインデックスに追加
index.add(new_embeddings_array)

# 追加後の総ベクトル数を表示
print(index.ntotal)

削除はできないので、「無視」

  • faissではいったん与えたデータを物理的に削除することはできない。
  • もととなるnumpy配列から特定のデータを削除し、再度faissのindexを再生成する、、、ことはできるが、たぶん求めている答えではないでしょう
クリックで展開
import faiss
import numpy as np

# 元のデータリストから6番目のデータを除外
# Pythonのインデックスは0から始まるため、5を指定して6番目の要素を削除
filtered_embeddings = np.delete(embeddings_array, 5, axis=0)

# 新しいインデックスを作成
new_index = faiss.IndexFlatL2(1536)

# 修正されたデータリストを新しいインデックスに追加
new_index.add(filtered_embeddings)

# 追加後の総ベクトル数を表示
print(new_index.ntotal)
  • faiss.IDSelectorArrayを使用することで、特定のIDを持つベクトルを選択的に無視することで、あたかも削除したかのように振る舞わせることができる
import faiss
import numpy as np

# インデックスの生成とデータの追加
index = faiss.IndexFlatL2(1536)
index.add(embeddings_array)

# 検索クエリの生成(ここでは簡単のため embeddings_array の最初のベクトルを使用)
query = embeddings_array[0:1]

# IDSelectorArray を使用して特定のID(この例ではID 5)を無視する
ids_to_ignore = np.array([5], dtype=np.int64)  # 6番目のデータのIDを無視
selector = faiss.IDSelectorArray(ids_to_ignore)

# 検索の実行時にselectorを使用
D, I = index.search(query, 1, selector)

# 検索結果の表示
print("Distances:", D)
print("Indices:", I)

もうちょっと完成度を高めてみる

  • 先ほどのコードをひとつのファイルとして実行できるようにしました。
import sqlite3

client = OpenAI(api_key= "*****")

index = faiss.IndexFlatL2(1536)

initial_greetings = [
      "好きな食べ物は何ですか?",
      "どこにお住まいですか?",
      "朝の電車は混みますね",
      "今日は良いお天気ですね",
      "最近景気悪いですね"
    ]

# 0. embed配列を生成するための関数
def create_embeds_array(texts):
    res = client.embeddings.create(
        input=texts, model = "text-embedding-3-small"
    )
    embeds = [e.embedding for e in res.data]
    return np.array(embeds).astype('float32')


def initialize():
    # 1. SQLiteデータベースの準備
    conn = sqlite3.connect("aisatsu.db")
    cursor = conn.cursor()
    cursor.execute("CREATE TABLE IF NOT EXISTS aisatsu (id INTEGER PRIMARY KEY, text TEXT)")
    cursor.execute("DELETE FROM aisatsu")

    # 2. SQLiteに初期データの登録
    cursor.executemany("INSERT INTO aisatsu (text) VALUES (?)", [(g,) for g in initial_greetings])
    conn.commit()

    # 3. faissに初期データを登録
    cursor.execute("SELECT id, text FROM aisatsu")
    greeting_data = cursor.fetchall()
    greeting_ids, greeting_texts = zip(*greeting_data)
    embeds_array = create_embeds_array(greeting_texts)
    index.add(embeds_array)
    
def search_greeting(text):
    embed_array = create_embeds_array([text])
    D,I = index.search(embed_array,2)
    distance = D[0][0]
    id = I[0][0] + 1 # sqliteにidを登録する時は1オリジンなので、index.searchの結果の0オリジンに1を足す必要がある

    conn = sqlite3.connect("aisatsu.db")
    cursor = conn.cursor()
    cursor.execute(f'SELECT text FROM aisatsu WHERE id = {id}')
    result = cursor.fetchone() #fetchoneはクエリ結果の最初の行を返します
    return (result[0], distance) if result else 'データが見つかりませんでした'


initialize()
res = search_greeting('大阪に住んでいます')
print(res)

Discussion