深層学習フレームワークの設計を考えてるメモ書き
リポジトリ
(cl-waffe2のドキュメント書こうと思ってたけど、設計の説明が難しかったので一旦日本語で書き下ろしてる)
・殴り書きですみません あんまり推敲してないです
・以下の文章は深層学習フレームワークでちょこっと何か作ったことある人をターゲットにしています。
cl-waffe2, programmable deep-learning framework.
⚠️ cl-waffe2 はまだコンセプトの段階です。 API は変更される可能性があります。
cl-waffe2 は深層学習モデルを構築するための微分可能な行列演算を提供する Common Lisp 製フレームワークです。
プロジェクトのゴールとコンセプトは以下の通りです
- 複数の小さなバックエンドから構成されるバックエンドから構成されるcl-waffeのAPI
- 全ての演算は遅延評価され、JITコンパイルされる
- defined-by-runに限りなく近いdefined-and-runスタイル
コンセプト
複数のバックエンドから構成される演算
cl-waffe2の全ての演算は、以下のような構造を取ります。
[抽象的な定義]
|
|-----------|-----------|
[CPUでの実装1] [CPUでの実装2] [CUDAでの実装1] ...
抽象的な定義とはdefnode
マクロによって宣言され、各実装はdefine-impl
マクロを用いて宣言されます。 一つのデバイスに対して複数の実装があっても構いません。(例えば:exp 関数に対して通常の実装と近似版の実装を用意することが可能です)
例として、加算演算!add
を実装することを考えてみましょう。
加算演算 AddNode
とは、与えられた二つの行列 A と B の和を求め、その結果を A に格納する演算のことを言います。
(defnode (AddNode (myself)
:where (A[~] B[~] -> A[~])
:documentation "A <- A + B"
:backward ((self dout dx dy)
(declare (ignore dx dy))
(values dout dout))
;; コンストラクタの処理はここに書かれます。
;; 初期化されたAddNodeクラスはmyselfとして渡されます。
)
ここで、
:where
行列の形状を宣言します。->の前が引数の行列 ->の後が演算後の行列の形状を指します。
~は全ての引数の行列の形状が同じであることを意味します。
:backward
は逆伝播の演算を定義します。この宣言はdefnode
内もしくはdefine-impl
内のどちらか片方で宣言すればOKです。
宣言したノードは(AddNode)
コンストラクタを用いて初期化することができますが、まだAddNode
に対しての実装が一つもないので、エラーが返ってきます。
(AddNode)
;; -> Couldn't find any implementation of AddNode for (CPUTENSOR LISPTENSOR).
そのため、define-impl
マクロを用いて演算の実態を定義します。
演算は、cl-waffe2/vm.generic-tensor:AbstractTensor
クラスを継承して宣言できるバックエンドに対して一つ定義できます。例えばcl-waffe2は標準で(2023/06/18日現在)
- LispTensor (ANSI Common Lisp 環境のみで動作する Portable な実装を提供)
- CPUTensor (SBCL依存だが OpenBLAS を用いて高速に動作する)
の二つのバックエンドを提供しています。新しくバックエンドを作りたい場合は
(defclass MyTensor (AbstractTensor) nil)
のように宣言してください。(参考に:https://github.com/hikettei/cl-waffe2/blob/master/source/backends/lisp/tensor.lisp)
例えば、cl-waffe2でLispTensorに対する実装は次のように定義されています。
(define-impl (AddNode :device LispTensor)
:forward ((self x y)
(let ((adder (matrix-add (dtype x))))
`(,@(call-with-view
#'(lambda (x-view
y-view)
`(funcall ,adder
(tensor-vec ,x)
(tensor-vec ,y)
,(offset-of x-view 0)
,(offset-of y-view 0)
,(size-of x-view 0)
,(stride-of x-view 0)
,(stride-of y-view 0)))
`(,x ,y))
,x))))
:forward
にはdefmacro
でマクロを定義するときと同じ要領で演算の展開式を書きます。(詳しくは後述)
なぜこのような周りくどい方法にするのかというと:
- 小さい行列/スカラー値に対して高速に動作させるため
- AddNodeであれば演算に必要な最小次元数(この場合要素ごとなので1)を定義しておいて、最適化して呼び出したい。
- 関数内部で実際に計算を行わなくても、例えば将来的にcl-waffe2からCのコードを生成させるみたいな振る舞いをさせたい
という目的があります。
(define-implの書き方はあまり綺麗じゃないので、もう少しなんとかしたいと考えている・・・)
これでAddNode
の宣言と実装ができたので、あとはこのようにして
(forward (AddNode) (randn `(10 10)) (randn `(10 10)))
{CPUTENSOR[float] :shape (10 10) :named ChainTMP9412
:vec-state [maybe-not-computed]
((-0.33475596 1.0127474 -0.060175765 ~ 1.4573603 -0.987001 -1.0165008)
(-0.045512 -0.17995936 0.23593931 ~ 0.8409552 2.6434622 -0.5789532)
...
(0.13282542 1.9386152 0.16213055 ~ 0.4363958 0.8294802 -0.1558509)
(1.1732875 -1.5769591 -1.2152125 ~ -0.2833903 -0.81108683 0.9846606))
:facet :input
:requires-grad NIL
:backward <Node: ADDNODE-CPUTENSOR (A[~] B[~] -> A[~])>}
(:vec-stateに注目, この時点ではまだ演算は実行されていないので注意。表示された行列は最初の引数Aである。)
ノードを構築していく、ある程度ノードが出来上がったらbuild
かproceed
関数でコンパイル/実行できる。
(proceed (AddNode) (randn `(10 10)) (randn `(10 10)))
;; proceed-time 関数は、コンパイル時間を除いた実行時間を測ることができる。
(proceed-time (AddNode) (randn `(10 10)) (randn `(10 10)))
Evaluation took:
0.000 seconds of real time
0.000014 seconds of total run time (0.000014 user, 0.000000 system)
100.00% CPU
30,512 processor cycles
0 bytes consed
{CPUTENSOR[float] :shape (10 10) :named ChainTMP9447
:vec-state [computed]
((-1.5820543 2.2804832 -0.5613338 ~ 1.1143546 -1.3096298 -1.3756635)
(-1.5208249 0.21621853 2.660368 ~ -1.032644 0.25917292 -1.9737494)
...
(2.2557664 2.4791012 -0.04298857 ~ -1.2520232 1.8216541 -2.818116)
(0.8615336 0.92017823 -0.25378937 ~ 0.9697968 -0.6300591 1.5660275))
:facet :input
:requires-grad NIL
:backward <Node: PROCEEDNODE-T (A[~] -> A[~])>}
バックエンドはwith-devices
を用いてシームレスに切り替えられる
(with-devices (LispTensor CPUTensor) ;; LispTensor -> CPUTensorの優先順位 LispTensorの実装がないならCPUTensorを使う
(!add (randn `(10 10)) (randn `(10 10))))
気が向いたら書く:
JITコンパイルとキャッシュの最適化
BroadcastingとView
proceed関数
Shaping APIs
References