【論文読解めも】Reformer: The Efficient Transformer
Transformer, BERTあたりまではちゃんと追っていたけれど、しばらくNLP分野は放置していたのでキャッチアップのためにいくつか論文読んでみる。
手始めにReformer。
論文を読む前に、日本語解説記事を眺める。
- https://www.infoq.com/jp/news/2020/04/google-reformer-deep-learning/
- https://recruit.gmo.jp/engineer/jisedai/blog/reformer/
以下の2つがメモリを大幅に削減できる工夫っぽい。
- Local-Sensitive-Hashing (LSH)
- Reversible Residual layers
Reformerで導入しているのは、以下の3つ。
- Reversible layersによって、レイヤー数
に比例して増えるactivationを保存しておくためのメモリを削減。N - Feed Forward部分の次元数
はAttentionのactivationの次元数d_{ff} よりもずっと大きいのでメモリを食うが、Feed Forward内のactivationを位置に基づくチャンクに分割して処理するとd_{model} によるメモリ増加を防げる。d_{ff} - LSHによって長い入力シーケンスをクラス分けし、シーケンス長
に対して必要な計算量を、L から\mathcal{O}(L^2) まで削減。\mathcal{O}(L \log L)
図はTransformer論文より。
Attention周りの計算量
通常のDot-product Attentionは以下のように書くことができる。
これを実現するためにLSHを用いたLSH Attentionを提案している。
通常、
実際には、
ハッシュ関数
LSH Attention
通常のAttentionを以下のように書き換える。
さらに、
LSH Attentionの目的は、Attentionを適用する集合
さらに、
この流れは、図にするとスッキリとわかる。
- 1-2段目:まず、入力されたシーケンス
は、ハッシュ関数によってバケットに割り当てられる。q_i=k_i - 2-3段目:バケット順ごとにソートし、バケット内も元のインデックス順にソートされている。
- 3-4段目:サイズ
のチャンクを作成する。図ではm 。m=4 - 4-5段目:Attentionは、元々の位置より手前側(マスク関数
)、同一のバケット内(m(j, \mathcal{P}_i) )、1つ前までのチャンク(\mathcal{P}_i )という3つの条件を満たす場合のみ適用される。\widetilde{\mathcal{P}}_{i}
ハッシュ関数は1つではなく、複数使用する。たった1つのハッシュ関数では拾いきれないqueryとkeyの類似性を抽出することができる。この場合の
Decoderでは、未生成の部分にAttentionを当てるわけには行けないので、マスク関数
Reversible Residual Network
通常、勾配の計算のために、各レイヤーのactivationは保存しておく必要があり、それがメモリを圧迫する。Reversible Residual Networkは、出力のactivationを保存しておけば、入力のactivationを簡単な計算で復元できる、ResNetに類似した構造である。
Gomez, Aidan N., et al. "The reversible residual network: Backpropagation without storing activations." Advances in neural information processing systems. 2017.
Reversible Transformer
Reformerでは、このような構造を、以下の図におけるMulti-head AttentionとFeedForwardの組に対して適用する(ただし、「Add + Norm」のNormは取り除くする必要がある)。これによって、メモリを節約する。
🤔Decoder側での実装はどうなるのだろうか?
Feed Forwardのチャンク化
Feed Forward層の次元数は
レイヤー数に比例して増えるパラメータ数への対処
Reversible Transformerとチャンク化によって、メモリ使用量はレイヤー数によって増加しないが、パラメータそのものは当然増えるのでそれによってメモリが圧迫される。
Reformerではあるレイヤーの計算中には、別のレイヤーのパラメータをCPUに退避させることでこの問題を回避している。この方法は、通常のTransformerでは退避に使用する時間が問題になるが、Reformerでは入力を長いシーケンスにできるため、あるレイヤーの計算中に別のレイヤーのパラメータを退避させても時間的な余裕がある。
だいたい把握したので、おしまい。