Transformerよりもシンプル?「MLP-Mixer」爆誕(6日目最終日) ~Source Code編~
ニツオです。TwitterでAIやMLについて関連する話題を紹介してます。お気軽にフォローやご質問ください。
さて、2021年5月にMLP-Mixerというモデルが爆誕しました。本日はその解説シリーズ6日目です。
- 1日目: Abstract / Introduction
- 2日目: Mixer Architecture
- 3日目: Experiments 1
- 4日目: Experiments 2
- 5日目: Related Work / Conclusion
- 6日目: Source Code
「MLP-Mixer: An all-MLP Architecture for Vision」の原文はこちらです。2021年5月4日にGoogle ResearchとGoogle Brainの混合チームから発表され、関係者のTwitterでもかなり話題になっています。
- 論文の要約はこちら
- 論文のPDFはこちら
- 論文のコードはこちら ※但し、2021年5月16日時点ではMasterブランチにはまだ反映されていませんので、LinenブランチのURLを貼ってます
シリーズ関連記事は一番下にリンク貼ってます。
早速みていきましょう。
Source Code
より抜粋。
Copyright
# Copyright 2021 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Library
まずは必要なライブラリのimportを行います。本編とはあまり関係ありませんが、簡単ですが省略せずに説明いれていきます。
from typing import Any
import einops
import flax.linen as nn
import jax.numpy as jnp
typing
まず、typingはPythonの標準ライブラリで、型のアノテーション(注釈)をつけることができます。今回はAnyしか使ってないようですが、このAnyはどんな型で来た場合も許容します。標準ライブラリなので下記パスにありました。
型についてはあまり立ち入らないので、もっと知りたい方は下記参照ください。
einops
次に、einopsは、Alex Rogozhnikovが開発した、Machine Learning向けの便利なライブラリです(後で1回だけ出てくる)。リンク先の動画がイメージつきやすいです。
事前にインストールしておく必要はあります。Google Colabで実行するとここに入ります。pipでインストールした格納先は特に気にすることなくimport出来ます。
jax and flax
jaxとflaxは、ともにDeep Learningのためのライブラリで、Googleが開発、オープンソース化したものです。詳しくは公式ドキュメントや参考記事をご覧ください。pipでインストール( pip install flax
)すると下記に入りました。
MlpBlockクラスの実装
MLPブロックのクラスのコードです。MLPブロックとは下記の図の黄色で囲った部分のことを指します。
class MlpBlock(nn.Module):
mlp_dim: int
@nn.compact
def __call__(self, x):
y = nn.Dense(self.mlp_dim)(x)
y = nn.gelu(y)
return nn.Dense(x.shape[-1])(y)
MLPブロックのクラスとして、class MlpBlock(nn.Module):
でMlpBlock()が定義されています。flax.linen.Moduleを継承しているので、()内に nn.Module
が入っています。この nn.Module
は、flaxの中でレイヤーやモデルを記述する際に継承するべきベースとなるクラス、と説明されています。今回登場する3つのクラス、すべてにおいて継承されています。
mlp_dim: int
の部分でmlp_dimという変数がint型で宣言されてます。これは下記のMLPブロックの図の黄色にハイライトした最初のFully-connectedレイヤーの重み mlp_dim
です。
この
MlpBlockクラスのメソッド
次に、このクラス=MLPブロックの変換関数を def __call__()
で定義していきます。変数は入力 def __call__(self, x):
とします。
また、@nn.compact
の@はデコレータで、関数 nn.compact
で直後の def __call__(self, x):
で定義したメソッドをラップしています。
flax.line.compactの公式ドキュメントはこちら。
class MlpBlock(nn.Module):
mlp_dim: int
@nn.compact
def __call__(self, x):
y = nn.Dense(self.mlp_dim)(x)
y = nn.gelu(y)
return nn.Dense(x.shape[-1])(y)
def __call__(self, x):
内では、MLPブロックの図のように3層の変換を表現している。Fully-connectedレイヤー、GELU関数による活性化、再度Fully-connectedレイヤーという形だ。
MlpBlock内の1つ目のFully-connectedレイヤー
まず入力 nn.Dense
は公式ドキュメントによると下記ソースである。
flax.linen.Denseの公式ドキュメントはこちら。
class Dense(Module):
features: int
use_bias: bool = True
dtype: Any = jnp.float32
precision: Any = None
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
@compact
def __call__(self, inputs: Array) -> Array:
inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param('kernel',
self.kernel_init,
(inputs.shape[-1], self.features))
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(inputs, kernel,
(((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision)
if self.use_bias:
bias = self.param('bias', self.bias_init, (self.features,))
bias = jnp.asarray(bias, self.dtype)
y = y + bias
return y
DenseはFully-connectedレイヤーの変換を表現する関数で、nn.Dense
を呼び出す際、クラス Dense
内の feature
にあたるのが self.mlp_dim
であり、クラス内に定義されたメソッドの引数が入力 y = nn.Dense(self.mlp_dim)(x)
と()が2つになる。
MlpBlock内のGELUレイヤー
その y = nn.gelu(y)
となる。ちなみにGELU関数は、RELU関数とよく似ていますが、こういう形状のグラフです。
GELU関数の公式ドキュメントはこちら。
MlpBlock内の2つ目のFully-connectedレイヤー
残すは2つ目のFully-connectedレイヤーです。
class MlpBlock(nn.Module):
mlp_dim: int
@nn.compact
def __call__(self, x):
y = nn.Dense(self.mlp_dim)(x)
y = nn.gelu(y)
return nn.Dense(x.shape[-1])(y)
2つ目のFully-connectedレイヤーの入力は
Fully-connectedレイヤーにおける重み
入力 x.shape[-1]
がfeatureとして使用されます。x.shape[-1]
は入力
だとすると、論文内での数式における重みも右側に書いてほしい、と思うのですが、実際に論文著者のLucas Beyerに質問したところ、数式における左側・右側の表現は、慣習的なものであり、厳密ではないそうです。やや納得できなかったのですが、そうなのであれば、実装通りに右側に書いてほしいところです。
MixerBlockクラスの実装
次にMixerBlockの実装です。図の部分になります。
同様に、クラスの直下に変数 tokens_mlp_dim
と channels_mlp_dim
の型を int
で宣言します。
class MixerBlock(nn.Module):
"""Mixer block layer."""
tokens_mlp_dim: int
channels_mlp_dim: int
@nn.compact
def __call__(self, x):
y = nn.LayerNorm()(x)
y = jnp.swapaxes(y, 1, 2)
y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
y = jnp.swapaxes(y, 1, 2)
x = x + y
y = nn.LayerNorm()(x)
return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
Mixer Block Layerの図のように、
- LayerNormで1回目の標準化
- 1回目の転置
- 1回目のMLPブロックによる変換
- 2回目の転置で元に戻す
- 1回目のスキップ結合
- LayerNormで2回目の標準化
- 2回目のMLPブロックによる変換と2回目のスキップ結合
で実装されています。
MixerBlock: 1回目の標準化
LayerNormは正規化と混同されやすいですが、実装を見ると正しくは平均が0、標準偏差が1になるようにスケーリングする標準化です。以下のLayerNormの公式ドキュメントにあるように、LayerNormクラス直下の変数はデフォルト通りなので、y = nn.LayerNorm()(x)
と2つ()があり、1つ目は指定なし、2つ目は入力引数です。
class LayerNorm(Module):
epsilon: float = 1e-6
dtype: Any = jnp.float32
use_bias: bool = True
use_scale: bool = True
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
@compact
def __call__(self, x):
x = jnp.asarray(x, jnp.float32)
features = x.shape[-1]
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
var = mean2 - lax.square(mean)
mul = lax.rsqrt(var + self.epsilon)
if self.use_scale:
mul = mul * jnp.asarray(
self.param('scale', self.scale_init, (features,)),
self.dtype)
y = (x - mean) * mul
if self.use_bias:
y = y + jnp.asarray(
self.param('bias', self.bias_init, (features,)),
self.dtype)
return jnp.asarray(y, self.dtype)
flax.linen.LayerNormの公式ドキュメントはこちら。
MixerBlock: 1回目の転置
次に、y = jnp.swapaxes(y, 1, 2)
で転置します。
class MixerBlock(nn.Module):
"""Mixer block layer."""
tokens_mlp_dim: int
channels_mlp_dim: int
@nn.compact
def __call__(self, x):
y = nn.LayerNorm()(x)
y = jnp.swapaxes(y, 1, 2)
y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
y = jnp.swapaxes(y, 1, 2)
x = x + y
y = nn.LayerNorm()(x)
return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
2番目と3番目の引数に与えられた軸を入れ替える、つまり転置します。2次元配列の場合は軸の数え方は縦が0、横が1なので、ここでは3次元配列上の縦と横、つまり
元々の論文ではこう書かれていた。
Figure 1 summarizes the architecture. Mixer takes as input a sequence of
non-overlapping image patches, each one projected to a desired hidden dimension S . This results in a two-dimensional real-valued input table, C . \mathrm{X} ∈ \mathbb{R}^{S×C}
なので、入力
例えば、通常RGB画像のデータは、縦ピクセル×横ピクセル×色チャネル=縦×横×3の3次元配列である。これを
1回目のMlpBlockによる変換
3つ目の変換 y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
にきました。
class MixerBlock(nn.Module):
"""Mixer block layer."""
tokens_mlp_dim: int
channels_mlp_dim: int
@nn.compact
def __call__(self, x):
y = nn.LayerNorm()(x)
y = jnp.swapaxes(y, 1, 2)
y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
y = jnp.swapaxes(y, 1, 2)
x = x + y
y = nn.LayerNorm()(x)
return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
ここで、さきほど定義したクラス MlpBlock 内で定義したメソッドを使います。MlpBlockクラスは flax.linen.Module
を継承したクラスであり、Moduleでは name
という属性値をセットできます。このnameをセットすることで、2回出てくるMlpBlockを区別します。同様に、元々 mlp_dim
であった部分も、1回目用の tokens_mlp_dim
に書き換えられています。
2回目の転置以降の処理
あとはこれまで出てきた内容の繰り返しになります。これで、MixerBlockが読み終わりました。
MlpMixerクラスの実装
最後のクラスがMlpMixerです。図のこの部分です。
ソースコードは下記のような感じです。
class MlpMixer(nn.Module):
"""Mixer architecture."""
patches: Any
num_classes: int
num_blocks: int
hidden_dim: int
tokens_mlp_dim: int
channels_mlp_dim: int
@nn.compact
def __call__(self, inputs, *, train):
del train
x = nn.Conv(self.hidden_dim, self.patches.size,
strides=self.patches.size, name='stem')(inputs)
x = einops.rearrange(x, 'n h w c -> n (h w) c')
for _ in range(self.num_blocks):
x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
x = nn.LayerNorm(name='pre_head_layer_norm')(x)
x = jnp.mean(x, axis=1)
return nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
name='head')(x)
クラス内に変数のアノテーションが複数記載されました。ほとんどは int
ですが、patchesだけ Any
です。これは後で出てきますが、パッチサイズのことではなく、パッチサイズを含む変数のようです。
MlpMixerクラスのメソッド
def __call__(self, inputs, *, train):
で定義されています。引数に * と train もありますが、これは論文に記載されたコードとは異なる点ですね。その後の del train
含めて意図がわからないので説明割愛します。
入力画像データをパッチに変換
x = nn.Conv(self.hidden_dim, self.patches.size, strides=self.patches.size, name='stem')(inputs)
の部分です。入力画像データをパッチに変換しています。下図の部分の変換です。
これは畳み込み処理における、カーネル=フィルターサイズはパッチサイズ self.patches.size
で取得してますね。name=stem
は茎(くき)などの意味です。
flax.linen.Convの公式ドキュメントはこちら。
パッチデータの次元を削減
x = einops.rearrange(x, 'n h w c -> n (h w) c')
の部分です。図で言うとこちらです。
これは最初にインポートしたモジュール、einopsの中にあるrearrageというサブモジュールを使って、次元を減らします。減らすのは、
einops.rearrageの公式ドキュメントはこちら。今回で言うと、()で囲まれた部分を掛け算することで次元が減った形状に変化します。
N_x 回のMixer Block Layer
for _ in range(self.num_blocks):
とそれに続く x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
のfor文です。Mixer Block Layerが
また、この構文 for _ in range(self.num_blocks):
内の、「_」は、
Global Average Poolingレイヤー
実装は、x = nn.LayerNorm(name='pre_head_layer_norm')(x)
と x = jnp.mean(x, axis=1)
の部分です。図でいうとこの部分です。
x = jnp.mean(x, axis=1)
では、1番目の軸方向の要素を単純平均して、次元を減らします。1番目の軸方向は、パッチ、つまり1つの画像に対する縦軸ですね。
最後のFully-Connectedレイヤー
return nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros, name='head')(x)
の部分です。図で言うと、この部分です。
この処理で最終的にClassを推定します。
おわり
「MLP-Mixer」を解説するシリーズすべて終了です。しかし書いてる間にもどんどん出てきますね。
感想や要望・指摘等は、お気軽に本記事へのコメントや、TwitterのリプライやDMに頂ければ幸いです。おしまい。
シリーズ関連記事はこちら
【2023年5月追記】
また、Slack版ChatGPT「Q」というサービスを開発・運営しています。
こちらもぜひお試しください。
Discussion