🚄

PythonとNumbaで数値計算を高速化するときの知見

2021/03/27に公開

研究でPythonプログラムの高速化のためにNumbaを使用しました.
実装にあたって,いろいろエラーが出てつまずいたので,その知見をサンプルとして共有します.

私の場合ですが,ルンゲクッタ法を使う粒子群最適化法の計算が,約2000秒→約60秒と33倍の高速化になりました.

※注意※

  • Numbaの基本的な使い方は参考文献を参照してください.
  • すべての使用例では,nopythonモードであるnjitを使用しています.
  • Numbaは引数や戻り値の型を指定しなくても動作する場合がありますが,ここではすべて指定することを前提とします.

1. 環境

$ sw_vers
ProductName:	Mac OS X
ProductVersion:	10.15.2
BuildVersion:	19C57

$ python -V
Python 3.8.5

$ pip freeze
numba==0.51.0
numpy==1.19.1

Numbaは以下のコマンドでインストール出来ます.

$ pip install numba

2. 使用例

2-1. 関数内でnp.emptyを使用する

numbaを使用する場合,関数内でnp.emptyを使用するとエラーが出ることがあります.
その場合,以下のように型を指定するとうまく動作しました.

main.py
import numpy as np
from numba import njit

@njit("f8[:,:]()")
def func():
    x = np.empty((1, 2), dtype=np.float64)
    return x

print(func())

2-2. 複数の戻り値を返す

複数の戻り値を返す場合,Tuple((i8, i8))のように書きます.
カッコが2重になっていることに注意が必要です.

main.py
import numpy as np
from numba import njit

@njit("Tuple((i8, i8))(i8, i8)")
def func(x, y):
    return x, y

print(func(1, 2))

2-3. 多次元のリストを扱う

Numbaで多次元のリストを扱う場合,f8[:,:]のように書きます.
2次元だからコロンが2つというわけではなくて,何次元でも2つでいいようです.

main.py
import numpy as np
from numba import njit

@njit("f8[:,:](f8[:,:])")
def func(x):
    return x ** 2

x = np.random.rand(5, 5)
print(func(x))

3. 終わり

Pythonの高速化はCythonやJuliaなど色々な方法がありますが,デコレーターを書くだけのNumbaによる方法は一番簡単なものだと思います.

クラスやジェネレーターが使えないなどの制約はありますが,ボトルネックを局所的に高速化するやり方であれば,比較的容易に実装できるものだと感じました.

4. 参考文献

Discussion