🌟

Pytorch Lightningに入門してみた

に公開

今回はpytorchのエコシステムの一つであるLightningに入門してみました。

Lightningとは?

PyTorch Lightningとは、PyTorchのエコシステムの一つであり、モデルの開発から本番運用まで広く対応したフレームワークです。最小限の労力で最大限の柔軟性を与えることができ、パフォーマンスを落とすことなくモデル開発を行うことができます。PyTorchのモデル学習では必ずしもLightningのようなフレームワークは必要なく、PyTorchのみでも十分モデル学習はできます。一方、実験管理などをしようとお思うと独自に実装する必要があったりで結構面倒な作業が多いのも事実です。その代わりにLightningを利用することで、モデルの学習コードが冗長になることを防ぎつつ柔軟に調整できます。

https://lightning.ai/docs/pytorch/stable/

実際に使ってみる!

今回は公式が提供しているLightning in 15 minutesというチュートリアルをやってみます!内容としては、MNISTデータに対してLinearレイヤのみを使ってAutoEncoderを学習させるというものです。

環境構築

まずは環境構築から!uvを使って環境構築します。

uv init lightning-tutorial -p 3.12
cd lightning-tutorial
uv add torch torchvision lightning tensorboard

コード実装

チュートリアルで提供されているコードは以下になります。

train.py
import torch
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=9-step=1000.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

まずはAutoEncoderを実装するにあたりエンコーダとデコーダを定義します。今回はCNNを用いずにLinearだけを用いることになっています。MNISTデータは縦28ピクセル、横28ピクセルのデータなので、Encoderは入力として、Decoderは出力として28*28の一次元データを返すようになっています。

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

次にLightningを用いる上で一番重要なLightningModuleの実装になります。特に重要なのがtraining_step関数になります。training_step関数はいわゆるあるデータに対するモデルの更新の1ステップの内容となっており、バッチ(学習に利用されるデータの塊)とそのインデックスを受け取り、計算されたlossを返す関数となっています。Lightningを用いずに学習ステップを実装すると以下の折りたたみないのコードのようにfor分でデータローダーから取得されたデータを順番に利用しますが、そのforの中身を切り出したような感じですね。

# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
Lightningを用いない学習コード例
for _ in range(epoch):
    for x, _ in dataloader:
        # ここがtraining_stepのなかみ
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)

LightningModuleの準備ができたので、あとはモデルを学習するだけです。LitAutoEncoderにEncoderとDecoderを与えた上で、L.Trainer.fitにデータローダーと合わせて提供することであとは自動的に学習が進行します。データセットについてはtorchvisionがMNISTデータを提供するAPIを用意しているのでそれを利用します。

# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

モデルの学習が完了すると、モデル情報は自動でlightning_logsフォルダ以下に格納されます。推論を行うときはそこから任意のチェックポイントを参照して推論させることができます。

# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=9-step=1000.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

学習を実行してみる

それでは先ほどのコードを実行してみましょう。実行すると、以下のgifのようなログが表示されます。モデルのパラメータや実行モードなどの情報を表示しつつ、エポックごとの情報なども適切に表示されています。

uv run train.py

結果はlightning_logsに保存されており、そこを参照することでtensorboardで結果が確認できます。

uv run tensorboard --logdir .

ちなみにテストステップはというと、、、

先ほど作ったモデルのテストをしたい場合、以下のようなコードを入れるとテストができます!学習時はtrain_stepを実装していましたが、テストはtest_stepを実装します。

class LitAutoEncoder(L.LightningModule):
    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

テストの実行はtrainer.test(model, dataloaders=DataLoader(test_data))のように実装することで対応できます。

そのほかの機能について

Lightningでは様々な機能が提供されているので、一部ご紹介します。

  • Hooks: 様々なフックを実装可能
    • たとえばbackwardステップの前/後段階で特定の処理をしたいとかができるようになります
  • Trainer flags: フラグによる動作制御
    • たとえば学習デバイス(CPUやGPUなど)の変更やデバイス数の設定などができます

まとめ

今回はPytorchのエコシステムの一つであるLightningを利用しました。学習コードが最低限に収まるだけでなく、様々なフックなども用意されていることから、Pytorchを利用して独自モデルを学習される方はぜひ使ってみてください。

Discussion