🎃

[paper-reading] Reformer: The Efficient Transformer

2023/09/09に公開

論文リンク

Reformer: The Efficient Transformer

概要

Transformerのcomputing/memory costを簡略化した改良アーキテクチャを提案する。

特徴

  1. LSH(Locality Sensitive Hash) attention
  2. reversible transformer block

LSH attention

計算機コストとメモリコストがトークン長の2乗に比例することがlanguage modelのコンテキスト長を拡大することのボトルネックになっている。提案手法の基本的なコンセプトは「重要でないアテンションの計算を省略することで計算量を削減する」というものである。この要件をナイーブに実現しようとするとembeddingのMIPS(Maximum Inner Product Search)のようなことが必要になるが、提案手法では「embedding空間内で近い位置にあるembeddingに高い確率で同じ値を割り当てる」という性質を持った特殊なハッシュ関数を使ってこの要件を実現する。提案手法により、計算量とメモリ消費のコストは O(N^2) から O(N\log N) に削減される。

なお、画像タスクにおいてTransformerのself-attentionの近似計算によって計算量を削減しようとする試みはこれまでにいくつか行われてきた。代表的なものはShifted-window[1], deformable-attention[2], spatial-reduction attention[3]といったように、データの幾何学的な連結性(=ピクセルの隣接関係)に着目してattentionを作用させる範囲を限定するものだった。これに対し、提案手法はこのようなデータの構造に由来する近接性の情報を利用せずに、embedding空間内の距離のみを用いてattentionを作用させる範囲を限定するという違いがある。

Reversible Transformer

GPU, TPUなどのアクセラレータのメモリ消費を削減する方法として、提案手法ではTransformerブロックに「k層目のレイヤー入力を出力から逆算できる」という性質を課す。これにより、forward/backwardの計算グラフで常に1層分のアクティベーションを保持するだけで良くなる。

提案手法はPyTorchが採用しているようなナイーブなgradient check-pointing(=中間のアクティベーションは一切保持しない)とメモリ削減効果についてはほぼ同じである。一方、計算コストについては、backward pathの計算において、PyTorchのcheck-pointingがk層目の入力のアクティベーションの復元にk-1層分のforwardパスを計算するのに対し、提案手法では出力のactivationから1層分のforwardパスを計算するだけで良いので、提案手法の方が計算効率が良いことが期待できる。

なお、提案手法はネットワークのアークテクチャに可逆的な制約を課すという点ではFlowベースネットワークの一種とも解釈できる。

主要な結果

  • 64Kのコンテクスト長の入力を16GBのメモリ内に収めることができた
  • トークン長の増加に対して計算量があまり増えないモデルを実現した
  • LSH attentionはTransformerのfull-attentionの近似計算を与える。これにより、精度を犠牲にして計算量予算の範囲内で学習できるようになった。

論文内でやれてないこと

  • 著者らは「Transformerとほぼ同等の性能を維持したまま計算量を減らせた」かのように受け取れる主張をしているが、LSH attentionとfull-attentionのBLEUスコアによる比較は掲載されていない。また、ablationの比較結果は3-layerの浅いネットワークを用いた比較のみである。
    したがって、本論文の結果から提案手法が「Transformerの上位互換」と受け取るのは楽観的すぎると言える。両者の性能をきちんと評価するには十分なレイヤー数のネットワークでBLEUスコアで比較するのが妥当だろう。

参考文献

GitHubで編集を提案

Discussion