🐿️

JAXでがんばる 速く計算したい

2024/06/14に公開

前回の記事ではjax.custom_vjpを使って自動微分を使わずに逆伝搬を実装しました。
今回は、OpenAIが開発したTritonとそのJAXラッパーであるPallasの解説をします。
一見繋がりが見えない気がしますが, 次回の記事でPallasを使って逆伝搬を実装したいと思います。

Tritonとは?

Tritonは、OpenAIが開発した高性能な機械学習向けのプログラミングライブラリです。機械学習のトレーニングや推論のためのカスタムGPUカーネルを簡単に作成できるように設計されています。Tritonを使用すると、既存のディープラーニングフレームワーク(例えば、PyTorchやTensorFlow)のボトルネックとなる部分を効率的に最適化できます。

Tritonの主要な特徴

a. 高パフォーマンス
Tritonは、カスタムGPUカーネルを自動的に最適化し、高パフォーマンスを実現します。これにより、従来のフレームワークでボトルネックとなっていた部分を高速化できます。

使いやすさ

TritonはPythonで書かれており、使いやすさに重点を置いています。開発者は複雑なCUDAコードを書く必要がなく、簡潔で直感的なコードでGPUカーネルを記述できます。

柔軟性

Tritonは、さまざまな種類の機械学習タスクに対応できる柔軟な設計を持っています。カスタムカーネルの作成や既存のカーネルの最適化が容易に行えます。

Tritonの技術的な優位性

自動チューニング

Tritonは、カーネルの実行パラメータ(例えば、スレッド数やメモリレイアウト)を自動的にチューニングします。これにより、開発者が最適なパフォーマンスを得るための試行錯誤を大幅に削減します。

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K):
    m = tl.program_id(0)
    n = tl.program_id(1)
    a = tl.load(a_ptr + m * K + tl.arange(0, K))
    b = tl.load(b_ptr + n * K + tl.arange(0, K))
    c = tl.dot(a, b)
    tl.store(c_ptr + m * N + n, c)

def matmul(a, b):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    grid = (M, N)
    matmul_kernel[grid](a, b, c, M, N, K)
    return c

この例では、行列積のカーネルを定義し、Tritonの自動チューニング機能を活用して高パフォーマンスを実現しています。

メモリ最適化

Tritonは、メモリの使用効率を最大化するための最適化を行います。これにより、GPUのメモリ帯域幅を最大限に活用し、計算速度を向上させます。

@triton.jit
def vector_add_kernel(x_ptr, y_ptr, output_ptr, size):
    idx = tl.program_id(0)
    x = tl.load(x_ptr + idx)
    y = tl.load(y_ptr + idx)
    tl.store(output_ptr + idx, x + y)

def vector_add(x, y):
    size = x.size(0)
    output = torch.empty_like(x)
    grid = (size,)
    vector_add_kernel[grid](x, y, output, size)
    return output

この例では、ベクトル加算のカーネルを定義し、効率的なメモリアクセスを行うことで高速な計算を実現しています。

簡潔なAPI

Tritonは、直感的で簡潔なAPIを提供し、開発者が迅速にカーネルを開発できるようにしています。複雑なCUDAの知識がなくても、高性能なGPU計算を実現できます。

@triton.jit
def relu_kernel(input_ptr, output_ptr, size):
    idx = tl.program_id(0)
    x = tl.load(input_ptr + idx)
    tl.store(output_ptr + idx, tl.max(0, x))

def relu(x):
    size = x.size(0)
    output = torch.empty_like(x)
    grid = (size,)
    relu_kernel[grid](x, output, size)
    return output

この例では、ReLU関数を実装し、シンプルなコードで高効率なGPUカーネルを作成しています。

Pallas

PallasはTritonのJAXラッパーで, tritonよりシンプルに実装できるように工夫されています。
基本的な考え方として, 行列をblockに切り分けて, GPUのshared memoryへとload/storeを抑えつつ効率的に計算を行います。
以下の図は公式の解説ページからです、

X = \begin{bmatrix} X_0 \\ X_1 \end{bmatrix}
Y = \begin{bmatrix} Y_0 & Y_1 \end{bmatrix}
Z = \begin{bmatrix} X_0 \\ X_1 \end{bmatrix} \begin{bmatrix} Y_0 & Y_1 \end{bmatrix} = \begin{bmatrix} X_0 Y_0 & X_0 Y_1 \\ X_1 Y_0 & X_1 Y_1 \end{bmatrix}

このようにそれぞれのi行とj列を切り出していき, (i, j) 成分を計算します。
それを異なるGPUのSMで行うことですべての成分を計算します.

以下では, 実際に前回の記事でも扱ったlinear関数をPallasで実装したいと思います。

まずは

まずは実装を先に示します。

from jax import numpy as jnp
from jax import random, ShapeDtypeStruct, jit
from jax.experimental import pallas as pl

def linear(x, w, b):
    return jnp.dot(x, w) + b

B, D1, D2 = 32, 32, 64
rng = random.split(random.key(1234), 3)  # rngキーを作り3つに分割
x = random.uniform(rng[0], (B, D1))
w = random.uniform(rng[1], (D1, D2))
b = random.uniform(rng[2], (D2, ))

pred = linear(x, w, b)

def pallas_linear_kernel(x_ref, w_ref, b_ref, out_ref):
    out_ref[...] = pl.dot(x_ref[...], w_ref[...]) + b_ref[...]

@jit
def pallas_linear(x, w, b):
    block_size = 16
    B, D1 = x.shape
    D2 = w.shape[-1]
    
    in_specs = [pl.BlockSpec(lambda i, j: (i, 0), (block_size, D1)),  # x
                pl.BlockSpec(lambda i, j: (0, j), (D1, block_size)),  # w
                pl.BlockSpec(lambda i, j: (j, ), (block_size, )),  # b
                ]
    out_spec = pl.BlockSpec(lambda i, j: (i, j), (block_size, block_size))
    grid_size = (pl.cdiv(B, block_size), pl.cdiv(D2, block_size))
    out_shape = ShapeDtypeStruct(shape=(B, D2), dtype=x.dtype)
    
    out = pl.pallas_call(
        kernel=pallas_linear_kernel,
        compiler_params=dict(triton=dict(num_warps=8, num_stages=2)),
        grid=grid_size,
        in_specs=in_specs,
        out_specs=out_spec,
        out_shape=out_shape,
        interpret=False,
        name='pallas_linear',
    )(x, w, b)
    return out

pred_pallas = pallas_linear(x, w, b)

print(pred[:4, :4])
print(pred_pallas[:4, :4])

Pallasでは、pl.pallas_call methodで実際にSMで行う計算を呼び出します(kernel).

  • pallas_callの引数は,
  • kernel : 計算内容
  • grid : 全体の切り分け方
  • in_specs/out_specs : 入出力をgridの添字からどのように分割するか
  • out_shape : 出力のshape全体

となっています.
上記のpallas_linearでは, 添字(i, j)をそれぞれバッチ軸(B)と出力軸(D2)で分割するようにしています。

pallas_linear_kernelでは, 入出力arrayのreferenceが与えられます.
この関数では, [...]を使うことでポインタ全体の添字を取得しています.

次回の例では引き続きPallasでの実装例を、少し複雑な例を使って実装したいと思います.

Discussion