【Optimization Method】Optuna Tutorial part3
This is part 3 of the Optuna tutorial series.
・Part 1
・Part 2
Official Page:
Official Tutorial:
3.0 Review the basic code
import optuna


def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


study = optuna.create_study()
study.optimize(objective, n_trials=100)

best_params = study.best_params
found_x = best_params["x"]
print("Found x: {}, (x - 2)^2: {}".format(found_x, (found_x - 2) ** 2))
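For reference, the best objective value and the best trial can also be read back directly from the study object, instead of recomputing them; a small sketch:

print(study.best_value)         # best objective value, i.e. the smallest (x - 2) ** 2 found
print(study.best_trial.number)  # index of the trial that achieved it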
3.1 Sampling Algorithms
Optuna provides many algorithms for parameter optimization. Each method has its own characteristics, so we have to choose one appropriate for the task. (You can also try several of them if you have enough time for testing.)
A more detailed explanation of how samplers suggest parameters is given in BaseSampler.
Optuna provides the following sampling algorithms:
・Grid Search implemented in GridSampler
・Random Search implemented in RandomSampler
・Tree-structured Parzen Estimator algorithm implemented in TPESampler
・CMA-ES based algorithm implemented in CmaEsSampler
・Gaussian process-based algorithm implemented in GPSampler
・Algorithm to enable partial fixed parameters implemented in PartialFixedSampler
・Nondominated Sorting Genetic Algorithm II implemented in NSGAIISampler
・A Quasi Monte Carlo sampling algorithm implemented in QMCSampler
The default sampler is TPESampler.
3.1.1 Switching Sampler
・Switching Sampler
import optuna
study = optuna.create_study() # default
print(f"Sampler is {study.sampler.__class__.__name__}")
study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
・Result
[I 2024-09-26 03:04:53,090] A new study created in memory with name: no-name
[I 2024-09-26 03:04:53,091] A new study created in memory with name: no-name
[I 2024-09-26 03:04:53,091] A new study created in memory with name: no-name
Sampler is TPESampler
Sampler is RandomSampler
Sampler is CmaEsSampler
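Note that some samplers take additional arguments. For example, GridSampler requires an explicit search space. The following is a minimal sketch; the objective and the grid values are illustrative assumptions, not part of the official tutorial:

import optuna


def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    y = trial.suggest_int("y", 0, 10)
    return (x - 2) ** 2 + y


# GridSampler evaluates every combination in the grid below (values are illustrative).
search_space = {"x": [-5.0, 0.0, 2.0, 5.0], "y": [0, 5, 10]}
study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space))
study.optimize(objective, n_trials=12)  # one trial per grid point (4 x 3 = 12)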
3.2 Pruning Algorithms
Pruners automatically stop unpromising trials at an early stage of training (a.k.a. automated early stopping). Currently, the pruners module is intended to be used only for single-objective optimization.
Optuna provides the following pruning algorithms:
・Median pruning algorithm implemented in MedianPruner
・Non-pruning algorithm implemented in NopPruner
・Algorithm to operate pruner with tolerance implemented in PatientPruner
・Algorithm to prune specified percentile of trials implemented in PercentilePruner
・Asynchronous Successive Halving algorithm implemented in SuccessiveHalvingPruner
・Hyperband algorithm implemented in HyperbandPruner
・Threshold pruning algorithm implemented in ThresholdPruner
・A pruning algorithm based on Wilcoxon signed-rank test implemented in WilcoxonPruner
Most of the Optuna example code uses MedianPruner. However, it is basically outperformed by SuccessiveHalvingPruner and HyperbandPruner, as shown in this benchmark result.
3.2.1 Activating Pruners
To turn on the pruning feature, you need to call report() and should_prune() after each step of the iterative training. report() periodically reports the intermediate objective values, and should_prune() decides whether to terminate a trial that does not meet the predefined condition.
・Example Pruner Code
import logging
import sys

import optuna
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection


def objective(trial):
    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25, random_state=0
    )

    alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(valid_x, valid_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

    return 1.0 - clf.score(valid_x, valid_y)


# Add stream handler of stdout to show the messages.
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
・Result
[I 2024-09-26 10:47:43,878] A new study created in memory with name: no-name
[I 2024-09-26 10:47:43,961] Trial 0 finished with value: 0.3421052631578947 and parameters: {'alpha': 0.0004612327672225534}. Best is trial 0 with value: 0.3421052631578947.
[I 2024-09-26 10:47:44,041] Trial 1 finished with value: 0.052631578947368474 and parameters: {'alpha': 4.0761044211755606e-05}. Best is trial 1 with value: 0.052631578947368474.
[I 2024-09-26 10:47:44,114] Trial 2 finished with value: 0.1842105263157895 and parameters: {'alpha': 0.002963224994149077}. Best is trial 1 with value: 0.052631578947368474.
[I 2024-09-26 10:47:44,188] Trial 3 finished with value: 0.3157894736842105 and parameters: {'alpha': 0.00029443834676040354}. Best is trial 1 with value: 0.052631578947368474.
[I 2024-09-26 10:47:44,261] Trial 4 finished with value: 0.368421052631579 and parameters: {'alpha': 0.0005968789848913645}. Best is trial 1 with value: 0.052631578947368474.
[I 2024-09-26 10:47:44,322] Trial 5 pruned.
[I 2024-09-26 10:47:44,403] Trial 6 finished with value: 0.10526315789473684 and parameters: {'alpha': 0.0002549952615034128}. Best is trial 1 with value: 0.052631578947368474.
[I 2024-09-26 10:47:44,407] Trial 7 pruned.
[I 2024-09-26 10:47:44,411] Trial 8 pruned.
[I 2024-09-26 10:47:44,414] Trial 9 pruned.
[I 2024-09-26 10:47:44,420] Trial 10 pruned.
[I 2024-09-26 10:47:44,425] Trial 11 pruned.
[I 2024-09-26 10:47:44,509] Trial 12 finished with value: 0.42105263157894735 and parameters: {'alpha': 9.034472208694208e-05}. Best is trial 1 with value: 0.052631578947368474.
[I 2024-09-26 10:47:44,592] Trial 13 finished with value: 0.1842105263157895 and parameters: {'alpha': 7.208212462168948e-05}. Best is trial 1 with value: 0.052631578947368474.
[I 2024-09-26 10:47:44,600] Trial 14 pruned.
[I 2024-09-26 10:47:44,609] Trial 15 pruned.
[I 2024-09-26 10:47:44,614] Trial 16 pruned.
[I 2024-09-26 10:47:44,698] Trial 17 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.00024609691733570106}. Best is trial 1 with value: 0.052631578947368474.
[I 2024-09-26 10:47:44,705] Trial 18 pruned.
[I 2024-09-26 10:47:44,788] Trial 19 finished with value: 0.07894736842105265 and parameters: {'alpha': 0.0008168248767813228}. Best is trial 1 with value: 0.052631578947368474.
As shown above, some unpromising trials are pruned at an early step, which makes the search more efficient.
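If you want to check how many trials were actually pruned, you can filter the trials by their state; a small sketch:

from optuna.trial import TrialState

pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
print(f"Pruned trials: {len(pruned_trials)}, Complete trials: {len(complete_trials)}")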
3.3 Which Sampler and Pruner Should be Used?
・Non-deep-learning tasks
From this benchmark:
・For RandomSampler, MedianPruner is the best.
・For TPESampler, HyperbandPruner is the best (a minimal sketch of this combination follows the table below).
・Deep learning tasks
From this book:
Parallel Compute Resource | Categorical/Conditional Hyperparameters | Recommended Algorithms |
---|---|---|
Limited | No | TPE, GP-EI if search space is low-dimensional and continuous. |
Limited | Yes | TPE, GP-EI if search space is low-dimensional and continuous. |
Sufficient | No | CMA-ES, Random Search |
Sufficient | Yes | Random Search or Genetic Algorithm |
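For example, the combination recommended above for non-deep-learning tasks (TPESampler with HyperbandPruner) can be set up as follows. This is a minimal sketch that reuses the objective from 3.2.1; the HyperbandPruner arguments are illustrative assumptions:

import optuna

# Sketch: TPESampler + HyperbandPruner (argument values below are illustrative).
study = optuna.create_study(
    sampler=optuna.samplers.TPESampler(seed=42),
    pruner=optuna.pruners.HyperbandPruner(
        min_resource=1, max_resource=100, reduction_factor=3
    ),
)
study.optimize(objective, n_trials=20)  # objective is the one defined in 3.2.1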
3.4 Integration
To implement the pruning mechanism in a much simpler form, Optuna provides integration modules for the following libraries.
Integration | Dependencies |
---|---|
AllenNLP | allennlp, torch, psutil, jsonnet |
BoTorch | botorch, gpytorch, torch |
CatBoost | catboost |
ChainerMN | chainermn |
Chainer | chainer |
pycma | cma |
Dask | distributed |
FastAI | fastai |
Keras | keras |
LightGBMTuner | lightgbm, scikit-learn |
LightGBMPruningCallback | lightgbm |
MLflow | mlflow |
MXNet | mxnet |
PyTorch Distributed | torch |
PyTorch (Ignite) | pytorch-ignite |
PyTorch (Lightning) | pytorch-lightning |
SHAP | scikit-learn, shap |
Scikit-learn | pandas, scipy, scikit-learn |
SKorch | skorch |
TensorBoard | tensorboard, tensorflow |
TensorFlow | tensorflow, tensorflow-estimator |
TensorFlow + Keras | tensorflow |
Weights & Biases | wandb |
XGBoost | xgboost |
We can check each implementation in the Optuna example code.
For example, the LightGBM version looks like this:
import lightgbm as lgb
import optuna.integration

# Inside objective(trial); param, dtrain, and dvalid are assumed to be prepared beforehand.
pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'validation-error')
gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])
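Roughly speaking, this callback reports the specified metric to trial.report() at each boosting iteration and raises optuna.TrialPruned when trial.should_prune() returns True, so we do not have to write the reporting loop from 3.2.1 ourselves.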
3.5 Summary
This time, I explained the basics and implementation of the following:
・The sampling algorithms (samplers)
・The pruning algorithms (pruners)
Please choose an appropriate method for your task.