【PyTorch】SimSiam で対照学習を実装してみた
はじめに
はじめに,SimSiam についての軽い説明と筆者が用いている環境について話します.
本記事の詳細については以下の GitHub にコードを載せているのでよければ見てください.Star や Pull Request 等頂けるとやる気が出ます↓
この GitHub では,SimSiam
を使ってCIFAR10
で学習させるサンプルコードも置いているのでお力になれれば幸いです.
SimSiam とは
論文はこちら↓
SimSiam
とは「教師なし学習における Contrastive Learning」の手法で,1つの画像に対してランダムな変形を2回行い,同時に入力します.そして,特徴空間において類似した画像が互いに近くにあり,異なる画像が遠く離れているような特徴表現を学習します.
例えばクラス分類が目的なら
- サンプルペアが同じクラスに属するならば特徴ベクトルを近づける
- サンプルペアが異なるクラスに属するならば特徴ベクトルを遠ざける
様に学習をさせるために
上記の Loss を低くなる様にステップさせて行きます.
環境
PC | MacBook Pro (16-inch, 2019) |
---|---|
OS | Monterey |
CPU | 2.3 GHz 8コアIntel Core i9 |
メモリ | 16GB |
Python | 3.9 |
ライブラリ
今回の記事で用いるライブラリとバージョンをまとめますが,特に気にせず
pip install numpy pillow torch torchvision
で問題ないかと思います.
以下,使用するライブラリのバージョンです.
ライブラリ | バージョン |
---|---|
numpy | 1.21.2 |
pillow | 9.0.1 |
torch | 1.10.1 |
torchvision | 0.11.2 |
階層構造
simsiam
├── loss
│ ├── __init__.py
│ └── negative_cosine_similarity.py
├── model
│ ├── __init__.py
│ ├── backbone
│ │ ├── __init__.py
│ │ └── resnet18.py
│ ├── modules
│ │ ├── __init__.py
│ │ ├── encoder.py
│ │ └── predictor.py
│ └── simsiam.py
├── trainer.py
└── transforms
├── __init__.py
└── transforms.py
実装
以降では,実際にSimSiam
のコードを書いていきます.
transforms
まず,1枚の入力画像に対して,複数の変形をおこなった2枚の画像を得るためのtransforms
を記述します.
import random
from typing import List
import torch
from PIL import ImageFilter
from torchvision import transforms
class TwoCropsTransform:
"""Take two random crops of one image as the query and key."""
def __init__(self, base_transform):
self.base_transform = base_transform
def __call__(self, x) -> List[torch.Tensor]:
q = self.base_transform(x)
k = self.base_transform(x)
return [q, k]
class GaussianBlur:
def __init__(self, sigma=[0.1, 2.0]):
self.sigma = sigma
def __call__(self, x):
sigma = random.uniform(self.sigma[0], self.sigma[1])
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
return x
def get_transforms(mode: str):
base_transform = {
"train": transforms.Compose(
[
# transforms.ToPILImage(),
transforms.RandomResizedCrop(size=(512, 512), scale=(0.2, 1.0)),
transforms.RandomGrayscale(p=0.2),
transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
),
"valid": transforms.Compose(
[
# transforms.ToPILImage(),
transforms.RandomResizedCrop(size=(512, 512), scale=(0.2, 1.0)),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
),
}
return TwoCropsTransform(base_transform[mode])
Encoder
次にEncoder
部分のコードです.
import torch
from torch import nn
class Encoder(nn.Module):
def __init__(self, dim: int = 2048) -> None:
super().__init__()
self.layer = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True), # first layer
nn.Linear(dim, dim, bias=False),
nn.BatchNorm1d(dim),
nn.ReLU(inplace=True), # second layer
nn.BatchNorm1d(dim, affine=False),
)
def forward(self, x) -> torch.Tensor:
return self.layer(x)
Predictor
次にPredictor
部分のコードです.
import torch
from torch import nn
class Predictor(nn.Module):
def __init__(self, dim: int = 2048, pred_dim: int = 512) -> None:
super().__init__()
self.layer = nn.Sequential(
nn.Linear(dim, pred_dim, bias=False),
nn.BatchNorm1d(pred_dim),
nn.ReLU(inplace=True), # hidden layer
nn.Linear(pred_dim, dim), # output layer
)
def forward(self, x) -> torch.Tensor:
return self.layer(x)
SimSiam
いよいよSimSiam
のコードに入ります.
import torch
from torch import nn
from .modules import Encoder, Predictor
class SimSiam(nn.Module):
def __init__(
self,
backbone: nn.Module,
dim: int = 2048,
pred_dim: int = 512,
) -> None:
super().__init__()
self.dim = dim
self.pred_dim = pred_dim
self.backbone = backbone
self.encoder = Encoder(dim=dim)
self.predictor = Predictor(dim=dim, pred_dim=pred_dim)
def forward(self, x0: torch.Tensor, x1: torch.Tensor):
f0 = self.backbone(x0).flatten(start_dim=1)
f1 = self.backbone(x1).flatten(start_dim=1)
z0 = self.encoder(f0)
z1 = self.encoder(f1)
p0 = self.predictor(z0)
p1 = self.predictor(z1)
return (p0, z0.detach()), (p1, z1.detach())
ポイント
ここでのポイントは.detach
です.
結論から言うと,.detach
が使われることによって,勾配計算から切り離すことができます.つまり,写真で言うところのstop-grad
の役割を担っていると言うことです.
validation
におけるtorch.no_grad
と同じ役割をしています.
criterion
最後に,損失関数のコードです.
損失関数については,小さくなるように学習させるため,-1
をかけています.
import torch
from torch.nn.functional import cosine_similarity
class NegativeCosineSimilarity(torch.nn.Module):
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
return -cosine_similarity(x0, x1, self.dim, self.eps).mean()
おわりに
大学での研究でSimSiam
に触れる機会があったので,Pytorch
を使って実際に自分でコードを書いてみました.
論文を全て追って実装した経験があまりなかったので,とても良い経験となりました.
また (気が向いたら) 自分が追ったことのある論文について,実装して記事にできたら良いなと考えています.
間違ったことなどありましたら,優しく教えていただけると幸いです.
Discussion