📝

OmegaConfを使った機械学習用ScriptのExample

2021/07/30に公開

メモです。
特に、CLI引数の読み込みとデフォルト値とのマージ、OmegaConfの`DictConfig`からdataclassへの変換はよく忘れるので、自分で見直すために書いておきます。

import os
from dataclasses import dataclass, field
from pathlib import Path

from omegaconf import MISSING, OmegaConf

@dataclass()
class TrainingConfig:
    """Training hyper-parameters with their default values.

    Used as the ``training`` section of the experiment config; each value
    can be overridden from the command line with a dotlist key such as
    ``training.batch_size=64``.
    """

    lr: float = 0.01
    batch_size: int = 32
    input_size: int = 256
    max_epochs: int = 100
    seed: int = 42

@dataclass()
class LossWeights:
    """Scalar weights for the loss terms, with their default values.

    Used as the ``loss_weights`` section of the experiment config; override
    from the command line with e.g. ``loss_weights.beta=0.1``.
    """

    alpha: float = 1.0
    beta: float = 0.01

@dataclass
class Config:
    """Top-level experiment configuration, used as an OmegaConf structured schema.

    Fields whose default is ``MISSING`` have no value and must be supplied
    (e.g. on the command line) before the config is converted to an object.
    """

    image_root_dir: str = MISSING
    output_root_dir: str = MISSING
    experiment_name: str = MISSING
    # 0 is replaced by the CPU count in main().
    num_workers: int = 0
    # Nested dataclass defaults must go through default_factory: a plain
    # instance default is a mutable default, which @dataclass rejects on
    # Python 3.11+ (ValueError) and which would otherwise be one object shared
    # by every Config instance.
    training: TrainingConfig = field(default_factory=TrainingConfig)
    loss_weights: LossWeights = field(default_factory=LossWeights)
        

def main() -> None:
    """Build the run config from defaults plus CLI overrides, save it, and print it.

    CLI arguments use OmegaConf's dotlist syntax (e.g. ``experiment_name=exp1
    training.lr=0.001``) and override the defaults declared on ``Config``.
    The merged config is written to
    ``<output_root_dir>/<experiment_name>/config.yaml``.
    """
    # Dotlist overrides parsed from sys.argv.
    config_cli = OmegaConf.from_cli()
    # Schema + defaults; merging onto it checks CLI values against field types.
    config_default = OmegaConf.structured(Config)
    # to_object turns the merged DictConfig into a plain Config instance and
    # raises if a MISSING field (e.g. image_root_dir) was never supplied.
    config: Config = OmegaConf.to_object(OmegaConf.merge(config_default, config_cli))

    if config.num_workers == 0:
        # 0 means "auto". os.cpu_count() returns None when the count cannot be
        # determined, so fall back to 1 to keep num_workers an int.
        config.num_workers = os.cpu_count() or 1

    output_dir = Path(config.output_root_dir) / config.experiment_name
    # parents=True also creates output_root_dir and any missing ancestors;
    # exist_ok=True lets the same experiment directory be reused.
    output_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(config, output_dir / "config.yaml")

    print("==== config ====")
    print(OmegaConf.to_yaml(config))


if __name__ == "__main__":
    main()

Discussion