🧮

深層学習モデルの推論ランタイムを0から作った話

2022/11/24に公開

はじめに

深層学習モデルを動作させるためのソフトウェアは数多くあります。
PyTorch や TensorFlow などのフレームワークはそれ自身がモデルを実行する機能を持っていますし、ONNX Runtime のようにモデルを動作させることに特化したソフトウェアも存在します。

これらのソフトウェアは大抵、Python などから簡単に扱うことができます。
しかしながら、それらがどのように動作しているのか疑問に思うことはないでしょうか。

この記事では、0 から深層学習モデルの推論ランタイム(長いので以下「深層学習ランタイム」)を作った過程で学んだことを、とりとめもなく紹介していきます。ほとんど、自分用のメモのようになってしまうかもしれません。
作ったものは以下のリポジトリにあります。
技術的にはかなり適当なことを書いてしまうかもしれません。

https://github.com/maekawatoshiki/altius

深層学習ランタイムは何をするのか

深層学習ランタイムは、深層学習モデルを動作させるためのソフトウェアです。
もう少し具体的に考えてみると、深層学習ランタイムは「モデル」と「入力データ」を受け取り、「出力データ」を返すソフトウェアだと言えます。(ここでは推論時のみを取り扱います。)

今は機械学習寄りの話をしているため、入力・出力データはテンソルとして表現されていることがほとんどです。
例えば、入力として shape が 1 \times 3 \times 224 \times 224 の画像、出力として shape が 1 \times 1000 のクラス分類、などが挙げられます。
(ここでは、テンソルは単に多次元配列だと捉えていただければ大丈夫です。流石に適当すぎるかもしれないですが。)

テンソルを受け取り、モデルの構造に沿って何らかの計算を行い、テンソルを返す。この一連の処理を行うのが深層学習ランタイムというわけですね。

深層学習モデルはどのように表現されるのか

深層学習モデルを動作させるためには、それがどのような構造をしているのか知る必要があります。
フレームワークやフォーマットごとに細かい部分は異なりますが、基本的に以下のような特徴を持ちます。

  • 有向非巡回グラフ(DAG)
    • ノード: テンソルに対する操作(e.g. Add, Relu, Conv)
    • エッジ: テンソルの値(= データの流れを表現している)
  • 入次数が 0 のノードに(入力)テンソルを渡すと、出次数が 0 のノードから(出力)テンソルが得られる
    • ノードは、テンソルを受け取り、それに対して何らかの操作(計算)を行い、テンソルを返す
    • あるノードからあるノードへと行き着くまでに、それぞれのノードの種類に対応する計算が行われる
    • 外から見ると、テンソルを与えたら、別のテンソルが得られるように見える
  • DAG であるため、トポロジカルソートするとノードを一列に並べられる
    • 計算機で動かしやすい

また、広く使われているフォーマットとして、ONNX が存在します。
ONNX 形式で表現されたモデルを可視化した例を以下に示します。(一番上の Input3 はノードではなくて入力テンソルの名前、一番下のPlus214_Output_0 は出力テンソルの名前、それ以外の四角形はノードです)

mnist

苦労話

ここまで色々と書いてきましたが、私が深層学習ランタイムを作り始めた頃はこれらのほぼ全てを知らない状態でした。

さらに言えば、深層学習自体はもちろん、NumPy などのライブラリが配列(テンソル)をどのように扱っているのか・配列に対する操作にはどのようなものがあるのか、すらも知りませんでした。

例えば以下のようなことを知らなかったはずです。(思い出せる限り列挙)

  • N 次元配列がどう実現されるのか
    • 何次元の配列だろうが、(基本的に)値はメモリ上に一列に並んでいることは知っていた
      • float A[H][W]; と定義された変数なら、A[y][x]A + (y * W + x) のアドレスが同じことなどは理解していた
    • ただ、これを N 次元に拡張したとき、Stride という概念が登場してきて、最初は混乱した
      • float A[N][C][H][W]; の Stride は例えば [C*H*W, H*W, W, 1] と表現できる
      • 便宜上、S=[C*H*W, H*W, W, 1] とおいて(S[0]=C*H*W, S[3]=1) 、
      • Stride を使えば、A[x][y][z][w] のアドレスが A+x*S[0]+y*S[1]+z*S[2]+w*[3] と表現できる(すごい!)
    • さらに、Layout といった概念も後から登場してきて、値がメモリ上に一列に並んでいることすらも幻想だったと気づいた
  • Broadcast がどう実現されるのか
    • Broadcast という考え方を用いると、異なる shape のテンソル間の計算が可能になる
      • e.g. shape がそれぞれ (8, 4), (8, ) のテンソルを足して、(8, 4) のテンソルを得られる
      • なんとなく何が起きているのかは理解できたが、じゃあ実際計算機上でどう実現されているのかわからなかった
    • Broadcast は一部の Stride を 0 にすることで実現される、と知った時はとても賢いなぁと感激した覚えがある
      • e.g. float A[8][4];, float B[8]; の Stride は(このままだと)それぞれ [4, 1], [1] になる
      • ここで、それぞれの Stride を [4, 1], [1, 0] だと解釈してあげる
      • すると、要素和は Output[y][x] = *(A+y*4+x*1) + *(B+y*1+x*0) だと書ける(!)
  • Reshape とは何なのか
    • shape を (1, 3, 20, 20) から (1, 1200) に変換できる、と言われてもピンとこなかった
    • (1, 10) から (1, 2, 5) のように次元が増える場合もある、と言われた時は頭が ??? で埋め尽くされた覚えがある
    • これは結局、Stride の値を書き換えているだけなんだと理解した
      • データの中身が変化することはない
  • im2col とは何か
    • 畳み込み演算(Convolution)を素直に実装すると何重ものループを使うことになる
      • そして(少なくとも CPU 上なら)遅い
    • そこで画像を上手く変換してあげることで、重みとの行列積だけで Convolution を実現するのが im2col
      • ただしメモリを消費するので気を付ける
  • 単純な行列積の実装は本当に遅い
    • 知識としては知っていたことだけれど、いざ手を動かしてみると実感する
    • OpenBLAS, Intel MKL, Blis は本当に速い
  • Protocol Buffers の扱い方
    • これは ONNX に関係する話
    • Rust や Python で、自在に Protocol Buffers を扱う力がついた気がする
  • Python
    • おそらく機械学習に入門していなければ、こんなに Python を書けるようにはならなかった気がする

さらに、以下のようなことでも苦労しました。

  • 深層学習ランタイムを自作する人は少ない
    • 畳み込み演算を手書きする、MNIST の推論モデルを手書きする、くらいならいくつか見つけることができました
    • もっと調べれば見つかるのかもしれませんが、既存のランタイムのようにさまざまなモデルを動作させている人がなかなか見つかりませんでした
    • 必然的に、有名で大きなソフトウェアを参考にするしかなくなります
  • モデルが大きい
    • 数十MB から 1GB ほどまで、さまざまなモデルが存在します
    • Git LFS で管理したり、Google Drive に置いたりしましたが、やっぱり扱いが面倒

出来上がったもの

読者の方がおそらく一番気になっているであろう、結局何が出来上がったのかを紹介します。

私は Rust が好きなので Rust でランタイムを実装し、さらに Python からも使えるように pyo3 でバインディングも作りました。Python バインディングはほとんど ONNX Runtime の API のサブセットとして振る舞うため、かなり簡単に使えます。

import onnxruntime as ort
# このくらいのコードなら、ort を altius_py に置き換えるだけで動きます
import altius_py # as ort

sess = ort.InferenceSession("model.onnx")
output = sess.run(None, input)

モデルのフォーマットとしては ONNX にのみ対応しています。PyTorch や TensorFlow などに対応してもよかったのですが、単純に面倒なのと、それらは ONNX として export できるので優先度が低くなりました。

また、現在は CPU 上での実行にのみ対応しています。
一応マルチスレッドにも対応しているため、モデルによっては ONNX Runtime と互角の速度で動作します。

動作を確認したモデルは以下の通りです。(ほとんどのモデルは onnx simplifier を通してから実行しました)
Computer vision から自然言語処理まで色々なモデルが動きます。

自分の環境で試してみたい方は、リポジトリ の README.md を参考にしてみてください。(ドキュメント不足なので、気軽に Discussion や Issue を立ててください…)

少しだけ実装の話を

  • DAG をトポロジカルソートして、その順で逐次実行
    • 同時に2つのノードが別スレッドで動いたりはしない
    • ソートの方法は適当; メモリ使用量を最小にするなどの工夫はなし
  • 簡単な Gelu fusion
    • Gelu という活性化関数は ONNX では直接サポートされていないため、いろんなノードの組み合わせで表現される
    • これらの散らばったノードをまとめあげて Gelu として認識・実行する最適化を実装
      • ほんのちょっとだけ速くなる
      • Transformer 系のモデルで活躍
  • 並列化
    • オペレータによっては内部実装が並列化されている
    • 要素単位の操作(e.g. Add, Relu, ...)など
  • サボった部分
    • 行列積
      • 自分では実装していない
      • 最終的に Blis を使うことにした(最も性能がよかったため)
        • macOS では accelerate も使えます。とても速い。
      • GEMM や Convolution で使っている
    • いくつかのメモリ操作
      • Rust の ndarray crate を使って、自分で実装していない部分がある
      • そのうち自分で書き直したい

終わりに

今までいろんなもの(OS、Web ブラウザ、JavaScript 処理系、その他言語処理系関連)を自作してきましたが、深層学習ランタイムはそれらとは少し毛色が違って刺激的でした。
(グラフ)コンパイラや HPC 寄りの知識を得るきっかけにもなりますし、画像処理や言語処理などのモデルが動くと普通に面白いです。(特に言語処理モデルが動くと、計算機と会話できるので楽しい)

深層学習ランタイム自作、流行ってほしいですね。

Discussion