🫥
Databricks|Service Principalを使ってトークンを生成しようとしたら詰まった
起きていた問題
Databricks の API (Databricks SDK for Python) を利用してアクセス用のトークンを取得し、そのトークンを用いて MLflow のアーティファクトをダウンロードする処理を実装していました。しかし、以下のコードでは正常に動作しませんでした。
Databricks へのアクセス方法には、ユーザ認証 と サービスプリンシパル (Service Principal) の 2 種類があります。今回のケースでは、特定のユーザが存在しない環境での自動化を想定していたため、サービスプリンシパル (Service Principal) を利用しています。
download_artifact.py
import os
import mlflow
from databricks.sdk import WorkspaceClient
DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_CLIENT_ID = os.getenv("DATABRICKS_CLIENT_ID")
DATABRICKS_CLIENT_SECRET = os.getenv("DATABRICKS_CLIENT_SECRET")
client = WorkspaceClient(
host=DATABRICKS_HOST,
client_id=DATABRICKS_CLIENT_ID,
client_secret=DATABRICKS_CLIENT_SECRET,
auth_type="oauth",
)
token = client.token_management.create_obo_token(
DATABRICKS_CLIENT_ID, lifetime_seconds=3600
).token_value
os.environ["DATABRICKS_TOKEN"] = token
mlflow.set_tracking_uri("databricks")
run_id = "<RUN_ID>"
artifact_path = "model"
mlflow.artifacts.download_artifacts(
run_id=run_id,
artifact_path=artifact_path,
dst_path="tmp/"
)
発生したエラー
-
WorkspaceClient
のauth_type="oauth"
の指定が問題で、次のエラーが発生:
ValueError: cannot configure default credentials, please check https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication to configure credentials for your preferred authentication method.
-
client.token_management.create_obo_token
を利用すると、以下のエラーが発生:
databricks.sdk.errors.platform.PermissionDenied: Only Admins can access token management APIs.
環境
- Python 3.12
- databricks-sdk==0.44.1
解消方法
問題を解消するために、以下の2点を修正しました。
-
WorkspaceClient
のauth_type
指定を削除- 指定しなかった場合の
auth_type
については補足にて説明
- 指定しなかった場合の
-
client.token_management.create_obo_token
の代わりにclient.tokens.create
を使用
[修正後] download_artifact.py
import os
import mlflow
from databricks.sdk import WorkspaceClient
DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_CLIENT_ID = os.getenv("DATABRICKS_CLIENT_ID")
DATABRICKS_CLIENT_SECRET = os.getenv("DATABRICKS_CLIENT_SECRET")
# WorkspaceClient の初期化(auth_type の指定なし)
client = WorkspaceClient(
host=DATABRICKS_HOST,
client_id=DATABRICKS_CLIENT_ID,
client_secret=DATABRICKS_CLIENT_SECRET,
)
token_lifetime_seconds = 3600
# `create_obo_token` の代わりに `tokens.create` を使用
token_response = client.tokens.create(
lifetime_seconds=token_lifetime_seconds,
comment="Service principal token for automation",
)
token = token_response.token_value
# 取得したトークンを環境変数に設定
os.environ["DATABRICKS_TOKEN"] = token
# MLflow のトラッキング URI を Databricks に設定
mlflow.set_tracking_uri("databricks")
# ダウンロード対象の run_id とアーティファクトパスを指定
run_id = "<RUN_ID>"
artifact_path = "model"
# アーティファクトをダウンロード
mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_path, dst_path="tmp/")
まとめ
具体的な原因の究明まではできていませんが、Databricks の Service Principal 認証を用いた MLflow アーティファクトの取得において、
-
auth_type
の指定を削除する -
create_obo_token
ではなくtokens.create
を使う
という修正を行うことで、問題を解決しました。
auth_type
未指定の場合の値
補足:client = WorkspaceClient(
host=os.environ["DATABRICKS_HOST"],
client_id=os.environ["DATABRICKS_CLIENT_ID"],
client_secret=os.environ["DATABRICKS_CLIENT_SECRET"],
)
print(vars(client))
を実行すると
{'_config': <Config: host=https://XXXX.cloud.databricks.com, client_id=XXXX, client_secret=***, auth_type=oauth-m2m. ...
と出力され、 auth_type=oauth-m2m
となっていることがわかりました。
Discussion