📊

PyTorch2.1のcompile機能の実力を試す

2023/10/07に公開

はじめに

半年ぶりにPyTorchにアップデートが来ました。PyTorch2.0の段階ではなんだか不安定だったtorch.compile()による高速化機能がようやく安定したようなので、いったいどれくらいの高速化効果があるのか早速検証していこうと思います。

torch.compile()について

torch.compile()はPyTorch2.0の新機能にしておそらく最大の目玉機能です。

net=torch.compile(net,mode="default")

こんな風にコードを1行付け足すだけで、ネットワークの学習が高速化します。
また、コンパイルのモードについては"default"のほかに"reduce-overhead","max-autotune"の計3種類が存在します。なので今回はこちらの速度についても計測していきたいと思います(2.0の時はこのオプションをいじると学習が全く進まなくなっていました。おま環と言われればそれまでかもしれませんが…)。
ただし、現状はLinuxでしか動かせないopenAI tritonを最適化エンジンに採用しているためか、この機能はWindowsでは使用できません。よって、Windows使いの私は愛用のNGCイメージ(NVIDIAさん謹製のAI用Dockerイメージ)でいろいろと試していくことにします。

ちなみにお気づきの方もいらっしゃるかとは思いますが、使い方がTorchScriptのtorch.jit.trace()と全く同じです。TorchScriptは失敗だったんですね、まぁ全然速くなんなかったしそりゃあね。

インストール

たかがPyTorchのバージョン上げるくらいで解説もなにもないだろう…
と思っていたら、typing_extensionsのバージョン要件が4.0.0以上に上がっているというちょっとした罠がありました。(筆者はここでプチはまりしました。こういうのって普通勝手にあげてくれませんでしたっけ?)

というわけで

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install typing_extensions==4.1.1

この2つコマンドを打ってインストール完了です。

検証方法

ResNet50をfood-101データセットの101,000枚の画像で3エポック訓練します。あまりに小さいネットワーク(俗にいうトイネット)を使うと正常にベンチマークしたことになるか怪しいので、ちゃんとRes50を使いますし、画像もResNet論文と同じ方式で224*224にリサイズします。
少し迷いましたが、画像枚数を稼ぎたかったので10.1万枚の画像全てをトレーニングに突っ込みました(train,validに分割しても問題なくvalid-lossが下がることは別で確認しているのでご安心を)。
今回はバッチサイズは64、ampの半精度を有効にして測ってみることにしました。
なお、PyTorch標準のデータローダーは(たとえ最適化したとしても)かなり遅いことで有名で、これを使うとデータローダーをベンチマークしていることになってしまいかねません。よって、今回はNVIDIA DALIをデータローダーに採用します。
また、WSLを使う場合は、Windowsからマウントした共有ディレクトリにデータを置いてはいけません。Windowsからの変換処理が挟まって劇的に遅くなってしまいます。

結果

コンパイル方法 Epoch1 Epoch2 Epoch3
コンパイル不使用 139s 141s 141s
default 167s 129s 128s
reduce-overhead 150s 116s 117s
max-autotune 182s 121s 121s

所感

defaultで約1.1倍、reduce-overheadで約1.2倍の高速化といったところでしょうか。事前の謳い文句ではもう少し高速化するようなことが言われていたのでちょっと残念ではあります。
一方、max-autotuneは一番長時間コンパイルしている割に大して速くないという結果に
ただ、これはnightlyの頃から一貫してこんな感じの結果になっていたので、これはそういう仕様なんだと思います。

まとめ

勝手に期待を膨らませすぎていたせいか、所感のところではちょっとテンションが下がっています…が、冷静に考えて学習がほぼ無条件で高速化するというのは非常に魅力的な機能だと思います。だってお金を払わずにGPUのグレードが一段階上がっているようなものですからね
PyTorchを使われている方は、torch.compile()をぜひ一度は必ず試してみるべきでしょう。
mode引数のチューニングも忘れずに!

環境

Ryzen7 3700X & GeForce RTX 3090
Windows11 Pro & WSL
NGCイメージ「nvcr.io/nvidia/tensorflow:22.12-tf2-py3」
PyTorch 2.1.0

Discussion