🔧
PyTorchのモデルをPruneしてProfileする - 推論の効率化の検証 -
なにこれ
- PyTorchの枝刈り(Pruning)と分析(Profile)を紹介したい
- Pruneしたモデルの効率化具合をProfileする
Prune
-
PRUNING TUTORIAL
- 重みがsparseになって推論処理が軽くなることが期待できる
import torch
import torch.nn.utils.prune as prune
from torchvision import models
resnet18のconv1を使って確認する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=True).to(device)
module = model.conv1
実際のパラメータの変化の一部を見てみる
print(list(module.named_parameters())[0][1][0][0])
prune.l1_unstructured(module, name="weight", amount=0.3)
prune.remove(module, "weight")
print(list(module.named_parameters())[0][1][0][0])
"""
tensor([[-0.0104, -0.0061, -0.0018, 0.0748, 0.0566, 0.0171, -0.0127],
[ 0.0111, 0.0095, -0.1099, -0.2805, -0.2712, -0.1291, 0.0037],
[-0.0069, 0.0591, 0.2955, 0.5872, 0.5197, 0.2563, 0.0636],
[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006, 0.0576],
[-0.0275, 0.0160, 0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
[ 0.0306, 0.0410, 0.0628, 0.2390, 0.4138, 0.3936, 0.1661],
[-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
grad_fn=<SelectBackward>)
tensor([[-0.0000, -0.0000, -0.0000, 0.0748, 0.0566, 0.0171, -0.0000],
[ 0.0000, 0.0000, -0.1099, -0.2805, -0.2712, -0.1291, 0.0000],
[-0.0000, 0.0591, 0.2955, 0.5872, 0.5197, 0.2563, 0.0636],
[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0000, 0.0576],
[-0.0275, 0.0160, 0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
[ 0.0306, 0.0410, 0.0628, 0.2390, 0.4138, 0.3936, 0.1661],
[-0.0000, -0.0000, -0.0241, -0.0659, -0.1507, -0.0822, -0.0000]],
grad_fn=<SelectBackward>)
"""
torch.nn.Conv2dとtorch.nn.Linearをpruningする
model = models.resnet18(pretrained=True).to(device) # ↑でconv1だけpruneしてあるのでリロード
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
prune.remove(module, "weight")
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
prune.remove(module, "weight")
- 注意:prune.removeしないと、forwardの際にpruneの結果を計算するhookがオーバーヘッドになってむしろ遅くなる場合も。(下記参照)
Profile
- PYTORCH PROFILER(基本的な使用方法)
-
PROFILING YOUR PYTORCH MODULE(改善例)
- torch.floatが必要な処理に対してtorch.doubleからの変換をかませるとメモリ使用量が大きくなってしまう
- CUDAからCPUへのコピーやCUDA上でもできる処理をCPU上でわざわざ行うと処理時間が伸びる
- PyTorch moduleがどれくらいのスピードで処理されるのかを確認できる
import torch.autograd.profiler as profiler
結果をexport_chrome_traceするとchrome
model = models.resnet18(pretrained=True).to(device) # ↑でpruneしたのでリロード
inputs = torch.randn(5, 3, 28, 28).to(device)
model(inputs) # warming up
use_cuda = torch.cuda.is_available()
with profiler.profile(record_shapes=True, profile_memory=True, use_cuda=use_cuda, with_stack=True) as prof:
with profiler.record_function("model_inference"):
model(inputs)
prof.export_chrome_trace("before.json")
- tracing(chrome://tracing)でGUI付きで分析できます
PruneしてProfileするなら
- 複数のexportをまとめる場合、pidを変えれば並べて比較できる
import json
pruneしてexport
model = models.resnet18(pretrained=True).to(device)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
prune.remove(module, "weight")
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
prune.remove(module, "weight")
use_cuda = torch.cuda.is_available()
with profiler.profile(record_shapes=True, profile_memory=True, use_cuda=use_cuda, with_stack=True) as prof:
with profiler.record_function("model_inference"):
model(inputs)
prof.export_chrome_trace("after.json")
- prune.removeを忘れると遅くなることを確認したい
removeしなかった場合をexport
model = models.resnet18(pretrained=True).to(device)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# prune.remove(module, "weight")
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
# prune.remove(module, "weight")
use_cuda = torch.cuda.is_available()
with profiler.profile(record_shapes=True, profile_memory=True, use_cuda=use_cuda, with_stack=True) as prof:
with profiler.record_function("model_inference"):
model(inputs)
prof.export_chrome_trace("with_hooks.json")
exportしたjsonを一つにまとめる
with open("before.json", "r") as f:
before = json.load(f)
with open("with_hooks.json", "r") as f:
with_hooks = json.load(f)
with open("after.json", "r") as f:
after = json.load(f)
concat = []
for prof in before:
if prof["pid"] == "CPU functions":
prof["pid"] = "CPU (before)"
elif prof["pid"] == "CUDA functions":
prof["pid"] = "CUDA (before)"
concat.append(prof)
for prof in with_hooks:
if prof["pid"] == "CPU functions":
prof["pid"] = "CPU (with_hooks)"
elif prof["pid"] == "CUDA functions":
prof["pid"] = "CUDA (with_hooks)"
concat.append(prof)
for prof in after:
if prof["pid"] == "CPU functions":
prof["pid"] = "CPU (after)"
elif prof["pid"] == "CUDA functions":
prof["pid"] = "CUDA (after)"
concat.append(prof)
with open("trace_all.json", "w") as f:
json.dump(concat, f)
-
tracingで表示してみると
-
pruning前に49.601 msかかっていた処理は、
- pruning後に49.485 msにできる。
- ただし、prune.removeを忘れると55.591 msかかる
-
ちなみに、同じpruningを行った場合、resnet152であれば10%近い処理スピードの改善が見られた
- 前:399.683 ms
- 後:367.759 ms
まとめ
- スカスカにすると精度が落ちることは忘れてはならない。早くなったことを喜んでいる場合じゃないかもしれない。
- resnet18やresnet152のような小さなモデルであればこの程度だが、より巨大なモデルでのpruningの効果は大きい。精度を許容範囲内に維持しつつ疎にできるパラメータをどのように発見するかがキモでしょう。
Discussion