🐇

JAXでがんばる 微分したい

2024/06/13に公開

前回の記事では、なんとなくJAXの概要が伝えられたかなと思います
今回の記事では、JAXの優れた機能である自動微分をあえて使わずに、自分で逆伝搬を行います。

まずは

そうはいってもまずはJAXの自動微分を使ってみましょう。
例として、入力xを重みwとbiasbを使って線形変換することを考えます。

from jax import numpy as jnp
from jax import random, grad, jit

def linear(x, w, b):
    return jnp.dot(x, w) + b

B, D1, D2 = 16, 32, 64
rng = random.split(random.key(1234), 4)  # rngキーを作り3つに分割
x = random.uniform(rng[0], (B, D1))
w = random.uniform(rng[1], (D1, D2))
b = random.uniform(rng[2], (D2, ))
y = random.uniform(rng[3], (B, D2))

def fn(params, x, y):
    w, b = params
    pred = linear(x, w, b)
    loss = jnp.mean((pred - y)**2)
    return loss

grad_fn = jit(grad(fn))
grads = grad_fn((w, b), x, y)

print(grads[0].shape)
print(grads[1].shape)

なんの変哲もない実装ですが、実際にjax.gradを使って自動微分を行うとこのようになります。
一つだけ注意することがあって, jax.gradは第一引数に関する微分を行います。

jax.vjp

jax.vjpはVector Jacobian Productsの略で、関数fを評価するPython関数が与えられたとき、VJP

(x, v) \to (f(x), v.T \partial f(x))

を評価するPython関数を返します。
要するに逆伝搬に必要なforwardとbackwardをjax.vjpを通して得られます。

jax.custom_vjp

つぎに、この記事のトピックである逆伝搬を実装するを行います。
jax.custom_vjpを使って先程のlinear関数を実装したいと思います。

@custom_vjp
@jit
def custom_linear(x, w, b):
    out = jnp.dot(x, w) + b
    return out

def custom_linear_fwd(x, w, b):
    out = jnp.dot(x, w) + b
    return out, (x, w)

def custom_linear_bwd(res, g):
    x, w = res
    dx = jnp.dot(g, w.T)
    dw = jnp.dot(x.T, g)
    db = jnp.sum(g, axis=0)
    return dx, dw, db

custom_linear.defvjp(custom_linear_fwd, custom_linear_bwd)

ここでは新しくcustom_linearを定義し, forwardとbackwardのそれぞれを実装しています。
その後defvjp関数でそれぞれ指定してあげます。
それでは、custom_linearを使って同じ逆伝搬を計算してみましょう。


def fn_custom(params, x, y):
    w, b = params
    pred = custom_linear(x, w, b)
    loss = jnp.mean((pred - y)**2)
    return loss

grad_fn_custom = jit(grad(fn_custom))
grads_custom = grad_fn_custom((w, b), x, y)

print(grads[0][:4, :4])
print(grads_custom[0][:4, :4])

print(grads[1][:4])
print(grads_custom[1][:4])

出力は

[[0.11254271 0.12456626 0.13236916 0.12915088]
 [0.09029176 0.1007601  0.10745039 0.10439933]
 [0.11102881 0.12185482 0.13367452 0.12996063]
 [0.14152832 0.15898883 0.16999632 0.16602176]]
[[0.11254271 0.12456626 0.13236916 0.12915088]
 [0.09029176 0.1007601  0.10745039 0.10439933]
 [0.11102881 0.12185482 0.13367452 0.12996063]
 [0.14152832 0.15898883 0.16999632 0.16602176]]
[0.23434031 0.2604819  0.27758297 0.27279755]
[0.23434031 0.2604819  0.27758297 0.27279755]

となり一致することが確認できました。
実装としてはそこまで難しくないのですが、厄介なのはより複雑な計算を手で微分しVJPを計算することにあります。

次回は、FlashAttentionでも利用されているOpenAIが開発したTritonとそのJAXラッパーであるPallasの紹介をしたいと思います。

Discussion