🗂

pytorchのdtype, shapeを型安全にするjaxtypingのすすめ

2024/09/07に公開

型安全なテンソル操作のためのjaxtyping入門

こんにちは、Yosematです。ディープラーニングや数値計算において、テンソル(多次元配列)は非常に重要なデータ構造です。テンソル操作の際に、しばしば「次元が合わない」「型が違う」といった問題に悩まされることがあります。そんなときに役立つのがjaxtypingです。jaxtypingは、テンソルの型と形状をチェックするライブラリで名前に反してJaxだけでなくPyTorchやNumpyユーザーにとっても便利なツールです。PyTorchユーザーとして一通り使ってみましたがモダンな機械学習研究開発に必須のツールと判断できたので共有します。

jaxtypingとは?

jaxtypingは、Python の型ヒント機能を活用しテンソルの型安全性と形状チェックを実現するためのライブラリです。これにより間違った型や形状のテンソルが関数に渡された際にエラーを早期に検知できます。エラーは静的解析では検知できませんが、ランタイムに検出することができます。PyTorchやNumpyのテンソルは多少次元が違っても雰囲気で計算するためのbroadcasting機能を持ちますが、そのせいで意図しないバグを有むことがよくあります。jaxtypingは研究開発を堅実で確かなものにするための最高のツールです。

インストール

jaxtypingを利用するためにはjaxtypingとランタイム型チェックツールを組み合わせてインストールする必要があります。ここではbeartypeを利用します。

pip install jaxtyping beartype

そして型をチェックしたいモジュールの__init__.pyに以下を記述します。

from beartype.claw import beartype_this_package
beartype_this_package()

これで下準備は完璧です。

使い方

jaxtypingからdtypeをimportするとともにspaceで区切ったstrでShapeを指定します。

from jaxtyping import Float32
import torch

def run_tensor_operation(
    tensor: Float32[torch.Tensor, "batch height width channels"]
) -> Float32[torch.Tensor, "batch height width channels"]:
    return tensor * 2

# うまく行くケース
valid_tensor = torch.rand(16, 64, 64, 3)
run_tensor_operation(valid_tensor)  # OK

# 実行時エラーになるケース
invalid_tensor = torch.rand(16, 64, 64)  # channels がない
# run_tensor_operation(invalid_tensor)  # 実行時にエラー

細かいところ

1次元とスカラー

次元が1の場合はDtype[Tensor, " B"], 次元が0(スカラ)の場合にはDtype[Tensor, ""]のように指定します。1次元の場合に必要になる最初のスペースに気をつけてください。

from jaxtyping import Int
import torch

def sum(x: Int[torch.Tensor, " B") -> Int[torch.Tensor, ""]:
    return x.sum()

F722 Lint Error

正常なアノテーションでもF722エラーでPyrightやMypyが不平を言い出します。これを防ぐためにF722エラーをignoreしてください。Ruffを使っている場合にはpyproject.tomlに以下を記述します。

pyproject.toml
[tool.ruff.lint]
ignore = [
    "F722"
]

jaxtyping のメリット

  1. 開発速度の向上
    常々いっていることですが開発工数の基本は正常系2割異常系8割です。テンソルの形状や型の不一致は、特に複雑なモデルやデータ前処理で頻発します。jaxtyping を使うことで、こうしたバグを開発の早い段階で発見できるため結果的に工数が大幅に削減できます。

  2. コードの可読性向上
    テンソルの型や形状を関数の引数や返り値に明示的に記載することで、他の開発者(あるいは未来の自分)がコードを読む際に、関数が何を期待しているかをすぐに理解できます。多くの場合助かるのはあなた自身です。

パフォーマンスへの影響

ランタイムにテンソル型をチェックするので軽微なパフォーマンス低下が見込まれます。とはいえ、ディープラーニングの計算ボトルネックにこの型チェックが含まれることはまずありえないでしょう。もし気になるのであればJAXTYPING_DISABLE=1を設定することでjaxtypingを無効にすることができます。

まとめ

jaxtypingは、テンソル操作において型安全性を高め早期にバグを発見するための非常に有効なツールです。特に複雑なテンソル操作が絡むディープラーニングのプロジェクトでは、その恩恵を強く感じることができるでしょう。

型安全なテンソル操作を求めている方にはぜひjaxtypingを導入してみることをおすすめします。

Discussion