JAXでがんばる 速く計算したい
前回の記事ではjax.custom_vjp
を使って自動微分を使わずに逆伝搬を実装しました。
今回は、OpenAIが開発したTritonとそのJAXラッパーであるPallasの解説をします。
一見繋がりが見えない気がしますが, 次回の記事でPallasを使って逆伝搬を実装したいと思います。
Tritonとは?
Tritonは、OpenAIが開発した高性能な機械学習向けのプログラミングライブラリです。機械学習のトレーニングや推論のためのカスタムGPUカーネルを簡単に作成できるように設計されています。Tritonを使用すると、既存のディープラーニングフレームワーク(例えば、PyTorchやTensorFlow)のボトルネックとなる部分を効率的に最適化できます。
Tritonの主要な特徴
a. 高パフォーマンス
Tritonは、カスタムGPUカーネルを自動的に最適化し、高パフォーマンスを実現します。これにより、従来のフレームワークでボトルネックとなっていた部分を高速化できます。
使いやすさ
TritonはPythonで書かれており、使いやすさに重点を置いています。開発者は複雑なCUDAコードを書く必要がなく、簡潔で直感的なコードでGPUカーネルを記述できます。
柔軟性
Tritonは、さまざまな種類の機械学習タスクに対応できる柔軟な設計を持っています。カスタムカーネルの作成や既存のカーネルの最適化が容易に行えます。
Tritonの技術的な優位性
自動チューニング
Tritonは、カーネルの実行パラメータ(例えば、スレッド数やメモリレイアウト)を自動的にチューニングします。これにより、開発者が最適なパフォーマンスを得るための試行錯誤を大幅に削減します。
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K):
m = tl.program_id(0)
n = tl.program_id(1)
a = tl.load(a_ptr + m * K + tl.arange(0, K))
b = tl.load(b_ptr + n * K + tl.arange(0, K))
c = tl.dot(a, b)
tl.store(c_ptr + m * N + n, c)
def matmul(a, b):
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
grid = (M, N)
matmul_kernel[grid](a, b, c, M, N, K)
return c
この例では、行列積のカーネルを定義し、Tritonの自動チューニング機能を活用して高パフォーマンスを実現しています。
メモリ最適化
Tritonは、メモリの使用効率を最大化するための最適化を行います。これにより、GPUのメモリ帯域幅を最大限に活用し、計算速度を向上させます。
@triton.jit
def vector_add_kernel(x_ptr, y_ptr, output_ptr, size):
idx = tl.program_id(0)
x = tl.load(x_ptr + idx)
y = tl.load(y_ptr + idx)
tl.store(output_ptr + idx, x + y)
def vector_add(x, y):
size = x.size(0)
output = torch.empty_like(x)
grid = (size,)
vector_add_kernel[grid](x, y, output, size)
return output
この例では、ベクトル加算のカーネルを定義し、効率的なメモリアクセスを行うことで高速な計算を実現しています。
簡潔なAPI
Tritonは、直感的で簡潔なAPIを提供し、開発者が迅速にカーネルを開発できるようにしています。複雑なCUDAの知識がなくても、高性能なGPU計算を実現できます。
@triton.jit
def relu_kernel(input_ptr, output_ptr, size):
idx = tl.program_id(0)
x = tl.load(input_ptr + idx)
tl.store(output_ptr + idx, tl.max(0, x))
def relu(x):
size = x.size(0)
output = torch.empty_like(x)
grid = (size,)
relu_kernel[grid](x, output, size)
return output
この例では、ReLU関数を実装し、シンプルなコードで高効率なGPUカーネルを作成しています。
Pallas
PallasはTritonのJAXラッパーで, tritonよりシンプルに実装できるように工夫されています。
基本的な考え方として, 行列をblockに切り分けて, GPUのshared memoryへとload/storeを抑えつつ効率的に計算を行います。
以下の図は公式の解説ページからです、
このようにそれぞれの
それを異なるGPUのSMで行うことですべての成分を計算します.
以下では, 実際に前回の記事でも扱ったlinear
関数をPallasで実装したいと思います。
まずは
まずは実装を先に示します。
from jax import numpy as jnp
from jax import random, ShapeDtypeStruct, jit
from jax.experimental import pallas as pl
def linear(x, w, b):
return jnp.dot(x, w) + b
B, D1, D2 = 32, 32, 64
rng = random.split(random.key(1234), 3) # rngキーを作り3つに分割
x = random.uniform(rng[0], (B, D1))
w = random.uniform(rng[1], (D1, D2))
b = random.uniform(rng[2], (D2, ))
pred = linear(x, w, b)
def pallas_linear_kernel(x_ref, w_ref, b_ref, out_ref):
out_ref[...] = pl.dot(x_ref[...], w_ref[...]) + b_ref[...]
@jit
def pallas_linear(x, w, b):
block_size = 16
B, D1 = x.shape
D2 = w.shape[-1]
in_specs = [pl.BlockSpec(lambda i, j: (i, 0), (block_size, D1)), # x
pl.BlockSpec(lambda i, j: (0, j), (D1, block_size)), # w
pl.BlockSpec(lambda i, j: (j, ), (block_size, )), # b
]
out_spec = pl.BlockSpec(lambda i, j: (i, j), (block_size, block_size))
grid_size = (pl.cdiv(B, block_size), pl.cdiv(D2, block_size))
out_shape = ShapeDtypeStruct(shape=(B, D2), dtype=x.dtype)
out = pl.pallas_call(
kernel=pallas_linear_kernel,
compiler_params=dict(triton=dict(num_warps=8, num_stages=2)),
grid=grid_size,
in_specs=in_specs,
out_specs=out_spec,
out_shape=out_shape,
interpret=False,
name='pallas_linear',
)(x, w, b)
return out
pred_pallas = pallas_linear(x, w, b)
print(pred[:4, :4])
print(pred_pallas[:4, :4])
Pallasでは、pl.pallas_call
methodで実際にSMで行う計算を呼び出します(kernel).
-
pallas_call
の引数は, -
kernel
: 計算内容 -
grid
: 全体の切り分け方 -
in_specs/out_specs
: 入出力をgrid
の添字からどのように分割するか -
out_shape
: 出力のshape全体
となっています.
上記のpallas_linear
では, 添字(i, j)
をそれぞれバッチ軸(B
)と出力軸(D2
)で分割するようにしています。
pallas_linear_kernel
では, 入出力arrayのreferenceが与えられます.
この関数では, [...]
を使うことでポインタ全体の添字を取得しています.
次回の例では引き続きPallasでの実装例を、少し複雑な例を使って実装したいと思います.
Discussion