🪞

【PyTorch】SimSiam で対照学習を実装してみた

2022/06/23に公開

はじめに

はじめに,SimSiam についての軽い説明と筆者が用いている環境について話します.

本記事の詳細については以下の GitHub にコードを載せているのでよければ見てください.Star や Pull Request 等頂けるとやる気が出ます↓

この GitHub では,SimSiamを使ってCIFAR10で学習させるサンプルコードも置いているのでお力になれれば幸いです.

SimSiam とは

論文はこちら↓

SimSiamとは「教師なし学習における Contrastive Learning」の手法で,1つの画像に対してランダムな変形を2回行い,同時に入力します.そして,特徴空間において類似した画像が互いに近くにあり,異なる画像が遠く離れているような特徴表現を学習します.

simsiam

例えばクラス分類が目的なら

  • サンプルペアが同じクラスに属するならば特徴ベクトルを近づける
  • サンプルペアが異なるクラスに属するならば特徴ベクトルを遠ざける

様に学習をさせるために

Loss = - \frac{\bm{Q} \cdot \bm{D}}{\displaystyle \left\Vert \bm{Q} \right\Vert \left\Vert \bm{D} \right\Vert}

上記の Loss を低くなる様にステップさせて行きます.

環境

PC MacBook Pro (16-inch, 2019)
OS Monterey
CPU 2.3 GHz 8コアIntel Core i9
メモリ 16GB
Python 3.9

ライブラリ

今回の記事で用いるライブラリとバージョンをまとめますが,特に気にせず

terminal
pip install numpy pillow torch torchvision

で問題ないかと思います.

以下,使用するライブラリのバージョンです.

ライブラリ バージョン
numpy 1.21.2
pillow 9.0.1
torch 1.10.1
torchvision 0.11.2

階層構造

simsiam
├── loss
│   ├── __init__.py
│   └── negative_cosine_similarity.py
├── model
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── resnet18.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── encoder.py
│   │   └── predictor.py
│   └── simsiam.py
├── trainer.py
└── transforms
    ├── __init__.py
    └── transforms.py

実装

以降では,実際にSimSiamのコードを書いていきます.

transforms

まず,1枚の入力画像に対して,複数の変形をおこなった2枚の画像を得るためのtransformsを記述します.

transforms.py
import random
from typing import List

import torch
from PIL import ImageFilter
from torchvision import transforms


class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x) -> List[torch.Tensor]:
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]


class GaussianBlur:
    def __init__(self, sigma=[0.1, 2.0]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x


def get_transforms(mode: str):
    base_transform = {
        "train": transforms.Compose(
            [
                # transforms.ToPILImage(),
                transforms.RandomResizedCrop(size=(512, 512), scale=(0.2, 1.0)),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        ),
        "valid": transforms.Compose(
            [
                # transforms.ToPILImage(),
                transforms.RandomResizedCrop(size=(512, 512), scale=(0.2, 1.0)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        ),
    }

    return TwoCropsTransform(base_transform[mode])

Encoder

次にEncoder部分のコードです.

encoder.py
import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self, dim: int = 2048) -> None:
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),  # first layer
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),  # second layer
            nn.BatchNorm1d(dim, affine=False),
        )

    def forward(self, x) -> torch.Tensor:
        return self.layer(x)

Predictor

次にPredictor部分のコードです.

predictor.py
import torch
from torch import nn


class Predictor(nn.Module):
    def __init__(self, dim: int = 2048, pred_dim: int = 512) -> None:
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),  # hidden layer
            nn.Linear(pred_dim, dim),  # output layer
        )

    def forward(self, x) -> torch.Tensor:
        return self.layer(x)

SimSiam

いよいよSimSiamのコードに入ります.

simsiam.py
import torch
from torch import nn

from .modules import Encoder, Predictor


class SimSiam(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        dim: int = 2048,
        pred_dim: int = 512,
    ) -> None:

        super().__init__()

        self.dim = dim
        self.pred_dim = pred_dim

        self.backbone = backbone
        self.encoder = Encoder(dim=dim)
        self.predictor = Predictor(dim=dim, pred_dim=pred_dim)

    def forward(self, x0: torch.Tensor, x1: torch.Tensor):
        f0 = self.backbone(x0).flatten(start_dim=1)
        f1 = self.backbone(x1).flatten(start_dim=1)

        z0 = self.encoder(f0)
        z1 = self.encoder(f1)

        p0 = self.predictor(z0)
        p1 = self.predictor(z1)

        return (p0, z0.detach()), (p1, z1.detach())

ポイント

ここでのポイントは.detachです.
結論から言うと,.detachが使われることによって,勾配計算から切り離すことができます.つまり,写真で言うところのstop-gradの役割を担っていると言うことです.
validationにおけるtorch.no_gradと同じ役割をしています.

criterion

最後に,損失関数のコードです.
損失関数については,小さくなるように学習させるため,-1をかけています.

negative_cosine_similarity.py
import torch
from torch.nn.functional import cosine_similarity


class NegativeCosineSimilarity(torch.nn.Module):
    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        return -cosine_similarity(x0, x1, self.dim, self.eps).mean()

おわりに

大学での研究でSimSiamに触れる機会があったので,Pytorchを使って実際に自分でコードを書いてみました.

論文を全て追って実装した経験があまりなかったので,とても良い経験となりました.
また (気が向いたら) 自分が追ったことのある論文について,実装して記事にできたら良いなと考えています.

間違ったことなどありましたら,優しく教えていただけると幸いです.

GitHubで編集を提案

Discussion