Lightningにおけるmetricロギングの調査
PyTorchで記述したモデル開発に関連するコードを簡潔に記述できるLightningのmetric計算について調べたのでまとめておく.
Lightningにおけるmetric計算
Lightningではlightning.LightningModule
を継承したクラスにPyTorchの文法で記述したモデルを学習(training),検証(validation),テスト(test),推論(prediction)に関する情報と一緒に記述する.モデル学習時のlossの計算やモデル検証時のmetricの計算に関しては,それぞれtraining_step
,validation_step
というメソッドに実装する.それらのメソッドの中でlog
というメソッドを呼び出すことで計算したlossやmetricの保存が行える.例えばvalidation_step
は公式ドキュメントには以下のように記述されている.
class LitAutoEncoder(pl.LightningModule):
def training_step(self, batch, batch_idx):
...
def validation_step(self, batch, batch_idx):
# this is the validation loop
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
val_loss = F.mse_loss(x_hat, x)
self.log("val_loss", val_loss)
この実装を見て疑問に思う点としては1つのバッチに対して計算したlossがloggingされている点である.実際に検証時に確認したいlossは検証用データセット全体に対するlossのはずで,バッチに対して計算されたlossではない.log
メソッドはon_epochという引数を取り,epoch終了時に計算を行ってくれるオプションが存在するが,上記実装を素直に読み解くとバッチごとに計算したlossの平均値を出すのでは?と考えてしまう.これだと実際に計算したいlossやmetricとの乖離が発生してしまう.実際,いくつかのドキュメントではvalidation_step
では推論実行のみを行い.その後のmetric計算およびロギングをvalidation_epoch_end
というメソッドで実行している記述も存在した[1].そのあたりに関する疑問を払拭するために今回log
メソッドがどのように動作しているかについて調査する.
結論
- 平均値,合計値,最大/最小値を計算するmetricは
log
メソッド側で適切にまとめてくれる - 複雑な集約処理が必要なmetric計算はTorchMetricsを利用すると楽に実装できる
- 上記に当てはまらない場合はTorchMetricsの
Metric
クラスを継承して実装するとよい
log
メソッドの挙動
lightning.LightningModuleのlightning.LightningModuleのlog
メソッドから登録されたmetricは_ResultMetric
というクラスで管理される.この_ResultMetric
はtorchmetrics.Metric
クラスを継承しており,update
メソッドでmetric計算に必要な情報を更新し,compute
メソッドで実際にmetric計算を行う._ResultMetric
のupdate
メソッドではreduction functionに合わせてmetric計算に必要な情報をストアする.このreduction functionとしてLightningがネイティブにサポートしているのは平均,合計,最大/最小である.つまり,lossのようなmetricの場合,特に意識しなくてもLightning側で検証用データセット全体の平均値として登録される.
複雑な集約処理が必要なmetricの場合
それでは,平均,合計,最大/最小以外のreduction functionが必要な場合はどうすれば良いだろうか.例えば,ROCAUCやランキングを評価するようなmetricの場合,検証用データセット全体の推論が完了してから出ないとmetricが計算できない場合がある.そのような場合はTorchMetricsを使用することがLightningのドキュメントでも奨励されている.
TorchMetricsで実装されているmetricを使用する場合
TorchMetricsで実装されているmetricの場合,必要なmetricを計算するclassを使用する.
例えばROCAUCを計算する場合は以下のようにvalidation_step
を実装する.
import lightning.pytorch as pl
import torchmetrics
class BinaryModel(pl.LightningModule):
def __init__(self):
...
self.rocauc = torchmetrics.classification.BinaryAUROC()
def validation_step(self, batch, batch_idx):
# this is the validation loop
x, y = batch
y_hat = self.forward(x)
self.rocauc(y_hat, y)
self.log("val_rocauc", self.rocauc, on_step=False, on_epoch=True)
TorchMetricsで実装されていないmetricを使用する場合
どうしても独自で実装したmetric計算が必要な場合でもTorchMetricsが活用できる.TorchMetricsはMetric
というクラスを用意しており,TorchMetricsで実装されているmetricはこのclassを継承して実装されている.つまり,Metric
を継承したmetric計算クラスを実装すれば独自のmetricもLightningで計算できる.具体的な実装方法はTorchMetricsのドキュメントに記載されているのでそちらを参照.また例として,huggingfaceのevaluateをTorchMetricsで書き直した実装を載せておく.
import evaluate
import torchmetrics
class EvaluateROCAUC(torchmetrics.Metric):
def __init__(self):
super().__init__()
self.evaluate_roc_auc = evaluate.load('roc_auc')
self.add_state("preds", default=torch.tensor([]), dist_reduce_fx="cat")
self.add_state("target", default=torch.tensor([]), dist_reduce_fx="cat")
def update(self, preds: torch.Tensor, target: torch.Tensor):
assert preds.shape == target.shape
self.preds = torch.cat([self.preds, preds])
self.target = torch.cat([self.target, target])
def compute(self):
roc_auc = self.evaluate_roc_auc.compute(references=self.target, prediction_scores=self.preds)['roc_auc']
return torch.tensor(roc_auc)
おわりに
log
メソッドは初見だと少し挙動に戸惑うが,動作がわかればかなり実装がわかりやすくなると思う.lightning.LightningModule
に関しても全てをこのモジュールに実装すると複雑な実装となってしまうが,モデル定義はtorch.nn.Module
を継承したもの,データ周りはtorch.utils.data.Dataset
の実装,metric計算はTorchMetricsのように責務を分解して,それを取りまとめる立ち位置としてlightning.LightningModule
を導入するとかなり見通しの良いコードになると思う.
Discussion