Closed5

Pytorch LightningとMLflowを使ってメトリクスを可視化する

rintaro121rintaro121

まずはmflowのinstall

poetry add mlflow

学習時にはloggerを作成してtrainerに与える

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger

logger = MLFlowLogger()
trainer = Trainer(logger=logger)
rintaro121rintaro121

MLFlowLogger()の引数

引数 説明
experiment_name 実験名(プロジェクト名)
run_name プロジェクトにおけるの各実験の名前を指定
tracking_uri tracking内容を送る先のURI(ローカルでもリモートでもいい)。
ローカルの時は、file:///パス名のような感じで指定
tags プロジェクトに対して付与するタグ
save_dir runの実行結果が保存されるディレクトリ。
tracking_uriが指定されない場合、デフォルトは./mlrunsに保存される。tracking_uriを指定した場合は、ここの値を指定しても何も意味なし。
log_model ModelCheckpointによって作られたチェックポイントをMLFlowにおけるartifactsとして保存
log_model==”all”の場合、学習中の各チェックポイントを保存
log_model== Trueの場合、最後の学習のチェックポイントを保存
log_model==Falseの場合、チェックポイントは保存されない、デフォルトはこれ
prefix 指定した文字がメトリクスの説明の前つく
artifact_location 各runで保存されるartifactの保存先
run_id プロジェクトにおけるrunのID(指定しない場合新たなrun_idが生成される)
synchronous runの完了時に全てのログをとるか逐次的にログをとるかを指定

https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.mlflow.html

rintaro121rintaro121

Lightning Moduleのクラス内でロギングしたい値を指定。

class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        ...
        # Logging Value
        loss = ...
        self.log("train_loss", loss)

    def validation_step(self, batch, batch_idx):
        ...
        # Logging Value
        loss = ...
        self.log("valid_loss", loss)

    def on_validation_epoch_end(self):
        ...
        # Logging Value
        accuracy = ...
        self.log("accuracy", accuracy)
rintaro121rintaro121

学習が始まると、mlrunsというディレクトリができる。

├── data
├── mlruns
├── poetry.lock
├── pyproject.toml
└── src

mlrunsが配下にある状態で

poetry run mlflow ui  # or mlflow ui

と打つと、ローカルサーバが立ち上がって実験結果を確認できる。

このスクラップは2025/02/22にクローズされました