🐉
argparseの代わりにhydraを使ってみる (1)
数理最適化ソルバを使うときに,いくつかパラメータを振りたいときがあります.こういうときにパラメータをargparseで書いていると死んでしまう(可読性も低いです)ので,機械学習分野で成熟している気がする技術を調査しています.今回はまずhydraを使ってargparseを置き換えてみました.
- レポジトリです: https://github.com/cocomoff/ortools_with_hydra
- 参考にしたサイトです
問題設定
- google OR-toolsでCVRPを解きます
- パラメータとして「drop penalty」を「#vehicles」を想定します
- 実装を見てください: https://github.com/cocomoff/ortools_with_hydra/blob/main/solver.py
- 計算した結果,dropされたユーザの集合,総移動距離,総load数を返す関数にしています
- これを外から管理したいです
hydra example
簡単な例
- ディレクトリ構造です
.
├── conf
│ └── default.yaml
├── experiment.py
└── solver.py
- パラメータの初期値はGoogle OR-Toolsのコード例では1000と4が設定されています.
conf/default.yaml
solver:
penalty: 1000
num_vehicle: 4
- 具体的にPythonで呼び出す例です.
experiment.py
import os
import hydra
from solver import main
from omegaconf import OmegaConf
@hydra.main(config_name="conf/default.yaml")
def run_experiment(cfg):
print(OmegaConf.to_yaml(cfg))
ds, td, tl = main(cfg.solver.penalty, cfg.solver.num_vehicle)
print("Dropped nodes :", ds)
print("Total distance :", td)
print("Total load :", tl)
if __name__ == '__main__':
run_experiment()
実行
デフォルトパラメータで呼び出してみます.
$ python experiment.py
solver:
penalty: 1000
num_vehicle: 4
Dropped nodes : {16, 15}
Total distance : 5548
Total load : 54
パラメータを上書きして呼び出してみます.
$ python experiment.py solver.penalty=500
solver:
penalty: 500
num_vehicle: 4
Dropped nodes : {8, 16, 14, 15}
Total distance : 4452
Total load : 40
multirunの使用例
hydraにまかせて見ます.コマンド呼び出しは以下の形です.
$ python experiment.py -m solver.penalty=100,250,500,1000,2000
結果の出力される様子です.今回はloggerを特に設定していないので,ターミナルに出力された結果をコピーしてきました.出力が長くなるので,cfgを出力する部分はコメントアウトしてから実行しました.
[2020-12-29 13:43:24,342][HYDRA] Launching 5 jobs locally
[2020-12-29 13:43:24,342][HYDRA] #0 : solver.penalty=100
Dropped nodes : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
Total distance : 0
Total load : 0
[2020-12-29 13:43:25,593][HYDRA] #1 : solver.penalty=250
Dropped nodes : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
Total distance : 0
Total load : 0
[2020-12-29 13:43:26,846][HYDRA] #2 : solver.penalty=500
Dropped nodes : {8, 16, 14, 15}
Total distance : 4452
Total load : 40
[2020-12-29 13:43:28,087][HYDRA] #3 : solver.penalty=1000
Dropped nodes : {16, 15}
Total distance : 5548
Total load : 54
[2020-12-29 13:43:29,348][HYDRA] #4 : solver.penalty=2000
Dropped nodes : {16, 15}
Total distance : 5548
Total load : 54
この記事で特に言及していない点
- multirunのときのlog管理
- デフォルトで作成されるディレクトリのlogファイルに出力する方法
- 見やすく出力する方法 (他フレームワークとの連携など)
Discussion