🙆

カスタム埋め込みモデルDatabricks上でサービングする

2024/01/30に公開

先日、こちらの記事で埋め込みモデルのファインチューニング方法について記載をしました。

https://zenn.dev/hiouchiy/articles/e608379945ec47

今回は、チューニング済みのモデルをサービングエンドポイントとしてデプロイする方法に焦点を当てます。過去にはOpenAIのモデルやカスタム文章生成LLMのサービング方法についても触れましたが、埋め込みモデルについてはまだ扱っていなかったため、その方法について今回詳しく解説します。

なぜカスタム埋め込みモデルをサービングするか?

RAGやセマンティック検索などのアプリケーション構築には、文章データや商品データなどの自然言語データをベクトル化する必要があります。これを実現するためには埋め込みモデルの使用が不可欠です。さらに、さまざまなアプリケーションとの統合を容易にするためには、これらのモデルをRESTエンドポイントとしてデプロイすることが望まれます。
ご存知の通り、Databricksでは複数のSaaS型LLMが「基盤モデル」として提供されており、トークン課金で利用可能です。この中には「BGE Large (English)」という埋め込みモデルも含まれていますが、その名の通り、日本語への対応がまだ十分ではありません。この点は、以下のブログで確認できます。

https://www.netone.co.jp/media/detail/20240122-01/

したがって、高精度な日本語対応の埋め込みモデルを使用するには、基盤モデル以外の選択肢を検討する必要があります。特にセキュリティ面やレスポンスタイムを重視する場合、セルフホスティングが最適な選択肢であると考えられます。

余談ですが、最新の埋め込みモデルの性能については、HuggingFaceの以下のサイトで詳しく紹介されています。

https://huggingface.co/spaces/mteb/leaderboard

また、日本語に特化した埋め込みモデルに関しては、個人的に以下のサイトを参考にしています。

https://github.com/oshizo/JapaneseEmbeddingEval

具体的な手順

0. 環境

1. 必要なライブラリーのインストール

%pip install mlflow==2.9.0 langchain==0.0.344 databricks-vectorsearch==0.22 databricks-sdk==0.12.0 mlflow[databricks]
dbutils.library.restartPython()

2. 埋め込みモデルをダウンロード

今回は日本語対応埋め込みとして評価が高いintfloat/multilingual-e5-largeを例に取ります。

from sentence_transformers import SentenceTransformer
model_name = "intfloat/multilingual-e5-large"

model = SentenceTransformer(model_name)

3. モデルをMLFlow Model Trackingに登録

ここが最大のポイントかもしれませんが、MLFlowのSentence Transformerフレーバーを使ってモデルを登録します。

import mlflow
import pandas as pd

# 入出力スキーマの定義
sentences = ["これは例文です", "各文章は変換されます"]
signature = mlflow.models.infer_signature(
    sentences,
    model.encode(sentences),
)

# MLFlowのSentence Transformerフレーバーを使って登録
with mlflow.start_run() as run:  
    mlflow.sentence_transformers.log_model(
      model, 
      "multilingual-e5-large-embedding", 
      signature=signature,
      input_example=sentences)

4. モデルをMLFlow Model Registryに登録

Unity Catalog上のモデルレジストリーへ登録します。その後、モデルにアライアスをつけます。

# Unityカタログにモデルを登録するためにMLflow Pythonクライアントを設定する
import mlflow
mlflow.set_registry_uri("databricks-uc")

# Unityカタログへのモデル登録
registered_name = "hiroshi.models.multilingual-e5-large" # UCモデル名は<カタログ名>.<スキーマ名>.<モデル名>のパターンに従っており、カタログ名、スキーマ名、登録モデル名に対応していることに注意してください。
result = mlflow.register_model(
    "runs:/"+run.info.run_id+"/multilingual-e5-large-embedding",
    registered_name,
)

from mlflow import MlflowClient
client = MlflowClient()

# 上記のセルに登録されている正しいモデルバージョンを選択
client.set_registered_model_alias(name=registered_name, alias="Champion", version=result.version)

5. エンドポイントのデプロイ

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedModelInput

# サービングエンドポイントの名前を指定
endpoint_name = 'multilingual-e5-large-embedding-endpoint'

# テスト用なので、テンポラリーなトークンを利用
databricks_url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)

# サービングエンドポイントの作成または更新
model_version = result  # mlflow.register_modelの返された結果

serving_endpoint_name = endpoint_name
latest_model_version = model_version.version
model_name = model_version.name

w = WorkspaceClient()
endpoint_config = EndpointCoreConfigInput(
    name=serving_endpoint_name,
    served_models=[
        ServedModelInput(
            model_name=model_name,
            model_version=latest_model_version,
            workload_type="GPU_SMALL",
            workload_size="Small",
            scale_to_zero_enabled=False
        )
    ]
)

existing_endpoint = next(
    (e for e in w.serving_endpoints.list() if e.name == serving_endpoint_name), None
)
serving_endpoint_url = f"{databricks_url}/ml/endpoints/{serving_endpoint_name}"
if existing_endpoint == None:
    print(f"Creating the endpoint {serving_endpoint_url}, this will take a few minutes to package and deploy the endpoint...")
    w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config)
else:
    print(f"Updating the endpoint {serving_endpoint_url} to version {latest_model_version}, this will take a few minutes to package and deploy the endpoint...")
    w.serving_endpoints.update_config_and_wait(served_models=endpoint_config.served_models, name=serving_endpoint_name)
    
displayHTML(f'Your Model Endpoint Serving is now available. Open the <a href="/ml/endpoints/{serving_endpoint_name}">Model Serving Endpoint page</a> for more details.')

6. エンドポイントをテスト

最近は、Databricks SDKのWorkspaceClientクラスを使うと簡単にエンドポイントを叩けます。

endpoint_response = w.serving_endpoints.query(
    name=endpoint_name, 
    dataframe_records=['こんにちは', 'おはようございます'])

print(endpoint_response)

まとめ

本記事では埋め込みモデルをサービングエンドポイントとしてDatabricks上にデプロイする方法を記載しました。皆様のご参考になれば幸いです。

Discussion