👀

【Python】TensorBoardで学習状況を可視化してみよう

2022/03/13に公開約2,300字

はじめに

この記事を読むとaccuracylossなどの値を可視化して学習状況を理解しやすくなります.
尚本記事では,こちらの学習コードに追記していく形になるのでご了承ください↓

環境

PC MacBook Pro (16-inch, 2019)
OS Monterey
CPU 2.3 GHz 8コアIntel Core i9
メモリ 16GB
Python 3.9

使用するライブラリ

本記事で用いるライブラリとバージョンをまとめます.

ライブラリ バージョン
numpy 1.21.2
tensorboard 2.8.0
torch 1.10.1
torchvision 0.11.2
terminal
pip install numpy tensorboard torch torchvision

で問題ないかと思います.

本編

いよいよlogを可視化する記事の本編です.

コード

実際にlogを可視化するためのコードを追記していきます.

main.py
import torch
from torch import nn, optim
+ from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset.mnist import get_dataloader
from model.resnet import get_resnet

def train(total_epoch: int=100):
+   writer = SummaryWriter(log_dir="log")
    dataloader = get_dataloader(root="data", batch_size=64)

    model = get_resnet(pretrained=True)
    optimizer = optim.SGD(
        params=model.parameters(),
        lr=1e-3
    )
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=1e-3,
        total_steps=len(dataloader),
    )
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(total_epoch):
        accuracy, train_loss = 0.0, 0.0
        for images, labels in tqdm(dataloader):
            optimizer.zero_grad()
            
            out = model(images)
            loss = criterion(out, labels)

            loss.backward()
            optimizer.step()

            # 推測値
            preds = out.argmax(axis=1)

            # lossの算出
            train_loss += loss.item()
            accuracy += torch.sum(preds == labels).item() / len(labels)

        scheduler.step()

        # logの記録
+       writer.add_scalar("loss", train_loss / len(dataloader), epoch)
+       writer.add_scalar("accuracy", accuracy / len(dataloader), epoch)     
+       writer.add_scalar("lr", scheduler.get_lr()[0], epoch)   

        print(f"epoch: {epoch + 1}")
        print(f"loss: {train_loss / len(dataloader)}")
        print(f"accuracy: {accuracy / len(dataloader)}")


if __name__ == "__main__":
    train()

出力を見てみる

以下のコマンドを実行して出力を確認してみましょう.

terminal
tensorboard --logdir log

acc
loss
lr

GitHubで編集を提案

Discussion

ログインするとコメントできます