MIXI DEVELOPERS
🐲

Hydra + PyTorch Lightningを使ったDeep Learning モデル構築テンプレート紹介

2024/12/08に公開

始めに

これはMIXI DEVELOPERS Advent Calendar 2024の8日目の記事です。

こんにちは、みてねプロダクト開発部Data Engineeringグループの kittchy です。現在、MLエンジニアとしてML解析パイプラインの整備やMLモデルの構築などを担当しています。その中でも特に研究開発の一環として、PyTorchを用いたモデルの構築や学習を行っています。

本記事では、PyTorchによるDeep Learningモデルの学習において、私が最近よく使用している便利なツールであるHydraとPyTorch Lightningの紹介を行います。また、ashleve/lightning-hydra-templateを用いて手軽にDeep Learningモデルを構築する方法についても紹介します。

https://github.com/ashleve/lightning-hydra-template

この記事が向いている方

  • Deep LearningモデルをPyTorchで簡易的な実装、学習をしたことがある。
  • Pure PyTorchの実装が面倒くさいと思っている。
  • PyTorch Lightningをモデル構築経験がない。
  • Hydraを使ったモデル構築経験がない。
  • ashleve/lightning-hydra-templateを知らない。

PyTorchのみを用いて学習する際の面倒くさいところ

Deep Learningモデルを学習する場合、データセットの構築、データセットを読み込むコード作成、モデルコードの作成、学習コードの作成、ハイパーパラメータ調整、モデルの評価など、非常に複雑な処理を行う必要があります。その中でも特に重要かつ面倒くさいのが、学習コードの作成と実験パラメータの管理です。

1つ目の学習コードの作成では、学習のイテレーション処理だけでなく、Early StoppingやGradient Accumulationなど、さまざまな最適化手法があり、非常に複雑な処理を記述する必要があります。これらを1から記述する場合、コードが複雑かつ膨大になるため実装ミスに怯えながら実験することになり、結果的に実験が失敗してしまう場合があります。

2つ目の実験パラメータの管理に関しては、実験の再現性につながります。Deep Learningモデルはハイパーパラメータを調整して何度も学習することでより精度の良いモデルを作ることができます。この膨大なハイパーパラメータを管理し、実験の記録や再現性を確保する必要がありますが、これを手動で行うと、実験結果の記録漏れや、再現性の確保が難しくなります。

面倒くさいを解消するツール

これらの背景から、私は学習コード作成を楽にするために、PyTorch Lightning 、2つ目の実験パラメータの管理を楽にするために Hydra を主に使って、面倒くさいを解消しています。

PyTorch Lightning : 学習コード作成を楽にしてくれる

https://lightning.ai/docs/pytorch/stable/

PyTorch LightningはPyTorchフレームワークを利用して機械学習モデルを簡便かつ効率的に開発するためのライブラリです。コードの構造化、トレーニングの設定行うことで、機械学習の実装を大幅に簡略化します。

私達が学習のために行うことは、Lightning Moduleの作成、Datasets/Dataloaderの作成、Trainerの作成の3つのみです。

  • Lightning Module
    PyTorchの torch.nn.Module のようにDeep Learningモデルを実装しますが、PyTorch Lightning独自のHooksが用意されていて、これらを実装することで Trainer による学習が可能となります。

    以下のコードは、PyTorch Lightningを用いてTransformerモデルを学習する例です。

    • training_step : batchとbatch indexを引数として学習イテレーションのたびに呼び出されます。ここでlossも計算して戻り値とすることで、自動で逆伝搬も計算してくれます。
    • forward : training_stepself(inputs, target) によって呼び出され、順伝搬が実行されます。
    • configure_optimizers:最適化関数や学習率schedulerなどを定義します。
    import lightning as L
    import torch
    
    from lightning.pytorch.demos import Transformer
    
    class LightningTransformer(L.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)
    
        def forward(self, inputs, target):
            return self.model(inputs, target)
    
        def training_step(self, batch, batch_idx):
            inputs, target = batch
            output = self(inputs, target) # self.forwardが呼び出される
            loss = torch.nn.functional.nll_loss(output, target.view(-1))
            return loss
    
        def configure_optimizers(self):
            return torch.optim.SGD(self.model.parameters(), lr=0.1)
    model = LightningTransformer(vocab_size=dataset.vocab_size)
    
  • Datasets/Dataloader
    PyTorchのDatasets, Dataloaderを作成します。必要に応じてCustomDatasets を作成する場合もあります。

    from lightning.pytorch.demos import WikiText2
    from torch.utils.data import DataLoader
    
    dataset = WikiText2()
    dataloader = DataLoader(dataset)
    
  • Trainer構築、学習実行
    Trainerを用いて、作成したモデルとdataloaderで学習を行います。ここで、さまざまなサブモジュールによって、学習テクニック(例:Early Stopping, Gradient Accumulationなど)を使うことができます。詳しくはTrainer を御覧ください。

    trainer = L.Trainer(fast_dev_run=100)
    trainer.fit(model=model, train_dataloaders=dataloader)
    

Hydra : 実験パラメータの管理を楽にしてくれる

https://hydra.cc/docs/intro/

HydraはPythonプログラムにおいて設定管理を効率的に行うためのフレームワークです。設定ファイルの階層化や動的な構成変更を容易にし、大規模なプロジェクトや研究開発において非常に有用です。

たとえば、以下のPythonコードのハイパーパラメータを調整することを考えてみます。

# my_app.py
class Optimizer:
    algo: str
    lr: float

    def __init__(self, algo: str, lr: float) -> None:
        self.algo = algo
        self.lr = lr

class Dataset:
    name: str
    path: str

    def __init__(self, name: str, path: str) -> None:
        self.name = name
        self.path = path

class Trainer:
    def __init__(self, optimizer: Optimizer, dataset: Dataset) -> None:
        self.optimizer = optimizer
        self.dataset = dataset
  • 従来までのハイパーパラメータの管理
    Pythonのargparseがパラメータ調整によく用いられていましたが、パラメータが増えるたびに add_argument を追加する必要がありました。

    import argparse
    
    parser = argparse.ArgumentParser(
        prog="DeepLearningModel", description="model parameter"
    )
    parser.add_argument("--algo", default="SGD")
    parser.add_argument("--lr", default=0.01)
    parser.add_argument("--dataset-name", default="Imagenet")
    parser.add_argument("--data-path", default="/datasets/imagenet")
    arg = parser.parse_args()
    
    datasets = Dataset(name=arg.dataset_name, path=arg.data_path)
    optimizer = Optimizer(algo=arg.algo, lr=arg.lr)
    trainer = Trainer(optimizer=optimizer, dataset=datasets)
    
    

    また、モデルの学習を実行する際は、引数を大量につけて実行するため可読性が下がり、さらに実験パラメータの記録を残すためには実行コードをshell scriptファイルとして保存するなどの工夫が必要です。

    python src/train.py --algo SGD --lr 0.01 --dataset-name Imagenet --data-path /datasets/imagenet
    
  • Hydraを使った場合
    一方、Hydraを使うと以下のようなyamlファイルでハイパーパラメータを管理できます。ここで、 _target_ はClassのパスを指定しており、それ以外は、コンストラクタの引数を表しています。たとえば2つ目の my_app.Optimizer./my_app.pyOptimizer classを使うことを示しており、コンストラクタの引数には “SGD” と0.01が入ります。

    # config.yaml
    trainer:
      _target_: my_app.Trainer
      optimizer:
        _target_: my_app.Optimizer
        algo: SGD
        lr: 0.01
      dataset:
        _target_: my_app.Dataset
        name: Imagenet
        path: /datasets/imagenet
    

    あとは以下のようなPythonコードを作成して実行するだけで、instanceが自動で作成されます。

    import hydra
    
    @hydra.main(version_base=None, config_path=".", config_name="config")
    def main(cfg: DictConfig) -> None:
      # cnf には config.yaml の内容が入っている
      # instantinateで自動でインスタンスが作成される
      trainer = hydra.utils.instantiate(cfg.trainer)
    
    if __name__ == "__main__":
        main()
    

lightning-hydra-template の紹介

HydraとPyTorch Lightningを用いた便利なテンプレートリポジトリとしてlightning-hydra-templateが公開されています。これ以降では、実際にashleve/lightning-hydra-templateを使いながらDeep Learningモデルを構築・学習するコードを実際に見ながら学んでいきます。

作成するモデル

今回は、自分が過去に趣味で作った body-height-estimator を使いながら解説していきます。GitHubにarchiveで公開しているため、Git cloneでダウンロードしてくることも可能です。

こちらは、画像に写った写真の身長体重がラベル付されたデータセットのHeights and Weights Datasetを用いて、画像のみから身長を推定するタスクです。

以下の論文を参考に作成しました。

ディレクトリ構造

テンプレートのディレクトリ構造は以下のようになっています。

.
├── Makefile
├── README.md
├── configs/
├── data/
├── environment.yaml
├── logs
├── notebooks
├── pyproject.toml
├── requirements-dev.lock
├── requirements.lock
├── scripts/
├── src
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── components
│   │   │   └── datasets.py
│   │   └── datamodule.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   ├── resnet.py
│   │   │   └── senet.py
│   │   └── module.py
│   ├── preprocess.py
│   ├── eval.py
│   ├── train.py
│   └── utils/
└── tests/

私達が主に編集すべきディレクトリは以下の3つです。

  • configs : Hydraのyamlファイルを格納する場所
  • scripts: モデル学習のスクリプトを格納する。
  • src: Lightning Module, Datasets, 学習、 テストのエンドポイントなどが格納される

また、以下のディレクトリには実験データやログなどが保存されます。

  • data: 前処理などで作成した画像やマニフェストを格納する。
  • logs: 実験のときに出力されたログが格納される。

実装

1. マニフェスト作成

Datasetsでデータを読み込みやすいようにjsonp形式(jsonやyamlでも可)でマニフェストファイルを作成し、事前にtrain, valid, testセットに分割します。ここでは original_image_path , depth_image_path, mask_image_path , pose_image_path , height を1データとして保持するように作成します。

{"original_image_path":"data/Dump/original/606-370_Vance_L1.jpg", "height":198.12}
{"original_image_path":"data/Dump/original/601-210_cecil_L1.jpg", "height":185.42}
{"original_image_path":"data/Dump/original/505-140_Julie_L1.jpg", "height":160.02}
{"original_image_path":"data/Dump/original/604-120_Thomas_L1.jpg", "height":190.5}
{"original_image_path":"data/Dump/original/603-150_Mdoucet_L1.jpg", "height":190.5}
...

今回は src/preprocess.py にマニフェスト生成コードを作成しました。以下のコマンドを実行して data/manifests に生成します。

2. Datasets/Datamodule(dataloader)作成

PyTorchのDatasetsと、Datasetsの出力をバッチに固める処理を collent_fn に定義します。

src/data/components/datasets.py

import torch
from torch.utils.data import Dataset
import numpy as np
from data.manifest import Manifest
from PIL import Image
from src.data.components.augmentor import Augmentor

# リソースの指定(CPU/GPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class HeightDataset(Dataset):
    def __init__(self, manifests_path: str):
        self.manifests = Manifest.load_jsonp(manifests_path)

    def __len__(self):
        return len(self.manifests)

    def __getitem__(self, index):
        manifest = self.manifests[index]

        original = Image.open(manifest.original_image_path).convert(mode="L")
        original = np.expand_dims(original, 0)
        original = torch.tensor(original, dtype=torch.float).to(device)


        height = torch.tensor(manifest.height).to(device)
        return original, height

def collent_fn(batch):
    original, height = zip(*batch)

    # サイズを揃えるためにpadding
    max_height = max([o.size(1) for o in original])
    max_width = max([o.size(2) for o in original])
    original = [torch.nn.functional.pad(o, (0, max_width - o.size(2), 0, max_height - o.size(1))) for o in original]

    height = torch.stack(height)
    return (
        torch.stack(original),
        height,
    )

作成したDatasetsやcollent_fnを用いて、PyTorch LightningのDataModuleを構築します。これにより、PyTorch Lightningが学習時や評価時に適切なDataloaderを呼び出せます。

src/data/datamodule.py

from typing import Any, Optional

from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split
from data.components.datasets import HeightDataset, collent_fn

class DataModule(LightningDataModule):
    def __init__(
        self,
        train_manifest_file: str,
        valid_manifest_file: str,
        test_manifest_file: str,
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ) -> None:
        """ """
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.batch_size_per_device = batch_size
        self.train_manifest_file = train_manifest_file
        self.valid_manifest_file = valid_manifest_file
        self.test_manifest_file = test_manifest_file

    def setup(self, stage: Optional[str] = None) -> None:
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = (
                self.hparams.batch_size // self.trainer.world_size
            )
        if stage == "fit":
            self.data_train = HeightDataset(self.train_manifest_file)
            self.data_val = HeightDataset(self.valid_manifest_file)
        elif stage == "test":
            self.data_test = HeightDataset(self.test_manifest_file)

    def train_dataloader(self) -> DataLoader[Any]:
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            collate_fn=collent_fn,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
            collate_fn=collent_fn,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
            collate_fn=collent_fn,
        )

3. モデル構築

PyTorch Lightning Moduleを作成します。

前述で紹介したLightning Moduleの training_step forward configure_optimizers のほかにもHooksが増えています。

  • 私が作った独自の関数
    • average_relative_error Loss関数
    • model_step batch取得からloss計算までの処理を共通化
  • PyTorch Lightningが用意したHooks
    • on_{train/validation/test}_start : epochの始まりに実行されます。
    • on_{train/validation/test}_epoch_end :epochの終了後に実行されます。
    • setup training/testが実行される前に実行されます。

src/models/module.py

from typing import Any, Dict, Tuple

import torch
from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric, MinMetric
from torchmetrics import MeanAbsoluteError

class BodyHeightModule(LightningModule):
    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile: bool,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(logger=False)

        self.net = net

        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

    def average_relative_error(
        self, preds: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        return torch.mean(torch.abs(preds - targets) / targets)

    def forward(
        self,
        original: torch.Tensor,
    ) -> torch.Tensor:
        x = original
        return self.net(x)

    def on_train_start(self) -> None:
        self.val_loss.reset()

    def model_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        original, height = batch
        est_height = self.forward(original)
        loss = self.average_relative_error(est_height, height)
        return loss

    def training_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> torch.Tensor:
        loss = self.model_step(batch)

        self.train_loss(loss)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def on_train_epoch_end(self) -> None:
          pass

    def validation_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> None:
        loss = self.model_step(batch)

        # update and log metrics
        self.val_loss(loss)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> None:
        loss = self.model_step(batch)

        # update and log metrics
        self.test_loss(loss)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)

    def setup(self, stage: str) -> None:
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}

4. 実験ハイパラ定義

最後に configs/ で学習のハイパーパラメータを定義します。変更する必要があるファイルは以下の3つです。

  • configs/data/body_height.yaml
    新しくファイルを作成します。今回は body_height.yaml としました。
    先ほど作成したDataModuleの引数として、マニフェストファイルやbatchサイズなどを指定します。

    _target_: src.data.datamodule.DataModule
    train_manifest_file: ${paths.data_dir}/Manifest/train.jsonp
    valid_manifest_file: ${paths.data_dir}/Manifest/valid.jsonp
    test_manifest_file: ${paths.data_dir}/Manifest/test.jsonp
    batch_size: 32
    num_workers: 0
    pin_memory: False
    
  • configs/model/body_height.yaml
    新しくファイルを作成します。今回は body_height.yaml としました。先ほど作成したLightningModuleの引数のOptimizer, Scheduler, Modelを定義し、ハイパーパラメータを指定します。

    _target_: src.models.module.BodyHeightModule
    
    optimizer:
      _target_: torch.optim.Adam
      _partial_: true
      lr: 0.001
      weight_decay: 0.0
    
    scheduler:
      _target_: torch.optim.lr_scheduler.CosineAnnealingLR
      _partial_: true
      T_max: 10
    
    net:
      _target_: src.models.components.resnet.ResNet
      layers: [2, 2, 2, 2]
      num_classes: 1
      use_senet: true
      ratio: 16
    
    compile: false
    
  • configs/main.yaml
    上記で作成したconfigs/data/body_height.yamlconfigs/model/body_height.yaml をdefaultに定義します。

    defaults:
      - _self_
      - data: body_height # ← configs/model/にあるファイル名に変更
      - model: body_height # ← configs/model/にあるファイル名に変更
      - callbacks: default
      - logger: tensorboard # tensorboard形式でlogが保存される
      - trainer: default
      - paths: default
      - extras: default
      - hydra: default
      - experiment: null
      - hparams_search: null
      - optional local: default
      - debug: null
    
    task_name: "train"
    tags: ["dev"]
    
    train: True
    test: True
    ckpt_path: null
    seed: null
    

5. 学習

pytorch-lightning-templateに予め用意されている train.py コードをそのまま使って実行します。

しばらく待つと学習が始まり、epochごとのlossが出てくると思います。

python src/train.py

終わりに

今回はDeep Learningモデルを学習する際の面倒くさいポイントとそれを解消するPyTorch LightningとHydraを紹介しました。

また、後半ではlightning-hydra-template を用いたモデル構築方法について紹介しました。

PyTorch LightningとHydraを用いた実装はDeep LearningフレームワークのNeMoでも使われており、現在のDeep Learningのデファクトスタンダードとなっています。また、今回紹介しなかったですが、WandBやOptunaといった、ほかのフレームワークも柔軟に組み合わせることが可能なで、非常に便利です。

ぜひこれらを使いこなして、いろいろなモデルを構築してみてください。

MIXI DEVELOPERS
MIXI DEVELOPERS

Discussion