🗂

LLM(GPT)の自前 training 実装のための Backward のメモ

2024/03/27に公開

背景

富岳で C++ LLM pretrain したい...
DePIN で C++ LLM pretrain したい...

pytorch とかめんどすぎ(特にちょっと特殊なことをしたいときとか)
自前実装でコンパクトにしつつ, 大規模トレーニング対応のための hack を入れたりしたい.

まずは backward 計算です.

追記: 俺たちの karpathy 先生が llm.c でやってくれました! llm.c 見ればよいでしょう. https://github.com/karpathy/llm.c

Backward 計算

Embedding

ぺろっと解説してくれているのはあまりない...?

Back propagation in an embedding layer
https://medium.com/@ilyarudyak/back-propagation-in-an-embedding-layer-30382fa7f023

https://info.drobe.co.jp/blog/engineering/pytorch-embedding

forward は実質 look up テーブルである.

weight = embedding[index]

backward 計算は, index を one-hot vector として扱うことで微分計算可能にする.

CPU 実装

https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp

および llm.c 参照.

Linear, Matmul

これは機械学習の教科書でよく解説されているであろう.

https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html

https://www.qoosky.io/techs/fc5cfc47cc

ReLU

https://www.anarchive-beta.com/entry/2020/07/31/180000

GELU(Gaussian Error Linear Units)

https://paperswithcode.com/method/gelu

https://alaaalatif.github.io/2019-04-11-gelu/

Softmax

https://qiita.com/okayu303/items/09efd10161d0a764833d

Dropout

mask(乱数)は変数の依存関係がないため, x' * mask.
forward で求まった mask を使う.

dropout:
  forward(x, ratio = 0.5):
    mask = random < ratio
    return x * mask

  backward(x'):
    return x' * mask

LayerNorm

https://liorsinai.github.io/mathematics/2022/05/18/layernorm.html

https://gist.github.com/domarps/8e390411940a6c3b712cdaf95f009040

Flash attention(Fused attention)

実際には Attention の各要素で backward 計算すると無駄が多いので, kernel をまとめて一つにする.

jax, triton あたりの実装が参考になるか

https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py

https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html

あとは llm.c に入りそうなこちら

https://github.com/karpathy/llm.c/pull/60

その他実装

llm.c llama.cpp が参考になるでしょう.

Discussion