🦜

Transformer に触れてみる (2) — ViT もどき

に公開

目的

のを何とかしたい。ということで、Transformer に触れてみる (1) の続き。

再び Vibe coding

またもや GPT-4.1 にお願いして学習素材を作ってもらうが、ViT は Transformer エンコーダに MLP を繋げる形なので、前回の Transformer に触れてみる (1) MiniFormer が再利用できるような気がする。よって、前回のソースコードをコンテキストとして与えて作成してもらった。

作ったもの

  • タスクの内容: 32x32 の画像のデータセットを input_seq = ["[CLS]", "p0", "1", ..., "p63"] で、パッチサイズ 4x4 で 8x8 個のパッチに分割して Transformer エンコーダに入力して特徴量を抽出し、MLP で 4 クラス分類するタスクを考慮。
  • データセットには CIFAR-10 を用い、['cat', 'airplane', 'frog', 'truck'] のサブセットを用いた。
  • An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale のアーキテクチャを大幅に簡易化した MiniViT を用いる。

実装

全部を理解できていないので、単純に貼り付ける。Google Colab 上で実行する。NVIDIA T4 で実行できる。訓練は 7 分程度なのですぐ完了する。

準備やデータセット

以下のような形で準備する。

!pip install -qU bertviz
from __future__ import annotations

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import random
import json
import matplotlib.pyplot as plt

# クラス設定
selected_classes = ['cat', 'airplane', 'frog', 'truck']
class_to_idx = {
    'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
    'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9
}
selected_labels = [class_to_idx[c] for c in selected_classes]

# ラベルリマップ: 元ラベル -> 0~3
label_map = {orig: i for i, orig in enumerate(selected_labels)}

# データセット取得
transform = transforms.ToTensor()
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# サブセット抽出
indices = [i for i, label in enumerate(dataset.targets) if label in selected_labels]
subset = Subset(dataset, indices)

# リマップ用ラッパー
class RemapLabels(torch.utils.data.Dataset):
    def __init__(self, subset, label_map):
        self.subset = subset
        self.label_map = label_map
    def __getitem__(self, idx):
        img, label = self.subset[idx]
        return img, self.label_map[label]
    def __len__(self):
        return len(self.subset)

filtered_dataset = RemapLabels(subset, label_map)

# train/test分割
indices = list(range(len(filtered_dataset)))
random.seed(42)
random.shuffle(indices)
split = int(len(indices) * 0.8)
train_idx, test_idx = indices[:split], indices[split:]
train_ds = torch.utils.data.Subset(filtered_dataset, train_idx)
test_ds = torch.utils.data.Subset(filtered_dataset, test_idx)

batch_size = 64
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

データセットを見てみる。

# 可視化: 1バッチ分表示
images, labels = next(iter(test_loader))
class_names = selected_classes

viz_batch_size = 8

fig, axes = plt.subplots(1, viz_batch_size, figsize=(2.5*viz_batch_size, 3))
for i in range(viz_batch_size):
    img = images[i].permute(1, 2, 0).numpy()
    axes[i].imshow(img)
    axes[i].set_title(class_names[labels[i]])
    axes[i].axis('off')
plt.tight_layout()
plt.show()

MiniViT

# 2. パッチ分割ユーティリティ
def img_to_patches(img, patch_size=4):
    # img: [C, H, W]
    C, H, W = img.shape
    assert H % patch_size == 0 and W % patch_size == 0
    patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    patches = patches.contiguous().view(C, -1, patch_size, patch_size)
    patches = patches.permute(1, 0, 2, 3).contiguous().view(-1, C * patch_size * patch_size)
    return patches # [num_patches, patch_dim]

# 3. 位置エンコーディング
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model > 1:
            pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x):
        # x: [B, N, d]
        return x + self.pe[:x.size(1)]

# 4. MiniViT (MiniFormerベース)
class MiniViT(nn.Module):
    def __init__(self,
                 patch_dim,    # 3*patch*patch
                 num_patches,
                 num_classes=4,
                 d_model=64):
        super().__init__()
        self.patch_embed = nn.Linear(patch_dim, d_model)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_enc = PositionalEncoding(d_model, num_patches+1)
        # Self-Attention (single head, 1層)
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.attn_out = nn.Linear(d_model, d_model)
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, num_classes)
        self.attn_weights = None  # for visualization

    def forward(self, img, return_attn=False):
        # img: [B, 3, H, W]
        B = img.size(0)
        patches = torch.stack([img_to_patches(im) for im in img]) # [B, N, patch_dim]
        x = self.patch_embed(patches) # [B, N, d_model]
        # cls token
        cls_token = self.cls_token.expand(B, -1, -1) # [B, 1, d_model]
        x = torch.cat([cls_token, x], dim=1) # [B, N+1, d_model]
        x = self.pos_enc(x)
        # Attention
        Q = self.q_linear(x)
        K = self.k_linear(x)
        V = self.v_linear(x)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(Q.size(-1))
        attn = torch.softmax(scores, dim=-1)
        attn_out = torch.matmul(attn, V)
        attn_out = self.attn_out(attn_out)
        x1 = self.ln1(x + attn_out)
        x2 = self.ln2(x1 + self.ffn(x1))
        cls_out = x2[:, 0] # [B, d_model]
        logits = self.fc_out(cls_out)
        if return_attn:
            self.attn_weights = attn.detach().cpu().numpy()
            return logits, attn
        return logits

# 5. BertViz互換のAttention保存
def save_attention(attn_matrix, input_tokens, filename="attn_weights.json"):
    data = {
        "tokens": input_tokens,
        "attentions": attn_matrix.tolist()
    }
    with open(filename, "w") as f:
        json.dump(data, f, indent=2)

訓練ループ

# 6. 学習・評価
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            total_loss += loss.item()
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return total_loss / len(loader), correct / total

学習

%%time

device = "cuda" if torch.cuda.is_available() else "cpu"
patch_size = 4
num_patches = (32 // patch_size) * (32 // patch_size)
patch_dim = 3 * patch_size * patch_size
d_model = 64
n_epochs = 100

model = MiniViT(patch_dim, num_patches, num_classes=4, d_model=d_model).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, n_epochs+1):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Epoch {epoch:2d}: train loss={train_loss:.4f}, test loss={test_loss:.4f}, test acc={test_acc:.4f}")

# 推論+Attention保存例
model.eval()
x, y = next(iter(test_loader))
x = x.to(device)
with torch.no_grad():
    logits, attn = model(x[:1], return_attn=True)
    pred = logits.argmax(dim=1).cpu().item()
    print(f"True: {y[0].item()}, Pred: {pred}")
    # 可視化用token名 (patch番号)
    input_tokens = ["[CLS]"] + [f"p{i}" for i in range(num_patches)]
    save_attention(attn[0], input_tokens, filename="attn_weights.json")
    print("Saved attention weights to attn_weights.json. Visualize with BertViz or similar tools.")

Epoch 1: train loss=1.1329, test loss=1.0079, test acc=0.5683
Epoch 2: train loss=0.9938, test loss=0.9656, test acc=0.6025
Epoch 3: train loss=0.9543, test loss=0.9706, test acc=0.5925
Epoch 4: train loss=0.9334, test loss=0.9198, test acc=0.6110
Epoch 5: train loss=0.9158, test loss=0.8760, test acc=0.6400
...
Epoch 96: train loss=0.7041, test loss=0.7867, test acc=0.6730
Epoch 97: train loss=0.7026, test loss=0.7763, test acc=0.6817
Epoch 98: train loss=0.7061, test loss=0.7667, test acc=0.6827
Epoch 99: train loss=0.7046, test loss=0.7796, test acc=0.6770
Epoch 100: train loss=0.7053, test loss=0.7757, test acc=0.6817
True: 0, Pred: 2
Saved attention weights to attn_weights.json. Visualize with BertViz or similar tools.
CPU times: user 7min 11s, sys: 1.43 s, total: 7min 12s
Wall time: 7min 19s

当てずっぽうなら test acc=0.25 であろうから、0.68 に到達しているならそれなりに特徴をとらえたと言えるだろう。

セルフアテンション(1 ヘッド)の可視化

import json
import numpy as np
import torch
from bertviz import head_view

with open('attn_weights.json') as f:
    data = json.load(f)
attn = np.array(data['attentions'])  # (65, 65)
tokens = data['tokens']

# (1, 1, 65, 65) に変換し、リストで包む(レイヤー数=1のViTの場合)
attention = [torch.tensor(attn).unsqueeze(0).unsqueeze(0)]  # [ (1, 1, 65, 65) ]

print(f"attn.shape: {attn.shape}, tokens: {tokens}")
head_view(attention=attention, tokens=tokens)
全体 「p30」の部分

アテンションは全体に薄く広がっているようだ。

推論

推論して GT ラベルと推論したラベルを比較する。実際には訓練の具合にもよるのだが、学習時のログが示すように、概ね 6 割以上は正解する感じである。

import matplotlib.pyplot as plt
import torch

# モデル, test_loader, device, class_names などは既に用意されているものとします

def visualize_predictions(model, test_loader, class_names, device, num_images=8):
    model.eval()
    images_shown = 0
    fig, axes = plt.subplots(1, num_images, figsize=(num_images*2, 2.5))
    axes = axes if num_images > 1 else [axes]

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = outputs.max(1)
            for i in range(images.size(0)):
                if images_shown >= num_images:
                    break
                ax = axes[images_shown]
                img = images[i].cpu()
                # 画像変換(例:0-1スケールに戻す)
                img = img.permute(1, 2, 0).numpy()
                img = (img - img.min()) / (img.max() - img.min())
                ax.imshow(img)
                ax.axis('off')
                true_label = class_names[labels[i].item()]
                pred_label = class_names[preds[i].item()]
                color = 'green' if true_label == pred_label else 'red'
                ax.set_title(f"T:{true_label}\nP:{pred_label}", color=color, fontsize=10)
                images_shown += 1
            if images_shown >= num_images:
                break
    plt.tight_layout()
    plt.show()

# 例:class_names = ['cat', 'airplane', 'frog', 'truck']
visualize_predictions(model, test_loader, class_names, device, num_images=8)

アテンションマップ

あまり良く知らないのだが、先頭にクラストークン ([CLS]) を言うのを置いて、この [CLS] からのアテンションを可視化するらしい [1]ImageNet について考える (3) — Tiny ImageNet の分類の説明可能性とモデル圧縮 で CNN ベースの画像分類モデルに対して行った Grad-CAM みたいなものと思っている。

import torch
import matplotlib.pyplot as plt
import numpy as np

def show_attention_on_image(img, attn_map, patch_size=4):
    # img: [3, 32, 32] (CIFAR-10)
    # attn_map: [n_patches] (64次元)
    img = img.cpu()
    attn_map = attn_map.cpu().numpy()
    # パッチを画像サイズに展開
    n_patches = attn_map.shape[0]
    grid_size = int(np.sqrt(n_patches))  # 8
    attn_map_2d = attn_map.reshape(grid_size, grid_size)
    # パッチごとに拡大
    attn_map_up = np.kron(attn_map_2d, np.ones((patch_size, patch_size)))
    # 端数処理
    attn_map_up = attn_map_up[:img.shape[1], :img.shape[2]]
    # 画像のスケール調整
    img_np = img.permute(1,2,0).numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
    plt.imshow(img_np)
    plt.imshow(attn_map_up, cmap='jet', alpha=0.5)
    plt.axis('off')
    plt.title("Attention Heatmap")
    plt.show()

# 推論・可視化例
idx = 0
model.eval()
x, y = next(iter(test_loader))
x = x.to(device)
with torch.no_grad():
    logits, attn = model(x[idx:idx+1], return_attn=True)
    pred = logits.argmax(dim=1).item()
    # [CLS]トークンから各パッチへのアテンション(1ヘッド分)
    #attn_map = torch.from_numpy(attn[0, 0, 1:])  # 64次元
    attn_map = attn[0, 0, 1:]  # 64次元Tensor
    show_attention_on_image(x[idx], attn_map, patch_size=4)
    print(f"True label: {y[0].item()}, Pred label: {pred}")

テストローダの先頭の 8 枚に対するアテンションマップの出力を、先ほどの GT とのラベル比較の図と並べると以下のようになる。大体が背景よりもターゲットオブジェクトに注目しているようである。

0 1 2 3 4 5 6 7

因みに学習前のモデルでアテンションマップを可視化すると、先頭の画像に対する結果は以下のようであり、まったくもって無茶苦茶である。

まとめ

ViT をフルスクラッチで実装することも、簡易化することもできないので、今回も生成 AI に実装してもらった。思ったよりもなるほどなというアテンションマップになったので良かった。

補足

パッチの大きさを固定する場合を考える: パッチ分割が n \times m 個だとする場合、画像サイズを 2 倍の 64x64 にすると 2n \times 2m = 4 nm、画像サイズを 3 倍の 64x64 にすると 9 nm と、スケールの 2 乗のオーダーで増えるので結構計算時間が増えそうである。このため 64x64 の Tiny ImageNet はやる前から断念し CIFAR-10 にしたので、あまり綺麗な結果ではないかもしれない。

参考文献

脚注
  1. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale では Figure 1 の caption に In order to perform classification, we use the standard approach of adding an extra learnable “classification token” to the sequence. と書かれている。 ↩︎

GitHubで編集を提案

Discussion