🗂
LLM(GPT)の自前 training 実装のための Backward のメモ
背景
富岳で C++ LLM pretrain したい...
DePIN で C++ LLM pretrain したい...
pytorch とかめんどすぎ(特にちょっと特殊なことをしたいときとか)
自前実装でコンパクトにしつつ, 大規模トレーニング対応のための hack を入れたりしたい.
まずは backward 計算です.
追記: 俺たちの karpathy 先生が llm.c でやってくれました! llm.c 見ればよいでしょう. https://github.com/karpathy/llm.c
Backward 計算
Embedding
ぺろっと解説してくれているのはあまりない...?
Back propagation in an embedding layer
forward は実質 look up テーブルである.
weight = embedding[index]
backward 計算は, index を one-hot vector として扱うことで微分計算可能にする.
CPU 実装
および llm.c 参照.
Linear, Matmul
これは機械学習の教科書でよく解説されているであろう.
ReLU
GELU(Gaussian Error Linear Units)
Softmax
Dropout
mask(乱数)は変数の依存関係がないため, x' * mask
.
forward で求まった mask を使う.
dropout:
forward(x, ratio = 0.5):
mask = random < ratio
return x * mask
backward(x'):
return x' * mask
LayerNorm
Flash attention(Fused attention)
実際には Attention の各要素で backward 計算すると無駄が多いので, kernel をまとめて一つにする.
jax, triton あたりの実装が参考になるか
あとは llm.c に入りそうなこちら
その他実装
llm.c llama.cpp が参考になるでしょう.
Discussion