Pytorchで途中のlayerの入出力と勾配を保存する
目的:Deep Learningで途中のlayerの入出力を保存する
背景:Deep Learningの性質を調べるときに途中のlayerの入出力(weightではない)を保存したいことがある
前方伝搬をノードを出力する簡単なサンプル
利用するモデル
Sequential(
(0): Linear(in_features=10, out_features=5, bias=True)
(1): Linear(in_features=5, out_features=2, bias=True)
)
import torch
import torch.nn as nn

# Toy two-layer MLP: 10 -> 5 -> 2.
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))


def forward_hook(model, inputs, output):
    # `inputs` is a tuple, since a module may receive several inputs.
    print('input:', inputs[0].shape, 'output', output.shape)


# Attach the hook to every direct child layer.
for name, layer in model.named_children():
    print(f'hook onto {name}')
    layer.register_forward_hook(forward_hook)

x = torch.rand(5, 10)
model(x).mean()
出力
input: torch.Size([5, 10]) output torch.Size([5, 5])
input: torch.Size([5, 5]) output torch.Size([5, 2])
前方伝搬するときに、forward_hookが呼ばれ、inputsとoutputそれぞれに入出力のtensorが入る。
※inputsは入力のtuple、outputはtensorになる
逆誤差伝搬の場合
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))


def backward_hook(module, grad_input, grad_output):
    # The gradient w.r.t. the network input is None at the first layer.
    if len(grad_input) >= 2:
        for i, g in enumerate(grad_input):
            if g is not None:
                print('grad_in', i, g.shape)
    print('grad_output:', grad_output[0].shape)


for name, layer in model.named_children():
    print(f'hook onto {name}')
    layer.register_backward_hook(backward_hook)
    # layer.register_full_backward_hook(backward_hook)

x = torch.rand(3, 10)
model(x).mean().backward()
出力
grad_in 0 torch.Size([2])
grad_in 1 torch.Size([3, 5])
grad_in 2 torch.Size([5, 2])
grad_output: torch.Size([3, 2])
grad_in 0 torch.Size([5])
grad_in 2 torch.Size([10, 5])
grad_output: torch.Size([3, 5])
grad_input/grad_outputのinとoutの意味は
x(3, 10)->fc1(10,5)->h1(3, 5)->fc2(5,2)->h2(3, 2)->mean->loss(スカラー)
fc1とfc2の前後の微分値を返している
つまりfc2のoutはlossを h2(3, 2) で微分した値
つまりfc2のinはlossを h1(3, 5) で微分した値
今回だとfc2のinputとfc1のoutputは一致する
fc1のinはlossを x(3, 10) で微分した値になるはずだがNoneになる
ノードの出力を用いたlossの微分値を知りたいなら、grad_output[0]だけを監視していればいい
※ register_backward_hookは古く、register_full_backward_hookを使うべきだが、ReLU(inplace)がある場合動かない[1][2]ため、今回はあえてregister_backward_hookを使っている
register_backward_hookと register_full_backward_hookを使った場合で、grad_input中身が変わる。
register_full_backward_hookの場合
[0] is the derivative of loss wrt layer input
[1] is the derivative of loss wrt layer output (before activation)
[2] is the derivative of loss wrt layer weights
register_backward_hookの場合fcとconvで異なる
fcの場合
[0] shape [10] - Bias values.
[1] shape [64, 84] - Data. The first value is the 64 batches, 84 inputs from the previous layer.
[2] shape [84, 10] - Layer weights. Each node in the fully connected layer receives the 84 outputs from the previous layer. There are 10 nodes.
convの場合
[0] shape [64, 16, 16, 16] - This is the input data. 64 batches, 16 feature maps deep, 16 width, 16 height.
[1] shape [32, 16, 3, 3] - This is the kernel weight data. 32 kernels with 16 depth (to match number if input feature maps), and 3x3 height/width.
[2] shape [32] - This is the bias for each kernel
保存するためのクラス
hook関数を使って、ノードの値をprintするだけでなく、保存する必要があるのでそれ用のクラス。
前方伝搬が行われるたびに記録して、平均した値を返す
class SaveActive(object):
    """Record per-layer forward activations and backward gradients via hooks.

    Hooks are registered on every *leaf* module of ``model`` at construction
    time.  Intended to be used as a context manager so the hooks are removed
    (and the buffers cleared) when the ``with`` block exits::

        with SaveActive(model) as sa:
            model(x).mean().backward()
            print(sa.get_fw_input_mean_norm())

    NOTE(review): ``register_backward_hook`` is deprecated, but
    ``register_full_backward_hook`` fails when the model contains
    ``ReLU(inplace=True)`` (pytorch/pytorch#61519), so the old API is used
    deliberately here.
    """

    def __init__(self, model):
        self.model = model
        # name -> list of recorded tensors, one entry per forward/backward pass
        self.fw_output = {}
        self.fw_input = {}
        self.bw_output = {}
        self.bw_input = {}
        # Hook handles, kept so the hooks can be removed later.
        self.fw_hook_lst = []
        self.bw_hook_lst = []
        self.clear_buffer()
        self.__register_model(model)

    def __enter__(self):
        return self

    def __call__(self, model):
        # Re-target to a (possibly different) model.  Remove the hooks this
        # instance still holds first; the original simply re-ran __init__,
        # which dropped the handle lists and leaked the registered hooks.
        self.remove_hook()
        self.__init__(model)

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hook()
        self.clear_buffer()

    def clear_buffer(self):
        """Reset the recording buffers: one empty list per leaf module."""
        for name, layer in self.model.named_modules():
            if len(list(layer.named_children())) == 0:  # leaf modules only
                self.fw_input[name] = []
                self.fw_output[name] = []
                self.bw_input[name] = []
                self.bw_output[name] = []

    def __register_model(self, model):
        """Attach forward and backward hooks to every leaf module."""
        for name, layer in model.named_modules():
            if len(list(layer.named_children())) == 0:
                fw_handle = layer.register_forward_hook(self.fw_save(name))
                self.fw_hook_lst.append(fw_handle)
                # register_full_backward_hook is avoided because of inplace
                # ops: https://github.com/pytorch/pytorch/issues/61519
                bw_handle = layer.register_backward_hook(self.bw_save(name))
                self.bw_hook_lst.append(bw_handle)

    def remove_hook(self):
        """Detach every hook this instance registered (idempotent)."""
        for fw_handle in self.fw_hook_lst:
            fw_handle.remove()
        for bw_handle in self.bw_hook_lst:
            bw_handle.remove()
        self.fw_hook_lst = []
        self.bw_hook_lst = []

    def fw_save(self, name):
        """Build a forward hook recording input/output of layer ``name``."""
        def forward_hook(model, inputs, output):
            # `inputs` is a tuple; only the first input is recorded.
            self.fw_input[name].append(self.__to_buffer(inputs[0]))
            if output is not None:
                self.fw_output[name].append(self.__to_buffer(output))
        return forward_hook

    def bw_save(self, name):
        """Build a backward hook recording gradients of layer ``name``."""
        def backward_hook(module, grad_input, grad_output):
            # With register_backward_hook on an fc layer, grad_input is
            # (grad_bias, grad_data, grad_weight); index 1 is the gradient
            # w.r.t. the layer input and is None at the first layer.
            if len(grad_input) >= 2 and grad_input[1] is not None:
                self.bw_input[name].append(self.__to_buffer(grad_input[1]))
            self.bw_output[name].append(self.__to_buffer(grad_output[0]))
        return backward_hook

    @staticmethod
    def __to_buffer(tensor):
        """Detach a tensor to CPU float32; 0-dim tensors get a batch axis so
        that torch.cat can concatenate them later."""
        out = tensor.detach().clone().cpu().to(torch.float32)
        if out.dim() == 0:
            out = out.unsqueeze(0)
        return out

    def get_fw_input_mean_norm(self):
        """Per-layer norm of the sample-mean of recorded forward inputs."""
        return self.__mean_norm(self.fw_input, 'forward prop')

    def get_fw_output_mean_norm(self):
        """Per-layer norm of the sample-mean of recorded forward outputs."""
        return self.__mean_norm(self.fw_output, 'forward prop')

    def get_bw_input_mean_norm(self):
        """Per-layer norm of the sample-mean of recorded input gradients."""
        return self.__mean_norm(self.bw_input, 'backward prop')

    def get_bw_output_mean_norm(self):
        """Per-layer norm of the sample-mean of recorded output gradients."""
        return self.__mean_norm(self.bw_output, 'backward prop')

    def __mean_norm(self, store, prop_name):
        """Return {layer name: ||mean over recorded samples||}.

        Returns {} (with an error message) when nothing has been recorded,
        i.e. the corresponding pass has not been run yet.
        """
        if self.__is_null(store):
            print(f'Error: Please try {prop_name}')
            return {}
        means = {}
        for key, tensors in store.items():
            if tensors:
                data = torch.cat(tensors, dim=0)
                # print(f'mean norm by n_samples: {len(data)}')
                means[key] = data.mean(0).norm()
        return means

    def __is_null(self, dat):
        """True if no tensor has been recorded in ``dat`` yet."""
        return sum(len(v) for v in dat.values()) == 0
使い方
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))
# Just register; the hooks are active only while inside the with-block.
with SaveActive(model) as sa:
    x = torch.rand(6, 10)
    model(x).mean().backward()
    for getter in (sa.get_fw_input_mean_norm, sa.get_fw_output_mean_norm,
                   sa.get_bw_input_mean_norm, sa.get_bw_output_mean_norm):
        print(getter())
出力
mean norm by n_sameples: 6
{'0': tensor(1.5072), '1': tensor(0.5847)}
mean norm by n_sameples: 6
{'0': tensor(0.5847), '1': tensor(0.4009)}
mean norm by n_sameples: 6
{'1': tensor(0.1083)}
mean norm by n_sameples: 6
{'0': tensor(0.1083), '1': tensor(0.1179)}
前方伝搬だけすると、fwだけが記録され、逆誤差伝搬までするとbwにも値が入る
すべての入力に対してbwしたときの勾配の平均がほしいなら、batchごとにbackward()を呼ぶ必要がある。
平均に使われたn_sample数が出力されるので、思った通りの動作をしているかはそこでわかる。
今は出力を楽にするためにnormを返しているが、.mean(0).norm()→.mean(0)に修正すればtensor自体がreturnされる。
with構文を用いて、withの中にあるときだけ記録するようにする
つまり
with SaveActive(model) as sa:
model(x).mean()
model(x).mean()
model(x).mean().backward()
とすると
前方伝搬はサンプル数 18の平均になり(batch size 6 * 3回)
逆誤差伝搬はサンプル数6の平均になる
ResNetに使ってみる
resblockも x+h1 → h2 となり 1つ入力すると1つ出力するので問題ない
relu inplaceのみ気をつける必要がある
利用するモデル
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
import torchvision

# Pretrained ResNet-18; its ReLU(inplace=True) layers are the reason
# register_backward_hook is used inside SaveActive.
model = torchvision.models.resnet18(pretrained=True)
x = torch.rand(2, 3, 32, 32)
with SaveActive(model) as sa:
    model(x).mean()             # forward only: fills fw buffers
    model(x).mean().backward()  # forward + backward: fills fw and bw buffers
    for getter in (sa.get_fw_input_mean_norm, sa.get_fw_output_mean_norm,
                   sa.get_bw_input_mean_norm, sa.get_bw_output_mean_norm):
        print(getter())
出力
mean norm by n_sameples: 4
{'conv1': tensor(30.0120), 'bn1': tensor(88.0899), 'relu': tensor(46.0366), 'maxpool': tensor(46.0366), 'layer1.0.conv1': tensor(38.5812), 'layer1.0.bn1': tensor(110.7326), 'layer1.0.relu': tensor(26.8837), 'layer1.0.conv2': tensor(14.6708), 'layer1.0.bn2': tensor(21.9908), 'layer1.1.conv1': tensor(44.4266), 'layer1.1.bn1': tensor(87.3894), 'layer1.1.relu': tensor(28.5649), 'layer1.1.conv2': tensor(11.4086), 'layer1.1.bn2': tensor(17.4100), 'layer2.0.conv1': tensor(50.4790), 'layer2.0.bn1': tensor(41.1456), 'layer2.0.relu': tensor(8.3203), 'layer2.0.conv2': tensor(7.5420), 'layer2.0.bn2': tensor(13.2688), 'layer2.0.downsample.0': tensor(50.4790), 'layer2.0.downsample.1': tensor(21.5922), 'layer2.1.conv1': tensor(11.9498), 'layer2.1.bn1': tensor(19.3483), 'layer2.1.relu': tensor(7.3593), 'layer2.1.conv2': tensor(4.8510), 'layer2.1.bn2': tensor(7.4920), 'layer3.0.conv1': tensor(12.2743), 'layer3.0.bn1': tensor(12.0483), 'layer3.0.relu': tensor(4.1795), 'layer3.0.conv2': tensor(4.0709), 'layer3.0.bn2': tensor(4.8925), 'layer3.0.downsample.0': tensor(12.2743), 'layer3.0.downsample.1': tensor(4.2953), 'layer3.1.conv1': tensor(5.5685), 'layer3.1.bn1': tensor(9.7999), 'layer3.1.relu': tensor(3.5061), 'layer3.1.conv2': tensor(2.0692), 'layer3.1.bn2': tensor(3.0278), 'layer4.0.conv1': tensor(5.9326), 'layer4.0.bn1': tensor(5.6728), 'layer4.0.relu': tensor(1.1065), 'layer4.0.conv2': tensor(0.6796), 'layer4.0.bn2': tensor(0.7235), 'layer4.0.downsample.0': tensor(5.9326), 'layer4.0.downsample.1': tensor(3.8254), 'layer4.1.conv1': tensor(1.8447), 'layer4.1.bn1': tensor(1.8723), 'layer4.1.relu': tensor(10.7230), 'layer4.1.conv2': tensor(0.7485), 'layer4.1.bn2': tensor(0.4364), 'avgpool': tensor(20.9373), 'fc': tensor(20.9373)}
mean norm by n_sameples: 4
{'conv1': tensor(88.0899), 'bn1': tensor(53.5855), 'relu': tensor(46.0366), 'maxpool': tensor(38.5812), 'layer1.0.conv1': tensor(110.7326), 'layer1.0.bn1': tensor(25.1572), 'layer1.0.relu': tensor(26.8837), 'layer1.0.conv2': tensor(21.9908), 'layer1.0.bn2': tensor(22.6131), 'layer1.1.conv1': tensor(87.3894), 'layer1.1.bn1': tensor(21.3544), 'layer1.1.relu': tensor(28.5649), 'layer1.1.conv2': tensor(17.4100), 'layer1.1.bn2': tensor(22.5308), 'layer2.0.conv1': tensor(41.1456), 'layer2.0.bn1': tensor(12.3804), 'layer2.0.relu': tensor(8.3203), 'layer2.0.conv2': tensor(13.2688), 'layer2.0.bn2': tensor(12.8707), 'layer2.0.downsample.0': tensor(21.5922), 'layer2.0.downsample.1': tensor(8.5293), 'layer2.1.conv1': tensor(19.3483), 'layer2.1.bn1': tensor(15.4010), 'layer2.1.relu': tensor(7.3593), 'layer2.1.conv2': tensor(7.4920), 'layer2.1.bn2': tensor(13.8359), 'layer3.0.conv1': tensor(12.0483), 'layer3.0.bn1': tensor(8.4056), 'layer3.0.relu': tensor(4.1795), 'layer3.0.conv2': tensor(4.8925), 'layer3.0.bn2': tensor(7.1288), 'layer3.0.downsample.0': tensor(4.2953), 'layer3.0.downsample.1': tensor(3.7063), 'layer3.1.conv1': tensor(9.7999), 'layer3.1.bn1': tensor(10.3078), 'layer3.1.relu': tensor(3.5061), 'layer3.1.conv2': tensor(3.0278), 'layer3.1.bn2': tensor(9.1545), 'layer4.0.conv1': tensor(5.6728), 'layer4.0.bn1': tensor(5.4074), 'layer4.0.relu': tensor(1.1065), 'layer4.0.conv2': tensor(0.7235), 'layer4.0.bn2': tensor(4.7105), 'layer4.0.downsample.0': tensor(3.8254), 'layer4.0.downsample.1': tensor(4.7105), 'layer4.1.conv1': tensor(1.8723), 'layer4.1.bn1': tensor(5.9726), 'layer4.1.relu': tensor(10.7230), 'layer4.1.conv2': tensor(0.4364), 'layer4.1.bn2': tensor(6.4667), 'avgpool': tensor(20.9373), 'fc': tensor(26.2633)}
mean norm by n_sameples: 2
{'conv1': tensor(9.8868e-06), 'bn1': tensor(1.3888e-06), 'layer1.0.conv1': tensor(9.0516e-06), 'layer1.0.bn1': tensor(1.0074e-07), 'layer1.0.conv2': tensor(6.2388e-06), 'layer1.0.bn2': tensor(1.0365e-07), 'layer1.1.conv1': tensor(4.8091e-06), 'layer1.1.bn1': tensor(1.0275e-07), 'layer1.1.conv2': tensor(4.9961e-06), 'layer1.1.bn2': tensor(1.1384e-07), 'layer2.0.conv1': tensor(5.5658e-06), 'layer2.0.bn1': tensor(2.1229e-07), 'layer2.0.conv2': tensor(5.4098e-06), 'layer2.0.bn2': tensor(4.3700e-07), 'layer2.0.downsample.0': tensor(1.4817e-06), 'layer2.0.downsample.1': tensor(7.2040e-07), 'layer2.1.conv1': tensor(4.3899e-06), 'layer2.1.bn1': tensor(1.8118e-07), 'layer2.1.conv2': tensor(4.0736e-06), 'layer2.1.bn2': tensor(1.9022e-07), 'layer3.0.conv1': tensor(1.9207e-06), 'layer3.0.bn1': tensor(4.6934e-08), 'layer3.0.conv2': tensor(1.5649e-06), 'layer3.0.bn2': tensor(9.5632e-08), 'layer3.0.downsample.0': tensor(6.0146e-07), 'layer3.0.downsample.1': tensor(2.4248e-08), 'layer3.1.conv1': tensor(4.3780e-06), 'layer3.1.bn1': tensor(4.8085e-08), 'layer3.1.conv2': tensor(1.3691e-06), 'layer3.1.bn2': tensor(2.0095e-07), 'layer4.0.conv1': tensor(6.0880e-07), 'layer4.0.bn1': tensor(1.8307e-07), 'layer4.0.conv2': tensor(6.9853e-07), 'layer4.0.bn2': tensor(6.8290e-09), 'layer4.0.downsample.0': tensor(2.0392e-07), 'layer4.0.downsample.1': tensor(1.7276e-08), 'layer4.1.conv1': tensor(9.9483e-08), 'layer4.1.bn1': tensor(8.7614e-08), 'layer4.1.conv2': tensor(2.3291e-07), 'layer4.1.bn2': tensor(2.2704e-08), 'fc': tensor(7.4263e-07)}
mean norm by n_sameples: 2
{'conv1': tensor(7.1883e-05), 'bn1': tensor(0.0001), 'relu': tensor(0.0001), 'maxpool': tensor(0.0001), 'layer1.0.conv1': tensor(3.6323e-05), 'layer1.0.bn1': tensor(5.3908e-05), 'layer1.0.relu': tensor(3.9725e-05), 'layer1.0.conv2': tensor(3.9918e-05), 'layer1.0.bn2': tensor(3.2062e-05), 'layer1.1.conv1': tensor(1.3335e-05), 'layer1.1.bn1': tensor(3.4516e-05), 'layer1.1.relu': tensor(2.7420e-05), 'layer1.1.conv2': tensor(3.2259e-05), 'layer1.1.bn2': tensor(2.1437e-05), 'layer2.0.conv1': tensor(1.7720e-05), 'layer2.0.bn1': tensor(5.2720e-05), 'layer2.0.relu': tensor(4.3331e-05), 'layer2.0.conv2': tensor(5.6932e-05), 'layer2.0.bn2': tensor(3.4321e-05), 'layer2.0.downsample.0': tensor(2.0645e-05), 'layer2.0.downsample.1': tensor(3.4321e-05), 'layer2.1.conv1': tensor(3.0945e-05), 'layer2.1.bn1': tensor(3.5875e-05), 'layer2.1.relu': tensor(3.2991e-05), 'layer2.1.conv2': tensor(5.2012e-05), 'layer2.1.bn2': tensor(2.3032e-05), 'layer3.0.conv1': tensor(2.8443e-05), 'layer3.0.bn1': tensor(2.7309e-05), 'layer3.0.relu': tensor(1.9914e-05), 'layer3.0.conv2': tensor(4.2691e-05), 'layer3.0.bn2': tensor(2.1144e-05), 'layer3.0.downsample.0': tensor(1.4941e-05), 'layer3.0.downsample.1': tensor(2.1144e-05), 'layer3.1.conv1': tensor(3.4979e-05), 'layer3.1.bn1': tensor(2.2853e-05), 'layer3.1.relu': tensor(1.3364e-05), 'layer3.1.conv2': tensor(4.1077e-05), 'layer3.1.bn2': tensor(9.4847e-06), 'layer4.0.conv1': tensor(1.9133e-10), 'layer4.0.bn1': tensor(3.0961e-05), 'layer4.0.relu': tensor(1.9245e-07), 'layer4.0.conv2': tensor(1.8833e-11), 'layer4.0.bn2': tensor(5.1805e-06), 'layer4.0.downsample.0': tensor(3.0601e-11), 'layer4.0.downsample.1': tensor(5.1805e-06), 'layer4.1.conv1': tensor(1.9200e-11), 'layer4.1.bn1': tensor(7.6009e-06), 'layer4.1.relu': tensor(3.7132e-07), 'layer4.1.conv2': tensor(7.4483e-12), 'layer4.1.bn2': tensor(3.8490e-07), 'avgpool': tensor(7.4263e-07), 'fc': tensor(0.0158)}
Discussion