
Saving intermediate layer inputs/outputs and gradients in PyTorch

Published 2022/07/07

Goal: save the inputs and outputs of intermediate layers in a deep learning model.

Background: when investigating the behavior of a deep network, you sometimes want to save the inputs and outputs of intermediate layers (not their weights).

A simple example that prints layer inputs and outputs during the forward pass

Model used

Sequential(
  (0): Linear(in_features=10, out_features=5, bias=True)
  (1): Linear(in_features=5, out_features=2, bias=True)
)
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

def forward_hook(model, inputs, output):
    # inputs is a tuple, because a layer may receive multiple inputs
    print('input:', inputs[0].shape, 'output', output.shape)    

for name, layer in model.named_children():
    print(f'hook onto {name}')
    layer.register_forward_hook(forward_hook)

x = torch.rand(5, 10)
model(x).mean()

Output

input: torch.Size([5, 10]) output torch.Size([5, 5])
input: torch.Size([5, 5]) output torch.Size([5, 2])

When the forward pass runs, forward_hook is called and inputs and output receive the layer's input and output tensors.
Note: inputs is a tuple of the input tensors, while output is a single tensor.
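
For reference, here is a minimal sketch (the names activations, make_forward_hook, and handles are illustrative, not from the article) that stores each layer's output under its layer name and shows that a hook can later be removed via the handle returned by register_forward_hook:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))
activations = {}

def make_forward_hook(name):
    def forward_hook(module, inputs, output):
        # keep a detached copy of the layer output, keyed by layer name
        activations[name] = output.detach()
    return forward_hook

handles = [layer.register_forward_hook(make_forward_hook(name))
           for name, layer in model.named_children()]

model(torch.rand(5, 10))
print({k: v.shape for k, v in activations.items()})
# expected: {'0': torch.Size([5, 5]), '1': torch.Size([5, 2])}

for h in handles:
    h.remove()  # after removal the hooks no longer fire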

The backward pass

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

def backward_hook(module, grad_input, grad_output):
    # at the first layer, the gradient w.r.t. the network input is None
    if len(grad_input) >=2: 
        for i, grad_input_ in enumerate(grad_input):
            if grad_input_ is not None:            
                print('grad_in', i,  grad_input_.shape)
    print('grad_output:', grad_output[0].shape)

for name, layer in model.named_children():
    print(f'hook onto {name}')
    layer.register_backward_hook(backward_hook)
    #layer.register_full_backward_hook(backward_hook)
    
x = torch.rand(3, 10)
model(x).mean().backward()

Output

grad_in 0 torch.Size([2])
grad_in 1 torch.Size([3, 5])
grad_in 2 torch.Size([5, 2])
grad_output: torch.Size([3, 2])
grad_in 0 torch.Size([5])
grad_in 2 torch.Size([10, 5])
grad_output: torch.Size([3, 5])

What do grad_in and grad_output in this output mean?

x(3, 10)->fc1(10,5)->h1(3, 5)->fc2(5,2)->h2(3, 2)->mean->1

The hooks report the gradients before and after fc1 and fc2.
That is, fc2's grad_output is the loss differentiated with respect to h2 (3, 2),
and fc2's grad input (the entry of shape (3, 5)) is the loss differentiated with respect to h1 (3, 5).

In this example fc2's grad input therefore coincides with fc1's grad output.
fc1's grad input should be the loss differentiated with respect to x (3, 10), but it comes back as None.

If you only want the derivative of the loss with respect to each layer's output, it is enough to watch grad_output[0].
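
As a minimal sketch of that idea (the dict grads and the helper make_backward_hook are illustrative names, not from the article), a backward hook can keep only grad_output[0] per layer:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))
grads = {}

def make_backward_hook(name):
    def backward_hook(module, grad_input, grad_output):
        # gradient of the loss w.r.t. this layer's output
        grads[name] = grad_output[0].detach()
    return backward_hook

for name, layer in model.named_children():
    layer.register_backward_hook(make_backward_hook(name))

x = torch.rand(3, 10)
model(x).mean().backward()
print({k: v.shape for k, v in grads.items()})
# expected: {'0': torch.Size([3, 5]), '1': torch.Size([3, 2])}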

Note: register_backward_hook is deprecated and register_full_backward_hook should be used instead, but the latter does not work when the model contains ReLU(inplace=True) [1][2], so this article deliberately uses register_backward_hook.

The contents of grad_input differ between register_backward_hook and register_full_backward_hook; a small sketch comparing the two follows the lists below.

With register_full_backward_hook

grad_input holds only the gradients with respect to the module's forward inputs (one entry per input tensor), so for a single-input layer:
[0] is the derivative of the loss wrt the layer input

With register_backward_hook, the contents differ between fully connected (fc) and convolutional (conv) layers.

For an fc layer
[0] shape [10] - gradient w.r.t. the bias.
[1] shape [64, 84] - gradient w.r.t. the layer input (64 is the batch size, 84 the number of inputs from the previous layer).
[2] shape [84, 10] - gradient related to the layer weights (transposed relative to layer.weight): each of the 10 nodes in the fully connected layer receives the 84 outputs from the previous layer.

For a conv layer
[0] shape [64, 16, 16, 16] - gradient w.r.t. the input: 64 batches, 16 feature maps, 16 x 16 spatial size.
[1] shape [32, 16, 3, 3] - gradient w.r.t. the kernel weights: 32 kernels of depth 16 (matching the number of input feature maps) and 3x3 height/width.
[2] shape [32] - gradient w.r.t. the bias of each kernel.
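
As a quick illustration of the difference, here is a minimal sketch (show_shapes is an illustrative helper, not from the article) that registers each hook type on a separate Linear layer and prints the grad_input shapes; the exact behavior may vary by PyTorch version:

import torch
import torch.nn as nn

def show_shapes(tag):
    def hook(module, grad_input, grad_output):
        shapes = [None if g is None else tuple(g.shape) for g in grad_input]
        print(tag, 'grad_input shapes:', shapes)
    return hook

old_style = nn.Linear(4, 3)
old_style.register_backward_hook(show_shapes('old '))        # deprecated hook
full_style = nn.Linear(4, 3)
full_style.register_full_backward_hook(show_shapes('full'))  # recommended hook

x = torch.rand(2, 4)
old_style(x).sum().backward()
full_style(x).sum().backward()
# expected (roughly):
# old  grad_input shapes: [(3,), (2, 4), (4, 3)]   bias / input / transposed weight grads
# full grad_input shapes: [(2, 4)]                 only the gradient w.r.t. the layer input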

A class for saving the values

Rather than just printing the values from the hook functions, we need to save them, so here is a class for that purpose.
It records the values every time a forward pass runs and returns their mean.

class SaveActive(object):
    def __init__(self, model):
        self.model = model
        self.fw_output = {}
        self.fw_input = {}
        self.bw_output = {}
        self.bw_input = {}
        self.fw_hook_lst = []
        self.bw_hook_lst = []
        self.clear_buffer()
        self.__register_model(model)

    def __enter__(self):
        return self
    
    def __call__(self, model):
        self.__init__(model)
        
    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hook()
        self.clear_buffer()

    def clear_buffer(self):
        for name, layer in self.model.named_modules():
            if len(list(layer.named_children())) == 0:
                self.fw_input[name] = []
                self.fw_output[name] = []
                self.bw_input[name] = []
                self.bw_output[name] = []

    def __register_model(self, model):
        for name, layer in model.named_modules():
            if len(list(layer.named_children())) == 0:
                # print(f'hook in {name}')
                fw_handle = layer.register_forward_hook(self.fw_save(name))
                self.fw_hook_lst.append(fw_handle)
                # register_full_backward_hook is avoided here because of in-place ops,
                # see https://github.com/pytorch/pytorch/issues/61519
                # layer.register_full_backward_hook(self.bw_save(name))
                bw_handle = layer.register_backward_hook(self.bw_save(name))
                self.bw_hook_lst.append(bw_handle)

    def remove_hook(self):
        for fw_handle in self.fw_hook_lst:
            fw_handle.remove()
        for bw_handle in self.bw_hook_lst:
            bw_handle.remove()

    def fw_save(self, name):
        def forward_hook(model, inputs, output):
            tmp1 = inputs[0].detach().clone().cpu().to(torch.float32)
            if tmp1.dim() == 0:
                tmp1 = tmp1.unsqueeze(0)
            self.fw_input[name].append(tmp1)
            if output is not None:
                tmp2 = output.detach().clone().cpu().to(torch.float32)
                if tmp2.dim() == 0:
                    tmp2 = tmp2.unsqueeze(0)  # dim!=0 for torch.concat
                self.fw_output[name].append(tmp2)

        return forward_hook

    def bw_save(self, name):
        def backward_hook(module, grad_input, grad_output):
            if len(grad_input) >= 2:
                if grad_input[1] is not None:
                    tmp1 = grad_input[1].detach().clone().cpu().to(torch.float32)
                    if tmp1.dim() == 0:
                        tmp1 = tmp1.unsqueeze(0)
                    self.bw_input[name].append(tmp1)
            tmp2 = grad_output[0].detach().clone().cpu().to(torch.float32)
            if tmp2.dim() == 0:
                tmp2 = tmp2.unsqueeze(0)
            self.bw_output[name].append(tmp2)

        return backward_hook

    def get_fw_input_mean_norm(self):
        if self.__is_null(self.fw_input):
            print('Error: Please try forward prop')
            return {}
        means = {}
        for key in self.fw_input:
            if len(self.fw_input[key]) != 0:
                means[key] = torch.cat(self.fw_input[key], dim=0).mean(0).norm()
                n_data = len(torch.cat(self.fw_input[key], dim=0))
        print(f'mean norm by n_samples: {n_data}')
        return means

    def get_fw_output_mean_norm(self):
        if self.__is_null(self.fw_output):
            print('Error: Please try forward prop')
            return {}
        means = {}
        for key in self.fw_output:
            if len(self.fw_output[key]) != 0:
                means[key] = torch.cat(self.fw_output[key], dim=0).mean(0).norm()
                n_data = len(torch.cat(self.fw_output[key], dim=0))
        print(f'mean norm by n_samples: {n_data}')
        return means

    def get_bw_input_mean_norm(self):
        if self.__is_null(self.bw_input):
            print('Error: Please try backward prop')
            return {}
        means = {}
        for key in self.bw_input:
            if len(self.bw_input[key]) != 0:
                means[key] = torch.cat(self.bw_input[key], dim=0).mean(0).norm()
                n_data = len(torch.cat(self.bw_input[key], dim=0))
        print(f'mean norm by n_samples: {n_data}')
        return means

    def get_bw_output_mean_norm(self):
        if self.__is_null(self.bw_output):
            print('Error: Please try backward prop')
            return {}
        means = {}
        for key in self.bw_output:
            if len(self.bw_output[key]) != 0:
                means[key] = torch.cat(self.bw_output[key], dim=0).mean(0).norm()
                n_data = len(torch.cat(self.bw_output[key], dim=0))
        print(f'mean norm by n_samples: {n_data}')
        return means

    def __is_null(self, dat):
        n_data = 0
        for key in dat:
            n_data += len(dat[key])
        return n_data == 0

Usage

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

# just register the model; the hooks fire only while inside the with scope
with SaveActive(model) as sa:
    x = torch.rand(6, 10)
    model(x).mean().backward()
    print(sa.get_fw_input_mean_norm())
    print(sa.get_fw_output_mean_norm())
    print(sa.get_bw_input_mean_norm())
    print(sa.get_bw_output_mean_norm())

Output

mean norm by n_samples: 6
{'0': tensor(1.5072), '1': tensor(0.5847)}
mean norm by n_samples: 6
{'0': tensor(0.5847), '1': tensor(0.4009)}
mean norm by n_samples: 6
{'1': tensor(0.1083)}
mean norm by n_samples: 6
{'0': tensor(0.1083), '1': tensor(0.1179)}

If you only run the forward pass, only the fw buffers are filled; once you also run backpropagation, the bw buffers get values as well.
If you want the gradients averaged over all of your inputs, you need to call backward() once per batch (see the sketch after the example below).
The number of samples used in each average is printed, so you can check there whether it behaves as expected.
For compact output the getters return a norm; change .mean(0).norm() to .mean(0) if you want the tensor itself.
The with statement ensures that values are recorded only while inside the with block.

In other words, if you write

with SaveActive(model) as sa:
    model(x).mean()
    model(x).mean()
    model(x).mean().backward()

then the forward pass is averaged over 18 samples (batch size 6 × 3 calls),
and the backward pass is averaged over 6 samples.
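
Along the same lines, if you want the backward statistics averaged over an entire dataset, one possible pattern (the dummy data below is illustrative, not from the article) is to call backward() once per batch inside the same SaveActive scope:

with SaveActive(model) as sa:
    for xb in torch.rand(4, 6, 10):        # 4 dummy batches of size 6
        model(xb).mean().backward()
    print(sa.get_bw_output_mean_norm())    # gradients averaged over 4 * 6 = 24 samples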

Trying it on ResNet

A residual block also maps a single input to a single output (x + h1 → h2), so it causes no problems.
The only thing to watch out for is ReLU(inplace=True).
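
If you do want to use register_full_backward_hook on a model like this, one possible workaround (a sketch, not used in this article) is to switch every ReLU module out of in-place mode before registering the hooks:

import torch.nn as nn
import torchvision

model = torchvision.models.resnet18(pretrained=True)
for m in model.modules():
    if isinstance(m, nn.ReLU):
        m.inplace = False   # uses slightly more memory, but avoids the in-place issue with full backward hooks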

Model used

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
import torchvision

model = torchvision.models.resnet18(pretrained=True)
x = torch.rand(2, 3, 32, 32)

with SaveActive(model) as sa:
    model(x).mean()
    model(x).mean().backward()
    print(sa.get_fw_input_mean_norm())
    print(sa.get_fw_output_mean_norm())
    print(sa.get_bw_input_mean_norm())
    print(sa.get_bw_output_mean_norm())    

Output

mean norm by n_samples: 4
{'conv1': tensor(30.0120), 'bn1': tensor(88.0899), 'relu': tensor(46.0366), 'maxpool': tensor(46.0366), 'layer1.0.conv1': tensor(38.5812), 'layer1.0.bn1': tensor(110.7326), 'layer1.0.relu': tensor(26.8837), 'layer1.0.conv2': tensor(14.6708), 'layer1.0.bn2': tensor(21.9908), 'layer1.1.conv1': tensor(44.4266), 'layer1.1.bn1': tensor(87.3894), 'layer1.1.relu': tensor(28.5649), 'layer1.1.conv2': tensor(11.4086), 'layer1.1.bn2': tensor(17.4100), 'layer2.0.conv1': tensor(50.4790), 'layer2.0.bn1': tensor(41.1456), 'layer2.0.relu': tensor(8.3203), 'layer2.0.conv2': tensor(7.5420), 'layer2.0.bn2': tensor(13.2688), 'layer2.0.downsample.0': tensor(50.4790), 'layer2.0.downsample.1': tensor(21.5922), 'layer2.1.conv1': tensor(11.9498), 'layer2.1.bn1': tensor(19.3483), 'layer2.1.relu': tensor(7.3593), 'layer2.1.conv2': tensor(4.8510), 'layer2.1.bn2': tensor(7.4920), 'layer3.0.conv1': tensor(12.2743), 'layer3.0.bn1': tensor(12.0483), 'layer3.0.relu': tensor(4.1795), 'layer3.0.conv2': tensor(4.0709), 'layer3.0.bn2': tensor(4.8925), 'layer3.0.downsample.0': tensor(12.2743), 'layer3.0.downsample.1': tensor(4.2953), 'layer3.1.conv1': tensor(5.5685), 'layer3.1.bn1': tensor(9.7999), 'layer3.1.relu': tensor(3.5061), 'layer3.1.conv2': tensor(2.0692), 'layer3.1.bn2': tensor(3.0278), 'layer4.0.conv1': tensor(5.9326), 'layer4.0.bn1': tensor(5.6728), 'layer4.0.relu': tensor(1.1065), 'layer4.0.conv2': tensor(0.6796), 'layer4.0.bn2': tensor(0.7235), 'layer4.0.downsample.0': tensor(5.9326), 'layer4.0.downsample.1': tensor(3.8254), 'layer4.1.conv1': tensor(1.8447), 'layer4.1.bn1': tensor(1.8723), 'layer4.1.relu': tensor(10.7230), 'layer4.1.conv2': tensor(0.7485), 'layer4.1.bn2': tensor(0.4364), 'avgpool': tensor(20.9373), 'fc': tensor(20.9373)}
mean norm by n_samples: 4
{'conv1': tensor(88.0899), 'bn1': tensor(53.5855), 'relu': tensor(46.0366), 'maxpool': tensor(38.5812), 'layer1.0.conv1': tensor(110.7326), 'layer1.0.bn1': tensor(25.1572), 'layer1.0.relu': tensor(26.8837), 'layer1.0.conv2': tensor(21.9908), 'layer1.0.bn2': tensor(22.6131), 'layer1.1.conv1': tensor(87.3894), 'layer1.1.bn1': tensor(21.3544), 'layer1.1.relu': tensor(28.5649), 'layer1.1.conv2': tensor(17.4100), 'layer1.1.bn2': tensor(22.5308), 'layer2.0.conv1': tensor(41.1456), 'layer2.0.bn1': tensor(12.3804), 'layer2.0.relu': tensor(8.3203), 'layer2.0.conv2': tensor(13.2688), 'layer2.0.bn2': tensor(12.8707), 'layer2.0.downsample.0': tensor(21.5922), 'layer2.0.downsample.1': tensor(8.5293), 'layer2.1.conv1': tensor(19.3483), 'layer2.1.bn1': tensor(15.4010), 'layer2.1.relu': tensor(7.3593), 'layer2.1.conv2': tensor(7.4920), 'layer2.1.bn2': tensor(13.8359), 'layer3.0.conv1': tensor(12.0483), 'layer3.0.bn1': tensor(8.4056), 'layer3.0.relu': tensor(4.1795), 'layer3.0.conv2': tensor(4.8925), 'layer3.0.bn2': tensor(7.1288), 'layer3.0.downsample.0': tensor(4.2953), 'layer3.0.downsample.1': tensor(3.7063), 'layer3.1.conv1': tensor(9.7999), 'layer3.1.bn1': tensor(10.3078), 'layer3.1.relu': tensor(3.5061), 'layer3.1.conv2': tensor(3.0278), 'layer3.1.bn2': tensor(9.1545), 'layer4.0.conv1': tensor(5.6728), 'layer4.0.bn1': tensor(5.4074), 'layer4.0.relu': tensor(1.1065), 'layer4.0.conv2': tensor(0.7235), 'layer4.0.bn2': tensor(4.7105), 'layer4.0.downsample.0': tensor(3.8254), 'layer4.0.downsample.1': tensor(4.7105), 'layer4.1.conv1': tensor(1.8723), 'layer4.1.bn1': tensor(5.9726), 'layer4.1.relu': tensor(10.7230), 'layer4.1.conv2': tensor(0.4364), 'layer4.1.bn2': tensor(6.4667), 'avgpool': tensor(20.9373), 'fc': tensor(26.2633)}
mean norm by n_samples: 2
{'conv1': tensor(9.8868e-06), 'bn1': tensor(1.3888e-06), 'layer1.0.conv1': tensor(9.0516e-06), 'layer1.0.bn1': tensor(1.0074e-07), 'layer1.0.conv2': tensor(6.2388e-06), 'layer1.0.bn2': tensor(1.0365e-07), 'layer1.1.conv1': tensor(4.8091e-06), 'layer1.1.bn1': tensor(1.0275e-07), 'layer1.1.conv2': tensor(4.9961e-06), 'layer1.1.bn2': tensor(1.1384e-07), 'layer2.0.conv1': tensor(5.5658e-06), 'layer2.0.bn1': tensor(2.1229e-07), 'layer2.0.conv2': tensor(5.4098e-06), 'layer2.0.bn2': tensor(4.3700e-07), 'layer2.0.downsample.0': tensor(1.4817e-06), 'layer2.0.downsample.1': tensor(7.2040e-07), 'layer2.1.conv1': tensor(4.3899e-06), 'layer2.1.bn1': tensor(1.8118e-07), 'layer2.1.conv2': tensor(4.0736e-06), 'layer2.1.bn2': tensor(1.9022e-07), 'layer3.0.conv1': tensor(1.9207e-06), 'layer3.0.bn1': tensor(4.6934e-08), 'layer3.0.conv2': tensor(1.5649e-06), 'layer3.0.bn2': tensor(9.5632e-08), 'layer3.0.downsample.0': tensor(6.0146e-07), 'layer3.0.downsample.1': tensor(2.4248e-08), 'layer3.1.conv1': tensor(4.3780e-06), 'layer3.1.bn1': tensor(4.8085e-08), 'layer3.1.conv2': tensor(1.3691e-06), 'layer3.1.bn2': tensor(2.0095e-07), 'layer4.0.conv1': tensor(6.0880e-07), 'layer4.0.bn1': tensor(1.8307e-07), 'layer4.0.conv2': tensor(6.9853e-07), 'layer4.0.bn2': tensor(6.8290e-09), 'layer4.0.downsample.0': tensor(2.0392e-07), 'layer4.0.downsample.1': tensor(1.7276e-08), 'layer4.1.conv1': tensor(9.9483e-08), 'layer4.1.bn1': tensor(8.7614e-08), 'layer4.1.conv2': tensor(2.3291e-07), 'layer4.1.bn2': tensor(2.2704e-08), 'fc': tensor(7.4263e-07)}
mean norm by n_samples: 2
{'conv1': tensor(7.1883e-05), 'bn1': tensor(0.0001), 'relu': tensor(0.0001), 'maxpool': tensor(0.0001), 'layer1.0.conv1': tensor(3.6323e-05), 'layer1.0.bn1': tensor(5.3908e-05), 'layer1.0.relu': tensor(3.9725e-05), 'layer1.0.conv2': tensor(3.9918e-05), 'layer1.0.bn2': tensor(3.2062e-05), 'layer1.1.conv1': tensor(1.3335e-05), 'layer1.1.bn1': tensor(3.4516e-05), 'layer1.1.relu': tensor(2.7420e-05), 'layer1.1.conv2': tensor(3.2259e-05), 'layer1.1.bn2': tensor(2.1437e-05), 'layer2.0.conv1': tensor(1.7720e-05), 'layer2.0.bn1': tensor(5.2720e-05), 'layer2.0.relu': tensor(4.3331e-05), 'layer2.0.conv2': tensor(5.6932e-05), 'layer2.0.bn2': tensor(3.4321e-05), 'layer2.0.downsample.0': tensor(2.0645e-05), 'layer2.0.downsample.1': tensor(3.4321e-05), 'layer2.1.conv1': tensor(3.0945e-05), 'layer2.1.bn1': tensor(3.5875e-05), 'layer2.1.relu': tensor(3.2991e-05), 'layer2.1.conv2': tensor(5.2012e-05), 'layer2.1.bn2': tensor(2.3032e-05), 'layer3.0.conv1': tensor(2.8443e-05), 'layer3.0.bn1': tensor(2.7309e-05), 'layer3.0.relu': tensor(1.9914e-05), 'layer3.0.conv2': tensor(4.2691e-05), 'layer3.0.bn2': tensor(2.1144e-05), 'layer3.0.downsample.0': tensor(1.4941e-05), 'layer3.0.downsample.1': tensor(2.1144e-05), 'layer3.1.conv1': tensor(3.4979e-05), 'layer3.1.bn1': tensor(2.2853e-05), 'layer3.1.relu': tensor(1.3364e-05), 'layer3.1.conv2': tensor(4.1077e-05), 'layer3.1.bn2': tensor(9.4847e-06), 'layer4.0.conv1': tensor(1.9133e-10), 'layer4.0.bn1': tensor(3.0961e-05), 'layer4.0.relu': tensor(1.9245e-07), 'layer4.0.conv2': tensor(1.8833e-11), 'layer4.0.bn2': tensor(5.1805e-06), 'layer4.0.downsample.0': tensor(3.0601e-11), 'layer4.0.downsample.1': tensor(5.1805e-06), 'layer4.1.conv1': tensor(1.9200e-11), 'layer4.1.bn1': tensor(7.6009e-06), 'layer4.1.relu': tensor(3.7132e-07), 'layer4.1.conv2': tensor(7.4483e-12), 'layer4.1.bn2': tensor(3.8490e-07), 'avgpool': tensor(7.4263e-07), 'fc': tensor(0.0158)}
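
As one possible way to use this output (illustrative, not part of the article), the returned dict can be sorted inside the with block to see which layers receive the smallest gradients:

with SaveActive(model) as sa:
    model(x).mean().backward()
    bw = sa.get_bw_output_mean_norm()
    # list layers from smallest to largest gradient norm of their outputs
    for name, norm in sorted(bw.items(), key=lambda kv: kv[1].item()):
        print(f'{name:25s} {norm.item():.3e}')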
