カスタム埋め込みモデルDatabricks上でサービングする
先日、こちらの記事で埋め込みモデルのファインチューニング方法について記載をしました。
今回は、チューニング済みのモデルをサービングエンドポイントとしてデプロイする方法に焦点を当てます。過去にはOpenAIのモデルやカスタム文章生成LLMのサービング方法についても触れましたが、埋め込みモデルについてはまだ扱っていなかったため、その方法について今回詳しく解説します。
なぜカスタム埋め込みモデルをサービングするか?
RAGやセマンティック検索などのアプリケーション構築には、文章データや商品データなどの自然言語データをベクトル化する必要があります。これを実現するためには埋め込みモデルの使用が不可欠です。さらに、さまざまなアプリケーションとの統合を容易にするためには、これらのモデルをRESTエンドポイントとしてデプロイすることが望まれます。
ご存知の通り、Databricksでは複数のSaaS型LLMが「基盤モデル」として提供されており、トークン課金で利用可能です。この中には「BGE Large (English)」という埋め込みモデルも含まれていますが、その名の通り、日本語への対応がまだ十分ではありません。この点は、以下のブログで確認できます。
したがって、高精度な日本語対応の埋め込みモデルを使用するには、基盤モデル以外の選択肢を検討する必要があります。特にセキュリティ面やレスポンスタイムを重視する場合、セルフホスティングが最適な選択肢であると考えられます。
余談ですが、最新の埋め込みモデルの性能については、HuggingFaceの以下のサイトで詳しく紹介されています。
また、日本語に特化した埋め込みモデルに関しては、個人的に以下のサイトを参考にしています。
具体的な手順
0. 環境
- Databricks Runtime: 14.2 ML GPU
- ノードタイプ: g4dn.xlarge (シングルノード)
- ソースコード: https://github.com/hiouchiy/databricks-ml-examples/tree/master/llm-models/embedding/e5/multilingual-e5-large
1. 必要なライブラリーのインストール
%pip install mlflow==2.9.0 langchain==0.0.344 databricks-vectorsearch==0.22 databricks-sdk==0.12.0 mlflow[databricks]
dbutils.library.restartPython()
2. 埋め込みモデルをダウンロード
今回は日本語対応埋め込みとして評価が高いintfloat/multilingual-e5-large
を例に取ります。
from sentence_transformers import SentenceTransformer
model_name = "intfloat/multilingual-e5-large"
model = SentenceTransformer(model_name)
3. モデルをMLFlow Model Trackingに登録
ここが最大のポイントかもしれませんが、MLFlowのSentence Transformerフレーバーを使ってモデルを登録します。
import mlflow
import pandas as pd
# 入出力スキーマの定義
sentences = ["これは例文です", "各文章は変換されます"]
signature = mlflow.models.infer_signature(
sentences,
model.encode(sentences),
)
# MLFlowのSentence Transformerフレーバーを使って登録
with mlflow.start_run() as run:
mlflow.sentence_transformers.log_model(
model,
"multilingual-e5-large-embedding",
signature=signature,
input_example=sentences)
4. モデルをMLFlow Model Registryに登録
Unity Catalog上のモデルレジストリーへ登録します。その後、モデルにアライアスをつけます。
# Unityカタログにモデルを登録するためにMLflow Pythonクライアントを設定する
import mlflow
mlflow.set_registry_uri("databricks-uc")
# Unityカタログへのモデル登録
registered_name = "hiroshi.models.multilingual-e5-large" # UCモデル名は<カタログ名>.<スキーマ名>.<モデル名>のパターンに従っており、カタログ名、スキーマ名、登録モデル名に対応していることに注意してください。
result = mlflow.register_model(
"runs:/"+run.info.run_id+"/multilingual-e5-large-embedding",
registered_name,
)
from mlflow import MlflowClient
client = MlflowClient()
# 上記のセルに登録されている正しいモデルバージョンを選択
client.set_registered_model_alias(name=registered_name, alias="Champion", version=result.version)
5. エンドポイントのデプロイ
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedModelInput
# サービングエンドポイントの名前を指定
endpoint_name = 'multilingual-e5-large-embedding-endpoint'
# テスト用なので、テンポラリーなトークンを利用
databricks_url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
# サービングエンドポイントの作成または更新
model_version = result # mlflow.register_modelの返された結果
serving_endpoint_name = endpoint_name
latest_model_version = model_version.version
model_name = model_version.name
w = WorkspaceClient()
endpoint_config = EndpointCoreConfigInput(
name=serving_endpoint_name,
served_models=[
ServedModelInput(
model_name=model_name,
model_version=latest_model_version,
workload_type="GPU_SMALL",
workload_size="Small",
scale_to_zero_enabled=False
)
]
)
existing_endpoint = next(
(e for e in w.serving_endpoints.list() if e.name == serving_endpoint_name), None
)
serving_endpoint_url = f"{databricks_url}/ml/endpoints/{serving_endpoint_name}"
if existing_endpoint == None:
print(f"Creating the endpoint {serving_endpoint_url}, this will take a few minutes to package and deploy the endpoint...")
w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config)
else:
print(f"Updating the endpoint {serving_endpoint_url} to version {latest_model_version}, this will take a few minutes to package and deploy the endpoint...")
w.serving_endpoints.update_config_and_wait(served_models=endpoint_config.served_models, name=serving_endpoint_name)
displayHTML(f'Your Model Endpoint Serving is now available. Open the <a href="/ml/endpoints/{serving_endpoint_name}">Model Serving Endpoint page</a> for more details.')
6. エンドポイントをテスト
最近は、Databricks SDKのWorkspaceClientクラスを使うと簡単にエンドポイントを叩けます。
endpoint_response = w.serving_endpoints.query(
name=endpoint_name,
dataframe_records=['こんにちは', 'おはようございます'])
print(endpoint_response)
まとめ
本記事では埋め込みモデルをサービングエンドポイントとしてDatabricks上にデプロイする方法を記載しました。皆様のご参考になれば幸いです。
Discussion