Saving Features and the Model During Training with PyTorch Lightning
TL;DR
Create attributes in your LightningModule that accumulate the features, and save them in on_validation_end.
Watch out for the Trainer's initial sanity check.
Below is how to save the features and the model from the epoch with the best validation score while training with Lightning.
The assumed scenario is cross-validation on a classification task: for each fold, the GAP (global average pooling) input/output features, the predicted labels, and the corresponding ground-truth labels from the best epoch are bundled into a single pkl file.
Environment: Docker
(Memo) Creating a container from an image built with a Dockerfile
ARG PYTORCH="2.6.0"
ARG CUDA="11.8"
ARG CUDNN="9"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
FORCE_CUDA="1" \
DEBIAN_FRONTEND=noninteractive
RUN apt-get update \
    && apt-get install -y ffmpeg git \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*
# torch torchvision
RUN pip install ipython jupyter \
&& pip install numpy scikit-learn pillow pandas polars matplotlib seaborn \
&& pip install torchinfo torchmetrics lightning timm transformers
CMD ["/bin/bash"]
docker image build -t <image name>:<tag name> .
docker run --name <container name> --gpus all --shm-size 8gb -v </C/mount/dir>:/workspace -i -t <image name>:<tag name>
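For example, with hypothetical image/container names and a hypothetical mount path (placeholders only; substitute your own):
docker image build -t lightning-feat:latest .
docker run --name lightning-feat --gpus all --shm-size 8gb -v /path/to/project:/workspace -i -t lightning-feat:latest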
Libraries
import multiprocessing
import pickle as pkl
from sklearn import metrics
import torch
from torch import nn
import torchvision
from torchvision.transforms import v2 as transforms
import timm
from timm.layers import SelectAdaptivePool2d
import lightning
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import LearningRateMonitor
import torchmetrics as tm
The dataset is as follows.
A dataset is defined that reads png/jpg images from the directory they are stored in and returns each image together with its ground-truth label, and a LightningDataModule prepares the train/validation/test versions of it.
Augmentation is applied to the training data.
class ClassificationDataset(torch.utils.data.Dataset):
def __init__(self, path_list, transforms=transforms.ToImage()):
self.path_list = path_list
self.label_list = _path_to_label(path_list)
self.transforms = transforms
return None
def __len__(self):
return len(self.path_list)
def __getitem__(self, idx):
img = torchvision.io.read_image(self.path_list[idx])
img = self.transforms(img)
target = torch.zeros(NUM_CLASSES)
target[self.label_list[idx]] = 1
return img.to(torch.float), target.to(torch.float)
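The helper _path_to_label is not shown in the article. A minimal sketch, assuming the integer class index is encoded in each image's parent directory name (this layout is an assumption, not necessarily the article's actual data layout):
import pathlib

def _path_to_label(path_list):
    # Hypothetical: each image lives under a directory named after its
    # class index, e.g. "data/train/042/img001.jpg" -> label 42.
    return [int(pathlib.Path(p).parent.name) for p in path_list]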
class ClassificationDataModule(lightning.LightningDataModule):
def __init__(self, img_size=(224,224), fold=0, batch_size=64, num_workers=4):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.fold = fold
        self.train_transforms = transforms.Compose([
            transforms.Resize(img_size[0]),
            transforms.CenterCrop(img_size),
            transforms.RandomRotation((-15, 15)),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomGrayscale(p=0.10),
            # v2.Normalize requires a float tensor, so convert (and rescale to [0, 1]) first
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.test_transforms = transforms.Compose([
            transforms.Resize(img_size[0]),
            transforms.CenterCrop(img_size),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
return None
def prepare_data(self):
return None
    def setup(self, stage=None):
dataset_path_list = _load_train_test(self.fold)
self.train_dataset = ClassificationDataset(
dataset_path_list["train"],
transforms=self.train_transforms
)
self.val_dataset = ClassificationDataset(
dataset_path_list["test"],
transforms=self.test_transforms
)
self.test_dataset = ClassificationDataset(
dataset_path_list["test"],
transforms=self.test_transforms
)
return None
def train_dataloader(self):
train_dataloader = torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
return train_dataloader
    def val_dataloader(self):
        val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return val_dataloader
def test_dataloader(self):
test_dataloader = torch.utils.data.DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
return test_dataloader
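Likewise, _load_train_test is not shown. A minimal sketch, assuming all images sit under a single data directory and a deterministic modulo-based K-fold split (data_dir and the split rule are assumptions; num_fold mirrors the NUM_FOLD constant defined with the model below):
import glob

def _load_train_test(fold, num_fold=10, data_dir="data"):
    # Hypothetical: collect all images, then split them into train/test
    # by index modulo num_fold, returning the dict that setup() expects.
    paths = sorted(
        glob.glob(f"{data_dir}/**/*.png", recursive=True)
        + glob.glob(f"{data_dir}/**/*.jpg", recursive=True)
    )
    train = [p for i, p in enumerate(paths) if i % num_fold != fold]
    test = [p for i, p in enumerate(paths) if i % num_fold == fold]
    return {"train": train, "test": test}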
The model is as follows.
A backbone is pulled in via timm, and its classification head is blown away so the feature map can be taken out. The features inferred in training_step and validation_step are accumulated in memory, and on_validation_end, which is called once validation finishes, writes them to storage.
MODEL_NAME = "resnetrs50.tf_in1k"
DIR_LOGS = "_logs"
NUM_FOLD = 10
NUM_CLASSES = 1000
MAXEPOCH = 100
class ClassificationModel(lightning.LightningModule):
def __init__(self, num_class=1000, fold=0):
super().__init__()
self.ext = timm.create_model(
MODEL_NAME,
pretrained=True,
num_classes=NUM_CLASSES
)
self.ext.global_pool = nn.Identity()
self.ext.fc = nn.Identity()
        self.gap = SelectAdaptivePool2d(pool_type="avg", flatten=True)  # flatten expects a bool; timm builds nn.Flatten(1) internally
self.head = nn.Linear(in_features=2048, out_features=NUM_CLASSES, bias=True)
self.criteria = nn.CrossEntropyLoss()
self.metrics_train = tm.MetricCollection([
tm.Accuracy(task="multilabel", num_labels=num_class),
tm.F1Score(task="multilabel", num_labels=num_class),
], prefix="train")
self.metrics = tm.MetricCollection([
tm.Accuracy(task="multilabel", num_labels=num_class),
tm.F1Score(task="multilabel", num_labels=num_class),
], prefix="test")
self.fold = fold
self.best_score = 0.0
self.train_feat_list = []
self.train_prediction_list = []
self.train_target_list = []
self.train_featmap_list = []
self.test_feat_list = []
self.test_prediction_list = []
self.test_target_list = []
self.test_featmap_list = []
return None
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=0.001,
weight_decay=0.01)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
MAXEPOCH,
eta_min=0.000001)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
def forward(self, x, mask=None):
fm = self.ext(x)
f = self.gap(fm)
o = self.head(f)
return o, f, fm
def training_step(self, batch, batch_idx):
x, target = batch
out, feat, featmap = self.forward(x)
loss = self.criteria(out, target)
acc = self.metrics_train(out, target)
acc_dict = {
"trainAcc": acc["trainMultilabelAccuracy"],
"trainF1": acc["trainMultilabelF1Score"],
}
self.log_dict(acc_dict, logger=True, prog_bar=True)
self.train_feat_list += [feat.detach().cpu()]
self.train_featmap_list += [featmap.detach().cpu()]
self.train_prediction_list += [torch.argmax(out.detach().cpu(), dim=1)]
self.train_target_list += [torch.argmax(target.detach().cpu(), dim=1)]
return {"loss": loss}
def validation_step(self, batch, batch_idx):
x, target = batch
out, feat, featmap = self.forward(x)
acc = self.metrics(out, target)
acc_dict = {
"Acc": acc["testMultilabelAccuracy"],
"F1": acc["testMultilabelF1Score"],
}
self.log_dict(acc_dict, logger=True, on_epoch=True, prog_bar=True)
self.test_feat_list += [feat.detach().cpu()]
self.test_featmap_list += [featmap.detach().cpu()]
self.test_prediction_list += [torch.argmax(out.detach().cpu(), dim=1)]
self.test_target_list += [torch.argmax(target.detach().cpu(), dim=1)]
return {"metrics": acc}
def test_step(self, batch, batch_idx):
x, target = batch
out, _, _ = self.forward(x)
acc = self.metrics(out, target)
acc_dict = {
"Acc": acc["testMultilabelAccuracy"],
"F1": acc["testMultilabelF1Score"],
}
self.log_dict(acc_dict)
return {"metrics": acc}
def on_validation_end(self):
try:
self.train_feat_list = torch.concat(self.train_feat_list).numpy()
self.train_featmap_list = torch.concat(self.train_featmap_list).numpy()
self.train_prediction_list = torch.concat(self.train_prediction_list).numpy()
self.train_target_list = torch.concat(self.train_target_list).numpy()
self.test_feat_list = torch.concat(self.test_feat_list).numpy()
self.test_featmap_list = torch.concat(self.test_featmap_list).numpy()
self.test_prediction_list = torch.concat(self.test_prediction_list).numpy()
self.test_target_list = torch.concat(self.test_target_list).numpy()
        except RuntimeError:
            # During the sanity check no training_step has run yet, so the
            # train-side lists are empty and torch.concat raises RuntimeError.
            print("sanity check: nothing to save yet")
            self.train_feat_list = []
            self.train_featmap_list = []
            self.train_prediction_list = []
            self.train_target_list = []
            self.test_feat_list = []
            self.test_featmap_list = []
            self.test_prediction_list = []
            self.test_target_list = []
            return super().on_validation_end()
score = metrics.accuracy_score(self.test_target_list, self.test_prediction_list)
if self.best_score < score:
data = {
"feat_train": self.train_feat_list,
"featmap_train": self.train_featmap_list,
"pred_train": self.train_prediction_list,
"target_train": self.train_target_list,
"feat_test": self.test_feat_list,
"featmap_test": self.test_featmap_list,
"pred_test": self.test_prediction_list,
"target_test": self.test_target_list,
}
with open(f"_log/feat_{self.fold:03}_best.pkl", "wb") as f:
pkl.dump(data, f)
model = nn.Sequential(
self.exter,
self.gpa,
self.head
)
torch.save(model.state_dict(), f"{DIR_LOGS}/model_{self.fold:03}.pth")
self.best_score = score
print(f"------------ BEST score updated: {score} ------------")
self.train_feat_list = []
self.train_featmap_list = []
self.train_prediction_list = []
self.train_target_list = []
self.test_feat_list = []
self.test_featmap_list = []
self.test_prediction_list = []
self.test_target_list = []
return super().on_validation_end()
def on_test_end(self):
print(f"------------ BEST score updated: {self.best_score} ------------")
return super().on_test_end()
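As a quick check of what forward returns (a sketch; the shapes assume the resnetrs50 backbone and a 224x224 RGB input batch):
m = ClassificationModel(num_class=NUM_CLASSES, fold=0)
m.eval()
with torch.no_grad():
    o, f, fm = m(torch.randn(2, 3, 224, 224))
# o: (2, 1000) logits, f: (2, 2048) pooled features, fm: (2, 2048, 7, 7) feature map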
The tricky part here is the Trainer's initial sanity check: it runs validation_step right at the start, before any training_step, so on_validation_end must be prevented from saving features that were never accumulated. The code above exploits the RuntimeError that torch.concat raises on the empty train-side lists during the sanity check to bail out of on_validation_end early.
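Instead of relying on the exception, you could also guard explicitly. A sketch of two alternatives (both use standard Lightning APIs):
# 1) Bail out explicitly; trainer.sanity_checking is True only during the check.
def on_validation_end(self):
    if self.trainer.sanity_checking:
        self.test_feat_list = []
        self.test_featmap_list = []
        self.test_prediction_list = []
        self.test_target_list = []
        return super().on_validation_end()
    # ...otherwise concatenate and save as above...

# 2) Or skip the sanity check entirely when constructing the Trainer.
trainer = lightning.Trainer(num_sanity_val_steps=0)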
The training loop can be written as follows; running it for every fold saves the features for all folds.
def train(fold):
datamodule = ClassificationDataModule(fold=fold)
model = ClassificationModel(num_class=NUM_CLASSES, fold=fold)
trainer = lightning.Trainer(
logger=CSVLogger(DIR_LOGS, name=MODEL_NAME),
enable_checkpointing=False,
check_val_every_n_epoch=1,
callbacks=LearningRateMonitor(logging_interval="epoch"),
accelerator="gpu",
devices=1,
max_epochs=MAXEPOCH,
)
trainer.fit(
model,
datamodule=datamodule
)
result = trainer.test(
model,
datamodule=datamodule
)
return result
if __name__ == "__main__":
    multiprocessing.freeze_support()
    for fold in range(NUM_FOLD):
        train(fold)
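To use the artifacts afterwards, the pkl and the state_dict can be loaded back like this (a sketch; fold 0 shown, and the nn.Sequential must mirror the structure saved above):
with open(f"{DIR_LOGS}/feat_000_best.pkl", "rb") as f:
    data = pkl.load(f)
print(data["feat_test"].shape, data["pred_test"].shape)

model = ClassificationModel(num_class=NUM_CLASSES, fold=0)
restored = nn.Sequential(model.ext, model.gap, model.head)
restored.load_state_dict(torch.load(f"{DIR_LOGS}/model_000.pth"))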
That's all.