Transformer に触れてみる (2) — ViT もどき
目的
- Vision Transformer入門 をパラパラめくってもさっぱり理解できる気がしない。
- arXiv:2010.11929 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale のいつものパッチ並べを見るのはもう飽きた。
のを何とかしたい。ということで、Transformer に触れてみる (1) の続き。
再び Vibe coding
またもや GPT-4.1 にお願いして学習素材を作ってもらうが、ViT は Transformer エンコーダに MLP を繋げる形なので、前回の Transformer に触れてみる (1) の MiniFormer
が再利用できるような気がする。よって、前回のソースコードをコンテキストとして与えて作成してもらった。
作ったもの
- タスクの内容: 32x32 の画像のデータセットを
input_seq = ["[CLS]", "p0", "1", ..., "p63"]
で、パッチサイズ 4x4 で 8x8 個のパッチに分割して Transformer エンコーダに入力して特徴量を抽出し、MLP で 4 クラス分類するタスクを考慮。 - データセットには CIFAR-10 を用い、
['cat', 'airplane', 'frog', 'truck']
のサブセットを用いた。 -
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale のアーキテクチャを大幅に簡易化した
MiniViT
を用いる。
実装
全部を理解できていないので、単純に貼り付ける。Google Colab 上で実行する。NVIDIA T4 で実行できる。訓練は 7 分程度なのですぐ完了する。
準備やデータセット
以下のような形で準備する。
!pip install -qU bertviz
from __future__ import annotations
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import random
import json
import matplotlib.pyplot as plt
# クラス設定
selected_classes = ['cat', 'airplane', 'frog', 'truck']
class_to_idx = {
'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9
}
selected_labels = [class_to_idx[c] for c in selected_classes]
# ラベルリマップ: 元ラベル -> 0~3
label_map = {orig: i for i, orig in enumerate(selected_labels)}
# データセット取得
transform = transforms.ToTensor()
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# サブセット抽出
indices = [i for i, label in enumerate(dataset.targets) if label in selected_labels]
subset = Subset(dataset, indices)
# リマップ用ラッパー
class RemapLabels(torch.utils.data.Dataset):
def __init__(self, subset, label_map):
self.subset = subset
self.label_map = label_map
def __getitem__(self, idx):
img, label = self.subset[idx]
return img, self.label_map[label]
def __len__(self):
return len(self.subset)
filtered_dataset = RemapLabels(subset, label_map)
# train/test分割
indices = list(range(len(filtered_dataset)))
random.seed(42)
random.shuffle(indices)
split = int(len(indices) * 0.8)
train_idx, test_idx = indices[:split], indices[split:]
train_ds = torch.utils.data.Subset(filtered_dataset, train_idx)
test_ds = torch.utils.data.Subset(filtered_dataset, test_idx)
batch_size = 64
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
データセットを見てみる。
# 可視化: 1バッチ分表示
images, labels = next(iter(test_loader))
class_names = selected_classes
viz_batch_size = 8
fig, axes = plt.subplots(1, viz_batch_size, figsize=(2.5*viz_batch_size, 3))
for i in range(viz_batch_size):
img = images[i].permute(1, 2, 0).numpy()
axes[i].imshow(img)
axes[i].set_title(class_names[labels[i]])
axes[i].axis('off')
plt.tight_layout()
plt.show()
MiniViT
# 2. パッチ分割ユーティリティ
def img_to_patches(img, patch_size=4):
# img: [C, H, W]
C, H, W = img.shape
assert H % patch_size == 0 and W % patch_size == 0
patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
patches = patches.contiguous().view(C, -1, patch_size, patch_size)
patches = patches.permute(1, 0, 2, 3).contiguous().view(-1, C * patch_size * patch_size)
return patches # [num_patches, patch_dim]
# 3. 位置エンコーディング
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
if d_model > 1:
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
# x: [B, N, d]
return x + self.pe[:x.size(1)]
# 4. MiniViT (MiniFormerベース)
class MiniViT(nn.Module):
def __init__(self,
patch_dim, # 3*patch*patch
num_patches,
num_classes=4,
d_model=64):
super().__init__()
self.patch_embed = nn.Linear(patch_dim, d_model)
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_enc = PositionalEncoding(d_model, num_patches+1)
# Self-Attention (single head, 1層)
self.q_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.attn_out = nn.Linear(d_model, d_model)
# FFN
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model),
nn.ReLU(),
nn.Linear(d_model, d_model)
)
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.fc_out = nn.Linear(d_model, num_classes)
self.attn_weights = None # for visualization
def forward(self, img, return_attn=False):
# img: [B, 3, H, W]
B = img.size(0)
patches = torch.stack([img_to_patches(im) for im in img]) # [B, N, patch_dim]
x = self.patch_embed(patches) # [B, N, d_model]
# cls token
cls_token = self.cls_token.expand(B, -1, -1) # [B, 1, d_model]
x = torch.cat([cls_token, x], dim=1) # [B, N+1, d_model]
x = self.pos_enc(x)
# Attention
Q = self.q_linear(x)
K = self.k_linear(x)
V = self.v_linear(x)
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(Q.size(-1))
attn = torch.softmax(scores, dim=-1)
attn_out = torch.matmul(attn, V)
attn_out = self.attn_out(attn_out)
x1 = self.ln1(x + attn_out)
x2 = self.ln2(x1 + self.ffn(x1))
cls_out = x2[:, 0] # [B, d_model]
logits = self.fc_out(cls_out)
if return_attn:
self.attn_weights = attn.detach().cpu().numpy()
return logits, attn
return logits
# 5. BertViz互換のAttention保存
def save_attention(attn_matrix, input_tokens, filename="attn_weights.json"):
data = {
"tokens": input_tokens,
"attentions": attn_matrix.tolist()
}
with open(filename, "w") as f:
json.dump(data, f, indent=2)
訓練ループ
# 6. 学習・評価
def train(model, loader, optimizer, criterion, device):
model.train()
total_loss = 0
for x, y in loader:
x = x.to(device)
y = y.to(device)
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(loader)
def evaluate(model, loader, criterion, device):
model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for x, y in loader:
x = x.to(device)
y = y.to(device)
logits = model(x)
loss = criterion(logits, y)
total_loss += loss.item()
pred = logits.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
return total_loss / len(loader), correct / total
学習
%%time
device = "cuda" if torch.cuda.is_available() else "cpu"
patch_size = 4
num_patches = (32 // patch_size) * (32 // patch_size)
patch_dim = 3 * patch_size * patch_size
d_model = 64
n_epochs = 100
model = MiniViT(patch_dim, num_patches, num_classes=4, d_model=d_model).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
for epoch in range(1, n_epochs+1):
train_loss = train(model, train_loader, optimizer, criterion, device)
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"Epoch {epoch:2d}: train loss={train_loss:.4f}, test loss={test_loss:.4f}, test acc={test_acc:.4f}")
# 推論+Attention保存例
model.eval()
x, y = next(iter(test_loader))
x = x.to(device)
with torch.no_grad():
logits, attn = model(x[:1], return_attn=True)
pred = logits.argmax(dim=1).cpu().item()
print(f"True: {y[0].item()}, Pred: {pred}")
# 可視化用token名 (patch番号)
input_tokens = ["[CLS]"] + [f"p{i}" for i in range(num_patches)]
save_attention(attn[0], input_tokens, filename="attn_weights.json")
print("Saved attention weights to attn_weights.json. Visualize with BertViz or similar tools.")
Epoch 1: train loss=1.1329, test loss=1.0079, test acc=0.5683
Epoch 2: train loss=0.9938, test loss=0.9656, test acc=0.6025
Epoch 3: train loss=0.9543, test loss=0.9706, test acc=0.5925
Epoch 4: train loss=0.9334, test loss=0.9198, test acc=0.6110
Epoch 5: train loss=0.9158, test loss=0.8760, test acc=0.6400
...
Epoch 96: train loss=0.7041, test loss=0.7867, test acc=0.6730
Epoch 97: train loss=0.7026, test loss=0.7763, test acc=0.6817
Epoch 98: train loss=0.7061, test loss=0.7667, test acc=0.6827
Epoch 99: train loss=0.7046, test loss=0.7796, test acc=0.6770
Epoch 100: train loss=0.7053, test loss=0.7757, test acc=0.6817
True: 0, Pred: 2
Saved attention weights to attn_weights.json. Visualize with BertViz or similar tools.
CPU times: user 7min 11s, sys: 1.43 s, total: 7min 12s
Wall time: 7min 19s
当てずっぽうなら test acc=0.25 であろうから、0.68 に到達しているならそれなりに特徴をとらえたと言えるだろう。
セルフアテンション(1 ヘッド)の可視化
import json
import numpy as np
import torch
from bertviz import head_view
with open('attn_weights.json') as f:
data = json.load(f)
attn = np.array(data['attentions']) # (65, 65)
tokens = data['tokens']
# (1, 1, 65, 65) に変換し、リストで包む(レイヤー数=1のViTの場合)
attention = [torch.tensor(attn).unsqueeze(0).unsqueeze(0)] # [ (1, 1, 65, 65) ]
print(f"attn.shape: {attn.shape}, tokens: {tokens}")
head_view(attention=attention, tokens=tokens)
全体 | 「p30」の部分 |
---|---|
![]() |
![]() |
アテンションは全体に薄く広がっているようだ。
推論
推論して GT ラベルと推論したラベルを比較する。実際には訓練の具合にもよるのだが、学習時のログが示すように、概ね 6 割以上は正解する感じである。
import matplotlib.pyplot as plt
import torch
# モデル, test_loader, device, class_names などは既に用意されているものとします
def visualize_predictions(model, test_loader, class_names, device, num_images=8):
model.eval()
images_shown = 0
fig, axes = plt.subplots(1, num_images, figsize=(num_images*2, 2.5))
axes = axes if num_images > 1 else [axes]
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, preds = outputs.max(1)
for i in range(images.size(0)):
if images_shown >= num_images:
break
ax = axes[images_shown]
img = images[i].cpu()
# 画像変換(例:0-1スケールに戻す)
img = img.permute(1, 2, 0).numpy()
img = (img - img.min()) / (img.max() - img.min())
ax.imshow(img)
ax.axis('off')
true_label = class_names[labels[i].item()]
pred_label = class_names[preds[i].item()]
color = 'green' if true_label == pred_label else 'red'
ax.set_title(f"T:{true_label}\nP:{pred_label}", color=color, fontsize=10)
images_shown += 1
if images_shown >= num_images:
break
plt.tight_layout()
plt.show()
# 例:class_names = ['cat', 'airplane', 'frog', 'truck']
visualize_predictions(model, test_loader, class_names, device, num_images=8)
アテンションマップ
あまり良く知らないのだが、先頭にクラストークン ([CLS]
) を言うのを置いて、この [CLS]
からのアテンションを可視化するらしい [1]。ImageNet について考える (3) — Tiny ImageNet の分類の説明可能性とモデル圧縮 で CNN ベースの画像分類モデルに対して行った Grad-CAM みたいなものと思っている。
import torch
import matplotlib.pyplot as plt
import numpy as np
def show_attention_on_image(img, attn_map, patch_size=4):
# img: [3, 32, 32] (CIFAR-10)
# attn_map: [n_patches] (64次元)
img = img.cpu()
attn_map = attn_map.cpu().numpy()
# パッチを画像サイズに展開
n_patches = attn_map.shape[0]
grid_size = int(np.sqrt(n_patches)) # 8
attn_map_2d = attn_map.reshape(grid_size, grid_size)
# パッチごとに拡大
attn_map_up = np.kron(attn_map_2d, np.ones((patch_size, patch_size)))
# 端数処理
attn_map_up = attn_map_up[:img.shape[1], :img.shape[2]]
# 画像のスケール調整
img_np = img.permute(1,2,0).numpy()
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
plt.imshow(img_np)
plt.imshow(attn_map_up, cmap='jet', alpha=0.5)
plt.axis('off')
plt.title("Attention Heatmap")
plt.show()
# 推論・可視化例
idx = 0
model.eval()
x, y = next(iter(test_loader))
x = x.to(device)
with torch.no_grad():
logits, attn = model(x[idx:idx+1], return_attn=True)
pred = logits.argmax(dim=1).item()
# [CLS]トークンから各パッチへのアテンション(1ヘッド分)
#attn_map = torch.from_numpy(attn[0, 0, 1:]) # 64次元
attn_map = attn[0, 0, 1:] # 64次元Tensor
show_attention_on_image(x[idx], attn_map, patch_size=4)
print(f"True label: {y[0].item()}, Pred label: {pred}")
テストローダの先頭の 8 枚に対するアテンションマップの出力を、先ほどの GT とのラベル比較の図と並べると以下のようになる。大体が背景よりもターゲットオブジェクトに注目しているようである。
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
因みに学習前のモデルでアテンションマップを可視化すると、先頭の画像に対する結果は以下のようであり、まったくもって無茶苦茶である。
まとめ
ViT をフルスクラッチで実装することも、簡易化することもできないので、今回も生成 AI に実装してもらった。思ったよりもなるほどなというアテンションマップになったので良かった。
補足
パッチの大きさを固定する場合を考える: パッチ分割が
参考文献
- Vision Transformer入門
- arXiv:2010.11929 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- bertviz
-
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale では Figure 1 の caption に In order to perform classification, we use the standard approach of adding an extra learnable “classification token” to the sequence. と書かれている。 ↩︎
Discussion