einsum!を作る
この記事は数値計算Advent Calendar 2022の1日目の記事です。
前回の記事では既存実装としてNumPyにおけるnumpy.einsum
の仕様を見ていきましたが、今回はRustのndarray crate向けにeinsumを作っていきます。完成形は次のようになります:
use ndarray::array;
use einsum_derive::einsum;
let a = array![
[1.0, 2.0],
[3.0, 4.0]
];
let b = array![
[1.0, 2.0],
[3.0, 4.0]
];
let c = einsum!("ij,jk->ik", a, b);
assert_eq!(c, array![
[6.0, 8.0],
[12.0, 16.0]
]);
このeinsum_derive
crateは下記のリポジトリで開発されています:
なお現在の実装では前回説明した省略記号...
を含むeinsumはサポート出来ていません。またBLAS演算に置き換える操作もまだ実装されておらず、素朴なループによる実装を生成しています。
全体像は次の通りです:
einsum入門
実装の解説を始める前にeinsumの機能について少し見てみましょう。
理論物理等の文脈に置いて、アインシュタインの縮約記法というのは複数のテンソルを取る演算で和の記号
を和の記号を省略して
を単に
上記の様な代表的な演算に付いては線形代数ライブラリ側にも対応した関数が用意されますが、このように複数のテンソルを受け取ってある添字に対して和を取る関数は非常にたくさん存在します。ではこれらの関数は自動的に作れないのでしょうか? そこでアインシュタインの縮約記法における和の補完規則の方に着目します。つまり
だと思って、この関数をi,i->
と書くことにします。添字はアルファベット一文字であれば何でもよく、i,i->
とj,j->
は同じ関数を表すことにします。同じように3つの行列を引数に行列-行列-行列積を計算する関数
をij,jk,kl->il
と書きます。このように縮約記法でかける操作を、その添字のパターンを記述することで文字列として関数を指定できます。この文字列を解釈して追加の引数としてもらったテンソルに対して実行する処理系の事をeinsumと呼びます。
手続きマクロによるコード生成
Rustには標準で手続きマクロ(procedural macro, proc-macroとよく呼ばれる)と呼ばれる、Rustのコードを生成するコードをRustで記述できる機能が存在します。上の例で言えば einsum!("ij,jk->ik", a, b)
の部分が手続きマクロの呼び出しに対応していて、これにより"ij,jk->ik", a, b
を入力がRustのコードの構文木(というかトークン列)を出力とする関数
#[proc_macro]
pub fn einsum(input: TokenStream) -> TokenStream { ... }
に渡されて実行され、この実行結果のトークン列
{
fn ij_jk__ik<T, S0, S1>(
arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
{
let (n_i, n_j) = arg0.dim();
let (_, n_k) = arg1.dim();
{
let (n_0, n_1) = arg0.dim();
assert_eq!(n_0, n_i);
assert_eq!(n_1, n_j);
}
{
let (n_0, n_1) = arg1.dim();
assert_eq!(n_0, n_j);
assert_eq!(n_1, n_k);
}
let mut out0 = ndarray::Array::zeros((n_i, n_k));
for i in 0..n_i {
for k in 0..n_k {
for j in 0..n_j {
out0[(i, k)] = arg0[(i, j)] * arg1[(j, k)];
}
}
}
out0
}
let arg0 = a;
let arg1 = b;
let out0 = ij_jk__ik(arg0, arg1);
out0
}
がeinsum!
の呼び出し部分に置換されて本来のコンパイルが行われます。ユーザーはこのようなコード生成が行われている事に全く気づかないまま、特別なコード生成の為の設定を記述すること無く、通常のprintln!
等のマクロの様に使うことが出来ます。
今回は手続きマクロを用いてeinsumを実装するため、実行時の情報であるテンソルの形状とstrideの情報が得られません。可能であればテンソルの形状とstrideの情報を持った上で計算する順序を決定する方が有利になり得ますが、この設計では全てのサイズが同一であると仮定してその最適化は初めから諦めます。
Subscriptのパース
Rustの代表的なパーサコンビネータライブラリであるnomを使ってパーサーを書きます。省略記号...
は今回サポートしていないと書きましたが、将来的にはサポート予定なのでパーサー部分には入っていて、後段の処理に行くところでエラーにしています。BNF-likeに書くと次のようになります:
ellipsis = ...
index = a | b | c | d | e | f | g | h | i | j | k | l | m
| n | o | p | q | r | s | t | u | v | w | x | y | z;
subscript = { index } [ ellipsis { index } ];
subscripts = subscript {, subscript} [ -> subscript ]
einsumの分解
例えば3つの
この処理はeinsumの記法でいうとij,jk,kl->il
と書けます。これは素朴に計算すると、各添字
の様に分けて中間の行列
テンソルの名前管理
これをeinsumの記法で書くとどうなるのでしょうか? 直感的にはij,jk,kl->il
をij,jk->ik
とik,kl->il
に分解したように見えますが、これだと1つ目のik
と2つ目のik
が対応している事が上手く表せていません。そこで元の数式の通り、ここにどのテンソルを使うのかを追加した記法を用意しましょう。まず分解前のeinsumを
ij,jk,kl->il | arg0,arg1,arg2->out0
のように|
で区切った右側にテンソルの名前を記述します。引数として取るテンソルはその順番に応じてarg{N}
、einsumが生成するテンソルはout{N}
のように名前を付けます。この表記を使うと上記の分解は次の様に書けます
ij,jk->ik | arg0,arg1->out1
ik,kl->il | out1,arg2->out0
処理は上から順番に実行されることにします。これで途中で生成されるout1
であり、それが1つ目のeinsumで生成された出力であって、2つ目の処理で第一引数として使われることが分かります。この様にテンソルに名前をつけた状態で並べたものをパスと呼ぶことにします。コード上では einsum_codgen::Pathが対応します。
この表記法は単に私の好みで、こうである必要は無いです。例えば元の数式に近づけて
arg0 @ ij, arg1 @ jk, arg2 @ kl -> out0 @ il
の様に書くことも出来ます。重要なのは分解を記述するにはテンソルにパスの中で有効な名前をつける必要があるという点です。
分解順序と計算量
この分解はいつ可能でしょうか? 行列積の場合は自明に分解出きることが分かりましたが、一般のユーザー入力に対してどのように分解するかをどう決めればいいのでしょう?
この問題を考えるため、上の例を少し書き換えてみましょう:
1つの
この分解は再帰的に行うことが出来ます。例えば行列-行列-ベクトル積ij,jk,kl,l->i
と書けますが、まず最後の行列-ベクトル積を行うことで
kl,l->k | arg2,arg3->out1
ij,jk,k->i | arg0,arg1,out1->out0
となります。さらに2つ目のeinsumを同じように最後の行列-ベクトル積を選んで分解することにより
kl,l->k | arg2,arg3->out1
jk,k->j | arg1,out1->out2
ij,j->i | arg0,out2->out0
のように行列-行列積を計算することない3段の行列-ベクトル積の形に分解され、計算量は
ij,jk->ik | arg0,arg1->out1
ik,kl,l->i | out1,arg2,arg3->out0
再び先頭の行列-行列積を選ぶと
ij,jk->ik | arg0,arg1->out1
ik,kl->il | out1,arg2->out2
il,l->i | out2,arg3->out0
のように分解され、これ以上分解できません。この形では計算量は
現在の実装(0.1.0)では全パターンの分解に対して計算量とメモリのオーダーを計算し、最小のPath
を計算するようになっています。
コード生成
最後に求まったPath
からRustのコードを生成します。上述した様にeinsumのsubscriptは関数を表すので、例えばij,jk->ik
は2つの2次元配列を受け取り2次元配列を返す関数に展開されます。Rustの関数の識別子として,
や->
は使えないので、この部分をエスケープして関数名にします
fn ij_jk__ik<T, S0, S1>(
arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
{ ... }
この関数はeinsum!
マクロのユーザーからは直接見えてほしくないので、スコープを作ってその中で定義します。つまり
let c = einsum!("ij,jk->ik", a, b);
を次の様に展開します:
let c = {
// 必要な関数を定義する
// この関数はこのスコープの外からは見えない
fn ij_jk__ik<T, S0, S1>(
arg0: ndarray::ArrayBase<S0, ndarray::Ix2>,
arg1: ndarray::ArrayBase<S1, ndarray::Ix2>,
) -> ndarray::Array<T, ndarray::Ix2>
where
T: ndarray::LinalgScalar,
S0: ndarray::Data<Elem = T>,
S1: ndarray::Data<Elem = T>,
{ ... }
// マクロの引数の名前を整理
let arg0 = a;
let arg1 = b;
// ↑で作った関数を呼び出し
let out0 = ij_jk__ik(arg0, arg1);
// ブロックから値を返す
out0
};
複数のeinsumに分解されているときは対応する関数をまず定義して、上で議論したテンソルの名前を頼りに順番にそれらを呼び出していきます。
Roadmap
今後の開発予定は次の通りです。他に機能要望などあればGitHub issueまでどうぞ。
- BLAS演算への置き換え
- 現状は素朴なfor-loopで実装しているので計算量は正しいがBLASに比べると随分遅い
- 省略記号
...
とブロードキャスティングのサポート
Discussion