wandbとOptunaでタイタニックをチューンする:現代的MLスタックの実践
はじめに
機械学習プロジェクトにおいて、モデルの精度を極限まで高めるハイパーパラメータチューニングは避けて通れない工程です。しかし、闇雲にパラメータを探るだけでは時間と計算リソースを浪費してしまいます。
本記事では、Kaggleの入門タスクとして名高い「タイタニック号の生存予測」を題材に、以下の現代的なPythonデータサイエンススタックを組み合わせた、効率的かつ可視化されたチューニングパイプラインの構築方法を解説します。
- Polars: 高速データフレームライブラリによる前処理
- LightGBM: 高速・高精度な勾配ブースティングモデル
- Optuna: 効率的なハイパーパラメータ自動最適化フレームワーク
- Weights & Biases (WandB): 実験管理と可視化プラットフォーム
単にスコアを出すだけでなく、Optunaの探索プロセスをWandBで可視化し、効率的な意思決定を行うフローを目指します。
実装の全体像
コード全体をいくつかのステップに分けて解説します。
1. ライブラリのインポートと設定
まずは必要なライブラリをインポートします。今回はOptunaとWandBの連携機能である WeightsAndBiasesCallback を利用するのがポイントです。
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import optuna
import lightgbm as lgb
import wandb
from optuna.integration.wandb import WeightsAndBiasesCallback
from optuna.integration.lightgbm import LightGBMPruningCallback
from optuna.pruners import SuccessiveHalvingPruner
# WandBのプロジェクト名設定
wandb_kwargs = {"project": "case-titanic"}
2. Polarsによるデータ前処理
データ操作には pandas ではなく、Rust製で高速な polars を使用します。記述がメソッドチェーンで完結しやすく、可読性が高いのが特徴です。
# データのロード
# ※ 手元の環境に合わせてパスを変更してください
train = pl.read_csv("sample/data/train.csv")
# 前処理パイプライン
train_processed = train.with_columns(
pl.col("Age").fill_null(pl.col("Age").mean()),
pl.col("Embarked").fill_null("S"),
pl.col("Fare").fill_null(pl.col("Fare").mean()),
# 文字列のカテゴリカル変数を数値に変換
pl.col("Sex").map_elements(lambda x: 1 if x == "female" else 0, return_dtype=pl.Int8),
)
# 特徴量とターゲットの分離
X = train_processed.select(
["Pclass", "Sex", "Age", "SibSp", "Parch", "Fare"]
).fill_nan(0).to_numpy()
y = train_processed.select("Survived").to_numpy().flatten()
# 訓練データと検証データの分割
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"Train shape: {X_train.shape}, Val shape: {X_val.shape}")
3. Optunaの目的関数とWandBコールバック
ここが本記事の核となる部分です。
Optunaの objective 関数内でLightGBMの学習を行いますが、以下の2つのコールバックを設定することで、効率的な探索と記録を実現します。
-
LightGBMPruningCallback: 学習曲線が芳しくないトライアル(試行)を早期に打ち切る(Pruning)ことで、計算時間を節約します。 -
WeightsAndBiasesCallback: Optunaの各トライアルの結果(パラメータとスコア)を自動的にWandBに記録します。
# OptunaとWandBをつなぐコールバックの初期化
wandb_callback = WeightsAndBiasesCallback(metric_name="accuracy", wandb_kwargs=wandb_kwargs)
def objective(trial):
# 探索空間の定義
param = {
"objective": "binary",
"metric": "auc", # 学習時の評価指標
"verbosity": -1,
"boosting_type": "gbdt",
"lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
"lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
"num_leaves": trial.suggest_int("num_leaves", 2, 256),
"feature_fraction": trial.suggest_float("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
"learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
}
dtrain = lgb.Dataset(X_train, label=y_train)
dvalid = lgb.Dataset(X_val, label=y_val)
# 枝刈り用コールバック: 検証データのAUCを監視
pruning_callback = LightGBMPruningCallback(trial, metric="auc", valid_name="valid_0")
# モデルの学習
gbm = lgb.train(
param,
dtrain,
valid_sets=[dvalid],
num_boost_round=1000, # 早期打ち切りを前提に多めに設定
callbacks=[pruning_callback], # ここで枝刈りを有効化
)
# 推論と評価
preds = gbm.predict(X_val)
pred_labels = np.rint(preds)
accuracy = np.mean(pred_labels == y_val)
auc = roc_auc_score(y_val, preds)
# 独自のメトリクスをWandBに追加記録
wandb.log({"accuracy": accuracy, "auc": auc})
# Optunaはaccuracyを最大化するように動作させる
return accuracy
実装のポイント:評価指標と打ち切り(Pruning)
上記のコードでは、LightGBMの学習中の評価(Early StoppingやPruningの基準)には AUC を用い、Optunaの最終的な最適化対象としては Accuracy を返しています。
- 評価と打ち切りの整合性: 一般的に、評価指標(Metric)と打ち切りの判定基準は方向性を揃えます(どちらも「大きい方が良い」など)。今回はAUCもAccuracyも「大きい方が良い」指標であるため、整合性が取れています。
-
最低保証回数: Prunerの設定(後述)で、初期の数epoch(
min_resource)は無条件に学習させることで、学習不足による誤った枝刈りを防ぎます。
4. 最適化の実行
SuccessiveHalvingPruner を用いて最適化を実行します。これは、有望でないトライアルを積極的に打ち切るアルゴリズムです。
# 実験の開始
study = optuna.create_study(
direction="maximize", # Accuracyの最大化を目指す
pruner=SuccessiveHalvingPruner(
min_resource=10, # 最低でも10ラウンドは学習する
reduction_factor=3 # 1/3ずつ候補を絞っていく
)
)
# callbacksにwandb_callbackを指定することで自動記録される
study.optimize(objective, n_trials=100, timeout=600, callbacks=[wandb_callback])
# WandBのrunを終了
wandb.finish()
# 結果の表示
print("Best trial:")
trial = study.best_trial
print(f" Value (Accuracy): {trial.value}")
print(" Params: ")
for key, value in trial.params.items():
print(f" {key}: {value}")
WandBでの可視化
このコードを実行すると、WandBのダッシュボードでグラフが自動生成されます。
MLOpsの視点: WandB か MLflow か
実験管理ツールの選定において、WandBとMLflowはよく比較対象になります。
-
Weights & Biases (WandB):
- Pros: クラウドベースでセットアップが圧倒的に容易。UIが洗練されており、可視化機能(特にハイパーパラメータ探索の分析)が強力です。
- Use Case: 個人の実験、小規模チームでの探索フェーズ、初期のモデル開発に最適です。サーバー構築不要ですぐに始められます。
-
MLflow:
- Pros: オープンソースであり、完全なオンプレミス環境や自社VPC内での構築が可能。モデルレジストリやデプロイ機能まで含めたライフサイクル全般の管理に長けています。
- Use Case: 本番環境へのデプロイを見据えた厳格な管理、外部へデータを出せないセキュリティ要件がある場合、CI/CDパイプラインへの統合。
結論としてのハイブリッドアプローチ:
初期の探索やパラメータチューニングには、セットアップが楽で分析機能が強い WandB を利用し、モデルが固まり本番運用のフェーズに入ったら、アーティファクト管理やデプロイパイプラインとして MLflow にバトンタッチするという使い分けも有効な戦略です。
おわりに
OptunaとWandBを組み合わせることで、「どのパラメータが効いているか」を視覚的に理解しながら、効率的にモデルを改善できます。タイタニックのようなシンプルなタスクでこのパイプラインに慣れておけば、より複雑なコンペティションや実務データにもスムーズに応用できるはずです。
Discussion