get_ipython().run_line_magic()について調べてみた。

2024/06/30に公開

はじめに

Pythonでjaxについて勉強している際に

    get_ipython().run_line_magic('timeit', f'-n {repeat} numpy_mod(size)')

このコードを目の当たりしましたが、この意味が全然分からなかったので、調べてみました。
今回使うコードは以下のようなものです。

import jax.numpy as jnp
from jax import jit
from functools import partial

# (size, size)の行列を作ってMod計算
@partial(jit, static_argnums=(0,))
def jax_jit_mod(size):
    x = jnp.arange(size, dtype=jnp.int32)
    mat = x[None, :] * x[:, None] # (size, size)
    return mat % 256

for i in range(4):
    size = 10**(i+1)
    repeat = 10**(4-i)
    print("size =", size, "repeat =", repeat)
    get_ipython().run_line_magic('timeit', f'-n {repeat} jax_jit_mod(size).block_until_ready()') # jitありJAX

コードの解説

全体の流れ

全体の流れとしては、異なるサイズの行列に対する計算速度を測定する際にjaxのjitを用いて高速化してみたものです。行列計算はxの外積を計算し、各要素を256で割るという処理をしています。

jitの説明

jitを使って関数をJITコンパイルします。

  • JIT(Just-In-Time)コンパイルは、プログラムの一部を実行時にネイティブマシンコードに変換します。これにより、通常のインタープリタ実行よりも高速に動作します。
  • 特に、数値計算や行列操作などの重い計算に対して、JITコンパイルによる最適化は大きなパフォーマンス向上をもたらします。

partial は、関数の一部の引数やキーワード引数を事前に設定して新しい関数を作成するためのものです。JAXのjitデコレータと一緒に使うことで、特定の引数を静的に指定できます。

@partial(jit, static_argnums=(0,)) の意味

@partial(jit, static_argnums=(0,)) は、以下のように解釈できます:

jit デコレータで関数をJITコンパイルする。
static_argnums=(0,) で、関数の最初の引数(0番目の引数)を静的引数として指定する。

get_ipython().run_line_magic('timeit', f'-n {repeat} jax_jit_mod(size).block_until_ready()') の説明

get_ipython().run_line_magic('timeit', f'-n {repeat} numpy_mod(size)') は、IPython環境でタイミング計測を行うためのコードです。具体的には、%timeit マジックコマンドを使って、関数の実行時間を測定します。以下に詳しく説明します。

IPython マジックコマンド

IPythonには、特定のタスクを簡略化するためのマジックコマンドが多数用意されています。その中でも %timeit は、コードの実行時間を測定するために使われます。

timeitマジックコマンド

%timeit は、指定したコードの実行時間を測定し、複数回の実行結果の平均を表示します。これにより、コードの性能を評価するのに役立ちます。

get_ipython().run_line_magic

get_ipython()は、現在のIPythonインタプリタインスタンスを取得するための関数です。IPythonは、Pythonの拡張シェルであり、Jupyter NotebookもIPythonをバックエンドとして使用しています。この関数を使用することで、IPythonの豊富な機能にアクセスすることができます。

run_line_magic()は、IPythonのマジックコマンドをプログラム内から実行するためのメソッドです。IPythonには多くの便利なマジックコマンドがあり、これらをプログラム的に呼び出すことができます。

使い方の例

get_ipython().run_line_magic(magic_name, line)
  • magic_name:実行したいマジックコマンドの名前
    • プログラムの途中でタイミング計測を行う(%timeitなど)
    • 環境変数や設定を一時的に変更する(%envなど)
    • 外部コマンドを実行する(%ls、%pwdなど)
  • line:マジックコマンドに渡す引数

f'-n {repeat} numpy_mod(size)'

この部分は、フォーマット文字列を使って、コマンドライン引数を動的に生成しています。

  • -n {repeat} は、timeit コマンドのオプションで、測定を指定した回数(repeat 回)繰り返すことを指定します。
  • numpy_mod(size) は、計測対象の関数呼び出しです。

例えば、repeat が10の場合、f'-n {repeat} numpy_mod(size)'-n 10 numpy_mod(size) という文字列に展開されます。

全体の解説

get_ipython().run_line_magic('timeit', f'-n {repeat} numpy_mod(size)') は、次の処理を行います:

  1. get_ipython() は、現在のIPythonインタプリタを取得します。
  2. run_line_magic('timeit', ...) は、IPythonの %timeit マジックコマンドをプログラムから実行します。
  3. f'-n {repeat} numpy_mod(size)' は、repeat 回数を指定して numpy_mod(size) 関数の実行時間を測定します。

注意点

  • %timeit マジックコマンドは、Jupyter Notebook や IPython シェルで動作します。通常の Python スクリプトでは動作しないので、注意が必要です。
  • get_ipython().run_line_magic を使うことで、IPythonのマジックコマンドをスクリプトから実行することができますが、コードの可読性を考慮して使うべきです。

Discussion