Closed11

【論文読解めも】Reformer: The Efficient Transformer

takoroytakoroy

Transformer, BERTあたりまではちゃんと追っていたけれど、しばらくNLP分野は放置していたのでキャッチアップのためにいくつか論文読んでみる。

手始めにReformer。

https://arxiv.org/abs/2001.04451

takoroytakoroy

Reformerで導入しているのは、以下の3つ。

  • Reversible layersによって、レイヤー数Nに比例して増えるactivationを保存しておくためのメモリを削減。
  • Feed Forward部分の次元数d_{ff}はAttentionのactivationの次元数d_{model}よりもずっと大きいのでメモリを食うが、Feed Forward内のactivationを位置に基づくチャンクに分割して処理するとd_{ff}によるメモリ増加を防げる。
  • LSHによって長い入力シーケンスをクラス分けし、シーケンス長Lに対して必要な計算量を、\mathcal{O}(L^2) から \mathcal{O}(L \log L) まで削減。

図はTransformer論文より。

takoroytakoroy

Attention周りの計算量

通常のDot-product Attentionは以下のように書くことができる。

\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V

Q, K , Vはそれぞれ (batchsize, length, d_{model}) というサイズ。実験では、length=64kという超長いシーケンスを使う。このとき問題になるのは、Q K^{T}のサイズが(batchsize, length, length)と、非常に大きくなること。これは、小説の最後の1文中のtokenに対するattentionを小説全文から計算するようなもので、無駄が多い。

QK^Tの目的は、全てのKがなくても、i=(1, 2, ..., length)番目に対応する各q_iに近いKのサブセット(32~64)のみでも達成できる。

\operatorname{softmax}\left(\frac{q_{i} K^{T}}{\sqrt{d_{k}}}\right) V

これを実現するためにLSHを用いたLSH Attentionを提案している。

takoroytakoroy

通常、Q, K, Vは全て同一のTensorA(batchsize, length, d_{model})から、異なるパラメータを持つ線形変換によって得られる。これを簡略化したバージョンとして、Q=Kとなるshared-QK Transformerを提案している。このような簡略化によっても、性能への悪影響はほとんどないことが実験で示されている。

実際には、k_{j}=\frac{q_{j}}{\left\|q_{j}\right\|}とするみたい。

takoroytakoroy

ハッシュ関数h(x)は、keyの次元数d_kからb次元のハッシュへと変換される。ランダムな行列R \in \mathbb{R}^{d_k \times b/2}を用いて、ハッシュ関数はh(x)=\arg \max ([x R ;-x R])と定義できる。[u; v]は2つのベクトルの結合を表す。このようなハッシュの作成方法をLSH schemeという。

takoroytakoroy

LSH Attention

通常のAttentionを以下のように書き換える。\mathcal{P}_{i}は、インデックスiがAttentionをかけうるインデックスの集合で、zは分配関数に相当する。簡単のためにスケーリングを表す\sqrt{d_k}は省略している。

o_{i}=\sum_{j \in \mathcal{P}_{i}} \exp \left(q_{i} \cdot k_{j}-z\left(i, \mathcal{P}_{i}\right)\right) v_{j} \quad \text { where } \mathcal{P}_{i}=\{j: i \geq j\}

さらに、\widetilde{\mathcal{P}}_{i}=\{0,1, \ldots, l\} \supseteq \mathcal{P}_{i}とマスク関数m(j, \mathcal{P}_i)を用いて、以下のように書き直せる。マスク関数は、Decoderの場合は必須だが、Encoder側では使用しない、はず。

o_{i}=\sum_{j \in \widetilde{\mathcal{P}}_{i}} \exp \left(q_{i} \cdot k_{j}-m\left(j, \mathcal{P}_{i}\right)-z\left(i, \mathcal{P}_{i}\right)\right) v_{j} \quad \text { where } m\left(j, \mathcal{P}_{i}\right)=\left\{\begin{array}{ll} \infty & \text { if } j \notin \mathcal{P}_{i} \\ 0 & \text { otherwise } \end{array}\right.

LSH Attentionの目的は、Attentionを適用する集合\mathcal{P}_iを別の形にすることで、それはハッシュ関数h(x)を用いて以下のように書くことができる。これにより、クエリq_iとキーk_jが同一のハッシュバケットに属するときのみAttentionの対象とすることになる。

\mathcal{P}_{i}=\left\{j: h\left(q_{i}\right)=h\left(k_{j}\right)\right\}

さらに、\widetilde{\mathcal{P}}_{i}のほうも冗長なので、より本質的な部分集合に絞る。まず、バケット番号順・index順にソートし、i \mapsto s_{i}という並べ替えを得る。この、s_iを用いて、\widetilde{\mathcal{P}}_{i}を以下のように書く。なお、mはチャンクサイズを表すハイパーパラメータで、論文ではm=\frac{2 l}{n_{\text {buckets }}}としている。これによって、同一のチャンク内または一つ前のチャンク内のインデックスのみがAttentionの対象となる。チャンクによって、Attentionがあまりにも遠い位置に当たらないようにする効果が生まれる。

\widetilde{\mathcal{P}}_{i}=\left\{j:\left\lfloor\frac{s_{i}}{m}\right\rfloor-1 \leq\left\lfloor\frac{s_{j}}{m}\right\rfloor \leq\left\lfloor\frac{s_{i}}{m}\right\rfloor\right\}

この流れは、図にするとスッキリとわかる。

  • 1-2段目:まず、入力されたシーケンスq_i=k_iは、ハッシュ関数によってバケットに割り当てられる。
  • 2-3段目:バケット順ごとにソートし、バケット内も元のインデックス順にソートされている。
  • 3-4段目:サイズmのチャンクを作成する。図ではm=4
  • 4-5段目:Attentionは、元々の位置より手前側(マスク関数m(j, \mathcal{P}_i))、同一のバケット内(\mathcal{P}_i)、1つ前までのチャンク(\widetilde{\mathcal{P}}_{i})という3つの条件を満たす場合のみ適用される。

ハッシュ関数は1つではなく、複数使用する。たった1つのハッシュ関数では拾いきれないqueryとkeyの類似性を抽出することができる。この場合の\mathcal{P}_{i}、\widetilde{\mathcal{P}}_{i}を用いた計算は、付録に記載されている。

\mathcal{P}_{i}=\bigcup_{r=1}^{n_{\text {rounds }}} \mathcal{P}_{i}^{(r)} \quad \text { where } \mathcal{P}_{i}^{(r)}=\left\{j: h^{(r)}\left(q_{i}\right)=h^{(r)}\left(q_{j}\right)\right\}

Decoderでは、未生成の部分にAttentionを当てるわけには行けないので、マスク関数m(j, \mathcal{P}_i)が必須。なお、q_i=k_iとしたことにより、自分自身の位置に対するAttentionが極端に大きくなってしまう可能性が高い。これを避けるため、シーケンスの最初の位置以外は、自分自身を含めない形で、自分自身よりも前の位置にのみAttentionを許可するようにマスク関数m(j, \mathcal{P}_i)を修正する。

takoroytakoroy

Reversible Residual Network

通常、勾配の計算のために、各レイヤーのactivationは保存しておく必要があり、それがメモリを圧迫する。Reversible Residual Networkは、出力のactivationを保存しておけば、入力のactivationを簡単な計算で復元できる、ResNetに類似した構造である。

Gomez, Aidan N., et al. "The reversible residual network: Backpropagation without storing activations." Advances in neural information processing systems. 2017.

\begin{array}{l} y_{1}=x_{1}+\mathcal{F}\left(x_{2}\right) \\ y_{2}=x_{2}+\mathcal{G}\left(y_{1}\right) \end{array}

\begin{array}{l} x_{2}=y_{2}-\mathcal{G}\left(y_{1}\right) \\ x_{1}=y_{1}-\mathcal{F}\left(x_{2}\right) \end{array}

Reversible Transformer

Reformerでは、このような構造を、以下の図におけるMulti-head AttentionとFeedForwardの組に対して適用する(ただし、「Add + Norm」のNormは取り除くする必要がある)。これによって、メモリを節約する。

Y_{1}=X_{1}+\text { Attention }\left(X_{2}\right) \quad Y_{2}=X_{2}+\text { FeedForward }\left(Y_{1}\right)

🤔Decoder側での実装はどうなるのだろうか?

takoroytakoroy

Feed Forwardのチャンク化

Feed Forward層の次元数はd_{f f}=4 Kとなり、大きい。Feed Forwardは、位置に依存せずに適用されるので、位置に応じてc個のチャンクに分割して処理することができる。バッチをチャンクごとに構成してあげれば、メモリ消費量を削減できる。また、このテクニックは、語彙数が多い場合は、モデルの出力であるlog probabilisticに対しても適用できる。

Y_{2}=\left[Y_{2}^{(1)} ; \ldots ; Y_{2}^{(c)}\right]=\left[X_{2}^{(1)}+\text { FeedForward }\left(Y_{1}^{(1)}\right) ; \ldots ; X_{2}^{(c)}+\text { FeedForward }\left(Y_{1}^{(c)}\right)\right]
takoroytakoroy

レイヤー数に比例して増えるパラメータ数への対処

Reversible Transformerとチャンク化によって、メモリ使用量はレイヤー数によって増加しないが、パラメータそのものは当然増えるのでそれによってメモリが圧迫される。

Reformerではあるレイヤーの計算中には、別のレイヤーのパラメータをCPUに退避させることでこの問題を回避している。この方法は、通常のTransformerでは退避に使用する時間が問題になるが、Reformerでは入力を長いシーケンスにできるため、あるレイヤーの計算中に別のレイヤーのパラメータを退避させても時間的な余裕がある。

このスクラップは2021/01/14にクローズされました