
# Saving intermediate layer inputs, outputs, and gradients in PyTorch

Published 2022/07/07

## A simple example that prints node inputs and outputs in the forward pass

```
Sequential(
  (0): Linear(in_features=10, out_features=5, bias=True)
  (1): Linear(in_features=5, out_features=2, bias=True)
)
```
```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

def forward_hook(model, inputs, output):
    # inputs is a tuple, because a module can take multiple inputs
    print('input:', inputs[0].shape, 'output', output.shape)

for name, layer in model.named_children():
    print(f'hook onto {name}')
    layer.register_forward_hook(forward_hook)

x = torch.rand(5, 10)
model(x).mean()
```

```
input: torch.Size([5, 10]) output torch.Size([5, 5])
input: torch.Size([5, 5]) output torch.Size([5, 2])
```

Note: inputs is a tuple of the input tensors, while output is a single tensor.
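
For example, a module whose forward takes two tensors receives them as a 2-tuple. The Add module below is a toy example made up purely for illustration:

```python
import torch
import torch.nn as nn

class Add(nn.Module):
    # toy two-input module, for illustration only
    def forward(self, a, b):
        return a + b

add = Add()
add.register_forward_hook(
    lambda module, inputs, output: print(len(inputs), output.shape))

add(torch.rand(3, 4), torch.rand(3, 4))  # prints: 2 torch.Size([3, 4])
```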

## The backpropagation case

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

def backward_hook(model, grad_input, grad_output):
    # grad_input / grad_output are tuples, and entries can be None
    for i, g in enumerate(grad_input):
        print('grad_in', i, g.shape if g is not None else None)

for name, layer in model.named_children():
    print(f'hook onto {name}')
    layer.register_backward_hook(backward_hook)
    #layer.register_full_backward_hook(backward_hook)

x = torch.rand(3, 10)
model(x).mean().backward()
```

```
grad_in 0 torch.Size([2])
```

```
x(3, 10) -> fc1(10,5) -> h1(3, 5) -> fc2(5,2) -> h2(3, 2) -> mean -> scalar loss
```

The hooks return the gradients just before and after fc1 and fc2.
That is, fc2's grad_output is the gradient of the loss with respect to h2(3, 2),
and fc2's grad_input is the gradient of the loss with respect to h1(3, 5).

fc1's grad_input should likewise be the gradient of the loss with respect to x(3, 10), but it comes back as None, because x itself does not require a gradient.

Note: register_backward_hook is deprecated and register_full_backward_hook should be used instead, but the latter breaks when the model contains ReLU(inplace=True) [1][2], so this article deliberately uses register_backward_hook.
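
One way to convince yourself that the None comes from x not requiring a gradient is the following sketch (not from the original article; register_full_backward_hook is fine here because this toy model has no inplace ReLU):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

def backward_hook(module, grad_input, grad_output):
    # with register_full_backward_hook, grad_input holds only the
    # gradients of the loss w.r.t. the module inputs
    print('grad_in:', [None if g is None else tuple(g.shape) for g in grad_input])

for layer in model.children():
    layer.register_full_backward_hook(backward_hook)

x = torch.rand(3, 10).requires_grad_()  # now the gradient w.r.t. x exists
model(x).mean().backward()
# fires in reverse order: fc2 prints [(3, 5)], then fc1 prints [(3, 10)]
```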

In the case of register_full_backward_hook, the meaning is uniform across layer types: grad_input holds the gradients of the loss with respect to the module's inputs, and grad_output the gradients with respect to the module's outputs. Parameter gradients are not included (read them from .grad after backward()).

In the case of register_backward_hook, the layout of grad_input differs between fc and conv layers:

```
For an fc layer:
[0] shape [10] - gradient w.r.t. the bias.
[1] shape [64, 84] - gradient w.r.t. the input data: 64 batches, 84 inputs from the previous layer.
[2] shape [84, 10] - gradient w.r.t. the layer weights. Each of the 10 nodes receives the 84 outputs of the previous layer.

For a conv layer:
[0] shape [64, 16, 16, 16] - gradient w.r.t. the input data: 64 batches, 16 feature maps, 16x16 in height/width.
[1] shape [32, 16, 3, 3] - gradient w.r.t. the kernel weights: 32 kernels of depth 16 (matching the number of input feature maps) and 3x3 in height/width.
[2] shape [32] - gradient w.r.t. the bias of each kernel.
```
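
The conv layout can be checked with a small sketch (my own toy example, assuming a PyTorch version where the deprecated register_backward_hook is still available):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)

def backward_hook(module, grad_input, grad_output):
    for i, g in enumerate(grad_input):
        print(i, None if g is None else tuple(g.shape))

conv.register_backward_hook(backward_hook)  # deprecated, but illustrative

x = torch.rand(4, 16, 8, 8, requires_grad=True)
conv(x).mean().backward()
# per the layout above: 0 (4, 16, 8, 8), 1 (32, 16, 3, 3), 2 (32,)
```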

## A class for saving the values

Instead of just printing node values from the hook functions, we need to store them, so here is a class for that.

```python
class SaveActive(object):
    def __init__(self, model):
        self.model = model
        self.fw_output = {}
        self.fw_input = {}
        self.bw_output = {}
        self.bw_input = {}
        self.fw_hook_lst = []
        self.bw_hook_lst = []
        self.clear_buffer()
        self.__register_model(model)

    def __enter__(self):
        return self

    def __call__(self, model):
        # re-initialize for a new model
        self.__init__(model)

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hook()
        self.clear_buffer()

    def clear_buffer(self):
        # only leaf modules (modules without children) get buffers
        for name, layer in self.model.named_modules():
            if len(list(layer.named_children())) == 0:
                self.fw_input[name] = []
                self.fw_output[name] = []
                self.bw_input[name] = []
                self.bw_output[name] = []

    def __register_model(self, model):
        for name, layer in model.named_modules():
            if len(list(layer.named_children())) == 0:
                # print(f'hook in {name}')
                fw_handle = layer.register_forward_hook(self.fw_save(name))
                self.fw_hook_lst.append(fw_handle)
                # not register_full_backward_hook because it fails with
                # inplace ops, see https://github.com/pytorch/pytorch/issues/61519
                # layer.register_full_backward_hook(self.bw_save(name))
                bw_handle = layer.register_backward_hook(self.bw_save(name))
                self.bw_hook_lst.append(bw_handle)

    def remove_hook(self):
        for fw_handle in self.fw_hook_lst:
            fw_handle.remove()
        for bw_handle in self.bw_hook_lst:
            bw_handle.remove()

    def fw_save(self, name):
        def forward_hook(model, inputs, output):
            tmp1 = inputs[0].detach().clone().cpu().to(torch.float32)
            if tmp1.dim() == 0:
                tmp1 = tmp1.unsqueeze(0)  # make 0-dim tensors 1-dim for torch.cat
            self.fw_input[name].append(tmp1)
            if output is not None:
                tmp2 = output.detach().clone().cpu().to(torch.float32)
                if tmp2.dim() == 0:
                    tmp2 = tmp2.unsqueeze(0)
                self.fw_output[name].append(tmp2)

        return forward_hook

    def bw_save(self, name):
        def backward_hook(model, grad_input, grad_output):
            # with register_backward_hook the layout of grad_input depends on
            # the layer type; entry [1] is saved when it exists and is not None
            if len(grad_input) > 1 and grad_input[1] is not None:
                tmp1 = grad_input[1].detach().clone().cpu().to(torch.float32)
                if tmp1.dim() == 0:
                    tmp1 = tmp1.unsqueeze(0)
                self.bw_input[name].append(tmp1)
            if grad_output[0] is not None:
                tmp2 = grad_output[0].detach().clone().cpu().to(torch.float32)
                if tmp2.dim() == 0:
                    tmp2 = tmp2.unsqueeze(0)
                self.bw_output[name].append(tmp2)

        return backward_hook

    def get_fw_input_mean_norm(self):
        return self.__mean_norm(self.fw_input)

    def get_fw_output_mean_norm(self):
        return self.__mean_norm(self.fw_output)

    def get_bw_input_mean_norm(self):
        return self.__mean_norm(self.bw_input)

    def get_bw_output_mean_norm(self):
        return self.__mean_norm(self.bw_output)

    def __mean_norm(self, dat):
        # norm of the mean over all saved samples, per layer
        if self.__is_null(dat):
            return {}
        means = {}
        n_data = 0
        for key in dat:
            if len(dat[key]) != 0:
                cat = torch.cat(dat[key], dim=0)
                means[key] = cat.mean(0).norm()
                n_data = len(cat)
        print(f'mean norm by n_samples: {n_data}')
        return means

    def __is_null(self, dat):
        n_data = 0
        for key in dat:
            n_data += len(dat[key])
        return n_data == 0
```
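
Note that __exit__ calls remove_hook() and clear_buffer(), so the recorded values must be read while still inside the with block.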

```python
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

# just register it; the hooks record only while inside the with scope
with SaveActive(model) as sa:
    x = torch.rand(6, 10)
    model(x).mean().backward()
    print(sa.get_fw_input_mean_norm())
    print(sa.get_fw_output_mean_norm())
    print(sa.get_bw_input_mean_norm())
    print(sa.get_bw_output_mean_norm())
```

```
mean norm by n_samples: 6
{'0': tensor(1.5072), '1': tensor(0.5847)}
mean norm by n_samples: 6
{'0': tensor(0.5847), '1': tensor(0.4009)}
mean norm by n_samples: 6
{'1': tensor(0.1083)}
mean norm by n_samples: 6
{'0': tensor(0.1083), '1': tensor(0.1179)}
```

If you want the mean gradient over all of your inputs, you need to call backward() once per batch.
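
For example, something along these lines (a sketch; random tensors stand in for a real DataLoader):

```python
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

with SaveActive(model) as sa:
    for _ in range(4):                 # stand-in for a DataLoader loop
        x = torch.rand(6, 10)
        model(x).mean().backward()     # backward per batch so the bw hooks fire
    print(sa.get_bw_output_mean_norm())  # mean over 4 * 6 = 24 samples
```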

The with statement is used so that recording happens only while inside the with block.

That is, with

```python
with SaveActive(model) as sa:
    model(x).mean()
    model(x).mean()
    model(x).mean().backward()
```

the forward buffers accumulate values from all three forward passes, while the backward buffers only receive the gradients from the final backward() call.
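
The buffer lengths make this easy to check (a sketch; the layer names '0' and '1' come from the two-Linear model above):

```python
x = torch.rand(6, 10)
with SaveActive(model) as sa:
    model(x).mean()
    model(x).mean()
    model(x).mean().backward()
    print(len(sa.fw_input['0']))   # 3: one entry per forward pass
    print(len(sa.bw_output['0']))  # 1: only the last call ran backward
```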

## Trying it on a ResNet

A residual block is also fine: it maps one input to one output (x + h1 → h2).
Only ReLU(inplace=True) needs care.
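
If you do want register_full_backward_hook on a ResNet, one workaround (my own sketch, not from the original article) is to turn off the inplace flag on every ReLU before registering hooks:

```python
import torchvision
import torch.nn as nn

model = torchvision.models.resnet18(pretrained=True)
for m in model.modules():
    if isinstance(m, nn.ReLU):
        m.inplace = False  # avoids the inplace issue with full backward hooks
```

The torchvision resnet18 printed below keeps inplace=True, so this article sticks with the old hook.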

```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
```
```python
import torchvision

model = torchvision.models.resnet18(pretrained=True)
x = torch.rand(2, 3, 32, 32)

with SaveActive(model) as sa:
    model(x).mean()
    model(x).mean().backward()
    print(sa.get_fw_input_mean_norm())
    print(sa.get_fw_output_mean_norm())
    print(sa.get_bw_input_mean_norm())
    print(sa.get_bw_output_mean_norm())
```

```
mean norm by n_samples: 4
{'conv1': tensor(30.0120), 'bn1': tensor(88.0899), 'relu': tensor(46.0366), 'maxpool': tensor(46.0366), 'layer1.0.conv1': tensor(38.5812), 'layer1.0.bn1': tensor(110.7326), 'layer1.0.relu': tensor(26.8837), 'layer1.0.conv2': tensor(14.6708), 'layer1.0.bn2': tensor(21.9908), 'layer1.1.conv1': tensor(44.4266), 'layer1.1.bn1': tensor(87.3894), 'layer1.1.relu': tensor(28.5649), 'layer1.1.conv2': tensor(11.4086), 'layer1.1.bn2': tensor(17.4100), 'layer2.0.conv1': tensor(50.4790), 'layer2.0.bn1': tensor(41.1456), 'layer2.0.relu': tensor(8.3203), 'layer2.0.conv2': tensor(7.5420), 'layer2.0.bn2': tensor(13.2688), 'layer2.0.downsample.0': tensor(50.4790), 'layer2.0.downsample.1': tensor(21.5922), 'layer2.1.conv1': tensor(11.9498), 'layer2.1.bn1': tensor(19.3483), 'layer2.1.relu': tensor(7.3593), 'layer2.1.conv2': tensor(4.8510), 'layer2.1.bn2': tensor(7.4920), 'layer3.0.conv1': tensor(12.2743), 'layer3.0.bn1': tensor(12.0483), 'layer3.0.relu': tensor(4.1795), 'layer3.0.conv2': tensor(4.0709), 'layer3.0.bn2': tensor(4.8925), 'layer3.0.downsample.0': tensor(12.2743), 'layer3.0.downsample.1': tensor(4.2953), 'layer3.1.conv1': tensor(5.5685), 'layer3.1.bn1': tensor(9.7999), 'layer3.1.relu': tensor(3.5061), 'layer3.1.conv2': tensor(2.0692), 'layer3.1.bn2': tensor(3.0278), 'layer4.0.conv1': tensor(5.9326), 'layer4.0.bn1': tensor(5.6728), 'layer4.0.relu': tensor(1.1065), 'layer4.0.conv2': tensor(0.6796), 'layer4.0.bn2': tensor(0.7235), 'layer4.0.downsample.0': tensor(5.9326), 'layer4.0.downsample.1': tensor(3.8254), 'layer4.1.conv1': tensor(1.8447), 'layer4.1.bn1': tensor(1.8723), 'layer4.1.relu': tensor(10.7230), 'layer4.1.conv2': tensor(0.7485), 'layer4.1.bn2': tensor(0.4364), 'avgpool': tensor(20.9373), 'fc': tensor(20.9373)}
mean norm by n_samples: 4
{'conv1': tensor(88.0899), 'bn1': tensor(53.5855), 'relu': tensor(46.0366), 'maxpool': tensor(38.5812), 'layer1.0.conv1': tensor(110.7326), 'layer1.0.bn1': tensor(25.1572), 'layer1.0.relu': tensor(26.8837), 'layer1.0.conv2': tensor(21.9908), 'layer1.0.bn2': tensor(22.6131), 'layer1.1.conv1': tensor(87.3894), 'layer1.1.bn1': tensor(21.3544), 'layer1.1.relu': tensor(28.5649), 'layer1.1.conv2': tensor(17.4100), 'layer1.1.bn2': tensor(22.5308), 'layer2.0.conv1': tensor(41.1456), 'layer2.0.bn1': tensor(12.3804), 'layer2.0.relu': tensor(8.3203), 'layer2.0.conv2': tensor(13.2688), 'layer2.0.bn2': tensor(12.8707), 'layer2.0.downsample.0': tensor(21.5922), 'layer2.0.downsample.1': tensor(8.5293), 'layer2.1.conv1': tensor(19.3483), 'layer2.1.bn1': tensor(15.4010), 'layer2.1.relu': tensor(7.3593), 'layer2.1.conv2': tensor(7.4920), 'layer2.1.bn2': tensor(13.8359), 'layer3.0.conv1': tensor(12.0483), 'layer3.0.bn1': tensor(8.4056), 'layer3.0.relu': tensor(4.1795), 'layer3.0.conv2': tensor(4.8925), 'layer3.0.bn2': tensor(7.1288), 'layer3.0.downsample.0': tensor(4.2953), 'layer3.0.downsample.1': tensor(3.7063), 'layer3.1.conv1': tensor(9.7999), 'layer3.1.bn1': tensor(10.3078), 'layer3.1.relu': tensor(3.5061), 'layer3.1.conv2': tensor(3.0278), 'layer3.1.bn2': tensor(9.1545), 'layer4.0.conv1': tensor(5.6728), 'layer4.0.bn1': tensor(5.4074), 'layer4.0.relu': tensor(1.1065), 'layer4.0.conv2': tensor(0.7235), 'layer4.0.bn2': tensor(4.7105), 'layer4.0.downsample.0': tensor(3.8254), 'layer4.0.downsample.1': tensor(4.7105), 'layer4.1.conv1': tensor(1.8723), 'layer4.1.bn1': tensor(5.9726), 'layer4.1.relu': tensor(10.7230), 'layer4.1.conv2': tensor(0.4364), 'layer4.1.bn2': tensor(6.4667), 'avgpool': tensor(20.9373), 'fc': tensor(26.2633)}
mean norm by n_samples: 2
{'conv1': tensor(9.8868e-06), 'bn1': tensor(1.3888e-06), 'layer1.0.conv1': tensor(9.0516e-06), 'layer1.0.bn1': tensor(1.0074e-07), 'layer1.0.conv2': tensor(6.2388e-06), 'layer1.0.bn2': tensor(1.0365e-07), 'layer1.1.conv1': tensor(4.8091e-06), 'layer1.1.bn1': tensor(1.0275e-07), 'layer1.1.conv2': tensor(4.9961e-06), 'layer1.1.bn2': tensor(1.1384e-07), 'layer2.0.conv1': tensor(5.5658e-06), 'layer2.0.bn1': tensor(2.1229e-07), 'layer2.0.conv2': tensor(5.4098e-06), 'layer2.0.bn2': tensor(4.3700e-07), 'layer2.0.downsample.0': tensor(1.4817e-06), 'layer2.0.downsample.1': tensor(7.2040e-07), 'layer2.1.conv1': tensor(4.3899e-06), 'layer2.1.bn1': tensor(1.8118e-07), 'layer2.1.conv2': tensor(4.0736e-06), 'layer2.1.bn2': tensor(1.9022e-07), 'layer3.0.conv1': tensor(1.9207e-06), 'layer3.0.bn1': tensor(4.6934e-08), 'layer3.0.conv2': tensor(1.5649e-06), 'layer3.0.bn2': tensor(9.5632e-08), 'layer3.0.downsample.0': tensor(6.0146e-07), 'layer3.0.downsample.1': tensor(2.4248e-08), 'layer3.1.conv1': tensor(4.3780e-06), 'layer3.1.bn1': tensor(4.8085e-08), 'layer3.1.conv2': tensor(1.3691e-06), 'layer3.1.bn2': tensor(2.0095e-07), 'layer4.0.conv1': tensor(6.0880e-07), 'layer4.0.bn1': tensor(1.8307e-07), 'layer4.0.conv2': tensor(6.9853e-07), 'layer4.0.bn2': tensor(6.8290e-09), 'layer4.0.downsample.0': tensor(2.0392e-07), 'layer4.0.downsample.1': tensor(1.7276e-08), 'layer4.1.conv1': tensor(9.9483e-08), 'layer4.1.bn1': tensor(8.7614e-08), 'layer4.1.conv2': tensor(2.3291e-07), 'layer4.1.bn2': tensor(2.2704e-08), 'fc': tensor(7.4263e-07)}
mean norm by n_samples: 2
{'conv1': tensor(7.1883e-05), 'bn1': tensor(0.0001), 'relu': tensor(0.0001), 'maxpool': tensor(0.0001), 'layer1.0.conv1': tensor(3.6323e-05), 'layer1.0.bn1': tensor(5.3908e-05), 'layer1.0.relu': tensor(3.9725e-05), 'layer1.0.conv2': tensor(3.9918e-05), 'layer1.0.bn2': tensor(3.2062e-05), 'layer1.1.conv1': tensor(1.3335e-05), 'layer1.1.bn1': tensor(3.4516e-05), 'layer1.1.relu': tensor(2.7420e-05), 'layer1.1.conv2': tensor(3.2259e-05), 'layer1.1.bn2': tensor(2.1437e-05), 'layer2.0.conv1': tensor(1.7720e-05), 'layer2.0.bn1': tensor(5.2720e-05), 'layer2.0.relu': tensor(4.3331e-05), 'layer2.0.conv2': tensor(5.6932e-05), 'layer2.0.bn2': tensor(3.4321e-05), 'layer2.0.downsample.0': tensor(2.0645e-05), 'layer2.0.downsample.1': tensor(3.4321e-05), 'layer2.1.conv1': tensor(3.0945e-05), 'layer2.1.bn1': tensor(3.5875e-05), 'layer2.1.relu': tensor(3.2991e-05), 'layer2.1.conv2': tensor(5.2012e-05), 'layer2.1.bn2': tensor(2.3032e-05), 'layer3.0.conv1': tensor(2.8443e-05), 'layer3.0.bn1': tensor(2.7309e-05), 'layer3.0.relu': tensor(1.9914e-05), 'layer3.0.conv2': tensor(4.2691e-05), 'layer3.0.bn2': tensor(2.1144e-05), 'layer3.0.downsample.0': tensor(1.4941e-05), 'layer3.0.downsample.1': tensor(2.1144e-05), 'layer3.1.conv1': tensor(3.4979e-05), 'layer3.1.bn1': tensor(2.2853e-05), 'layer3.1.relu': tensor(1.3364e-05), 'layer3.1.conv2': tensor(4.1077e-05), 'layer3.1.bn2': tensor(9.4847e-06), 'layer4.0.conv1': tensor(1.9133e-10), 'layer4.0.bn1': tensor(3.0961e-05), 'layer4.0.relu': tensor(1.9245e-07), 'layer4.0.conv2': tensor(1.8833e-11), 'layer4.0.bn2': tensor(5.1805e-06), 'layer4.0.downsample.0': tensor(3.0601e-11), 'layer4.0.downsample.1': tensor(5.1805e-06), 'layer4.1.conv1': tensor(1.9200e-11), 'layer4.1.bn1': tensor(7.6009e-06), 'layer4.1.relu': tensor(3.7132e-07), 'layer4.1.conv2': tensor(7.4483e-12), 'layer4.1.bn2': tensor(3.8490e-07), 'avgpool': tensor(7.4263e-07), 'fc': tensor(0.0158)}
```
