🗂

MLflow+Hydra 最小構成

2023/05/06に公開
directory
my_project/
├─ configs/
│   └─ config.yaml
├─ my_project/
│   ├─ __init__.py
│   ├─ __main__.py
│   ├─ train.py
│   └─ utils.py
├─ README.md
└─ pyproject.toml
pyproject.toml
[tool.poetry]
name = "my_project"
version = "0.1.0"
description = ""
authors = ["Your Name <you@example.com>"]
readme = "README.md"
# Ship the package directory. Note that configs/ lives next to it (not inside)
# and is located through Hydra's config_path="../configs".
packages = [{include = "my_project"}]

[tool.poetry.dependencies]
python = "^3.8"
torch = "^2.0.0"
# Hydra is distributed on PyPI as "hydra-core"; the package literally named
# "hydra" is an unrelated project and must not be listed here.
hydra-core = "^1.3.2"
# mlflow and omegaconf are imported at runtime (train.py, utils.py,
# __main__.py), so they belong in the main group, not a dev-only group.
mlflow = "^2.3.1"
omegaconf = "^2.3.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

my_project/configs/config.yaml
# my_project/configs/config.yaml
# Primary Hydra config (loaded via config_name="config" in __main__.py).
# Any key can be overridden from the command line, e.g. training.batch_size=128.

hydra:
  run:
    # Hydra's per-run output directory. ${now:...} is resolved at launch and
    # ${model.type} is appended so runs of different models are easy to tell
    # apart.
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${model.type}

data:
  path: "path/to/your/data"  # placeholder: point this at your dataset

model:
  type: "example_model"  # also interpolated into hydra.run.dir above
  hidden_size: 128

training:
  batch_size: 64
  num_epochs: 10  # iteration count of the epoch loop in train.py

mlflow:
  experiment_name: "my_experiment"  # passed to mlflow.set_experiment in utils.py
  # tracking_uri: ""
  # When tracking_uri is not set, my_project/__main__.py fills it in at
  # startup with a file:// URI for <launch directory>/mlruns.

__main__.py
# my_project/__main__.py

from pathlib import Path

import hydra
from omegaconf import OmegaConf, open_dict

from .train import train
from .utils import setup_mlflow

@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(config):
    """Run one training job from the Hydra-composed configuration.

    Hydra loads ``configs/config.yaml`` (``config_path`` is relative to this
    file), applies command-line overrides such as ``training.batch_size=128``,
    and passes the result in as ``config``. This function supplies a default
    MLflow tracking URI when the config does not provide one, configures
    MLflow, then hands off to ``train``.
    """
    # open_dict lifts Hydra's struct mode so a key that is absent from
    # config.yaml can be read with .get() and then added.
    with open_dict(config):
        if not config.mlflow.get("tracking_uri"):
            # Anchor the store to the launch directory, so it does not move if
            # Hydra is configured to chdir into the run's output directory.
            # Path.as_uri() yields a well-formed file:// URI on every OS.
            # Plain string concatenation breaks on Windows drive paths.
            mlruns_dir = Path(hydra.utils.get_original_cwd()) / "mlruns"
            config.mlflow["tracking_uri"] = mlruns_dir.as_uri()
    setup_mlflow(config)
    train(config)


if __name__ == "__main__":
    main()

train.py
# my_project/train.py

import mlflow
from omegaconf import DictConfig

def _flatten_params(node, prefix=""):
    """Flatten a nested config mapping into ``{"a.b.c": value}`` pairs.

    MLflow params form a flat key -> value map. Logging the nested config
    directly stores each top-level section (``data``, ``model``, ...) as a
    single stringified dict, which hides the individual hyperparameters.
    """
    flat = {}
    for key, value in node.items():
        name = f"{prefix}{key}"
        if isinstance(value, (dict, DictConfig)):
            flat.update(_flatten_params(value, prefix=f"{name}."))
        else:
            flat[name] = value
    return flat


def train(config: DictConfig):
    """Run the (placeholder) training loop inside an MLflow run.

    Logs every leaf of ``config`` as an MLflow param under a dotted key
    (e.g. ``training.batch_size``). Then, for each of the
    ``config.training.num_epochs`` epochs, it logs train/val loss and accuracy
    with ``step`` set to the epoch index.
    """
    with mlflow.start_run():
        mlflow.log_params(_flatten_params(config))

        for epoch in range(config.training.num_epochs):
            # training step
            # ... replace the placeholder values below with real results
            train_loss = 0.0
            train_accuracy = 0.0
            val_loss = 0.0
            val_accuracy = 0.0

            # step=epoch gives MLflow a per-epoch history. Without it, the
            # step defaults to 0 and every epoch overwrites the same point.
            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "train_accuracy": train_accuracy,
                    "val_loss": val_loss,
                    "val_accuracy": val_accuracy,
                },
                step=epoch,
            )

        # save models
        # mlflow.pytorch.log_model(model, "models")

utils.py
# my_project/utils.py

from omegaconf import DictConfig
import mlflow

def setup_mlflow(config: DictConfig):
    """Configure MLflow from the ``mlflow`` section of the config.

    Sets ``config.mlflow.tracking_uri`` as the tracking URI, then makes
    ``config.mlflow.experiment_name`` the active experiment for later runs.
    """
    mlflow_section = config.mlflow
    uri = mlflow_section.tracking_uri
    mlflow.set_tracking_uri(uri)
    mlflow.set_experiment(mlflow_section.experiment_name)

run
poetry run python -m my_project training.batch_size=128

Discussion