📝
OmegaConfを使った機械学習用ScriptのExample
メモです。
特にCLI引数の読み込みとデフォルト値とのマージ、DictConfig -> dataclassへの変換はよく忘れるので、自分で見直すために書いておく。
import os
from dataclasses import dataclass, field
from pathlib import Path

from omegaconf import OmegaConf, MISSING
@dataclass
class TrainingConfig:
    """Training hyper-parameters; used as the ``training`` field of ``Config``.

    Every field has a default, so the section may be omitted entirely on the
    command line and individual values overridden one at a time.
    """

    lr: float = 0.01
    batch_size: int = 32
    input_size: int = 256
    max_epochs: int = 100
    seed: int = 42
@dataclass
class LossWeights:
    """Loss weighting coefficients; used as the ``loss_weights`` field of ``Config``."""

    alpha: float = 1.0
    beta: float = 0.01
@dataclass
class Config:
    """Top-level experiment configuration, used as the schema for OmegaConf.structured.

    Fields whose default is ``MISSING`` are mandatory and must be supplied,
    e.g. as ``key=value`` command-line arguments.
    """

    image_root_dir: str = MISSING
    output_root_dir: str = MISSING
    experiment_name: str = MISSING
    # main() replaces 0 with the machine's CPU count.
    num_workers: int = 0
    # Nested dataclass defaults must go through default_factory: since Python
    # 3.11, dataclasses reject unhashable defaults (a non-frozen dataclass
    # instance is unhashable) with ValueError at class-definition time, and
    # on older versions a single default instance would be shared by every
    # Config.
    training: TrainingConfig = field(default_factory=TrainingConfig)
    loss_weights: LossWeights = field(default_factory=LossWeights)
def main():
    """Build the run configuration from defaults + CLI, save it, and print it.

    Defaults come from the ``Config`` dataclass. ``key=value`` command-line
    arguments (dotted keys such as ``training.lr=1e-3`` address nested
    fields) are merged over them. The merged result is converted back into a
    ``Config`` instance. It is written to
    ``<output_root_dir>/<experiment_name>/config.yaml`` and echoed to stdout.
    """
    config_cli = OmegaConf.from_cli()
    config_default = OmegaConf.structured(Config)
    # Merging into the structured config type-checks the CLI values against
    # the dataclass schema. to_object() then turns the merged DictConfig back
    # into a real Config dataclass instance (it should fail if a MISSING
    # field was never supplied — NOTE(review): confirm exact exception type).
    config: Config = OmegaConf.to_object(OmegaConf.merge(config_default, config_cli))

    # 0 means "use every CPU". os.cpu_count() returns None when the count
    # cannot be determined, which would put None into an int field and break
    # saving below, so fall back to a single worker.
    if config.num_workers == 0:
        config.num_workers = os.cpu_count() or 1

    experiment_dir = Path(config.output_root_dir) / config.experiment_name
    # parents=True also creates output_root_dir and any missing ancestors.
    experiment_dir.mkdir(parents=True, exist_ok=True)

    output_config_path = experiment_dir / "config.yaml"
    OmegaConf.save(config, output_config_path)

    print("==== config ====")
    print(OmegaConf.to_yaml(config))


if __name__ == "__main__":
    main()
Discussion