🐇
JAXでがんばる 微分したい
前回の記事では、なんとなくJAXの概要が伝えられたかなと思います
今回の記事では、JAXの優れた機能である自動微分をあえて使わずに、自分で逆伝搬を行います。
まずは
そうはいってもまずはJAXの自動微分を使ってみましょう。
例として、入力x
を重みw
とbiasb
を使って線形変換することを考えます。
from jax import numpy as jnp
from jax import random, grad, jit
def linear(x, w, b):
return jnp.dot(x, w) + b
B, D1, D2 = 16, 32, 64
rng = random.split(random.key(1234), 4) # rngキーを作り3つに分割
x = random.uniform(rng[0], (B, D1))
w = random.uniform(rng[1], (D1, D2))
b = random.uniform(rng[2], (D2, ))
y = random.uniform(rng[3], (B, D2))
def fn(params, x, y):
w, b = params
pred = linear(x, w, b)
loss = jnp.mean((pred - y)**2)
return loss
grad_fn = jit(grad(fn))
grads = grad_fn((w, b), x, y)
print(grads[0].shape)
print(grads[1].shape)
なんの変哲もない実装ですが、実際にjax.grad
を使って自動微分を行うとこのようになります。
一つだけ注意することがあって, jax.grad
は第一引数に関する微分を行います。
jax.vjp
jax.vjp
はVector Jacobian Productsの略で、関数fを評価するPython関数が与えられたとき、VJP
を評価するPython関数を返します。
要するに逆伝搬に必要なforwardとbackwardを
jax.vjp
を通して得られます。
jax.custom_vjp
つぎに、この記事のトピックである逆伝搬を実装するを行います。
jax.custom_vjp
を使って先程のlinear
関数を実装したいと思います。
@custom_vjp
@jit
def custom_linear(x, w, b):
out = jnp.dot(x, w) + b
return out
def custom_linear_fwd(x, w, b):
out = jnp.dot(x, w) + b
return out, (x, w)
def custom_linear_bwd(res, g):
x, w = res
dx = jnp.dot(g, w.T)
dw = jnp.dot(x.T, g)
db = jnp.sum(g, axis=0)
return dx, dw, db
custom_linear.defvjp(custom_linear_fwd, custom_linear_bwd)
ここでは新しくcustom_linear
を定義し, forwardとbackwardのそれぞれを実装しています。
その後defvjp
関数でそれぞれ指定してあげます。
それでは、custom_linear
を使って同じ逆伝搬を計算してみましょう。
def fn_custom(params, x, y):
w, b = params
pred = custom_linear(x, w, b)
loss = jnp.mean((pred - y)**2)
return loss
grad_fn_custom = jit(grad(fn_custom))
grads_custom = grad_fn_custom((w, b), x, y)
print(grads[0][:4, :4])
print(grads_custom[0][:4, :4])
print(grads[1][:4])
print(grads_custom[1][:4])
出力は
[[0.11254271 0.12456626 0.13236916 0.12915088]
[0.09029176 0.1007601 0.10745039 0.10439933]
[0.11102881 0.12185482 0.13367452 0.12996063]
[0.14152832 0.15898883 0.16999632 0.16602176]]
[[0.11254271 0.12456626 0.13236916 0.12915088]
[0.09029176 0.1007601 0.10745039 0.10439933]
[0.11102881 0.12185482 0.13367452 0.12996063]
[0.14152832 0.15898883 0.16999632 0.16602176]]
[0.23434031 0.2604819 0.27758297 0.27279755]
[0.23434031 0.2604819 0.27758297 0.27279755]
となり一致することが確認できました。
実装としてはそこまで難しくないのですが、厄介なのはより複雑な計算を手で微分しVJPを計算することにあります。
次回は、FlashAttentionでも利用されているOpenAIが開発したTritonとそのJAXラッパーであるPallasの紹介をしたいと思います。
Discussion