🐙

PyTorchのモデルをtorch.profilerにかける

2021/04/01に公開

なにこれ

時代遅れなtorch.autograd.profiler

環境

  • torch==1.8.1
  • tensorboard==2.4.1
  • torch-tb-profiler==0.1.0

紹介されているコード

pseudo_code
with torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=6, 
        repeat=1),
    on_trace_ready=tensorboard_trace_handler,
    with_trace=True
) as profiler:
    for step, data in enumerate(trainloader, 0):
        print("step:{}".format(step))
        inputs, labels = data[0].to(device=device), data[1].to(device=device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        profiler.step()

実際に動かすなら

  • インポート
実際のコード
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
  • 準備
準備
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
])
dataset = torchvision.datasets.CIFAR10(
    root='data',
    download=True,
    transform=transform
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2)
device = "cuda:0"
model = torchvision.models.resnet18(pretrained=True).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
dir_name = "logdir"
  • 適宜修正
    • tensorboardには自動で書き込まれる
実はそのままだと動かない
with torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=6,
        repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name),
) as profiler:
    for step, (inputs, labels) in enumerate(dataloader, 0):
        print("step:{}".format(step), end="\r")
        inputs, labels = inputs.to(device=device), labels.to(device=device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        profiler.step()

Tensorboardで確認する

  • 書き込まれるデータはjsonで、そのままだとchrome://tracingでも、tensorboardでも開けません。torch_tb_profilerをインストールしてtensorboardで開けるようにする必要があります。
tbのプラグインをインストールする
pip install torch_tb_profiler
tensorboardをlaunch
tensorboard --logdir logdir
  • Overview

  • Operator

  • GPU Kernel

  • Trace

最後に

  • GPU Kernelを表示されても正直「完全に理解した」レベルでは手に余る
  • transformをあえて組み込んだので、どれくらいの負荷がかかるのかを確認したかったが、OperatorのViewから該当するものを発見できなかった。たぶんNormalizationそれ自体は引き算割り算で実行しているのでaten::addとaten::divがそれなのか…?

  • 旧バージョンからは、Traceの他のOverview, Operator, GPU Kernelが追加されたことが大きな進化かと思われる
  • なにはともあれとにかく機能盛りだくさんで優秀

Discussion