🚄
PythonとNumbaで数値計算を高速化するときの知見
研究でPythonプログラムの高速化のためにNumbaを使用しました.
実装にあたって,いろいろエラーが出てつまずいたので,その知見をサンプルとして共有します.
私の場合ですが,ルンゲクッタ法を使う粒子群最適化法の計算が,約2000秒→約60秒と33倍の高速化になりました.
※注意※
- Numbaの基本的な使い方は参考文献を参照してください.
- すべての使用例では,nopythonモードである
njit
を使用しています. - Numbaは引数や戻り値の型を指定しなくても動作する場合がありますが,ここではすべて指定することを前提とします.
1. 環境
$ sw_vers
ProductName: Mac OS X
ProductVersion: 10.15.2
BuildVersion: 19C57
$ python -V
Python 3.8.5
$ pip freeze
numba==0.51.0
numpy==1.19.1
Numbaは以下のコマンドでインストール出来ます.
$ pip install numba
2. 使用例
np.empty
を使用する
2-1. 関数内でnumbaを使用する場合,関数内でnp.empty
を使用するとエラーが出ることがあります.
その場合,以下のように型を指定するとうまく動作しました.
main.py
import numpy as np
from numba import njit
@njit("f8[:,:]()")
def func():
x = np.empty((1, 2), dtype=np.float64)
return x
print(func())
2-2. 複数の戻り値を返す
複数の戻り値を返す場合,Tuple((i8, i8))
のように書きます.
カッコが2重になっていることに注意が必要です.
main.py
import numpy as np
from numba import njit
@njit("Tuple((i8, i8))(i8, i8)")
def func(x, y):
return x, y
print(func(1, 2))
2-3. 多次元のリストを扱う
Numbaで多次元のリストを扱う場合,f8[:,:]
のように書きます.
2次元だからコロンが2つというわけではなくて,何次元でも2つでいいようです.
main.py
import numpy as np
from numba import njit
@njit("f8[:,:](f8[:,:])")
def func(x):
return x ** 2
x = np.random.rand(5, 5)
print(func(x))
3. 終わり
Pythonの高速化はCythonやJuliaなど色々な方法がありますが,デコレーターを書くだけのNumbaによる方法は一番簡単なものだと思います.
クラスやジェネレーターが使えないなどの制約はありますが,ボトルネックを局所的に高速化するやり方であれば,比較的容易に実装できるものだと感じました.
Discussion