PyTorchオンリーで実装するGrad-CAM++
TL;DR
- Grad-CAM++の実装
def process_grad_cam(feat, pred, v=0):
eps = 0.000001
pred[0, v].backward(retain_graph=True)
g1 = feat.grad[0]
g2 = g1 ** 2
g3 = g1 ** 3
sumfeat = torch.sum(feat[0], dim=(1, 2))
aij = g2 / (2 * g2 + sumfeat[:, None, None] * g3 + eps)
aij = torch.where(g1 != 0, aij, torch.zeros(1))
weight = torch.maximum(g1, torch.zeros(1)) * aij
weight = torch.sum(weight, dim=(1, 2))
cam = F.relu(torch.sum(feat[0] * weight.view(-1,1,1), dim=0))
return cam
- ViTのfeature mapの可視化
register_forward_hook
をViTの最後のFFNへ登録し、その空間方向を含む特徴を用いてGrad-CAMを作れば良い。
ViTなどの画像モーダルTransformer系モデルを使う時に、使う出力はだいたい[cls]tokenだが、これをGrad-CAMで可視化しようとすると実は問題があるという説。
具体的には、CNNでは線形結合層(出力層)への入力は画像の空間方向の情報を含む特徴マップをGAPによってベクトル量へ圧縮し、それを予測に用いるのだが、ViTでは線形結合層への入力が[cls]tokenであり、これは画像に対応したtoken列からスライスによって切り取られたベクトルであるため、Grad-CAMのために特徴マップを取ってきても直接出力層との関連を見ることができないのだ。
この解決のため、FFN層の入力を取ってきて、画像を含んだ出力層を作る必要がある。
通常のGrad-CAM
まず、通常のCNNにおけるGrad-CAMを以下に書く。
以下では、ConvNeXtを例にGrad-CAMを出力する。なお、batch sizeは1である前提で話を進める。
import torch
from torch import nn
from torch.nn import functional as F
import timm
model = timm.create_model("convnext_tiny.fb_in1k")
extractor = nn.Sequential(
model.stem,
model.stages
).eval()
head = nn.Sequential(
model.head
).eval()
H, W = 224, 224
C = 768
img = torch.ones(1, 3, H, W) # 画像
# extractorの勾配を切ってheadの勾配だけ見れるようにする
feat = extractor(img)
feat = feat.clone().detach().requires_grad_(True)
pred = head(feat)
# Grad-CAM本体 ------------------------------------------------
def process_grad_cam(feat, pred, v=0): # v: 見たいクラス
# headでの計算で発生する勾配を計算
pred[0, v].backward(retain_graph=True)
# 空間方向の平均を取り特徴次元のみにする
weight = torch.mean(feat.grad[0], axis=(1,2))
# 勾配の大きさで特徴次元に重み付け
cam = F.relu(torch.sum(feat[0].view(C,H,W) * weight.view(-1,1,1), dim=0))
return cam
Grad-CAM++の実装
pytorch-grad-camを参考にtorchのみで書き直した。
def process_grad_cam(feat, pred, v=0):
eps = 0.000001
# headでの計算で発生する勾配を計算
pred[0, v].backward(retain_graph=True)
# 勾配の累乗項を計算
g1 = feat.grad[0]
g2 = g1 ** 2
g3 = g1 ** 3
# 特徴次元に掛ける重みの式を計算
sumfeat = torch.sum(feat[0], dim=(1, 2))
aij = g2 / (2 * g2 + sumfeat[:, None, None] * g3 + eps)
aij = torch.where(g1 != 0, aij, torch.zeros(1))
weight = torch.maximum(g1, torch.zeros(1)) * aij
weight = torch.sum(weight, dim=(1, 2))
# 勾配の大きさで特徴次元に重み付け
cam = F.relu(torch.sum(feat[0] * weight.view(-1,1,1), dim=0))
return cam
ViTをGrad-CAMで可視化する場合の注意点
このプログラムをViTに適用しようとすると、話が変わってくる。
img # [3, H, W]
model = timm.create_model("vit_tiny_patch16_224.augreg_in21k")
feat = extractor(img) # こいつは[197, 192] = [(H/32 * W/32 + 1), 192]
feat = feat.clone().detach().requires_grad_(True)
pred = head(feat) # ここで[1, 192]となるベクトルのみがスライスで取り出されて推論される
この中で何が起こっているかというと、[cls]token以外は全て捨てられており、出力層を通っているのは[cls]tokenである[1, 192]だけが推論に影響を与えているのだ!
Grad-CAMはheadがどの入力画像の空間領域で効いているかを可視化するため、画像と関係ないベクトルにくっついた勾配を見ていても埒が明かない。しかしとんでもないことに、このベクトルは画像に相当するtokenからdetachされずにスライスで切り取られているので、[cls]tokenへ与えた勾配(つまりAttentionの類似度?)の値が残っているのか謎の画像は得られてしまう。
これを改善し、画像に相当するtokenを含んだGrad-CAMを見るには、ViTの最後の層のFFNの入力を取ってくるのが良いと考えた。これであればHeadは3層の線形結合層のSequentialと考えることができるため、勾配の情報が適切にGrad-CAMとして出力されるだろう。
このためにregister_forward_hook
を用いてFFNから特徴をグローバル変数へ保存し、それを用いてViTのGrad-CAMを出力する。
def store_feature(self, model, input, output):
global FEATUER_STORED
FEATUER_STORED = output.clone().detach().requires_grad_(True)
return FEATUER_STORED
# register_forward_hook により特徴保存関数 store_feature を登録
handle = model.blocks[-1].norm2.register_forward_hook(store_feature)
feat = extractor(img)
feat = feat.clone().detach().requires_grad_(True)
pred = head(feat)
pred[0, v].backward(retain_graph=True)
alpha = torch.mean(feat.grad[0], axis=(1,2))
gradcam = F.relu(torch.sum(feat[0].view(C,H,W) * alpha.view(-1,1,1), dim=0))
# register_forward_hook を解除
handle.remove()
これによりViTでも画像に相当するtokenを用いてGrad-CAMを見ることができる。
ただ、これを行わなくても一応Grad-CAMらしい画像を見ることはできるので、画像が出力層を通っている必要があるかについて、もう少し真面目に考えて結論を出す必要がある。
ヒートマップの可視化
ヒートマップを可視化する時にtorchvisionのテンソルとして保存できると、concatenateで比較表を作りやすい。
def make_heatmap(img, size=(224,224), color="hot"):
buf = io.BytesIO()
plt.figure(figsize=size, dpi=1)
plt.gca().axis("off")
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
plt.imshow(torch.sum(img, dim=0), cmap=color)
plt.savefig(buf, format="png")
plt.clf()
plt.close()
buf = torchvision.io.decode_png(torch.frombuffer(buf.getvalue(), dtype=torch.uint8))
return buf[0:3]
また、Grad-CAMにReLUを掛けず負の要素を残して可視化する場合、ヒートマップの0に相当する色を固定したい気持ちになる。この場合、matplotlib.colors.TwoSlopeNorm
を用いて基準点を定めることで解決する。
from matplotlib.colors import TwoSlopeNorm
norm = TwoSlopeNorm(vcenter=0.0)
plt.imshow(img, cmap="bwr", norm=norm)
これにより0が白、負が青、正が赤のヒートマップを得ることができる。
Discussion