📺

Transformerよりもシンプル?「MLP-Mixer」爆誕(6日目最終日) ~Source Code編~

2021/06/02に公開

ニツオです。TwitterでAIやMLについて関連する話題を紹介してます。お気軽にフォローやご質問ください。

さて、2021年5月にMLP-Mixerというモデルが爆誕しました。本日はその解説シリーズ6日目です。

  • 1日目: Abstract / Introduction
  • 2日目: Mixer Architecture
  • 3日目: Experiments 1
  • 4日目: Experiments 2
  • 5日目: Related Work / Conclusion
  • 6日目: Source Code

「MLP-Mixer: An all-MLP Architecture for Vision」の原文はこちらです。2021年5月4日にGoogle ResearchとGoogle Brainの混合チームから発表され、関係者のTwitterでもかなり話題になっています。

シリーズ関連記事は一番下にリンク貼ってます。
早速みていきましょう。

Source Code

https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
より抜粋。

Copyright

# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Library

まずは必要なライブラリのimportを行います。本編とはあまり関係ありませんが、簡単ですが省略せずに説明いれていきます。

from typing import Any
import einops
import flax.linen as nn
import jax.numpy as jnp

typing

まず、typingはPythonの標準ライブラリで、型のアノテーション(注釈)をつけることができます。今回はAnyしか使ってないようですが、このAnyはどんな型で来た場合も許容します。標準ライブラリなので下記パスにありました。

型についてはあまり立ち入らないので、もっと知りたい方は下記参照ください。

https://docs.python.org/ja/3/library/typing.html
https://www.python.jp/pages/python3.9.html
https://qiita.com/papi_tokei/items/2a309d313bc6fc5661c3

einops

次に、einopsは、Alex Rogozhnikovが開発した、Machine Learning向けの便利なライブラリです(後で1回だけ出てくる)。リンク先の動画がイメージつきやすいです。
http://arogozhnikov.github.io/images/einops/einops_video.mp4
https://pypi.org/project/einops/
https://cpp-learning.com/einops/

事前にインストールしておく必要はあります。Google Colabで実行するとここに入ります。pipでインストールした格納先は特に気にすることなくimport出来ます。

jax and flax

jaxとflaxは、ともにDeep Learningのためのライブラリで、Googleが開発、オープンソース化したものです。詳しくは公式ドキュメントや参考記事をご覧ください。pipでインストール( pip install flax )すると下記に入りました。


https://opensource.google/projects/jax
https://github.com/google/jax
https://flax.readthedocs.io/en/latest/flax.linen.html

MlpBlockクラスの実装

MLPブロックのクラスのコードです。MLPブロックとは下記の図の黄色で囲った部分のことを指します。

class MlpBlock(nn.Module):
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)

MLPブロックのクラスとして、class MlpBlock(nn.Module): でMlpBlock()が定義されています。flax.linen.Moduleを継承しているので、()内に nn.Module が入っています。この nn.Module は、flaxの中でレイヤーやモデルを記述する際に継承するべきベースとなるクラス、と説明されています。今回登場する3つのクラス、すべてにおいて継承されています。

https://flax.readthedocs.io/en/latest/flax.linen.html

mlp_dim: int の部分でmlp_dimという変数がint型で宣言されてます。これは下記のMLPブロックの図の黄色にハイライトした最初のFully-connectedレイヤーの重み \mathrm{W_1} をかけられた後の出力の次元数が mlp_dim です。

この \mathrm{W_1} の形状は、入力 \mathrm{x} の形状によって異なるので、定義されるのは型のみです。どんな入力が来ても型は変わらないので、クラスの直下に書きます。今回、入力 \mathrm{x} の形状は 1 \times S の行ベクトル、S はパッチ数で本論文が定義した変数なので、mlp_dimを D_{mlp} とすると、重み \mathrm{W_1} の形状は、S \times D_{mlp} です。

MlpBlockクラスのメソッド

次に、このクラス=MLPブロックの変換関数を def __call__() で定義していきます。変数は入力 \mathrm{x} とこのクラス MlpBlock自身を指す \mathrm{self} なので、def __call__(self, x): とします。

また、@nn.compact の@はデコレータで、関数 nn.compact で直後の def __call__(self, x): で定義したメソッドをラップしています。

flax.line.compactの公式ドキュメントはこちら。
https://flax.readthedocs.io/en/latest/_modules/flax/linen/module.html#compact

class MlpBlock(nn.Module):
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)

def __call__(self, x): 内では、MLPブロックの図のように3層の変換を表現している。Fully-connectedレイヤー、GELU関数による活性化、再度Fully-connectedレイヤーという形だ。

MlpBlock内の1つ目のFully-connectedレイヤー

まず入力 \mathrm{x} の行ベクトルと重み \mathrm{W_1} を内積します。nn.Dense は公式ドキュメントによると下記ソースである。

flax.linen.Denseの公式ドキュメントはこちら。
https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#Dense

class Dense(Module):
  features: int
  use_bias: bool = True
  dtype: Any = jnp.float32
  precision: Any = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

  @compact
  def __call__(self, inputs: Array) -> Array:
    inputs = jnp.asarray(inputs, self.dtype)
    kernel = self.param('kernel',
                        self.kernel_init,
                        (inputs.shape[-1], self.features))
    kernel = jnp.asarray(kernel, self.dtype)
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),
                        precision=self.precision)
    if self.use_bias:
      bias = self.param('bias', self.bias_init, (self.features,))
      bias = jnp.asarray(bias, self.dtype)
      y = y + bias
    return y

DenseはFully-connectedレイヤーの変換を表現する関数で、nn.Dense を呼び出す際、クラス Dense 内の feature にあたるのが self.mlp_dim であり、クラス内に定義されたメソッドの引数が入力 \mathrm{x} なので、記載は、 y = nn.Dense(self.mlp_dim)(x) と()が2つになる。

MlpBlock内のGELUレイヤー

その \mathrm{y} をGELU関数で活性化するので、 y = nn.gelu(y) となる。ちなみにGELU関数は、RELU関数とよく似ていますが、こういう形状のグラフです。

GELU関数の公式ドキュメントはこちら。
https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.gelu.html
https://flax.readthedocs.io/en/latest/_modules/jax/_src/nn/functions.html#gelu

MlpBlock内の2つ目のFully-connectedレイヤー

残すは2つ目のFully-connectedレイヤーです。

class MlpBlock(nn.Module):
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)

2つ目のFully-connectedレイヤーの入力は \mathrm{y} です。この時、MLPブロックは元の入力の次元数と同じ次元数の配列(今回で言うと行ベクトル)を返す定義です。

Fully-connectedレイヤーにおける重み \mathrm{W_2} は、形状が 1 \times D_{mlp} である入力 \mathrm{y} に対して右からかけられます(Denseのソースのdot_generalのところ)。

入力 \mathrm{x} の形状は、1 \times S = (1, S) ですので、MLPブロックの出力を入力 \mathrm{x} の形状に戻すための次元数は S です。したがって、x.shape[-1] がfeatureとして使用されます。x.shape[-1] は入力 \mathrm{x} の行列の形状を配列型で返し、その要素の右から1つ目を取得したものです。

だとすると、論文内での数式における重みも右側に書いてほしい、と思うのですが、実際に論文著者のLucas Beyerに質問したところ、数式における左側・右側の表現は、慣習的なものであり、厳密ではないそうです。やや納得できなかったのですが、そうなのであれば、実装通りに右側に書いてほしいところです。

MixerBlockクラスの実装

次にMixerBlockの実装です。図の部分になります。

同様に、クラスの直下に変数 tokens_mlp_dimchannels_mlp_dim の型を int で宣言します。

class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y
    y = nn.LayerNorm()(x)
    return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)

Mixer Block Layerの図のように、

  1. LayerNormで1回目の標準化
  2. 1回目の転置
  3. 1回目のMLPブロックによる変換
  4. 2回目の転置で元に戻す
  5. 1回目のスキップ結合
  6. LayerNormで2回目の標準化
  7. 2回目のMLPブロックによる変換と2回目のスキップ結合

で実装されています。

MixerBlock: 1回目の標準化

LayerNormは正規化と混同されやすいですが、実装を見ると正しくは平均が0、標準偏差が1になるようにスケーリングする標準化です。以下のLayerNormの公式ドキュメントにあるように、LayerNormクラス直下の変数はデフォルト通りなので、y = nn.LayerNorm()(x) と2つ()があり、1つ目は指定なし、2つ目は入力引数です。

class LayerNorm(Module):
  epsilon: float = 1e-6
  dtype: Any = jnp.float32
  use_bias: bool = True
  use_scale: bool = True
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
  scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones

  @compact
  def __call__(self, x):
    x = jnp.asarray(x, jnp.float32)
    features = x.shape[-1]
    mean = jnp.mean(x, axis=-1, keepdims=True)
    mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
    var = mean2 - lax.square(mean)
    mul = lax.rsqrt(var + self.epsilon)
    if self.use_scale:
      mul = mul * jnp.asarray(
          self.param('scale', self.scale_init, (features,)),
          self.dtype)
    y = (x - mean) * mul
    if self.use_bias:
      y = y + jnp.asarray(
          self.param('bias', self.bias_init, (features,)),
          self.dtype)
    return jnp.asarray(y, self.dtype)

flax.linen.LayerNormの公式ドキュメントはこちら。
https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.LayerNorm.html
https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.LayerNorm.html

MixerBlock: 1回目の転置

次に、y = jnp.swapaxes(y, 1, 2) で転置します。

class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y
    y = nn.LayerNorm()(x)
    return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)

2番目と3番目の引数に与えられた軸を入れ替える、つまり転置します。2次元配列の場合は軸の数え方は縦が0、横が1なので、ここでは3次元配列上の縦と横、つまり S \times C のPatch軸とChannel軸を入れ替えます。軸の図のイメージはこちらの記事を参照。



https://deepage.net/features/numpy-axis.html

元々の論文ではこう書かれていた。

Figure 1 summarizes the architecture. Mixer takes as input a sequence of S non-overlapping image patches, each one projected to a desired hidden dimension C. This results in a two-dimensional real-valued input table, \mathrm{X} ∈ \mathbb{R}^{S×C}.

なので、入力 \mathrm{X} は2次元配列であり、実際、論文の図では0軸目が描かれていない、省略されているため、違和感はなかったが、実際のデータは3次元配列であることが実装上からわかった。

例えば、通常RGB画像のデータは、縦ピクセル×横ピクセル×色チャネル=縦×横×3の3次元配列である。これを S 個の小さなパッチに分解して、その分解された1つ1つのパッチを 1 \times C の行ベクトルに射影し、それを S 回分行方向にくっつけたものが今回の入力である。ここで C はチャネルの C だと思われるが、その値は3色の3ではなく、主に512である(512は慣習的に使用される値で、他にも実験的に値を変える)のは、論文に記載の通りです。なので、軸0はこのソースだけ見てもわからないが、データ数だと捉えて進みます。

1回目のMlpBlockによる変換

3つ目の変換 y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) にきました。

class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y
    y = nn.LayerNorm()(x)
    return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)

ここで、さきほど定義したクラス MlpBlock 内で定義したメソッドを使います。MlpBlockクラスは flax.linen.Module を継承したクラスであり、Moduleでは name という属性値をセットできます。このnameをセットすることで、2回出てくるMlpBlockを区別します。同様に、元々 mlp_dim であった部分も、1回目用の tokens_mlp_dim に書き換えられています。

https://flax.readthedocs.io/en/latest/flax.linen.html

2回目の転置以降の処理

あとはこれまで出てきた内容の繰り返しになります。これで、MixerBlockが読み終わりました。

MlpMixerクラスの実装

最後のクラスがMlpMixerです。図のこの部分です。

ソースコードは下記のような感じです。

class MlpMixer(nn.Module):
  """Mixer architecture."""
  patches: Any
  num_classes: int
  num_blocks: int
  hidden_dim: int
  tokens_mlp_dim: int
  channels_mlp_dim: int

  @nn.compact
  def __call__(self, inputs, *, train):
    del train
    x = nn.Conv(self.hidden_dim, self.patches.size,
                strides=self.patches.size, name='stem')(inputs)
    x = einops.rearrange(x, 'n h w c -> n (h w) c')
    for _ in range(self.num_blocks):
      x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
    x = nn.LayerNorm(name='pre_head_layer_norm')(x)
    x = jnp.mean(x, axis=1)
    return nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                    name='head')(x)

クラス内に変数のアノテーションが複数記載されました。ほとんどは int ですが、patchesだけ Any です。これは後で出てきますが、パッチサイズのことではなく、パッチサイズを含む変数のようです。

MlpMixerクラスのメソッド

def __call__(self, inputs, *, train): で定義されています。引数に * と train もありますが、これは論文に記載されたコードとは異なる点ですね。その後の del train 含めて意図がわからないので説明割愛します。

入力画像データをパッチに変換

x = nn.Conv(self.hidden_dim, self.patches.size, strides=self.patches.size, name='stem')(inputs) の部分です。入力画像データをパッチに変換しています。下図の部分の変換です。

これは畳み込み処理における、カーネル=フィルターサイズはパッチサイズ P、strides(どれだけカーネルをずらしていくか)もパッチサイズ P、という風にみなして、実装しています。上述のとおり、パッチサイズは、self.patches.size で取得してますね。name=stem は茎(くき)などの意味です。

flax.linen.Convの公式ドキュメントはこちら。
https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Conv.html
https://flax.readthedocs.io/en/latest/_modules/flax/linen/linear.html#Conv

パッチデータの次元を削減

x = einops.rearrange(x, 'n h w c -> n (h w) c') の部分です。図で言うとこちらです。

これは最初にインポートしたモジュール、einopsの中にあるrearrageというサブモジュールを使って、次元を減らします。減らすのは、\mathrm{h} \times \mathrm{w} = 縦×横の次元を1つの次元にまとめます。この実装によって、論文の図のように入力データは S \times C のテーブルデータになることが明確になりました。

einops.rearrageの公式ドキュメントはこちら。今回で言うと、()で囲まれた部分を掛け算することで次元が減った形状に変化します。
https://cgarciae.github.io/einops/api/rearrange/

N_x 回のMixer Block Layer

for _ in range(self.num_blocks): とそれに続く x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x) のfor文です。Mixer Block Layerが N_x 回繰り返されることが図で示されているので、その部分の実装です。図だけではわかりづらいですが、単純に縦に繰り返されます。

また、この構文 for _ in range(self.num_blocks): 内の、「_」は、ij が使われることが多いですが、ループ内の処理にこの ij が出てこない場合は、アンスコで代替するのが慣習のようです。

https://blog.pyq.jp/entry/Python_kaiketsu_180420
https://qiita.com/jamjamjam/items/eba5096a201745740dfa

Global Average Poolingレイヤー

実装は、x = nn.LayerNorm(name='pre_head_layer_norm')(x)x = jnp.mean(x, axis=1) の部分です。図でいうとこの部分です。

x = jnp.mean(x, axis=1) では、1番目の軸方向の要素を単純平均して、次元を減らします。1番目の軸方向は、パッチ、つまり1つの画像に対する縦軸ですね。

最後のFully-Connectedレイヤー

return nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros, name='head')(x) の部分です。図で言うと、この部分です。

この処理で最終的にClassを推定します。

おわり

「MLP-Mixer」を解説するシリーズすべて終了です。しかし書いてる間にもどんどん出てきますね。

感想や要望・指摘等は、お気軽に本記事へのコメントや、TwitterのリプライやDMに頂ければ幸いです。おしまい。

シリーズ関連記事はこちら
https://zenn.dev/attentionplease/articles/532a3de6308f57
https://zenn.dev/attentionplease/articles/7a11a56d767280
https://zenn.dev/attentionplease/articles/df6170f8581b71
https://zenn.dev/attentionplease/articles/7a3e74ad1bc9bf
https://zenn.dev/attentionplease/articles/a0d88939f9ceed
https://zenn.dev/attentionplease/articles/719580daf5a2d1

【2023年5月追記】
また、Slack版ChatGPT「Q」というサービスを開発・運営しています。
こちらもぜひお試しください。
https://q-bot.suchica.com/

Discussion