Lightningにおけるmetricロギングの調査

2023/08/02に公開

PyTorchで記述したモデル開発に関連するコードを簡潔に記述できるLightningのmetric計算について調べたのでまとめておく.

https://www.pytorchlightning.ai/index.html

Lightningにおけるmetric計算

Lightningではlightning.LightningModuleを継承したクラスにPyTorchの文法で記述したモデルを学習(training),検証(validation),テスト(test),推論(prediction)に関する情報と一緒に記述する.モデル学習時のlossの計算やモデル検証時のmetricの計算に関しては,それぞれtraining_stepvalidation_stepというメソッドに実装する.それらのメソッドの中でlogというメソッドを呼び出すことで計算したlossやmetricの保存が行える.例えばvalidation_step公式ドキュメントには以下のように記述されている.

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

この実装を見て疑問に思う点としては1つのバッチに対して計算したlossがloggingされている点である.実際に検証時に確認したいlossは検証用データセット全体に対するlossのはずで,バッチに対して計算されたlossではない.logメソッドはon_epochという引数を取り,epoch終了時に計算を行ってくれるオプションが存在するが,上記実装を素直に読み解くとバッチごとに計算したlossの平均値を出すのでは?と考えてしまう.これだと実際に計算したいlossやmetricとの乖離が発生してしまう.実際,いくつかのドキュメントではvalidation_stepでは推論実行のみを行い.その後のmetric計算およびロギングをvalidation_epoch_endというメソッドで実行している記述も存在した[1].そのあたりに関する疑問を払拭するために今回logメソッドがどのように動作しているかについて調査する.

結論

  • 平均値,合計値,最大/最小値を計算するmetricはlogメソッド側で適切にまとめてくれる
  • 複雑な集約処理が必要なmetric計算はTorchMetricsを利用すると楽に実装できる
  • 上記に当てはまらない場合はTorchMetricsのMetricクラスを継承して実装するとよい

lightning.LightningModuleのlogメソッドの挙動

lightning.LightningModuleのlogメソッドから登録されたmetricは_ResultMetricというクラスで管理される.この_ResultMetrictorchmetrics.Metricクラスを継承しており,updateメソッドでmetric計算に必要な情報を更新し,computeメソッドで実際にmetric計算を行う._ResultMetricupdateメソッドではreduction functionに合わせてmetric計算に必要な情報をストアする.このreduction functionとしてLightningがネイティブにサポートしているのは平均,合計,最大/最小である.つまり,lossのようなmetricの場合,特に意識しなくてもLightning側で検証用データセット全体の平均値として登録される.

https://github.com/Lightning-AI/lightning/blob/2.0.6/src/lightning/pytorch/trainer/connectors/logger_connector/result.py#L184

複雑な集約処理が必要なmetricの場合

それでは,平均,合計,最大/最小以外のreduction functionが必要な場合はどうすれば良いだろうか.例えば,ROCAUCやランキングを評価するようなmetricの場合,検証用データセット全体の推論が完了してから出ないとmetricが計算できない場合がある.そのような場合はTorchMetricsを使用することがLightningのドキュメントでも奨励されている.

https://lightning.ai/docs/pytorch/stable/extensions/logging.html#logging-from-a-lightningmodule

https://torchmetrics.readthedocs.io/en/stable/

TorchMetricsで実装されているmetricを使用する場合

TorchMetricsで実装されているmetricの場合,必要なmetricを計算するclassを使用する.

例えばROCAUCを計算する場合は以下のようにvalidation_stepを実装する.

import lightning.pytorch as pl
import torchmetrics


class BinaryModel(pl.LightningModule):
    def __init__(self):
	...
	self.rocauc = torchmetrics.classification.BinaryAUROC()

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        y_hat = self.forward(x)
	self.rocauc(y_hat, y)
        self.log("val_rocauc", self.rocauc, on_step=False, on_epoch=True)

TorchMetricsで実装されていないmetricを使用する場合

どうしても独自で実装したmetric計算が必要な場合でもTorchMetricsが活用できる.TorchMetricsはMetricというクラスを用意しており,TorchMetricsで実装されているmetricはこのclassを継承して実装されている.つまり,Metricを継承したmetric計算クラスを実装すれば独自のmetricもLightningで計算できる.具体的な実装方法はTorchMetricsのドキュメントに記載されているのでそちらを参照.また例として,huggingfaceのevaluateをTorchMetricsで書き直した実装を載せておく.

https://torchmetrics.readthedocs.io/en/stable/pages/implement.html

import evaluate
import torchmetrics


class EvaluateROCAUC(torchmetrics.Metric):
    def __init__(self):
        super().__init__()
        self.evaluate_roc_auc = evaluate.load('roc_auc')
        self.add_state("preds", default=torch.tensor([]), dist_reduce_fx="cat")
        self.add_state("target", default=torch.tensor([]), dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape

        self.preds = torch.cat([self.preds, preds])
        self.target = torch.cat([self.target, target])

    def compute(self):
        roc_auc = self.evaluate_roc_auc.compute(references=self.target, prediction_scores=self.preds)['roc_auc']
        return torch.tensor(roc_auc)

おわりに

logメソッドは初見だと少し挙動に戸惑うが,動作がわかればかなり実装がわかりやすくなると思う.lightning.LightningModuleに関しても全てをこのモジュールに実装すると複雑な実装となってしまうが,モデル定義はtorch.nn.Moduleを継承したもの,データ周りはtorch.utils.data.Datasetの実装,metric計算はTorchMetricsのように責務を分解して,それを取りまとめる立ち位置としてlightning.LightningModuleを導入するとかなり見通しの良いコードになると思う.

脚注
  1. https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/text-transformers.html ↩︎

  2. https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.hooks.ModelHooks.html ↩︎

Discussion