🐉

argparseの代わりにhydraを使ってみる (1)

2020/12/29に公開

数理最適化ソルバを使うときに,いくつかパラメータを振りたいときがあります.こういうときにパラメータをargparseで書いていると死んでしまう(可読性も低いです)ので,機械学習分野で成熟している気がする技術を調査しています.今回はまずhydraを使ってargparseを置き換えてみました.

問題設定

  • google OR-toolsでCVRPを解きます
  • パラメータとして「drop penalty」を「#vehicles」を想定します
  • これを外から管理したいです

hydra example

簡単な例

  • ディレクトリ構造です
.
├── conf
│   └── default.yaml
├── experiment.py
└── solver.py
  • パラメータの初期値はGoogle OR-Toolsのコード例では1000と4が設定されています.
conf/default.yaml
solver:
  penalty: 1000
  num_vehicle: 4
  • 具体的にPythonで呼び出す例です.
experiment.py
import os
import hydra
from solver import main
from omegaconf import OmegaConf

@hydra.main(config_name="conf/default.yaml")
def run_experiment(cfg):
    print(OmegaConf.to_yaml(cfg))
    ds, td, tl = main(cfg.solver.penalty, cfg.solver.num_vehicle)
    print("Dropped nodes  :", ds)
    print("Total distance :", td)
    print("Total load     :", tl)
    
if __name__ == '__main__':
    run_experiment()

実行

デフォルトパラメータで呼び出してみます.

$ python experiment.py 
solver:
  penalty: 1000
  num_vehicle: 4

Dropped nodes  : {16, 15}
Total distance : 5548
Total load     : 54

パラメータを上書きして呼び出してみます.

$ python experiment.py solver.penalty=500
solver:
  penalty: 500
  num_vehicle: 4

Dropped nodes  : {8, 16, 14, 15}
Total distance : 4452
Total load     : 40

multirunの使用例

hydraにまかせて見ます.コマンド呼び出しは以下の形です.

$ python experiment.py -m solver.penalty=100,250,500,1000,2000

結果の出力される様子です.今回はloggerを特に設定していないので,ターミナルに出力された結果をコピーしてきました.出力が長くなるので,cfgを出力する部分はコメントアウトしてから実行しました.

[2020-12-29 13:43:24,342][HYDRA] Launching 5 jobs locally
[2020-12-29 13:43:24,342][HYDRA]        #0 : solver.penalty=100
Dropped nodes  : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
Total distance : 0
Total load     : 0
[2020-12-29 13:43:25,593][HYDRA]        #1 : solver.penalty=250
Dropped nodes  : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
Total distance : 0
Total load     : 0
[2020-12-29 13:43:26,846][HYDRA]        #2 : solver.penalty=500
Dropped nodes  : {8, 16, 14, 15}
Total distance : 4452
Total load     : 40
[2020-12-29 13:43:28,087][HYDRA]        #3 : solver.penalty=1000
Dropped nodes  : {16, 15}
Total distance : 5548
Total load     : 54
[2020-12-29 13:43:29,348][HYDRA]        #4 : solver.penalty=2000
Dropped nodes  : {16, 15}
Total distance : 5548
Total load     : 54

この記事で特に言及していない点

  • multirunのときのlog管理
  • デフォルトで作成されるディレクトリのlogファイルに出力する方法
  • 見やすく出力する方法 (他フレームワークとの連携など)

Discussion