PyTorch LightningでTrain中に特徴量とモデルを保存する

2025/02/11に公開

TL;DR

LightningModuleの中に特徴保存用のプロパティを作成し、on_validation_endでそれらを保存すれば良い。
Trainerの最初のsanity checkに注意。


Lightningを使っている場合に、Train中のバリデーションスコアが良かったEpochの特徴量とモデルを保存したい時の方法を以下に書きます。

想定されるシチュエーションとしては、分類タスクにおいてクロスバリデーションを行い、各FOLDでバリデーションスコアが最も良かったEpochのGAP出力特徴と、推論された予測ラベル、対応する教師ラベルを1つのpklファイルにまとめて保存する感じです。

実行環境 Docker

(備忘録) Dockerfileで書いたImageからContainerを作成する

# Versions for the official PyTorch devel base image.
ARG PYTORCH="2.6.0"
ARG CUDA="11.8"
ARG CUDNN="9"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

# NOTE: ENV does not run a shell, so $(...) command substitution in the
# original was stored literally, never evaluated. The conda prefix in the
# pytorch/pytorch images is /opt/conda, so set it directly.
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
    TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
    CMAKE_PREFIX_PATH="/opt/conda" \
    FORCE_CUDA="1" \
    DEBIAN_FRONTEND=noninteractive

# Update and install in ONE layer so a cached "apt-get update" layer can never
# pair with a fresh install; then drop the apt lists to keep the image small.
RUN apt-get update \
    && apt-get install -y ffmpeg git \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Python tooling (torch/torchvision ship with the base image).
# The original had a trailing space after a line-continuation backslash,
# which silently terminated the RUN instruction mid-command.
RUN pip install ipython jupyter \
    && pip install numpy scikit-learn pillow pandas polars matplotlib seaborn \
    && pip install torchinfo torchmetrics lightning timm transformers
CMD ["/bin/bash"]
# Build the image from the Dockerfile in the current directory.
# (-t tags the *image*; the original template said <container name> here,
# which contradicts its own "docker run" line below.)
docker image build -t <image name>:<tag name> .
# Create and start a container with GPU access, a larger shared-memory segment
# for DataLoader workers, and the host directory mounted at /workspace.
docker run --name <container name> --gpus all --shm-size 8gb -v </C/mount/dir>:/workspace -i -t <image name>:<tag name>

ライブラリ

import multiprocessing

import pickle as pkl
from sklearn import metrics

import torch
from torch import nn
import torchvision
from torchvision.transforms import v2 as transforms

import timm
from timm.layers import SelectAdaptivePool2d

import lightning
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import LearningRateMonitor
import torchmetrics as tm

データセットは以下の通り。
画像が格納されているディレクトリからpng/jpg画像を読み込み、教師ラベルとともに出力するデータセットを定義し、LightningDataModuleでtrain/validation/testのそれらを用意する。
学習データはオーグメンテーションを掛けている。

class ClassificationDataset(torch.utils.data.Dataset):
    """Image-classification dataset returning ``(image, one-hot target)`` float pairs.

    Images are read from ``path_list``; labels are derived from each path by the
    module-level ``_path_to_label`` helper. ``NUM_CLASSES`` (module constant)
    fixes the width of the one-hot target vector.
    """

    def __init__(self, path_list, transforms=transforms.ToImage()):
        self.path_list = path_list
        self.label_list = _path_to_label(path_list)
        self.transforms = transforms

    def __len__(self):
        return len(self.path_list)

    def __getitem__(self, idx):
        image = torchvision.io.read_image(self.path_list[idx])
        image = self.transforms(image)
        # One-hot encode the class index for the multilabel-style metrics/loss.
        onehot = torch.zeros(NUM_CLASSES)
        onehot[self.label_list[idx]] = 1
        return image.to(torch.float), onehot.to(torch.float)
    

class ClassificationDataModule(lightning.LightningDataModule):
    """LightningDataModule serving one cross-validation fold.

    The train split gets geometric/color augmentation; the validation and test
    splits share the same deterministic resize/crop/normalize pipeline and the
    same underlying file list (the fold's held-out set).
    """

    def __init__(self, img_size=(224,224), fold=0, batch_size=64, num_workers=4):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.fold = fold

        # ImageNet normalization statistics, shared by both pipelines.
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        resize = [
            transforms.Resize(img_size[0]),
            transforms.CenterCrop(img_size),
        ]
        augment = [
            transforms.RandomRotation((-15, 15)),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomGrayscale(p=0.10),
        ]
        self.train_transforms = transforms.Compose(
            resize + augment + [transforms.Normalize(mean, std)]
        )
        self.test_transforms = transforms.Compose(
            resize + [transforms.Normalize(mean, std)]
        )

    def prepare_data(self):
        # Nothing to download or pre-process.
        return None

    def setup(self, stage=0):
        # _load_train_test (module-level helper) returns the fold's path splits;
        # validation and test both use the "test" split.
        splits = _load_train_test(self.fold)

        self.train_dataset = ClassificationDataset(
            splits["train"],
            transforms=self.train_transforms,
            )
        self.val_dataset = ClassificationDataset(
            splits["test"],
            transforms=self.test_transforms,
            )
        self.test_dataset = ClassificationDataset(
            splits["test"],
            transforms=self.test_transforms,
            )

    def _make_loader(self, dataset, shuffle):
        # Shared DataLoader construction; only shuffling differs across splits.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

モデルは以下の通り。
timmを利用してモデルを持ってきて、出力部分を吹き飛ばすことで特徴マップを得られるようにしている。また、training_stepとvalidation_stepで推論した特徴をメモリに保存しておき、バリデーションループが終わった時に呼び出されるon_validation_endでそれらをストレージへ保存している。

MODEL_NAME = "resnetrs50.tf_in1k"  # timm model id used as the backbone
DIR_LOGS = "_logs"  # output directory for logger files and saved weights
NUM_FOLD = 10  # number of cross-validation folds
NUM_CLASSES = 1000  # number of target classes (width of the one-hot labels)
MAXEPOCH = 100  # max training epochs; also the cosine LR schedule period


class ClassificationModel(lightning.LightningModule):
    """Classifier that captures GAP features during training and validation.

    Features, feature maps, predicted labels and targets from every epoch are
    buffered in memory; ``on_validation_end`` writes them (plus the model
    weights) to disk whenever the validation accuracy beats the best so far,
    then clears the buffers.
    """

    def __init__(self, num_class=1000, fold=0):
        super().__init__()

        # timm backbone with its pooling/classifier replaced by Identity so
        # forward() can expose the raw feature map.
        self.ext = timm.create_model(
            MODEL_NAME,
            pretrained=True,
            num_classes=NUM_CLASSES
            )
        self.ext.global_pool = nn.Identity()
        self.ext.fc = nn.Identity()

        self.gap = SelectAdaptivePool2d(pool_type="avg", flatten=nn.Flatten(start_dim=1, end_dim=-1))
        # Use the backbone-reported feature width instead of hard-coding 2048,
        # so swapping MODEL_NAME does not silently break the head.
        self.head = nn.Linear(in_features=self.ext.num_features, out_features=NUM_CLASSES, bias=True)

        self.criteria = nn.CrossEntropyLoss()
        self.metrics_train = tm.MetricCollection([
            tm.Accuracy(task="multilabel", num_labels=num_class),
            tm.F1Score(task="multilabel", num_labels=num_class),
            ], prefix="train")
        self.metrics = tm.MetricCollection([
            tm.Accuracy(task="multilabel", num_labels=num_class),
            tm.F1Score(task="multilabel", num_labels=num_class),
            ], prefix="test")

        self.fold = fold
        self.best_score = 0.0  # best validation accuracy seen so far

        # In-memory buffers, flushed and cleared at the end of every validation.
        self._reset_buffers()

    def _reset_buffers(self):
        # One list per saved quantity, for both the train and the val loop
        # (val buffers use the "test_" prefix to match the metric prefix).
        self.train_feat_list = []
        self.train_prediction_list = []
        self.train_target_list = []
        self.train_featmap_list = []
        self.test_feat_list = []
        self.test_prediction_list = []
        self.test_target_list = []
        self.test_featmap_list = []

    def configure_optimizers(self):
        """AdamW with cosine annealing down to eta_min over MAXEPOCH epochs."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=0.001,
            weight_decay=0.01)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            MAXEPOCH,
            eta_min=0.000001)
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def forward(self, x, mask=None):
        """Return (logits, pooled feature, feature map) for a batch of images."""
        fm = self.ext(x)
        f = self.gap(fm)
        o = self.head(f)
        return o, f, fm

    def training_step(self, batch, batch_idx):
        x, target = batch
        out, feat, featmap = self.forward(x)

        loss = self.criteria(out, target)

        acc = self.metrics_train(out, target)
        acc_dict = {
            "trainAcc": acc["trainMultilabelAccuracy"],
            "trainF1": acc["trainMultilabelF1Score"],
            }
        self.log_dict(acc_dict, logger=True, prog_bar=True)

        # Detach to CPU so the buffers do not pin GPU memory or the graph.
        self.train_feat_list += [feat.detach().cpu()]
        self.train_featmap_list += [featmap.detach().cpu()]
        self.train_prediction_list += [torch.argmax(out.detach().cpu(), dim=1)]
        self.train_target_list += [torch.argmax(target.detach().cpu(), dim=1)]

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, target = batch
        out, feat, featmap = self.forward(x)
        acc = self.metrics(out, target)
        acc_dict = {
            "Acc": acc["testMultilabelAccuracy"],
            "F1": acc["testMultilabelF1Score"],
            }
        self.log_dict(acc_dict, logger=True, on_epoch=True, prog_bar=True)

        # (The original detached `feat` twice; once is enough.)
        self.test_feat_list += [feat.detach().cpu()]
        self.test_featmap_list += [featmap.detach().cpu()]
        self.test_prediction_list += [torch.argmax(out.detach().cpu(), dim=1)]
        self.test_target_list += [torch.argmax(target.detach().cpu(), dim=1)]

        return {"metrics": acc}

    def test_step(self, batch, batch_idx):
        x, target = batch
        out, _, _ = self.forward(x)
        acc = self.metrics(out, target)
        acc_dict = {
            "Acc": acc["testMultilabelAccuracy"],
            "F1": acc["testMultilabelF1Score"],
            }
        self.log_dict(acc_dict)
        return {"metrics": acc}

    def on_validation_end(self):
        """Persist features and weights when the validation accuracy improves.

        The Trainer's initial sanity check runs validation before any training
        step, so the train buffers are empty then; ``torch.concat`` on an empty
        list raises RuntimeError, which we use to skip saving on that call.
        """
        try:
            self.train_feat_list = torch.concat(self.train_feat_list).numpy()
            self.train_featmap_list = torch.concat(self.train_featmap_list).numpy()
            self.train_prediction_list = torch.concat(self.train_prediction_list).numpy()
            self.train_target_list = torch.concat(self.train_target_list).numpy()
            self.test_feat_list = torch.concat(self.test_feat_list).numpy()
            self.test_featmap_list = torch.concat(self.test_featmap_list).numpy()
            self.test_prediction_list = torch.concat(self.test_prediction_list).numpy()
            self.test_target_list = torch.concat(self.test_target_list).numpy()
        except RuntimeError:
            print("sanity checked")
            self._reset_buffers()
            return super().on_validation_end()

        score = metrics.accuracy_score(self.test_target_list, self.test_prediction_list)
        if self.best_score < score:
            data = {
                "feat_train": self.train_feat_list,
                "featmap_train": self.train_featmap_list,
                "pred_train": self.train_prediction_list,
                "target_train": self.train_target_list,
                "feat_test": self.test_feat_list,
                "featmap_test": self.test_featmap_list,
                "pred_test": self.test_prediction_list,
                "target_test": self.test_target_list,
                }
            # Write into DIR_LOGS like every other artifact (the original used
            # a literal "_log/", a different directory from DIR_LOGS="_logs").
            with open(f"{DIR_LOGS}/feat_{self.fold:03}_best.pkl", "wb") as f:
                pkl.dump(data, f)

            # FIX: the original referenced non-existent attributes `self.exter`
            # and `self.gpa`, which raised AttributeError on the first save.
            model = nn.Sequential(
                self.ext,
                self.gap,
                self.head,
                )
            torch.save(model.state_dict(), f"{DIR_LOGS}/model_{self.fold:03}.pth")

            self.best_score = score
            print(f"------------ BEST score updated: {score} ------------")
        self._reset_buffers()
        return super().on_validation_end()

    def on_test_end(self):
        # Summary line (the original copy-pasted the "updated" message,
        # which is misleading here — nothing is being updated).
        print(f"------------ BEST score: {self.best_score} ------------")
        return super().on_test_end()

ここで厄介なのがTrainerが学習開始前に実行するsanity checkで、学習前にvalidation_stepが呼び出されるため、まだ何も溜まっていない特徴を保存しようとするのを防がなければならない。上のプログラムでは、sanity check時に空のリストへアクセスしようとした時に発生するRuntimeErrorを捕捉してon_validation_endを早期終了している。

学習ループは以下のように書けて、これを実行することで全FOLDで特徴が保存できる。

def train(fold):
    """Fit one cross-validation fold, then evaluate it; return the test result."""
    data = ClassificationDataModule(fold=fold)
    net = ClassificationModel(num_class=NUM_CLASSES, fold=fold)

    # Single-GPU trainer; checkpointing is disabled because the model saves
    # its own best weights from on_validation_end.
    trainer = lightning.Trainer(
        logger=CSVLogger(DIR_LOGS, name=MODEL_NAME),
        enable_checkpointing=False,
        check_val_every_n_epoch=1,
        callbacks=LearningRateMonitor(logging_interval="epoch"),
        accelerator="gpu",
        devices=1,
        max_epochs=MAXEPOCH,
        )
    trainer.fit(net, datamodule=data)
    return trainer.test(net, datamodule=data)

if __name__ == "__main__":
    multiprocessing.freeze_support()
    # Run every cross-validation fold. The original called train(_fold) with
    # an undefined `_fold` (NameError) and would only have covered one fold.
    for fold in range(NUM_FOLD):
        train(fold)

おわり

Discussion