👀

【Python】tqdmで進捗の可視化

2022/03/12に公開

はじめに

本記事ではPCに長い計算をさせている時の待ち時間を気持ち的に楽にさせる方法を紹介します.

なお本記事では,こちらの記事に出ているコードを変更していきます↓

環境

PC MacBook Pro (16-inch, 2019)
OS Monterey
CPU 2.3 GHz 8コアIntel Core i9
メモリ 16GB
Python 3.9

必要なライブラリ

ライブラリ バージョン
tensorboard 4.63.0

ライブラリのインストール方法は以下です.

pip install tensorboard

本編

実際にコードを書いて結果を確認してみましょう.

main.py
import torch
from torch import nn, optim
+ from tqdm import tqdm
+ from typing import OrderedDict

from dataset.mnist import get_dataloader
from model.resnet import get_resnet

def train(total_epoch: int=100):
    writer = SummaryWriter(log_dir="log")
    dataloader = get_dataloader(root="data", batch_size=64)

    model = get_resnet(pretrained=True)
    optimizer = optim.SGD(
        params=model.parameters(),
        lr=1e-3
    )
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=1e-3,
        total_steps=len(dataloader),
    )
    criterion = nn.CrossEntropyLoss()

    model.train()

    for epoch in range(total_epoch):
        accuracy, train_loss = 0.0, 0.0

+       with tqdm(dataloader) as pbar:
+           pbar.set_description(f'[Epoch {epoch + 1}/{total_epoch}]')

            for images, labels in pbar:
                optimizer.zero_grad()
                
                out = model(images)
                loss = criterion(out, labels)

                loss.backward()
                optimizer.step()

                # 推測値
                preds = out.argmax(axis=1)

                # lossの算出
                train_loss += loss.item()
                accuracy += torch.sum(preds == labels).item() / len(labels)

+               pbar.set_postfix(
+                   OrderedDict(
+                       Loss=loss.item(),
+                       Accuracy=torch.sum(preds == labels).item() / len(labels),
+                   )
+               )

            print(f"epoch: {epoch + 1}")
            print(f"loss: {train_loss / len(dataloader)}")
            print(f"accuracy: {accuracy / len(dataloader)}")

            scheduler.step()


if __name__ == "__main__":
    train()

出力結果

terminal
[Epoch 1/100]:  12%|██           | 117/937 [00:59<07:02,  1.94it/s, Loss=2.16, Accuracy=0.172]

これで,epoch毎の結果が後どのくらいで出力されるのか待ち焦がれる必要はないですね!
しかも,途中経過まで出力してくれるので深層学習をよくさせている人は必須なライブラリです!

ぜひ活用してみましょう!!

GitHubで編集を提案

Discussion