🌊

TransformerをRWKVに変換する手法について

に公開

TransformerをRWKVに変換する技術めも

そもそもRWKVって何?

RWKV(Receptance Weighted Key Value)は、BlinkDL氏を中心とした約9,000人+のコミュニティメンバーで研究・開発されているオープンソースのLLMアーキテクチャです。現在の最新版は「RWKV-7 "Goose"」で、伝統的にアーキテクチャのメジャー更新ごとに鳥の名前が付けられています。
自分は、v4からメンバーに入りました。

従来のTransformerとの違い

一般的なTransformerの処理フローはこのようになっていまして

Embedding → blocks[Attention → MLP] → Head

Transformerのアテンション機構では、Query(Q)、Key(K)、Value(V)を使ってソフトマックスを計算します:

\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

一方、RWKVはRNNアーキテクチャによる逐次生成機構を持ち、線形アテンション機構に似た設計になっています。従来の二次アテンションのKVキャッシュの代わりに、固定メモリステートで情報を管理する点が特徴です。

RWKVのコア部分:ステート管理機構

RWKVの最も重要な部分は、このステート更新式です:

S_t = S_{t-1} \odot (diag(w_t) - \kappa_t^T (a_t \odot \kappa_t)) + v_t^T k_t

この式を分解すると:

  • S_t:現在の時刻でのステート(記憶)
  • S_{t-1}:前の時刻でのステート
  • diag(w_t):何をどれだけ忘れるかを制御する対角行列
  • \kappa_t:情報の流れを制御するゲーティングベクトル
  • a_t:アクティベーション
  • v_t, k_t:Value と Key ベクトル

つまり、「前の記憶をいい感じに忘れながら、新しい情報を圧縮して追加する」仕組みです。

RWKVの計算効率

計算量の比較

Transformer(従来)

  • 1トークンあたりの計算量:O(n^2)(nはシーケンス長)
  • メモリ使用量:O(n)(KVキャッシュがシーケンス長に比例)

RWKV

  • 1トークンあたりの計算量:O(n)(線形時間)
  • メモリ使用量:O(1)(固定サイズのステート)
  • 生成可能コンテキスト長:理論上無限(崩壊は考えず)

実際の効果

この効率性により、8Bくらいであれば、24GB GPUで384バッチの同時処理が可能になります。マルチエージェントをバックグラウンドで流すタスクなどに向いているのではと思っています

RWKVの制約

もちろん、RWKVは魔法ではありません:

  • 過去情報の振り返りが弱い:ステートに情報を圧縮するため、古い情報が段階的に揮発
  • 開発者・文献が少ない:まだ新しいアーキテクチャのため、リソースが限定的
  • 推論エンジンの対応が限定的:現時点では llama.cpp でのみ対応

ただし、RWKV v7 "Goose" では状態追跡機構が向上し、NIAHなどの簡単なNeedleタスクで128kトークン程度まで100%の精度を達成できるようになりました。

なぜTransformerをRWKVに変換できるのか?

最近有名になってきている仮説があります:Transformer LLMの知識の大半はMLPレイヤーに蓄積されているということです。

つまり、Attentionブロックだけを RWKVに置き換えても、モデルの本質的な性能を維持できる可能性があります。MLPが知識の中核を担っているなら、アテンション機構の変更は「情報へのアクセス方法」を変えるだけで、知識そのものは保持されるはずです。

アーキテクチャの対応関係

TransformerからRWKVへの変換では、以下のような対応関係になります:

Transformerのアテンション

  • Query(Q)
  • Key(K)
  • Value(V)
  • Output(O)

RWKVの対応要素

  • Query(Q)→ Receptance(R)
  • Key(K)→ Key(K)
  • Value(V)→ Value(V)
  • Output(O)→ Output(O)

ただし、RWKVには忘却度調整w、残差接続などが追加されます。

単純な重み継承では動作しない

ここで注意が必要です。AttentionブロックのWeightをそのままRWKVに移植しても動作しません。

なぜなら、RWKVは「何を覚えて、何を忘れるか」を選択的に制御することで、ステートに情報を圧縮・保持する全く異なる仕組みだからです。同じ入出力形式でも、内部の動作原理が根本的に異なります。

知識蒸留による変換手法

そこで、元のTransformerモデルを「教師」、RWKVモデルを「生徒」とした蒸留学習を行います。

Step 1: レイヤーごとの出力を合わせる

まず、生徒モデル(RWKV)が教師モデル(Transformer)の各層のAttention Blockの出力と同じになるようにトレーニングします。

損失関数:

L_{\text{alignment}} = \sum_{i=1}^{L} \|H_i^{\text{teacher}} - H_i^{\text{student}}\|^2

ここで、L は層数、H_i は第 i 層の隠れ状態です。これにより、レイヤー間の分布のずれを最小化できます。通常、約150Mトークンで収束します。MSEがいい感じに機能します。

Step 2: 最終出力の分布を合わせる

次に、生徒モデルが教師モデルの最終出力と同じになるようにトレーニングします。KL Divergenceを用いて、トークン分布全体を学習させます:

L_{\text{KD}} = \text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_i P_{\text{teacher}}(i) \log \frac{P_{\text{teacher}}(i)}{P_{\text{student}}(i)}

これにより、蒸留データセットが不十分でも分布学習により、トレーニング効率が大幅に向上します。
場合によっては、Multi Stepped Temperatureを使用すると、より効率的に学習できます。

変換後のメリット

State-tuning:新しいPEFT手法

RWKVで提案されたPEFT手法です。モデルの重みを変更することなく、RWKVステートを直接学習します:

S_{\text{tuned}} = S_{\text{base}} + \Delta S

この手法により、大規模デプロイ時に単一モデルでステートのみを切り替える運用が可能になります。LoRAのように重みを変更する必要がないため、デプロイ側は非常に助かります。
ただし、トークンを生成するほど、揮発していくため、改良は必要です。

超高効率推論

RWKVの線形計算量により、従来ではA100 80GBユーザーしかできなかった大規模バッチ処理が可能になります。

変換のデメリット

性能は確実に劣化します。

教師モデルと完全に同じ性能にはなりません。知識蒸留のレベルにもよりますが、特に振り返りタスクなどではRWKVの本来の制約が現れ、性能劣化は確実に発生します。

これは、圧縮による情報損失と、アーキテクチャの根本的な違いによるものです。

ちょっとはAttentionがいるかも:

個人的の研究では、ハイブリッド構造で効率と性能のバランスを両立する手法が実用的性能の確保に必要だとわかってきました。大半の層をRWKVにしつつ、一部の重要な層を二次アテンションのまま維持する手法です。

このハイブリッド構成により、NIAHタスクを60k程度まで維持できる場合もあります。計算効率の大部分を保ちながら、振り返りタスク能力を大幅に向上できるパターンがありました。

まとめ

TransformerからRWKVへの変換は、計算効率と推論性能のトレードオフを考慮したアプローチです。完璧な変換は難しいものの、特定の用途では大きなメリットを提供できる技術として研究を進めています。

特に、大規模な推論処理や、メモリ制約の厳しい環境(私)では、このような変換技術の価値は非常に高いと信じています。(自分だけかも)

この分野はまだ発展途上ですが、Sliding Attentionハイブリッドアプローチなどの新しい手法も登場しており、面白い分野になってきました。

Discussion