↖︎

ついに来た Firebaseでベクトル検索

2024/04/21に公開

はじめに

Firebaseでベクトル検索ができるようになったとのことなので、試してみました!

そもそもベクトル検索って何ってところからGPTに質問しながらなんとかできた感じなので、解釈が間違っていたり説明が不十分な箇所もあるかと思いますが、こんなことができるんだという参考になれば幸いです。

今回やりたいこと

ベクトル検索ができると聞き、真っ先に思い浮かんだのが去年開発したポケモン対戦記録アプリの類似対戦検索機能です。

「過去の対戦から今回の対戦相手のパーティと相手のパーティが似ている順にソートして取得する」です。

理由は説明すると長いので、興味のある人だけ読んでみてください。

なぜ似ている順にソートして取得したいか

ポケモン対戦において最もよく遊ばれているルールの1つに相手と自分のパーティ6匹を見せ合って、実際に戦闘に繰り出す3匹を決めて3vs3で戦う6→3というルールがあります。
この時3匹を選ぶことを選出と言うのですが、ポケモン対戦では構築、選出、プレイングの3大要素の1つに数えられるとても重要な要素です。

選出する際には相手の6匹をみて、その6匹にできるだけ強い3匹を選ぶことになります。
また1匹目に選ぶポケモンはより重要で、選出した段階でほぼ勝負が決まるみたいなこともありえます。

そのためポケモン対戦記録アプリでは対戦相手のパーティを記録しておくことで、次に似たようなパーティとマッチした際に、前回は似たようなパーティにこんな選出をして、勝ったとか、負けたみたいなメモを瞬時に引き出して参考にするために、似ているパーティをFirestoreから引っ張ってきたいです。

しかし、Firestoreのクエリはあまり強くないため、これを今までは実現できませんでした。(少なくとも自分は色々考えたのですが、無理でした。)そのため、過去n件のデータを取得して、それをクライアントでソートすると言うなんとも微妙というか、なんならタブーくらいの実装をしてました。

これがベクトル検索で実現できそうな気がしたので、今回ベクトル検索機能の追加に飛びついているわけです。

似ているパーティとは

始めに似ているパーティについて定義をしておきます。

2つのパーティを比べた時に同じポケモンの数を類似度とします。
パーティには6匹ポケモンがいるため、類似度は0以上6以下の7段階で表せます。
つまり類似度が最大のパーティが似ていて類似度が0のパーティが最も似ていないと言えます。

ベクトル検索とは

データを多次元ベクトル化して、保存しておくことでそのデータの意味的な特徴をベクトル計算を使ってデータ間の類似度を測るもののようです。

例えばテキストデータでは単語ごとに意味を持ったベクトルに変換して似た単語はベクトルも近いのでうまく計算すると似たような単語を引っ張れるので、全文検索などに使われているようです。

が、今回はそこまで難しいことはせずに類似度順にソートして取得するということをしてみます。
そのため、ベクトル検索の真価にはあまり触れられないかもしれませんが、ご了承ください。

Firestoreでベクトル検索をする

https://cloud.google.com/firestore/docs/vector-search?hl=ja
2024/04/21現在Vector検索はプレビュー機能となっており、Cloud Functionsからのみ利用できる形になっています。

Firestoreでベクトル検索をする手順は3ステップです。

  1. ベクトル値を保存する。
  2. KNNベクトル インデックスを作成する。
  3. ベクトル距離関数を用いてKNNクエリを実行する。

1については、全文検索などをする場合は自然言語をベクトルに変換するロジックを用いて適切なベクトルに変換するそうです。今回は簡単なロジックでベクトルを用意します。

今回のベクトル定義

扱うポケモンの種類次元のベクトルを用意して、パーティにおけるポケモンの出現頻度を各次元の要素とします。

ちょっと、何言っているかわからないですね。わかりやすくしましょう。

例えば、簡単にするためにポケモンの種類が9種類だとします。

[フシギダネ,フシギソウ,フシギバナ,ヒトカゲ,リザード,リザードン,ゼニガメ,カメール,カメックス]

その場合用意するベクトルは9次元となります。

次にとあるパーティについて考えます。パーティは6匹なので例えば下記のような6匹について考えます。

[フシギダネ,フシギソウ,フシギバナ,ヒトカゲ,リザード,リザードン]

すると、それぞれの次元は出現頻度なので

[1,1,1,1,1,1,0,0,0]

のようなベクトルで表すことができます。

また、別のパーティも考えてみます。

[フシギバナ,ヒトカゲ,リザード,リザードン,カメール,カメックス]

この場合は

[0,0,1,1,1,1,0,1,1]

と表せます。

これらのベクトルについて何らかの方法でベクトル距離を計算します。
現在Firestoreのベクトル検索で利用できるのはEUCLIDEAN、COSINE、DOT_PRODUCTの3種類のようです。

今回はDOT_PRODUCT(内積)を利用しようと思います。

内積はA・B = |A||B|cosθ = A1B1 + A2B2と習いました。

今回は

A = [1,1,1,1,1,1,0,0,0]
B = [0,0,1,1,1,1,0,1,1]

と考えると

A・B = (1 × 0) + (1 × 0) + (1 × 1) + (1 × 1) + (1 × 1) + (1 × 1) + (0 × 0) +  (0 × 1) +  (0 × 1)

と表せて、

A・B = 0 + 0 + 1 + 1 + 1 + 1 + 0 + 0 + 0 = 4

となります。
数式を見て貰えばわかるように両方のパーティに含まれているポケモンの数だけ加算され、片方にしか含まれないポケモンの数は0をかけて打ち消されています。

また両方のパーティに入っているのはフシギバナ、ヒトカゲ、リザード、リザードンの4匹であることから内積の4という数字と先ほど定義した類似度が同じになりそうなこともわかります。

ベクトル値を計算する

では、実際に前述したベクトルを計算してみようと思います。今回はTypeScriptで実装しました。
ポケモンに1から1025のidを振って、パーティを表現したものがpokemonIdsです。

ベクトルの算出
const convertToVector = (pokemonIds: number[]): number[] => {
    // ポケモンの総数に基づく1025次元のベクトルを作成(初期値は0)
    const vector: number[] = new Array(1025).fill(0);

    // ポケモンIDの配列をループして、対応するインデックスの値を1に設定
    pokemonIds.forEach(id => {
        if (id >= 1 && id <= 1025) {
            vector[id - 1] = 1;  // IDが1から始まるので、インデックスはid - 1とする
        }
    });

    return vector;
}

pokemonIdsは最大でも要素が6の配列なので、forEachは最大でも6周しかしません。
また、ポケモンの出現頻度で表そうと考えていましたが、重複したポケモンを数え上げなくてもいいかなと思ったので、vector[id - 1] = 1;とし、1匹でも含まれていたら、1とするようにしました。
これで、6次元で表せていたパーティを1025次元のベクトルに変換することができました。

ベクトル値を保存する

まず、ベクトル検索を利用できるように@google-cloud/firestoreを最新にします。執筆時点では最新は7.6.0でした。

次に必要なライブラリをimportします。

import * as functions from "firebase-functions";
import * as admin from "firebase-admin";
import { FieldValue } from "@google-cloud/firestore";

今回はFirebaseFunctions経由でFirebaseへの書き込みをするので、functions.https.onCallでfunctionsを作成します。

functions名はsetBattleとしました。ちょっと長いので折りたたんでおきます。

setBattle
battle.ts
export const setBattle = functions.https.onCall(async (data, context) => {
    // 認証チェック
    if (!context.auth || context.auth.uid !== data.userId) {
        throw new functions.https.HttpsError("unauthenticated", "The function must be called while authenticated.");
    }

    // パラメータチェック
    const missingParams = checkBattleParams(data);
    if (missingParams.length > 0) {
        throw new functions.https.HttpsError("invalid-argument", `Missing or invalid parameters: ${missingParams.join(", ")}`);
    }

    const userId = data.userId;
    const opponentPartyIds = data.opponentPartyIds;

    try {
        const db = admin.firestore();
        const battleDoc = db.collection(`user/${userId}/battle`).doc();

        const battleData = {
            userId,
            partyId: data.partyId,
            battleId: battleDoc.id,
            opponentParty: data.opponentParty,
            myParty: data.myParty,
            opponentOrder: data.opponentOrder,
            myOrder: data.myOrder,
            memo: data.memo,
            eachMemo: data.eachMemo,
            result: data.result,
            createdAt: FieldValue.serverTimestamp(),
            embedding_field: FieldValue.vector(convertToVector(opponentPartyIds)),
        };

        await battleDoc.set(battleData);
        return { success: true };
    } catch (error) {
        console.error("Error writing battle:", error);
        throw new functions.https.HttpsError("unknown", "Failed to set battle", error);
    }
});

AdminSDKを用いたFirestoreへの書き込みはFirestoreルールを突破できてしまうため、認証チェックとパラメータのチェックを最初に行っています。
そして、先ほど定義したconvertToVectorを用いてembedding_fieldというパラメータにベクトルを保存しています。この際にfirestoreのFieldValueに含まれている.vectorを利用しています。
そのほかは特に変わったことはしていないです。

次にFlutterアプリからこのメソッドを呼び出します。
呼び出し側は特に何の変哲もありませんが、こんな感じになります。

setBattleの呼び出し
 Future<void> setBattle({
    required String userId,
    required String partyId,
    required List<String> opponentParty,
    required List<String> myParty,
    required List<int> opponentOrder,
    required List<int> myOrder,
    required String memo,
    required Map<String, String> eachMemo,
    required String result,
    required List<int> opponentPartyIds,
  }) async {
    final functions = FirebaseFunctions.instance;
    final callable = functions.httpsCallable('setBattle');
    try {
      final response = await callable.call({
        'userId': userId,
        'partyId': partyId,
        'opponentParty': opponentParty,
        'myParty': myParty,
        'opponentOrder': opponentOrder,
        'myOrder': myOrder,
        'memo': memo,
        'eachMemo': eachMemo,
        'result': result,
        'opponentPartyIds': opponentPartyIds,
      });
      print('Function returned: ${response.data}');
    } catch (e) {
      print('Caught Firebase Functions Exception:');
      print(e);
    }
  }

エラーハンドリングは握り潰しちゃっているので、ちゃんとハンドリングする必要はありますがベクトル計算と書き込みはfunctions側に寄せているので、クライアント側は何にも考えなくていいのが嬉しいですね。

実際にFirebaseEmulatorを利用して書き込んでみるとこんな感じで書き込まれました。
※FirebaseEmulatorだとまだ対応されてないかもと思っていましたが、そんなことはなかったです。

KNNベクトルインデックスを作成する

次にベクトルインデックスを構築する必要があります。
これはGoogleCloudCLIを利用して行います。GoogleCloudCLIのインストールはこちらからします。

もし初めてGoogleCloudCLIを利用する場合はログインやプロジェクトの設定が必要になります。

GoogleCloudCLIの初期設定

まずはGoogleCloudCLIにログインをします。

gcloud auth login

次に自分のプロジェクト一覧を見てみましょう。

gcloud projects list

次にgcloudで利用するプロジェクトを指定します。

gcloud config set project `project-id`

project-idは先ほどのプロジェクト一覧から確認できます。

設定ができたか確認してみます。

gcloud config list

project=project-idとなっていれば設定完了です。

公式ドキュメントを参考にGoogleCloudCLIを利用してインデックスを貼ってあげます。
https://firebase.google.com/docs/firestore/vector-search#create_a_single-field_vector_index

Create a single-field vector index
gcloud alpha firestore indexes composite create \
--collection-group="battle" \
--database="(default)" \
--query-scope=COLLECTION \
--field-config field-path=embedding_field,vector-config='{"dimension":"1025", "flat": "{}"}'

各種パラメータは私の環境での値になっているので下記を参考に適宜読み替えてください。

  • collection-groupは対象となるCollectionのIdを指定します。
  • databaseは基本的には(default)になっていることが多いと思います。
  • field-pathには先ほどvectorを指定したembedding_fieldを指定しました。
  • vector-configには今回扱うベクトルの次元の1025を指定しました。flatが何を意味しているかわかりませんが、ドキュメントにThe index type must be flat.と合ったので素直に従いました。

実際に実行してみるとインデックスが構築され、Firebaseのコンソールから確認することができました!

ベクトル距離関数を用いてKNNクエリを実行する

最後に実際に類似度順に対戦履歴を取得してみましょう。
対戦履歴の保存と同様に取得の際もfunctions経由で行います。
対象のCollectionReferenceに対して、findNearestでKNNクエリを記述します。
実際のクエリの部分はこんな感じです。

battle.ts
const snapshot = await battlesRef.findNearest("embedding_field", FieldValue.vector(convertToVector(data.opponentPartyIds)), {
    limit: 10,
    distanceMeasure: "DOT_PRODUCT"
}).get();

取得の際には、比較したいパーティのベクトルが必要です。そのため、functionsでopponentIdsをパラメータに加えており、先ほど定義したconvertToVectorメソッドでベクトル化しています。
limitは取得するDocumentの数です。
distanceMeasureはKNNクエリで利用するベクトル距離の種類を指定します。今回は前述した通りDOT_PRODUCT(内積)を指定しています。

取得関数の全容
battle.ts
export const fetchBattles = functions.https.onCall(async (data, context) => {
    // 認証チェック
    if (!context.auth) {
        throw new functions.https.HttpsError("unauthenticated", "The function must be called while authenticated.");
    }

    try {

        // パラメータチェック
        const missingParams = checkFetchBattlesParams(data);
        if (missingParams.length > 0) {
            throw new functions.https.HttpsError("invalid-argument", `Missing or invalid parameters: ${missingParams.join(", ")}`);
        }

        const db = admin.firestore();
        const battlesRef = db.collection(`user/${data.userId}/battle`);

        const snapshot = await battlesRef.findNearest("embedding_field", FieldValue.vector(convertToVector(data.opponentPartyIds)), {
            limit: 10,
            distanceMeasure: "COSINE"
        }).get();

        const battles = snapshot.docs.map(doc => ({ id: doc.id, ...doc.data() }));
        console.log("Fetched battles:", battles);
        return { battles };
    } catch (error) {
        console.error("Error fetching battles:", error);
        throw new functions.https.HttpsError("unknown", "Failed to fetch battles", error);
    }
});

次にFlutterアプリからこのメソッドを呼び出します。
こちらも保存の時と同様に特に何の変哲もない感じです。

fetchBattlesの呼び出し
  Future<List<Battle>> fetchBattles(
      String userId, List<int> opponentPartyIds) async {
    final functions = FirebaseFunctions.instance;
    final callable = functions.httpsCallable('fetchBattles');
    try {
      final result = await callable.call(<String, dynamic>{
        'userId': userId,
        'opponentPartyIds': opponentPartyIds
      });

      final List<Battle> battles = (result.data['battles'] as List)
          .map(
              (item) => Battle.fromJson(Map<String, dynamic>.from(item as Map)))
          .toList();

      return battles;
    } catch (e) {
      print('Caught Firebase Functions Exception:');
      print(e);
      return [];
    }
  }

実際に呼び出してみました。
例えば、今回は下記のパーティと対戦するとします。

[ゼニガメ,カメール,カメックス,ヒトカゲ,リザード,リザードン]

このパーティを使ってfetchBattlesを呼んでみると

確かに被っているポケモンが多い順に取得できています。(赤文字のポケモンが被っているポケモン)

これで似ている順に対戦履歴を取得することができました。

やりたかったけどできなかったこと

画像の例で言うと3つ目のパーティはヒトカゲしか被っておらず似ているとは言えなさそうです。
そのため、本当は閾値を決めて4匹以上被っていたら取得するみたいなクエリを書きたかったのですが、現時点ではそのような閾値を設定することはできなさそうでした。
また、startAfterのようなメソッドも使えなさそうなので、追加ローディングやPaginationの実装も現時点では難しいかなと思いました。

とはいえ、パブリックプレビュー段階ですし、今後これらの機能が追加される可能性は大いにあると思うので、追加されるのを願っています。

おわりに

今回はベクトル検索を利用して類似度順にFirestoreからデータを取得すると言うのをやってみました。今回のベクトル定義はかなり単純なものでしたが、実際には自然言語処理などを使ってベクトルを定義することで全文検索やレコメンドシステムなどにも応用ができると思うので、次はその辺りも手を出してみたいです。

2024年は技術発信も頑張ろうと思っているので、記事が参考になった方は記事とGitHubのいいね(スター)とフォローをしていただけると励みになります!
最後まで読んでいただきありがとうございました✨

https://github.com/miyasic/PokeScouter

Discussion