🦉

Pytorch lightning のフックとコールバックについて簡単まとめ

に公開

はじめに:Pytorch lightningについて

PyTorch Lightning は、PyTorch の柔軟性をそのままに、学習ループや分散処理、ロギング、チェックポイント保存といった煩雑な実装を自動化してくれるラッパーフレームワークです。
「モデル定義」と「訓練の仕組み(Trainer)」を明確に分離することで、可読性と再利用性の高いコードを書くことができ、研究・開発の両面で多くの場面で活用されています。

つまり、「研究のロジックに集中し、繰り返し書く訓練コードから解放される」のが PyTorch Lightning の最大のメリットです。

訓練のカスタマイズ:フックとコールバックについて

Pytorch lightning を使って機械学習を行う上で、最初につまずくところが、どうやってカスタマイズすれば思い通りの学習プロセスを実装できるのか?だと思います。
そんな時によく利用するのが、特定のタイミングで処理を差し込む仕組みである、「フック(Hook)」と「コールバック(Callback)」です。
PyTorch Lightning における「フック(Hook)」と「コールバック(Callback)」は似ているようで明確な違いがあります。どちらも「特定のタイミングで処理を差し込むための仕組み」ですが、目的・使い方・拡張性の観点で違いがあります。

以下にそれぞれの違いを分かりやすく整理します。

違いまとめ表

観点 フック(Hooks) コールバック(Callbacks)
定義場所 LightningModule Callback クラスとして別定義
使い方 メソッドをオーバーライド Trainer(callbacks=[...]) で渡す
主な用途 学習処理そのものの制御 ロギング、EarlyStopping、保存など外部的な制御
対象範囲 モデル単位で限定的 モデルに依存せず再利用可能
ユースケース training_step, on_train_epoch_start などで内部処理を記述 汎用的な処理(例:ModelCheckpoint, LearningRateMonitor
再利用性 モデルに依存しやすい モジュール化されていて使い回しやすい

違いがわかったところで、具体的にどのように実装するかを以下にまとめています。

具体例:Hook の使用

以下のように、モデル内で、オーバーライドする形で記述します。

class MyModel(pl.LightningModule):
    def on_train_epoch_start(self):
        print("エポック開始!")

    def training_step(self, batch, batch_idx):
        ...

具体例:Callback の使用

一方で、Callbackは、pl.Callbackクラスを継承し、trainerに渡す形で記述します。

class PrintEpochEndCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f\"Epoch {trainer.current_epoch} が終了しました\")

trainer = pl.Trainer(callbacks=[PrintEpochEndCallback()])

使い分けの目安

やりたいこと 選ぶべきもの
モデル内部の挙動を変えたい(学習ステップ・勾配制御など) フック(Hook)
学習進行に合わせて外部から監視・制御したい(EarlyStopping、ログなど) コールバック(Callback)
複数モデルで同じ処理を再利用したい コールバック(Callback)
モデル依存の処理をきめ細かくカスタムしたい フック(Hook)

補足

実際には、Trainer は on_train_start() などの共通のイベント名(=フック名)で両方を呼び出しており、「内部に書くか(hook)/外から渡すか(callback)」が最大の違いです。


フックの呼び出し順

ドキュメント読んだだけじゃわかりにくかったので、呼び出し順をまとめると以下のような形になります。
これらのフックのうち、適切なタイミングのものをオーバーライドして実装することで、学習をカスタムすることができます。
存在するフックはpytorch_lightningのソースコードを見ると早いので、参考に。

fit() 全体の処理とフック呼び出し順

  1. configure_callbacks()
  2. prepare_data()(※ local_rank == 0 のときのみ実行)
  3. setup(stage="fit")
  4. configure_model()
  5. configure_optimizers()
  6. on_fit_start()
    • Sanity check (trainer.validate(...))
  7. on_train_start()

fit_loop()(エポックごとの処理)

  1. on_train_epoch_start()

  2. 各トレーニングバッチに対して:

    1. on_train_batch_start()
    2. on_before_batch_transfer()
    3. transfer_batch_to_device()
    4. on_after_batch_transfer()
    5. training_step()
    6. on_before_zero_grad()
    7. optimizer_zero_grad()(通常は内部で optimizer.zero_grad()
    8. on_before_backward()
    9. backward()
    10. on_after_backward()
    11. on_before_optimizer_step()
    12. configure_gradient_clipping()(任意)
    13. optimizer_step()
    14. on_train_batch_end()
    • 条件付きで val_loop() 実行:if should_check_val
  3. on_train_epoch_end()


val_loop()(検証ループ)

  1. on_validation_model_eval()(内部で model.eval()

  2. on_validation_start()

  3. on_validation_epoch_start()

  4. 各検証バッチに対して:

    1. on_validation_batch_start(batch, batch_idx)
    2. on_before_batch_transfer(batch)
    3. transfer_batch_to_device(batch)
    4. on_after_batch_transfer(batch)
    5. validation_step(batch, batch_idx)
    6. on_validation_batch_end(outputs, batch, batch_idx)
  5. on_validation_epoch_end()

  6. on_validation_end()

  7. on_validation_model_train()(→ model.train() に戻す)

  8. torch.set_grad_enabled(True)


後処理

  • on_train_end()
  • on_fit_end()
  • teardown(stage="fit")

test_loop()(テストループ)

  1. on_test_model_eval()model.eval() を呼び出す)

  2. on_test_start()

  3. on_test_epoch_start()

  4. 各テストバッチに対して:

    a. on_test_batch_start(batch, batch_idx, dataloader_idx)
    b. on_before_batch_transfer(batch, dataloader_idx)
    c. transfer_batch_to_device(batch, device, dataloader_idx)
    d. on_after_batch_transfer(batch, dataloader_idx)
    e. test_step(batch, batch_idx, dataloader_idx)
    f. on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)

  5. on_test_epoch_end()

  6. on_test_end()

  7. on_test_model_train()(→ model.train() に戻す)


predict_loop()(予測ループ)

  1. on_predict_model_eval()model.eval() を呼び出す)

  2. on_predict_start()

  3. on_predict_epoch_start()

  4. 各予測バッチに対して:

    a. on_predict_batch_start(batch, batch_idx, dataloader_idx)
    b. on_before_batch_transfer(batch, dataloader_idx)
    c. transfer_batch_to_device(batch, device, dataloader_idx)
    d. on_after_batch_transfer(batch, dataloader_idx)
    e. predict_step(batch, batch_idx, dataloader_idx)
    f. on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx)

  5. on_predict_epoch_end(results)

  6. on_predict_end()


まとめ

以上が、Pytorch lightning を利用しつつ、学習プロセスをカスタマイズする方法のまとめでした。
この記事を参考に、フックとコールバックを駆使して今後のタスクに活用できるよう精進していきましょう。

参考リンク

Fusic 技術ブログ

Discussion