📝

BP-Transformer: 二分木構造を用いた効率的なAttention

に公開

はじめに

この記事は論文 [1]"BP-Transformer: Modelling Long-Range Context via Binary Partitioning" の解説です。要旨を一言で言うと、Transformer の自己注意機構において、入力系列を二分木構造に基づいた階層的スパン(multi-scale spans)に分割することで、計算量を O(n^2) から O(k \cdot n\log(n/k)) に削減しつつ、長距離依存関係のモデリング性能を維持または向上させうる、という手法になります。

論文内の手法の解説

本論文の手法は、図 3 に要約されます。以下ではこれの図の意味を説明します。

考える設定と既存手法の問題点

Transformer は自然言語処理の多くのタスクで優れた性能を示していますが、自己注意機構(self-attention)の計算量が入力系列長 n に対して O(n^2) であるため、長いテキストへの適用が困難という問題があります。

従来の Transformer の計算時間、パラメータ数

長さ n の入力系列に対して、通常の Transformer の自己注意機構は全てのトークンペアについて注意重みを計算するため、計算量は O(d \cdot n^2) となります(d は隠れ層の次元数)。また、メモリ使用量も O(n^2) となります。

本論文の手法:BP-Transformer

BP-Transformer (BPT) は、グラフニューラルネットワークの観点から Transformer を再定式化し、効率的なグラフ構造を構築することで計算量を削減します。

Transformer をグラフニューラルネットワークとして捉える

まず、Transformer の自己注意機構をグラフニューラルネットワーク (GNN) として解釈します。入力トークンをノード、注意機構による情報伝達をエッジとみなすと、通常の Transformer は完全連結グラフ上でのメッセージパッシングと考えられます。

グラフ \mathcal{G} 上で、ノード u の表現を近傍ノードの情報を集約して更新する操作を Graph Self-Attention (GSA) と呼びます。ノード u の近傍ノード集合を \mathcal{A}(u) とすると、GSA は以下のように定式化されます。

注記(表記について): 本稿では実装や一般的な表記に合わせ、モデル次元を d(= d_{model})、各ヘッドの出力次元を d_k と表記します。つまりトークン次元は d、ヘッド数を h とすると d_{model}=d(しばしば d_{model} と書くこともあります)で、マルチヘッド結合後の次元は h d_k になります。論文中では一部で表記が異なる箇所(例えば d をヘッド次元として使っている式)が見られますが、本稿では上記の方針に統一して説明します。

\begin{gathered} \mathbf{A}^u = \text{concat}(\{\mathbf{h}_v \mid v \in \mathcal{A}(u)\}), \\ \mathbf{Q}^u_i = \mathbf{h}_u \mathbf{W}_i^Q,\quad \mathbf{K}^u_i = \mathbf{A}^u \mathbf{W}_i^K,\quad \mathbf{V}^u_i = \mathbf{A}^u \mathbf{W}_i^V, \\ \mathrm{head}_i^u = \text{softmax}\left(\dfrac{\mathbf{Q}^u_i {\mathbf{K}^u_i}^T}{\sqrt{d_k}}\right) \mathbf{V}^u_i, \\ \text{GSA}(\mathcal{G}, \mathbf{h}_u) = [\mathrm{head}^u_1,\cdots,\mathrm{head}^u_h] \mathbf{W}^O \end{gathered}
  • \mathcal{G}: グラフ(ノードと有向エッジの集合)。ノードにはトークンノードとスパンノードが含まれる。
  • u, v: グラフのノード。通常はノード u の更新を考える時、v\in\mathcal{A}(u) は近傍ノードを意味する。
  • \mathcal{A}(u): ノード u の近傍集合(incoming predecessors)。GSA ではこの集合だけに attention を向ける。
  • \mathbf{h}_u,\,\mathbf{h}_v\in\mathbb{R}^{d}: ノード u,v の表現ベクトル(層入力)。ここで d はモデル(トークン)次元。ト
  • \mathbf{A}^u: 近傍ノードの表現を縦に並べた行列(形状: |\mathcal{A}(u)|\times d)。
  • h: attention ヘッドの数(head 数)。式中の添字 i はヘッドを表す(i=1,\dots,h)。
  • d_k: 各ヘッドの出力次元(per-head dimension)。
  • \mathbf{W}^Q_i,\mathbf{W}^K_i,\mathbf{W}^V_i \in\mathbb{R}^{d\times d_k}: ヘッド i の線形投影行列(入力はモデル次元 d、出力はヘッド次元 d_k)。
  • \mathbf{Q}^u_i=\mathbf{h}_u\mathbf{W}^Q_i\in\mathbb{R}^{d_k}: ノード u のヘッド i によるクエリベクトル。
  • \mathbf{K}^u_i=\mathbf{A}^u\mathbf{W}^K_i\in\mathbb{R}^{|\mathcal{A}(u)|\times d_k}: 近傍ノード群のヘッド i によるキー行列(行ごとに各近傍ノードのキー)。
  • \mathbf{V}^u_i=\mathbf{A}^u\mathbf{W}^V_i\in\mathbb{R}^{|\mathcal{A}(u)|\times d_k}: 近傍ノード群のヘッド i によるバリュー行列。
  • \mathrm{head}^u_i=\mathrm{softmax}\left(\dfrac{\mathbf{Q}^u_i(\mathbf{K}^u_i)^T}{\sqrt{d_k}}\right)\mathbf{V}^u_i\in\mathbb{R}^{d_k}: ヘッド i の出力(|\mathcal{A}(u)| 個の注意重みを用いて加重平均したベクトル)。
  • \mathbf{W}^O\in\mathbb{R}^{h d_k\times d}: 複数ヘッドを結合した後に使う線形射影行列(concat(heads) をモデル次元 d に戻す)。

通常の Transformer では、全てのトークンが互いに接続された完全グラフを使用しています。つまり、通常の Transformer では、A(u) が全てのトークンであり、|A(u)| = n となります(n は入力トークン長)。また、全入力トークンに対しての attention vector を計算するので、上式での u の範囲は 1\leq i \leq n です。

BPT の核心は、この完全グラフをより疎なグラフ構造に置き換えることです。具体的には、適当な定数kに対して、典型的には|A(u)| = O(k \log(n/k))となるようにすることです。また、元のnトークン以外にもそれらを粗視化したスパントークンと呼ばれるn-1個のトークンも追加で考えるため、上式での u の範囲は 1\leq u \leq2n-1となります。

この構造の詳細を以下で説明します。

グラフの構築:Binary Partitioning (二分分割)

BPT の鍵となるアイデアは、細かい粒度から粗い粒度へ(fine-to-coarse) という帰納的バイアスを導入することです。つまり、あるトークンから見て、近い文脈は細かい粒度で、遠い文脈は粗い粒度で表現します。

これを実現するために、入力系列を Binary Partitioning(二分分割) によって階層的なスパンに分割します。

ノードの構築

長さ n の入力系列に対して、以下の手順で二分木を構築します:

  1. 系列を再帰的に 2 つに分割していきます
  2. 各分割が単一のトークンになるまで続けます
  3. 結果として、n 個のトークンノード(葉ノード)と n-1 個のスパンノード(内部ノード)が生成されます

形式的には、ノード u_{l,m} をレベル lm 番目のノードとします。レベルは下から上に増加し、トークンノードのレベルは l=0、スパンノードのレベルは l \geq 1 です。スパンノード u_{l,m} は、トークンノード u_{0,2^{l}\cdot m+1}, \ldots, u_{0,2^{l}\cdot(m+1)} を含む範囲を表現します。

BPT では、2 種類のノードを区別します:

  • トークンノード (Token nodes): 二分木の葉ノード。入力系列の各トークンに対応
  • スパンノード (Span nodes): 二分木の内部ノード。複数のトークンをまとめた範囲を表現

エッジの構築

二分木構造をそのまま使用すると、遠いトークン間の情報伝達パスが長くなってしまいます。そこで、BPT では 2 種類のエッジを導入します:

1. Affiliated Edges (所属エッジ)

スパンノード u_{l,m} に対して、そのスパンに含まれる全てのトークンノード u_{0,2^{l}\cdot m+i} (1 \leq i \leq 2^l) からの有向エッジを追加します。これを所属エッジと呼びます。

所属エッジにより、スパンノードの表現は、そのスパンに含まれるトークンノードの情報を直接集約して計算されます。

2. Contextual Edges (文脈エッジ)

トークンノードu_{0,i}に対して、左右の文脈情報を階層的に接続するような文脈エッジと呼ばれる有向エッジを接続します。有向エッジは、任意のトークンノードに対して、接続する有向エッジの数が各レベル毎に左右それぞれk本となるように構築します。ここで、ハイパーパラメータ k は接続の密度を決めるハイパーパラメータであり、k\geq N において通常の Attention と同様に全トークンノードとの接続を持ち、追加で一切トークンノード → スパンノードへの接続のみ持ち、スパンノード → トークンノードのエッジは持たない構造となリます。そのため、k\geq N においてトークンノード部分は実効的に通常の Attention = 全トークン同士の密結合と等価になります。

トークンノード u_{0,i} の右側の文脈エッジは、以下のそれぞれのノードからu_{0,i}への有向エッジとして構築されます(ここで、iは論文に倣い 1-indexed です):

  • レベル 0(トークンレベル): u_{0,p_0}, \ldots, u_{0,p_0+k-1}
  • レベル 1(小スパン): u_{1,p_1}, \ldots, u_{1,p_1+k-1}
  • レベル 2(中スパン): u_{2,p_2}, \ldots, u_{2,p_2+k-1}
  • ...以降同様

ここで、レベル l でのスタートインデックス p_l は以下のように再帰的に計算されます:

  • p_0 = i+1
  • p_l = \mathrm{parent}(p_{l-1} + k)

左側の文脈も同様に構築します。ここで、\mathrm{parent}(p)は論文中では未定義のようですが、文字通り親ノード、つまり1つ上のレベルのノードの index を指しています。つまり \mathrm{parent}(j) = \lceil j/2 \rceil です。
この構造により、各レベルで(ほぼ※) k 個のノードを接続していくことで、近い文脈は詳細に、遠い文脈は粗い粒度で捉えることができます。
(※)実装上の効率のため(?)、もし p_l +k -1 が奇数であれば,同じレイヤの次のノードp_l +kまで文脈ノードとして追加し、結果として次レベルの開始インデックスは p_{l+1} = \mathrm{parent}(p_l +k +1) としているようです。そのため、kの取り方によっては、各レベルからあるトークンノードに向けてk+1個の文脈エッジがある場合があります。

計算量とメモリ使用量

長さ n の系列に対して、BPT は以下の特性を持ちます:

  • ノード数: O(2n)n 個のトークンノード + n-1 個のスパンノード)
  • エッジ数: O(k \cdot n \log(n/k))
  • 計算量: O(d \cdot k \cdot n \log(n/k))
  • メモリ使用量: O(k \cdot n \log(n/k))

例えば、計算量について比較すると、通常の Transformer の O(d \cdot n^2) と比較して、BP-transformer では k が定数の場合に O(d n \log n) となり、大幅な削減が実現されます。同様に、メモリ使用量についても通常の Transformer の O(n^2) と比較して、BP-transformer では k が定数の場合に O(n \log n) となり、大幅な削減が実現されます。

重要な性質として、BPT のグラフでは 任意の 2 つのトークンノード間の距離が最大でも 2 です。なぜなら、トークンノードが直接結ばれる場合(距離k以内のトークンの場合)は距離1で結ばれ、そうでない場合は任意のトークンノードの親スパンノードのいずれかから、別な任意のトークンノードへの文脈エッジが必ず存在し、トークンノード → その親スパンノード → 別なトークンノード で接続されるためです。これは、遠いトークン間でも長距離依存関係を効率的に学習できることを意味します。

グラフの更新

グラフ構造を構築した後、全てのノードの表現を Graph Self-Attention を用いて更新します。

有向グラフ \mathcal{G} において、ノード u の近傍 \mathcal{A}(u) は、その先行ノード(自分に向かうエッジの送信元)全てとして定義します:

  • トークンノードの先行ノードは、それに注意を向ける各層のノード(トークンノードもスパンノードも含まれる)
  • スパンノードの先行ノードは、それに含まれる全てのトークンノード

スパンノードの表現は初期値としてゼロベクトルで初期化され、トークンノードの表現は対応する単語埋め込みで初期化されます。

通常の Transformer と同様に、複数のグラフ層をスタックできます。各層で、以下の更新を行います:

\begin{aligned} \mathbf{Z}^t &= \text{norm}(\mathbf{H}^{t-1} + \text{GSA}^{(t)}(\mathcal{G}, \mathbf{H}^{t-1})), \\ \mathbf{H}^t &= \text{norm}(\mathbf{Z}^{t} + \text{FFN}^{(t)}(\mathbf{Z}^{t})) \end{aligned}
  • \text{GSA}^{(t)}(\mathcal{G}, \cdot): グラフ\mathcal{G} の下での graph self-attention の出力
  • \text{FFN}^{(t)}(\cdot): feed foward net の出力
  • \mathbf{H}^{t-1}: 層tにおける attention 層への入力(\mathbf{H}^{0}は初期トークン埋め込み)
  • \mathbf{Z}^{t}: 層tにおける feed foward 層への入力

(本記事の著者注)
つまり、Post‑Norm による学習(残差和のあとに LayerNorm を適用する)を採用していまるようです。
この論文出版時点と同時期〜それ以降頃に、 Pre-Norm ( LayerNorm を先に適用してからサブ層=Attention を通過する )ような学習法が学習安定性の観点から好ましいことが知られており、ここは1つ改善点になるかと考えています。

木構造上の相対位置エンコーディング

通常の Transformer では絶対位置エンコーディングや系列上の相対位置エンコーディングが使用されますが、BPT では木構造に基づく相対位置エンコーディングを導入します。

ノード u とその近傍ノード v について、木構造上の相対的な関係に基づいて位置表現 r_{v,u} を割り当てます:

  • r_{v,u} = r^{\text{self}}v = u のとき(自己ループ)
  • r_{v,u} = r^{\text{left}}_{j,i} または r^{\text{right}}_{j,i}vu の左/右文脈の第 j レベルで i 番目に追加されたノードのとき
  • r_{v,u} = r^{\text{anc}}_juv のレベル j での祖先ノードのとき

これらの位置表現を用いて、attention の計算を以下のように修正します:

\begin{gathered} \mathbf{R}^u = \text{concat}(\{r_{v,u} \mid v \in \mathcal{A}(u)\}), \\ \mathrm{head}_i^u = \text{softmax}\left(\dfrac{\mathbf{Q}^u_i \left(\mathbf{K}^u_i + \mathbf{R}^u \right)^T}{\sqrt{d}}\right) \mathbf{V}^u_i \end{gathered}

全ての r は学習可能なパラメータであり、各層で独自のパラメータセットを持ちます。注意ヘッド間では共有されます。

(本記事の著者注)
相対位置埋め込みについても、上記の方針ではノード u とその近傍ノード v の位置関係毎に学習可能なパラメータとして相対位置埋めこみを行っており、相対的な位置関係毎に独立した学習可能ベクトルを加算しています。
この埋め込み手法は比較的シンプルなタスクでは精度が良い一方、埋め込みの相対距離に対しての連続性が担保されないことや、埋め込みを行う相対距離を伸ばした際の計算量の問題などがあります。
近年では、より効率的な相対位置埋め込みとして本論文出版後に RoPE (Rotary Positional Embeddings) のような手法が提案され、実務でもよく使われています。標準的な RoPE は sin/cos による回転を用するため追加の学習パラメータを持たず、計算コストが低く、訓練時より長い系列長への外挿(extrapolation)にも比較的強いという利点があります。一方で RoPE は連続的な位置情報を扱う設計であり、BP のような木構造に固有の「左/右」や「どの階層で追加されたか」といったカテゴリ的な関係をそのまま明示的に表現するわけではありません。
簡潔に利点・注意点をまとめると:

  • RoPE の利点:追加学習パラメータが不要で外挿に強く、トークン単位の位置処理に低コストで適用可能。
  • 学習ベクトル r の利点:木構造固有のカテゴリ情報(階層性・左右性)を明示的に表現できる。

そのため、BP-Transformer の実装では次のようなハイブリッド設計が現実的で有用となる可能性が考えられます。すなわち、トークン間の相対位置情報には RoPE を適用して連続的な距離情報と外挿性を確保し、スパン間やレベル間の構造的な関係(r^{self}, r^{left}_{j,i}, r^{right}_{j,i}, r^{anc}_j など)は小さな学習可能ベクトルとして残しておく、という方法です。

k の役割

ハイパーパラメータ k は、各レベルで接続するノード数を制御します:

  • k が小さい:疎なグラフ、計算量削減、長距離文脈は粗い粒度のみ
  • k が大きい:密なグラフ、計算量増加、長距離文脈も細かい粒度で捉える
  • k を極端に大きくする(k\to N)とほぼ全トークンを参照するため計算量・メモリ使用量は通常の Transformer に近づき、k>=Nにおいては実効的に通常の Transformer と一致します。(ただしスパンノードや所属エッジ自体は残るため、不要な計算が追加された状態になります)

実験では、単語レベルのタスクでは k=4 程度、文字レベルのタスクでは k=64 程度が最適であることが示されています。

実用例:テキスト分類タスク

IMDB データセット(平均長 294 単語のレビュー文)でのテキスト分類を例に説明します。

  1. 各レビュー文を二分分割により階層的スパンに分割
  2. k=4 とすると、各トークンは左右それぞれ約 4 \log_2 294 \approx 32 個のノードに接続
  3. 通常の Transformer では約 294 個のノードに接続するのと比べ、大幅に疎なグラフ
  4. 最終層のルートノード(全系列を表すスパン)の表現を分類に使用

この設定で、BPT は通常の Transformer を精度で上回りつつ、メモリ使用量と計算時間を大幅に削減します。

結果

論文では、文レベルと文書レベルの様々なタスクで実験が行われています。以下に主要な結果を要約します。

テキスト分類

SST-5 データセット(短文)

平均長 19 単語の感情分類タスク(5 クラス)です。

  • BPT (k=2): 52.71% (標準偏差 0.32)
  • Star Transformer: 52.9%
  • Transformer: 50.4%
  • Bi-LSTM: 49.8%

BPT は短文でも通常の Transformer を上回る性能を示し、構造的な帰納的バイアスが有効であることを示しています。

IMDB データセット(長文)

平均長 294 単語のレビュー文の感情分類タスク(2 クラス)です。

  • BPT (k=4): 92.12% (標準偏差 0.11)
  • BCN+Char+CoVe(追加の事前学習利用): 91.8%
  • QRNN: 91.4%
  • Star Transformer: 90.50%
  • Transformer: 89.24%

BPT は長文において特に大きな改善を示し、通常の Transformer を 2.88 ポイント、Star Transformer を 1.62 ポイント上回りました。

k の影響

IMDB データセットで k \in \{1,2,4,8,16,32,64\} を試した結果、k=4 で最高性能を達成しました。k を大きくしてもグラフを密にするだけで、性能向上にはつながりませんでした。これは、適度な疎性が正則化として機能している可能性を示唆しています。

系列シフトへの頑健性

二分分割の性質上、系列の先頭に要素を挿入するとグラフ構造が変わります。しかし、SST-5 で先頭に 0〜7 個のパディングを追加する実験を行った結果、精度の変化は小さく(最大で 1.2 ポイント程度)、BPT がシフトに対して頑健であることが示されました。

言語モデリング

文字レベルの言語モデリングデータセット Enwik8 と Text8 で評価しました(BPC: bits-per-character で評価、低い方が良い)。

Enwik8

  • BPT (k=64, 文脈長 8192): 1.02 BPC
  • Adaptive Span Transformer: 1.02 BPC
  • Transformer-XL: 1.06 BPC
  • Transformer: 1.11 BPC
  • mLSTM: 1.24 BPC

Text8

  • BPT (k=64, 文脈長 8192): 1.11 BPC
  • Adaptive Span Transformer: 1.11 BPC
  • Recurrent Highway Networks: 1.27 BPC
  • HM-LSTM: 1.29 BPC

BPT は、パラメータ数 38M で最先端の性能を達成しました。これは他の手法と同等かより少ないパラメータ数です。

文脈長の影響

k=64 に固定して文脈長を \{512, 1024, 2048, 4096, 8192\} と変化させた結果、文脈長が長いほど性能が向上しました:

  • Enwik8: 1.07 (512) → 1.02 (8192)
  • Text8: 1.16 (512) → 1.11 (8192)

これは、BPT が長距離依存関係を効果的に利用できることを示しています。

Attention Degree(注意密度)の比較

文脈長 512 に固定し、異なるスパースアテンション手法を比較しました:

  • BPT: 様々な k 値で最も効率的な性能
  • Restricted Attention(ウィンドウサイズ制限): 同じ attention degree で BPT より低性能
  • Sparse Transformer: 同じ attention degree で BPT より低性能

これは、BPT の fine-to-coarse 構造が他のスパース化手法より効果的であることを示しています。

機械翻訳

文書レベル機械翻訳(IWSLT 2015 中英翻訳)

平均文書長 120 文に対して評価しました(BLEU スコアで評価)。

  • BPT (k=4, 文脈長 64): 19.84 BLEU
  • BPT (k=4, 単一文): 19.19 BLEU
  • Transformer(単一文、著者実装): 18.91 BLEU
  • HAN-NMT: 17.78 BLEU
  • Transformer+cache: 17.32 BLEU

BPT は文書レベルの文脈を利用することで、単一文のみのモデルより 0.65 ポイント改善しました。

文脈長の影響

文脈長を増やしても、BPT は性能が安定していました:

  • Transformer: 18.85 (文脈 0) → 15.55 (文脈 128) と大幅に悪化
  • BPT (k=4): 19.19 (文脈 0) → 19.84 (文脈 128) と改善を維持

これは、BPT の帰納的バイアスが過学習を防ぎ、長文脈でも効果的に学習できることを示しています。

文レベル機械翻訳(WMT14 英独翻訳)

大規模データセット(4.5M 文対)での評価です。

  • BPT (k=4): 27.6 BLEU
  • Transformer: 27.3 BLEU
  • ConvS2S: 25.16 BLEU
  • GNMT+RL: 24.6 BLEU

BPT は同じパラメータ数の通常の Transformer を上回り、k=4 が単語レベルタスクで一般的に良い設定であることが確認されました。

スループットと GPU メモリ使用量

BPT では、#計算量とメモリ使用量 に記載したように、計算量及びメモリ使用量の面で通常の Transformer より優れています。
そこで、疎なアテンションのための一連の CUDA カーネルを設計し、言語モデリングの設定(6 層、d=512d_{ff}=2048、8 ヘッド)で、推論時のメモリとスループットを測定しました。

GPU メモリ使用量

論文 Figure.6 に示されるように、系列長が増加するにつれて:

  • Transformer: 急激に増加(8192 で約 10.5GB)
  • BPT (k=1): 緩やかに増加(8192 で約 4GB)
  • BPT (k=4): 緩やかに増加(8192 で約 5.5GB)
  • BPT (k=64): やや増加(8192 で約 8GB)

BPT は常に Transformer より少ないメモリで動作します。

スループット(トークン/秒)

論文 Figure.7 に示されるように、

  • 短い系列(〜1024): Transformer の方が速い(ノード数が 2 倍になるオーバーヘッドのため)
  • 長い系列(2048〜): BPT の方が速く、差は系列長が長いほど顕著
  • 系列長 8192: BPT (k=4)は約 50,000 トークン/秒、Transformer は処理困難

ことがわかります。すなわち、BPT は Transformer が処理困難な長い系列でも、速度とメモリ効率の両面で実用的な性能を示します。

まとめ

BP-Transformer の主要な成果:

  1. 計算量削減: O(n^2) から O(k \cdot n \log(n/k)) へ削減
  2. 長距離依存: 任意のトークンペア間の距離が最大 2、効率的な長距離依存学習
  3. 帰納的バイアス: fine-to-coarse 構造により、短文でも過学習を防ぎ、長文で効果的
  4. 汎用性: テキスト分類、言語モデリング、機械翻訳で一貫して高性能
  5. 効率性: 実装レベルでも、長文で大幅なメモリ削減と速度向上を実現

k=4 が単語レベルタスクで、k=64 が文字レベルタスクで一般的に良い設定であることが、複数のタスクで確認されました。

参考文献

[1]: BP-Transformer: Modelling Long-Range Context via Binary Partitioning(https://arxiv.org/abs/1911.04070)

Discussion