🦆

その他言語モデル 論文解説②「RWKV v5 と v6」

に公開

論文

Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence

論文: https://arxiv.org/pdf/2404.05892
GitHub: https://github.com/BlinkDL/RWKV-LM
wiki: https://wiki.rwkv.com/basic/architecture.html

なんと本論文では v5 と v6 が同じ paper になっており、本記事もそれに倣って同時に解説します。

RWKV ではバージョンに応じてコードネームが付いています。v5 -> Eagle, v6 -> Finch です。v6 は v5 の上位互換でありますが、発想のステップが異なるため、段階的に解説できればと思います。

概要

  • RWKV-4 の改良モデルある、Eagle(RWKV-5) と Finch(RWKV-6) を提示する
  • Time mixing の \vec{k} \odot \vec{v}の箇所が外積になって行列になった
  • つまり、内部状態の次元が D \rightarrow (D/h) \times (D/h) \times h という行列 x ヘッド数に拡張された

前提知識

RWKV-v4

前回記事を参照してください。

取り扱う次元

  • d: 入力\bold{x} の特徴量次元. n_embd
  • D: \bold{W}_{\square} によって射影される次元. dim_att
  • h: ヘッド数. ヘッド次元は D/h

linear interpolation ( lerp )

新たに \text{lerp} が定義されています。

\text{lerp}_{\square}(\bold{a},\bold{b}) = \bold{a} + (\bold{b} - \bold{a}) \odot \pmb{\mu}_{\square}

\square には r ,k, v, g などが入ります。これは Token Shift で使われ、例えば以下のように変形すると、おなじみの形式になります。

\begin{align*} \text{lerp}(\bold{x}_{t-1},\bold{x}_t) &= \bold{x}_{t-1} + (\bold{x}_t - \bold{x}_{t-1}) \odot \pmb{\mu} \\ &= \bold{x}_{t-1} + \pmb{\mu} \odot \bold{x}_t - \pmb{\mu} \odot \bold{x}_{t-1} \\ &= \pmb{\mu} \odot \bold{x}_t + ( 1 - \pmb{\mu}) \odot \bold{x}_{t-1} \\ \end{align*} \\

ただ今回から、\text{lerp}(\bold{x}_{t},\bold{x}_{t-1}) 形式になっているのでご注意ください。

v4 から変わっていない箇所 ( v5 )

Channel Mixing

次元数だけ変わっているが、それ以外は同じ

v4 から変わった箇所 ( v5 )

基本的なブロック構造や、Time mixingChannel mixing が繰り返される構造は同じです。そのため、変わった箇所にフォーカスして解説します。

  • Multi head になった
  • lerp の作用で、\bold{g} が増えた. \bold{g} は v4 の \bold{r} と似たような Gating の役割をしている.
  • \bold{r} は Attention で言う KVキャッシュに対する Query に似た役割になった
  • 内部次元が ベクトル D から行列 (D/h) \times (D/h) \times h になった

Time Mixing ( Token Shift )

https://github.com/BlinkDL/RWKV-LM/blob/b3b4d056d0907bb9e9619be3de077f051091287c/RWKV-v5/src/model.py#L241

>>> args = SimpleNamespace(n_embd=192, dim_att=128, head_size_a=64, n_layer=4)
>>> args.head_size_divisor = args.head_size_a ** (1/2)
>>> self = RWKV_Tmix_x052(args, 1)
>>> self
RWKV_Tmix_x052(
  (time_shift): ZeroPad2d((0, 0, 1, -1))
  (receptance): Linear(in_features=192, out_features=128, bias=False)
  (key): Linear(in_features=192, out_features=128, bias=False)
  (value): Linear(in_features=192, out_features=128, bias=False)
  (output): Linear(in_features=128, out_features=192, bias=False)
  (gate): Linear(in_features=192, out_features=128, bias=False)
  (ln_x): GroupNorm(2, 128, eps=1e-05, affine=True)
)

まず、従来の Token Shift が行われます。v4 では r, k, v でしたが、今回は r, k, v, g と、ひとつ増えています。\bold{x}_t, \pmb{\mu}_{r,k,v,g} \in \mathbb{R}^{d}, \bold{W}_{r,k,v,g} \in \mathbb{R}^{d \times D}

\begin{align*} \bold{r}_t &= \text{lerp}_{r}(\bold{x}_t, \bold{x}_{t-1}) \bold{W}_r \in \mathbb{R}^D \\ \bold{k}_t &= \text{lerp}_{k}(\bold{x}_t, \bold{x}_{t-1}) \bold{W}_k \in \mathbb{R}^D \\ \bold{v}_t &= \text{lerp}_{v}(\bold{x}_t, \bold{x}_{t-1}) \bold{W}_v \in \mathbb{R}^D \\ \bold{g}_t &= \text{lerp}_{g}(\bold{x}_t, \bold{x}_{t-1}) \bold{W}_g \in \mathbb{R}^D \\ \end{align*}

(\bold{x}_t, \bold{x}_{t-1}) ではなく ( \bold{x}_{t-1}, \bold{x}_{t}) と思いますが、変更したんですかね。細かいので無視します。

Time Mixing ( WKV term )

ここはだいぶ変わっています。v4 では AFT をベースに組まれていましたが、今回は割と関係なくなっています。t-1までの減衰項 \bold{w}t の係数 u については、似たような概念で今回も含まれています。

そして今回から Multi Head の概念が登場します。ヘッド数h, ヘッド次元D/h です。

\begin{align*} \bold{w} &= \exp( - \exp (\pmb{\omega}) ) \ \ \ \in \mathbb{R}^{h \times (D/h)} \\ \bold{u} & \in \mathbb{R}^{h \times (D/h)} \\ \end{align*}

この \exp( - \exp (.)) は、0<w<1 を保証するものです。

uとwのパラメータ初期値
>>> self.time_faaaa # u
Parameter containing:
tensor([[ 3.3333e-01,  4.3071e-01,  2.2808e-01,  3.2546e-01,  4.2283e-01,
          2.2021e-01,  3.1759e-01,  4.1496e-01,  2.1234e-01,  3.0971e-01,
          4.0709e-01,  2.0446e-01,  3.0184e-01,  3.9921e-01,  1.9659e-01,
          2.9396e-01,  3.9134e-01,  1.8871e-01,  2.8609e-01,  3.8346e-01,
          1.8084e-01,  2.7822e-01,  3.7559e-01,  1.7297e-01,  2.7034e-01,
          3.6772e-01,  1.6509e-01,  2.6247e-01,  3.5984e-01,  1.5722e-01,
          2.5459e-01,  3.5197e-01,  1.4934e-01,  2.4672e-01,  3.4409e-01,
          1.4147e-01,  2.3885e-01,  3.3622e-01,  1.3360e-01,  2.3097e-01,
          3.2835e-01,  1.2572e-01,  2.2310e-01,  3.2047e-01,  1.1785e-01,
          2.1522e-01,  3.1260e-01,  1.0997e-01,  2.0735e-01,  3.0472e-01,
          1.0210e-01,  1.9948e-01,  2.9685e-01,  9.4226e-02,  1.9160e-01,
          2.8898e-01,  8.6352e-02,  1.8373e-01,  2.8110e-01,  7.8478e-02,
          1.7585e-01,  2.7323e-01,  7.0604e-02,  1.6798e-01],
        [ 2.6535e-01,  6.2730e-02,  1.6010e-01,  2.5748e-01,  5.4856e-02,
          1.5223e-01,  2.4961e-01,  4.6982e-02,  1.4436e-01,  2.4173e-01,
          3.9108e-02,  1.3648e-01,  2.3386e-01,  3.1234e-02,  1.2861e-01,
          2.2598e-01,  2.3360e-02,  1.2073e-01,  2.1811e-01,  1.5486e-02,
          1.1286e-01,  2.1024e-01,  7.6115e-03,  1.0499e-01,  2.0236e-01,
         -2.6247e-04,  9.7113e-02,  1.9449e-01, -8.1365e-03,  8.9239e-02,
          1.8661e-01, -1.6010e-02,  8.1365e-02,  1.7874e-01, -2.3885e-02,
          7.3491e-02,  1.7087e-01, -3.1759e-02,  6.5617e-02,  1.6299e-01,
         -3.9633e-02,  5.7743e-02,  1.5512e-01, -4.7507e-02,  4.9869e-02,
          1.4724e-01, -5.5381e-02,  4.1995e-02,  1.3937e-01, -6.3255e-02,
          3.4121e-02,  1.3150e-01, -7.1129e-02,  2.6247e-02,  1.2362e-01,
         -7.9003e-02,  1.8373e-02,  1.1575e-01, -8.6877e-02,  1.0499e-02,
          1.0787e-01, -9.4751e-02,  2.6247e-03,  1.0000e-01]],
       requires_grad=True)
>>> self.time_faaaa.shape
torch.Size([2, 64])
>>> self.time_decay # omega
Parameter containing:
tensor([[-6.0000, -5.9794, -5.9547, -5.9283, -5.9007, -5.8721, -5.8428, -5.8127,
         -5.7821, -5.7510, -5.7195, -5.6875, -5.6551, -5.6223, -5.5892, -5.5558,
         -5.5221, -5.4881, -5.4539, -5.4194, -5.3846, -5.3496, -5.3144, -5.2790,
         -5.2433, -5.2075, -5.1715, -5.1353, -5.0989, -5.0623, -5.0256, -4.9887,
         -4.9517, -4.9145, -4.8771, -4.8396, -4.8020, -4.7642, -4.7263, -4.6882,
         -4.6500, -4.6117, -4.5733, -4.5347, -4.4960, -4.4572, -4.4183, -4.3793,
         -4.3402, -4.3009, -4.2616, -4.2221, -4.1825, -4.1429, -4.1031, -4.0633,
         -4.0233, -3.9833, -3.9431, -3.9029, -3.8625, -3.8221, -3.7816, -3.7410],
        [-3.7003, -3.6596, -3.6187, -3.5778, -3.5368, -3.4957, -3.4545, -3.4133,
         -3.3719, -3.3305, -3.2890, -3.2475, -3.2059, -3.1642, -3.1224, -3.0805,
         -3.0386, -2.9966, -2.9546, -2.9124, -2.8703, -2.8280, -2.7857, -2.7433,
         -2.7008, -2.6583, -2.6157, -2.5731, -2.5304, -2.4876, -2.4447, -2.4019,
         -2.3589, -2.3159, -2.2728, -2.2297, -2.1865, -2.1432, -2.0999, -2.0566,
         -2.0131, -1.9697, -1.9261, -1.8826, -1.8389, -1.7952, -1.7515, -1.7077,
         -1.6638, -1.6199, -1.5760, -1.5320, -1.4879, -1.4438, -1.3996, -1.3554,
         -1.3112, -1.2669, -1.2225, -1.1781, -1.1336, -1.0891, -1.0446, -1.0000]],
       requires_grad=True)
>>> self.time_decay.shape
torch.Size([2, 64])

WKV の計算は、ヘッド毎で行われます。まずは \bold{r}_t,\bold{k}_t,\bold{v}_t,\bold{g}_t をヘッド単位に分割します。j番目をヘッドj として、以下のように記述し、\text{wkv}_t^{(j)} を次のように計算します。

\bold{r}_t^{(j)}, \bold{k}_t^{(j)}, \bold{v}_t^{(j)}, \bold{g}_t^{(j)}, \bold{w}^{(j)}, \bold{u}^{(j)} \ \ \ \in \mathbb{R}^{D/h}
\text{wkv}_t^{(j)} = \text{diag}(\bold{u}^{(j)}) {\bold{k}_t^{(j)}}^T \bold{v}_t^{(j)} + \sum_{i=1}^{t-1} \text{diag}(\bold{w}^{(j)})^{t-1-i} {\bold{k}_i^{(j)}}^T \bold{v}_i^{(j)} \in \mathbb{R}^{(D/h) \times (D/h)}

{\bold{k}_i^{(j)}}^T \bold{v}_i^{(j)} は外積を計算しており、\mathbb{R}^{(D/h) \times (D/h)} の行列になります。そこに、対角行列の \text{diag}(\bold{u}^{(j)}) \in \mathbb{R}^{(D/h) \times (D/h)} (もしくは \text{diag}(\bold{w}^{(j)})^{t-1-i})をかけます。

\bold{w} の項は v4 同様に減衰しており、書き下してみると

\begin{align*} i=1, & \ \ \ \text{diag}(\bold{w}^{(j)})^{t-2} \\ i=2, & \ \ \ \text{diag}(\bold{w}^{(j)})^{t-3} \\ \vdots & \\ i=t-2, & \ \ \ \text{diag}(\bold{w}^{(j)})^{1} \\ i=t-1, & \ \ \ \text{diag}(\bold{w}^{(j)})^{0} = \text{diag}(\bold{1}) \\ \end{align*}

です。v4 時に疑問視していた \bold{w} の発散は無くなっています。

そして、最終的な出力は

\bold{o}_t = \text{concat}_j \left( \text{SiLU}( \bold{g}_t^{(j)} ) \odot \text{LayerNorm}(\bold{r}_t^{(j)} \sdot \text{wkv}_t^{(j)}) \right) \ \ \ \in \mathbb{R}^{D}

\bold{r}_t^{(j)} \sdot \text{wkv}_t^{(j)}\mathbb{R}^{1 \times (D/h)} \times \mathbb{R}^{(D/h) \times (D/h)} = \mathbb{R}^{1 \times (D/h)} の次元です。

Time Mixing ( RNN 形式 )

v4 でも states があって、RNN 形式で表記できたように、 v5 でも同様に記述できます。

\begin{align*} \text{wkv}_t^{(j)} &= \bold{s}_{t-1}^{(j)} + \text{diag}(\bold{u}^{(j)}) \sdot {\bold{k}_t^{(j)}}^T \bold{v}_t^{(j)} \\ \bold{s}_t^{(j)} &= \text{diag}(\bold{w}^{(j)}) \sdot \bold{s}_{t-1}^{(j)} + {\bold{k}_t^{(j)}}^T \bold{v}_t^{(j)} \\ \bold{s}_t^{(j)} & \in \mathbb{R}^{(D/h) \times (D/h)} \end{align*}

v5 から変わっていない箇所 ( v6 )

  • Channel Mixing
  • Time Mixing ( RNN 形式 )は \bold{w} 以外は変わっていない

v5 から変わった箇所 ( v6 )

  • lerp をやめて、ddlerp を使うようにした
  • \bold{w}入力依存 (つまり時間依存) \bold{w}_t に変更
  • \text{diag}(\bold{w})^{t-1-i} のような時間減衰をやめて、\bold{w}_t の掛け合わせによって減衰を表現

Time Mixing ( Token Shift )

https://github.com/BlinkDL/RWKV-LM/blob/b3b4d056d0907bb9e9619be3de077f051091287c/RWKV-v5/src/model.py#L325

Token Shift が進化して(ややこしくなって)います。今までは、\bold{x}_t\bold{x}_{t-1} の混合割合を示す \pmb{\mu} は、\bold{x}_t に依存しない学習パラメータでした。その混合割合を、入力依存にするよう改良されています。

それが、ddlerp ( data-dependent linear interpolation ) です。

\text{ddlerp}_{\square}(\bold{a},\bold{b},\bold{c}) = \bold{a} + (\bold{b} - \bold{a}) \odot \pmb{\mu}_{\square}'(\bold{c})

\pmb{\mu}\pmb{\mu}'(\bold{c}) のように別の変数に依存できる関数となりました。
そして \pmb{\mu}'(.)\text{lora}(.) という関数で定義されます。

\pmb{\mu}_{\square}'(\bold{x}) = \text{lora}_{\square}(\bold{x}) = \pmb{\lambda}_{\square} + \text{tanh}(\bold{x}\bold{A}_{\square})\bold{B}_{\square}

\pmb{\lambda}_{\square} \in \mathbb{R}^{d} はベクトル, \bold{A} \in \mathbb{R}^{d \times r(=32)}, \bold{B} \in \mathbb{R}^{r(=32) \times d} は行列です。\square には r ,k, v, g などが入ります。

つまり、入力 \bold{x}\bold{A} で低次元 r に落として、再度 \bold{B}d 次元まで戻します。

一般系として \text{ddlerp}_{\square}(\bold{a},\bold{b},\bold{c}) としましたが、実は \bold{c}(\bold{a}, \bold{b}, \pmb{\mu}_x) といった関数にです。そして実際には、以下が関数として定義されています。

\text{ddlerp}_{\square}(\bold{a},\bold{b},\pmb{\mu}_x) = \bold{a} + (\bold{b} - \bold{a}) \odot \text{lora}_{\square}(\bold{a} + (\bold{b} - \bold{a}) \odot \pmb{\mu}_x)

なんかちょっと頭がこんがらがりますね...。そして何故この数式になったのかも分かりません。

少しコードベースに整理します。

まず、従来の Time Shift のような計算をします。

https://github.com/BlinkDL/RWKV-LM/blob/b3b4d056d0907bb9e9619be3de077f051091287c/RWKV-v5/src/model.py#L384-L386

\begin{align*} \bold{x}' &= \bold{x}_{t-1} - \bold{x}_t \\ \bold{x}'' &= \bold{x}_t + (\bold{x}_{t-1} - \bold{x}_t) \odot \pmb{\mu}_x \\ \end{align*}

\bold{x}''\text{lora}(.) への入力です。次のコードは \text{ddlerp}_{\square}(\bold{x}_t, \bold{x}_{t-1}) を計算しています。

https://github.com/BlinkDL/RWKV-LM/blob/b3b4d056d0907bb9e9619be3de077f051091287c/RWKV-v5/src/model.py#L387-L395

\bold{x}_{\square}''' = \text{tanh}(\bold{x}''\bold{A}_{\square})\bold{B}_{\square} \\ \text{ddlerp}_{\square}(\bold{x}_t, \bold{x}_{t-1}) = \bold{x}_t + (\bold{x}_{t-1} - \bold{x}_t) \odot (\pmb{\lambda}_{\square} + \bold{x}_{\square}''')

\bold{A}_{\square}, \bold{B}_{\square}r ,k, v, g, w 分用意されているため、予め \bold{A}_{\square} \in \mathbb{R}^{5 \times d \times r(=32)} のように定義しておきます。そして一度に行列計算した後に、mw, mk, mv, mr, mg = xxx.unbind(dim=0) で分割しています。

\begin{align*} \bold{r}_t' &= \text{ddlerp}_r (\bold{x}_t, \bold{x}_{t-1}) \ \ \ \in \mathbb{R}^{d} \\ \bold{k}_t' &= \text{ddlerp}_k (\bold{x}_t, \bold{x}_{t-1}) \ \ \ \in \mathbb{R}^{d} \\ \bold{v}_t' &= \text{ddlerp}_v (\bold{x}_t, \bold{x}_{t-1}) \ \ \ \in \mathbb{R}^{d} \\ \bold{g}_t' &= \text{ddlerp}_g (\bold{x}_t, \bold{x}_{t-1}) \ \ \ \in \mathbb{R}^{d} \\ \bold{w}_t' &= \text{ddlerp}_w (\bold{x}_t, \bold{x}_{t-1}) \ \ \ \in \mathbb{R}^{d} \\ \end{align*}

次は nn.Linear層(と SiLU)です。

https://github.com/BlinkDL/RWKV-LM/blob/b3b4d056d0907bb9e9619be3de077f051091287c/RWKV-v5/src/model.py#L397-L400

\begin{align*} \bold{r}_t &= \bold{r}_t' \sdot \bold{W}_r \ \ \ \in \mathbb{R}^{D} \\ \bold{k}_t &= \bold{k}_t' \sdot \bold{W}_k \ \ \ \in \mathbb{R}^{D} \\ \bold{v}_t &= \bold{v}_t' \sdot \bold{W}_v \ \ \ \in \mathbb{R}^{D} \\ \bold{g}_t &= \text{SiLU}( \bold{r}_t' \sdot \bold{W}_g ) \ \ \ \in \mathbb{R}^{D} \\ \end{align*}

そして w のファクターは、さらに lora(.) に入力します。

https://github.com/BlinkDL/RWKV-LM/blob/b3b4d056d0907bb9e9619be3de077f051091287c/RWKV-v5/src/model.py#L402-L403

\begin{align*} \bold{d}_t &= \pmb{\lambda}_d + \text{tanh}( \bold{w}_t' \bold{A}_d) \bold{B}_d \in \mathbb{R}^{D} \\ \bold{w}_t &= \exp( - \exp (\bold{d}_t) ) \ \ \ \in \mathbb{R}^{D} \\ \end{align*}

この lora だけ次元が少し違っていて \pmb{\lambda}_d \in \mathbb{R}^{D}, \bold{A}_d \in \mathbb{R}^{d \times r'(=64)}, \bold{B}_d \in \mathbb{R}^{r' \times D} となっています。そして同様に 0 < \bold{w}_t < 1 となっています。

Time Mixing ( WKV term )

さて、上の章で以下の変数がそろいました。しれっと書いていますが、すでに以下では \bold{w} は入力依存 \bold{w}_t' になっています。

https://github.com/BlinkDL/RWKV-LM/blob/b3b4d056d0907bb9e9619be3de077f051091287c/RWKV-v5/src/model.py#L405

\text{wkv}_t^{(j)} = \text{diag}(\bold{u}^{(j)}) {\bold{k}_t^{(j)}}^T \bold{v}_t^{(j)} + \sum_{i=1}^{t-1} \text{diag} \left( \bigodot_{k=i+1}^{t-1}\bold{w}_k^{(j)} \right) {\bold{k}_i^{(j)}}^T \bold{v}_i^{(j)} \in \mathbb{R}^{(D/h) \times (D/h)}

さて、時間減衰の箇所が変わっています。書き下してみると

\begin{align*} i=1, & \ \ \ \text{diag}(\bold{w}_{2}^{(j)} \odot \bold{w}_{3}^{(j)} ... \odot \bold{w}_{t-1}^{(j)}) \\ i=2, & \ \ \ \text{diag}(\bold{w}_{3}^{(j)} \odot \bold{w}_{4}^{(j)} ... \odot \bold{w}_{t-1}^{(j)}) \\ \vdots & \\ i=t-2, & \ \ \ \text{diag}(\bold{w}_{t-1}^{(j)}) \\ i=t-1, & \ \ \ \text{diag}(\bold{1}) \\ \end{align*}

そして最終的なアウトプットは以下です。(どうやら、ヘッドの concat を先に行ってそうな感じです)

https://github.com/BlinkDL/RWKV-LM/blob/b3b4d056d0907bb9e9619be3de077f051091287c/RWKV-v5/src/model.py#L407-L409

\bold{o}_t = \bold{g}_t \odot \text{LayerNorm} \left( \text{concat}_j ( \bold{r}_t^{(j)} \sdot \text{wkv}_t^{(j)} ) \right) \ \ \ \in \mathbb{R}^{D}

Time Mixing ( RNN 形式 )

v5 では \bold{w} \rightarrow \bold{w}_t の箇所だけ変更になります。

\begin{align*} \text{wkv}_t^{(j)} &= \bold{s}_{t-1}^{(j)} + \text{diag}(\bold{u}^{(j)}) \sdot {\bold{k}_t^{(j)}}^T \bold{v}_t^{(j)} \\ \bold{s}_t^{(j)} &= \text{diag}(\bold{w}_t^{(j)}) \sdot \bold{s}_{t-1}^{(j)} + {\bold{k}_t^{(j)}}^T \bold{v}_t^{(j)} \\ \bold{s}_t^{(j)} & \in \mathbb{R}^{(D/h) \times (D/h)} \end{align*}

結果と所感

なんかちょっと面白いですね。Finch (v6) は Eagle (v5) の上位互換なので、性能が良いことが分かります。そして、Multilingual では他のモデルを圧倒しているのが興味深いです。

lora あたりの記述は何であの数式を使っているのか、不思議です。v6 への進化として、入力依存になっているあたりは、S4D -> Mamba の流れに似ていますね。内部状態wkv を行列で表現しているのも、内部状態の拡張として、Mambaとの似た流れを感じます。

Discussion