👻

Monarch 行列によるNNの効率化

に公開

はじめに

この記事は論文 [1] "Monarch: Expressive Structured Matrices for Efficient and Accurate Training" の解説です。要旨を一言で言うと、Monarch 行列という新しい構造化行列を提案し、これをニューラルネットワークの重み行列として用いることで、計算効率(GPU での実行速度)と表現力(FFT や畳み込みなどを表現可能)を両立させ、End-to-End の学習高速化や、事前学習済みモデルの圧縮・ファインチューニングの効率化を実現した、という手法になります。
具体的には、行列のパラメータ数\mathcal{O}(n^2)\rightarrow \mathcal{O}(n\sqrt{n}), 行列-ベクトル積の演算量\mathcal{O}(n^2)\rightarrow \mathcal{O}(n\sqrt{n})となります。

論文内の手法の解説

考える設定と既存手法の問題点

大規模なニューラルネットワークは高い性能を示しますが、学習やファインチューニングには膨大な計算リソースが必要です。このコストを削減するために、密な重み行列(Dense Matrix)を構造化行列(疎行列、低ランク行列、フーリエ変換など)で置き換えるアプローチがあります。しかし、既存の構造化行列には以下の課題がありました。

  1. End-to-End (E2E) 学習でのトレードオフ:
    • 効率性 (Efficiency): 多くの疎行列手法は FLOPs を削減できても、GPU などのハードウェア上で実際に高速化するとは限りません(ランダムアクセスなどにより逆に遅くなることもあります)。
    • 表現力 (Expressiveness): 効率的な行列(例:対角行列)は表現力が低く、表現力が高い行列(例:畳み込み)は汎用的でなかったり計算コストが高かったりします。
  2. Dense-to-Sparse (D2S) ファインチューニングの困難さ:
    • 事前学習済みの密な重み行列を、特定の構造を持つ行列で近似(射影)する効率的なアルゴリズムが存在しない場合が多く、知識の転移が困難でした。

本論文の手法:Monarch 行列

本論文では、これらの課題を解決するために Monarch 行列 を提案しています。これは、蝶の羽(Monarch butterfly)にちなんで名付けられました。

Monarch 行列の定義

n \times n の Monarch 行列 \mathbf{M} は、2つのブロック対角行列 \mathbf{L}, \mathbf{R} と置換行列 \mathbf{P} を用いて以下のように定義されます(ここで n=m^2 とします)。

\mathbf{M} = \mathbf{P} \mathbf{L} \mathbf{P}^\top \mathbf{R}
  • \mathbf{L}, \mathbf{R}: 各ブロックが m \times m のサイズを持つブロック対角行列。
  • \mathbf{P}: 入力ベクトルを m \times m 行列に変形して転置し、再びベクトルに戻す操作に対応する置換行列。

この構造は、FFT(高速フーリエ変換)のアルゴリズム構造にヒントを得ており、GPU 上で非常に効率的に計算できる「Batch Matrix Multiply (BMM)」を活用できる形になっています。

(本記事の著者による)直感的な理解

この数式 M = P L P^\top R の意図を直感的に理解するために、入力ベクトル(長さ n)を m \times m の2次元グリッド(正方形のマス目)に並べ替えた状態を想像してみます。

  1. R(局所的な混合):

    • グリッドの「縦の列」ごとに、その列に含まれる m 個の要素を使って全結合(密結合)を行います。
    • これにより、同じ列にある要素同士の情報は混ざり合いますが、隣の列の情報はまだ入ってきません。
  2. P^\top(転置):

    • m \times m の 2次元グリッドの行と列を入れ替えます。
  3. L(局所的な混合):

    • 今度は(元のグリッドで言う)「横の行」ごとに、その行に含まれる m 個の要素を混ぜ合わせます。
    • ここでも、同じ行内であれば端から端まで情報が飛びます。
  4. P(元に戻す):

    • 最後に並び順を元に戻します。

ポイント:

  • 直接のつながり: 「同じ行」または「同じ列」にいる要素同士だけが直接のパラメータ(重み)を持ちます。「斜め」の位置にいる要素同士には直接のパスはありません。これによりパラメータ数を O(n^2) から O(n\sqrt{n}) に削減しています。
  • 間接的なつながり: しかし、この2ステップを経ることで、情報は「(自分) \to (同じ列の誰か) \to (その人の行にいる斜めの人)」という経路で伝わります。結果として、たった2回の操作でグリッド上のあらゆる地点からあらゆる地点へ情報が届く(全要素間の相関を持てる)ことになります。

つまり、Monarch行列は「計算コストを抑えつつ(直接のパスは限定的)、大域的な情報伝達(間接的には全員とつながる)を実現する」ための巧妙な構造と言えます。

特徴1: ハードウェア効率性 (Efficiency)

Monarch 行列の積 \mathbf{M}\mathbf{x} は、以下の手順で計算できます。

  1. \mathbf{R}\mathbf{x}: ブロック対角行列との積(BMM で計算可能)。
  2. \mathbf{P}^\top (\cdot): 転置(メモリ上の再配置)。
  3. \mathbf{L} (\cdot): ブロック対角行列との積(BMM で計算可能)。
  4. \mathbf{P} (\cdot): 転置。

これにより、密行列の計算量 O(n^2) に対して O(n\sqrt{n}) の計算量で済み、かつ GPU の特性を活かした実装が可能です。実験では密行列積と比較して最大 2 倍の高速化を達成しています。

(本記事の著者による)補足:なぜMonarch行列はGPUで効率的なのか?

論文中には「Batch Matrix Multiply (BMM) を活用できるため効率的である」という旨の記載があります。これを踏まえ、なぜスパースな構造なのにGPUで高速なのかを補足します。

  1. 一般的なスパース行列の課題:

    • 通常のスパース行列(CSR形式など)の積は、非ゼロ要素の位置に合わせてメモリ上の飛び飛びの場所にアクセスする「ランダムアクセス」が発生します。GPUは並列計算は得意ですが、このような不規則なメモリアクセスは非常に苦手で、計算性能が低下しやすいです。
    • また、NVIDIA GPU の Tensor Core のような、密行列積を爆速で計算する専用回路を利用しにくいという欠点もあります。
  2. Monarch行列の利点:

    • Monarch行列の構成要素である \mathbf{L}, \mathbf{R} は「ブロック対角行列」です。これは計算機上では「小さな密行列の束(バッチ)」として扱えます。
    • これにより、メモリへのアクセスが規則的になり、かつ Tensor Core をフル活用して計算することができます。
    • \mathbf{P}(置換)の計算も、メモリ上の規則的なデータの並び替え(Reshape/Transpose)に相当するため、ランダムアクセスよりも遥かに高速です。

つまり、Monarch行列は「デタラメに0があるスパース行列」ではなく、「GPUが得意な『小さな密行列演算』の集合体として扱える構造化されたスパース行列」であるため、理論的な計算量削減だけでなく、実効速度の向上も達成できていると言えます。

特徴2: 表現力 (Expressiveness)

Monarch 行列(およびその積)は非常に高い表現力を持ちます。

  • Butterfly 行列(FFT などを表現できる行列クラス)を包含しています。
  • 畳み込み、アダマール変換、テプリッツ行列などを表現可能です。
  • Monarch 行列の積(2層)で、フーリエ変換、離散コサイン変換 (DCT)、離散サイン変換 (DST) などを表現可能です。

(本記事の著者による)補足:MonarchとButterflyの関係(Radix-2 vs Radix-\sqrt{n}

「MonarchもButterflyの一種である」という点について、少し補足します。

  1. Radix-2 vs Radix-\sqrt{n}:

    • 一般的なButterfly行列は、サイズ n を 2 ずつ再帰的に分解していくため、深さが \log n になります(Radix-2)。
    • 一方、Monarch行列は、サイズ n\sqrt{n} ずつ分解するため、深さが 2 で済みます(Radix-\sqrt{n})。
  2. 共通点と包含関係:

    • どちらも「分割統治法(Cooley-Tukey法)」に基づくアプローチであり、その意味でMonarchは「GPU向けに調整された(浅くて太い)Butterflyの変種」と言えます。
    • なぜ包含するのか?: 一般的なButterfly行列は、サイズ m のブロックに対してもさらに再帰的に疎な構造(2分割)を作ります。一方、Monarch行列はサイズ m のブロックを「密行列(Dense)」として扱います。
    • 密行列は、そのサイズで表現可能なあらゆる線形変換(当然、さらに再帰したButterfly構造も含む)を表現できるため、Monarch行列はButterfly行列のスーパーセット(包含するクラス)となります。
    • つまり、「Butterflyの再帰を \log m 回繰り返した状態(サイズ m)でストップし、そこを自由な密行列にしたのがMonarch」 と解釈すると分かりやすいでしょう。

特徴3: 射影 (Projection) アルゴリズム

任意の密行列 \mathbf{A} を Monarch 行列 \mathbf{M} で近似する問題(\min_{\mathbf{M} \in \mathcal{M}} \|\mathbf{A} - \mathbf{M}\|_F^2)に対して、解析的な最適解が存在し、効率的に計算できることを示しました。

具体的には、行列 \mathbf{A} を 4 階テンソルに変形し、各ブロックに対して特異値分解 (SVD) を行ってランク 1 近似を行うことで、最適な \mathbf{L}, \mathbf{R} を求めることができます。これにより、事前学習済みのモデルを Monarch 行列に変換してファインチューニングすることが容易になります。

(本記事の著者による)射影アルゴリズムの数学的詳細

射影アルゴリズムの美しさは、一見複雑な「Monarch行列への近似問題」が、実は「たくさんの小さなランク1近似問題(SVDで解ける)」に分解できるという点にあります。以下にその詳細を示します。

1. 問題の定式化

与えられた密行列 \mathbf{A} \in \mathbb{R}^{n \times n} に対して、最も近い Monarch 行列 \mathbf{M} = \mathbf{P}\mathbf{L}\mathbf{P}^\top\mathbf{R} を見つけたいというのが目的です。

\min_{\mathbf{M} \in \mathcal{M}} \|\mathbf{A} - \mathbf{M}\|_F^2

ここで n = m^2 とします。

2. 行列のテンソル化(インデックスの分解)

ここが最大のポイントです。行列の行インデックス r と列インデックス c (1 \le r, c \le n) を、それぞれ m 進法のように2つのインデックスに分解します。

  • r \leftrightarrow (\ell, j) : r = (\ell-1)m + j
  • c \leftrightarrow (k, i) : c = (k-1)m + i

これを使うと、行列 \mathbf{A}m \times m \times m \times m の4階テンソル \tilde{\mathbf{A}} とみなせます。

\tilde{A}_{\ell j k i} = A_{(\ell-1)m+j, (k-1)m+i}

3. Monarch行列の構造解析

Monarch行列 \mathbf{M} = \mathbf{P}\mathbf{L}\mathbf{P}^\top\mathbf{R} の要素 M_{\ell j k i} がどうなるかを計算すると、論文中の式 (2) のように以下のシンプルな形になります。

M_{\ell j k i} = L_{j \ell k} \cdot R_{k j i}

ここで、

  • L_{j \ell k}: 行列 \mathbf{L}j 番目のブロックの (\ell, k) 成分
  • R_{k j i}: 行列 \mathbf{R}k 番目のブロックの (j, i) 成分

です。
元の行列形式

\mathbf{M} = \mathbf{P} \mathbf{L} \mathbf{P}^\top \mathbf{R}

で定義したL,Rに添え字をつける形で書き直すと

\mathbf{M}_{\ell j k i} = \sum_{j', k'} (\mathbf{L}_{(j, \ell), (j', k')} \delta_{j j'}) (\mathbf{R}_{(k', j), (k, i)} \delta_{k' k})

となります。ここで (a, b)m 進法分解されたインデックス (a-1)m + b を表し、\delta はクロネッカーのデルタを表します。つまり、\mathbf{L} は行ブロック j 内での変換、\mathbf{R} は列ブロック k 内での変換に対応していることが分かります。

4. ランク1近似への分解

ここで、インデックス jk を固定して考えてみます。すると、残る変数は \elli です。
このとき、テンソル \tilde{\mathbf{M}} のスライス(m \times m 行列)は以下のようになります。

\tilde{M}_{: j k :} = \mathbf{u}_{jk} \mathbf{v}_{jk}^\top

ただし、ベクトル \mathbf{u}_{jk}, \mathbf{v}_{jk} \in \mathbb{R}^m は以下のように定義されます。

  • \mathbf{u}_{jk} = [L_{j 1 k}, L_{j 2 k}, \dots, L_{j m k}]^\top\mathbf{L} の一部)
  • \mathbf{v}_{jk} = [R_{k j 1}, R_{k j 2}, \dots, R_{k j m}]^\top\mathbf{R} の一部)

これはまさに「ランク1行列」の形です。
つまり、「jk を固定したときの m \times m 部分行列が、すべてランク1であれば、それはMonarch行列である」と言えます。

5. アルゴリズムの手順

以上のことから、射影問題は m^2 個の独立したランク1近似問題に帰着されます。

  1. Reshape: 入力行列 \mathbf{A}m \times m \times m \times m のテンソル \tilde{\mathbf{A}} に変形する。
  2. Loop: j=1 \dots m, k=1 \dots m の各組み合わせについて:
    • 部分行列 \mathbf{B} = \tilde{\mathbf{A}}_{: j k :} (m \times m 行列) を取り出す。
    • \mathbf{B} を SVD(特異値分解) し、最大特異値に対応する左特異ベクトル \mathbf{u} と右特異ベクトル \mathbf{v} を求める(これが \mathbf{B} に最も近いランク1行列 \sigma \mathbf{u} \mathbf{v}^\top を与える)。
    • 求めた \mathbf{u}\mathbf{L} の対応箇所に、\mathbf{v}\mathbf{R} の対応箇所に格納する。
  3. Reconstruct: \mathbf{L}, \mathbf{R} をブロック対角行列として構成して完了。

このように、元の重み行列のスパースなMonarch行列への射影を得るのに、スパースな構造の下での再学習を行わずとも、SVDにより数学的に最適な射影が得られるというのもこの手法の強力な点です。これは、学習済みモデルの重みをMonarch行列に置き換える際に、近似行列を求めるための学習が不要であり、その後のファインチューニングも密行列を学習するよりも短時間で済むため、全体の計算量を大幅に削減できるという利点があります。

結果

論文では、画像認識、言語モデル、科学計算などの多岐にわたるタスクで検証を行っています。

End-to-End 疎学習 (Sparse Training)

最初から Monarch 行列を用いてモデルを学習させる設定です。

  • 画像分類 (ImageNet): ViT や MLP-Mixer の重みを Monarch 行列に置き換えたところ、精度を維持したまま学習速度が 1.7〜2倍 高速化しました。
  • 言語モデリング (WikiText-103): GPT-2 (Small/Medium) において、Perplexity を維持しつつ学習を 約2倍 高速化しました。
  • 科学計算 (PDE, MRI): 偏微分方程式 (PDE) の解法や MRI 画像再構成において、従来のフーリエ変換ベースの手法よりも高い精度(PDE で誤差 40% 減、MRI で pSNR 1.5dB 向上)を達成しました。これは Monarch 行列が学習可能なため、データに適応した変換を獲得できたためと考えられます。

Sparse-to-Dense 学習 (Reverse Sparsification)

学習の大部分(例:90%)を Monarch 行列で行い、最後の少し(例:10%)で密行列に戻して学習する「Reverse Sparsification」という手法を提案しています。

  • GPT-2 事前学習: OpenWebText での事前学習において、最初から密行列で学習する場合と比べて、精度を落とさずに総学習時間を 2倍 短縮しました。
  • BERT 事前学習: NVIDIA の最適化された実装(MLPerf 1.1 記録)と比較しても 23% 高速に学習を完了しました。

Dense-to-Sparse ファインチューニング

事前学習済みの密な BERT モデルを Monarch 行列に射影(近似)してから、GLUE タスクでファインチューニングを行いました。

  • BERT ファインチューニング: パラメータ数を半分に削減しつつ、密モデルと同等の精度を達成しました。また、ファインチューニングの速度も 1.7倍 高速化しました。

まとめ

Monarch 行列の主要な成果:

  1. 効率と表現力の両立: ブロック対角行列と置換の組み合わせにより、GPU で高速に動作しつつ、FFT や畳み込みなどの重要な変換を表現できる。
  2. 射影可能: 密行列からの最適な近似を求めるアルゴリズムが存在し、事前学習済みモデルの活用が可能。
  3. 実用的な高速化: 画像、言語、科学計算の幅広いタスクで、精度を犠牲にすることなく学習時間を大幅(約2倍)に短縮。

特に「Reverse Sparsification」は、大規模言語モデルの事前学習コストを削減する実用的なテクニックとして注目されます。

参考文献

[1]: Monarch: Expressive Structured Matrices for Efficient and Accurate Training (https://arxiv.org/abs/2204.00595)

Discussion