Open6

Torch.fx: Practical Program Capture and Transformation for Deep Learning in Python

lewisacidlewisacid

torch.fx

https://arxiv.org/abs/2112.08429

torch.fx: PyTorchでDLワークロードの計算グラフの書き換えを容易にする試みの1つ

https://pytorch.org/docs/stable/fx.html

主にsymbolic_trace、 中間表現、コード生成の3つの要素からなる

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

https://pytorch.org/docs/stable/fx.html#torch.fx.Node

torch.fxで使用されるIRのOpは以下の6つ

  1. placeholder: 関数の入力
  2. get_attr: モジュールからパラメータを取得する
  3. call_function: 値に対してfree functionを適用する
  4. call_module: 与えられた引数に対して、モジュールのforward()メソッドを適用する
  5. call_method: 値に対してメソッドを呼び出す
  6. output: トレースされた関数の出力がargs[0]属性で格納される。"return"ステートメントに対応

特徴:

  • Pythonで実装されている→使ったり、読んだり、カスタマイズしたりするのが簡単
  • Pythonコードを入力とし、Pythonコードを出力とする→Pythonエコシステムの恩恵(デバッガ、ランタイム)を受けられる

https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing

制限:

  • 動的な制御フローをトレースできない
  • PyTorch以外の関数呼び出しが使えない

わからなかったところ:

  • torch.fxのIRやグラフ書き換え機能自体は、PyTorch 2.0の主要機能に含まれていない(legacyの機能だと思われる)??
  • 入れ物であるfx.Graphやfx.NodeのみPyTorch 2.0に利用されている??

ソースコード

https://github.com/pytorch/pytorch/blob/main/torch/fx/README.md

https://github.com/pytorch/pytorch/blob/main/torch/fx/node.py#L117-L627

https://github.com/pytorch/pytorch/blob/main/torch/fx/proxy.py#L344-L474

https://github.com/pytorch/pytorch/blob/main/torch/fx/graph.py#L674-L1521

https://github.com/pytorch/pytorch/blob/main/torch/fx/graph_module.py#L287-L784

symbolic_trace(m)の呼び出しは、Tracer().trace(m)と等価

https://github.com/pytorch/pytorch/blob/main/torch/fx/_symbolic_trace.py#L216-L839

https://pytorch.org/docs/stable/notes/extending.html#extending-torch

lewisacidlewisacid

PyTorch 2.0

https://pytorch.org/get-started/pytorch-2.0/

  • TorchDynamo:Python Frame Evaluation Hooks (cf. PEP 523) を使用してPyTorchプログラムを安全にキャプチャする
  • AOTAutograd:PyTorchのautogradエンジンをオーバーロードし、事前にbackwardトレースを生成するためのtracing autodiffとして機能させる
  • PrimTorch:PyTorchの2000以上の演算子を、開発者が完全なPyTorchバックエンドを構築するためにターゲットとすることができる250のプリミティブ演算子の閉じたセットに正規化する。これにより、PyTorchの機能やバックエンドを書く際の障壁を大幅に下げることができる
  • TorchInductor:複数のアクセラレータとバックエンドのための高速なコードを生成するディープラーニングコンパイラ。NVIDIAとAMDのGPUに対しては、OpenAI Tritonを主要なビルディングブロックとして使用している

  • 目的:高速なeager実行
  • eager実行の利点(デバッグの容易さ)を損なわず、Pythonコードの変更を要求することなく、実行を高速化する

TorchScript、FXトレースとの比較

https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#comparison-to-torchscript-and-fx-tracing

torch.jit.trace

サンプル入力でメソッドを直接実行し、実行されたすべての演算子を順番に捕捉する

欠点:

  • 実際の制御フローパスしかトレースされないため、動的な制御フローをトレースできない
  • PyTorch以外の関数呼び出しの結果を定数として扱うため、(動的な制御フローを含む)関数呼び出しの結果が暗黙的にエラーになり得る

torch.jit.script

Pythonソースからプログラムを直接抽出し、コンパイルする

利点:

  • 動的な制御フローを扱える
  • サポートされていない機能を使用したことによるエラーを、報告することができる

欠点:

  • コンパイラスタックを必要とするため、実装が複雑化する
  • Pythonのすべての言語機能をサポートしない
    • torch.jit.scriptで使用可能なPythonのサブセットをTorchScriptと呼ぶ
    • 例えば型アノテーションをつける必要があるなど、Pythonコードの大幅な変更を必要とする
    • そのため、ユーザにとって使い勝手の良いものではない
  • PyTorch以外の関数呼び出しが使えない

torch.fx.symbolic_trace

torch.jit.scriptと似ているが、サンプル入力ではなく、抽象的な値でメソッドを実行し、トレースする(symbolic traceという)

利点:

  • Pythonで実装されており、修正や拡張が容易
  • キャプチャしたプログラムを6命令IRで表現しており、理解や解析が容易

欠点:

  • 動的な制御フローをトレースできない
  • PyTorch以外の関数呼び出しが使えない

TorchDynamo (Part of torch.compile)

Pythonのフレーム評価にフックして、Pythonのバイトコードからグラフを構築する(cf. PEP 523)

利点:

  • 動的な制御フローを含むグラフをキャプチャできる
  • PyTorch以外の関数呼び出しを使用できる
  • Pythonコードの修正を必要としない

TorchDynamo

TorchDynamoは、任意のPythonコードをFXグラフにJITコンパイルし、さらに最適化する役割を担う。TorchDynamoは、実行中のPythonバイトコードを解析し、PyTorch操作の呼び出しを検出することでFXグラフを抽出する。torch.compileのもう一つのコンポーネントであるTorchInductorが、FXグラフをさらに最適化カーネルにコンパイルする。

PrimTorch

https://dev-discuss.pytorch.org/t/where-do-the-2000-pytorch-operators-come-from-more-than-you-wanted-to-know/373

Untitled

TorchInductor

https://pytorch.org/get-started/pytorch-2.0/#torchinductor-fast-codegen-using-a-define-by-run-ir

TorchInductorは、Pythonicなdefine-by-runループレベルIRを使用して、PyTorchモデルをGPU上の生成TritonコードやCPU上のC++/OpenMPに自動的にマッピングする。TorchInductorのコアとなるループレベルIRには、わずか50個の演算子が含まれており、Pythonで実装されているため、容易にハッキングや拡張が可能である。

Next Step

↓fxグラフの入手方法、利用可能なIR、簡単な変換の実行方法について
[GoogleColab] PT2 Backend Integration

lewisacidlewisacid

LazyTensor: combining eager execution with domain-specific compilers

https://arxiv.org/abs/2102.13267

  • Lazy Tensor Core:アクセラレータとコンパイラのための新しい拡張ポイントの構築
  • 言語サブセット問題(あるドメインに特化した結果、元の言語機能の一部が使用できなくなる問題)に対処するために、イーガー実行の利点とドメイン固有コンパイラを組み合わせる新しいアプローチ
  • mainブランチに未マージ

https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/README.md