🤖

FNet: Mixing Tokens with Fourier Transformsを軽く紹介

1 min read

はじめに

結構面白そうな論文があったので紹介します.
その名も「FNet」!!
タイトルにもあるように「F」はFourierの「F」です.
簡単に言えば,SelfAttentionを2DのFourier Transformで置き換えてもそれほど性能を落とさずにパラメータ数及び学習速度を改善できるという大胆な論文です.
本記事では,この「FNet」を軽く紹介します.

https://arxiv.org/abs/2105.03824

本記事の画像はすべて論文より引用しています.

モチベーション

NLPを筆頭に,最近ではCVでもTransformerが用いられています.
しかし,通常のSelf-Attentionの計算量は2乗オーダーであり,メモリも多く使います.
そこで,最近の研究ではその計算量やメモリ消費量を抑える手法が数多く提案されてきました.
著者らも改善するために,sequenceの次元とhidden channelの次元,それぞれの方向での線形変換で近似することを考えたようですが,その流れでそれぞれの方向にFourier変換を適用する事により学習するパラメータがないにも関わらず,タスクによってはほぼ同等のスコアを出せることを見つけました.

FNetアーキテクチャ

非常にシンプルで,以下の図のように表されます.

ただ,BERTにおけるSelf-AttentionをFourierで置き換えるだけです.

なお,Fourier Layerは出力をyとすると以下の式で表されます.

y = \mathfrak{R}(\mathcal{F}_{seq}(\mathcal{F}_{h}(x)))

ここで,\mathcal{F}_{h}(\cdot)はhidden channel方向のFourier変換を表し,\mathcal{F}_{seq}(\cdot)はsequence方向のFourier変換を表します.
また,\mathfrak{R}(\cdot)は複素数の入力に対して実部を返す操作です.
つまり,入力の次元を(Batch, Seq, Hidden)とすればSeqとHiddenに対して2DのFourier変換を行っているわけです.

また,Fourier変換の双対性より,層を重ねることでTimeとFrequencyを交互に行き来して処理を行っているとも見ることができます.
このことから,FeedForward LayerはFrequencyでは畳み込みと見ることができます.

実験結果

実験結果として,GLUEでの結果を示します.

平均のスコアは流石にBERTには及びませんが,パラメータが無い割にはそこそこの結果が出ています.

次に,操作の計算量(左)とバッチサイズ64にした場合の事前学習で1stepにかかる時間(ms)です.

通常のBERTに比べるとFNetは1.8倍の学習速度向上が見られます.

感想

本記事では,FNetについて軽く紹介しました.
Self-Attentionは結構メモリを食うので,学習パラメータがいらないFourier変換で置き換えられれば,普通だったらOOMになってしまう場合でも動かせるし,層も大幅に増やせそうだと思いました.
この論文ではNLPでの実験でしたが,他の分野ではどのような結果が出るのか気になりますね.
実装も非常に簡単なので気になった方はぜひ試してみてください.

Discussion

ログインするとコメントできます