PytorchのモデルをTorchvistaで可視化する
TorchvistaはPytorchのモデル構造と入力・出力Tensorの形状を視覚的に可視化するパッケージです。jupyter やgoogle colabといったのWebベースのNotebook環境で利用できます
モデルが大規模・複雑になると、コードを読むだけでは構造やTensorの流れを把握するのが難しくなります。Torchvista を使えば、視覚的に構造を確認でき、各ブランチやレイヤーの処理、モジュール内の構成も直感的に理解しやすくなります。
便利な機能
-
インタラクティブ
ズーム・ドラッグ・展開・折りたたみを使用して、確認したいレベルでの構造を確認することができます。 -
複雑なモジュール構造の可視化
例えば、分岐構造を持つInceptionモデルや、複数のAttention Blockを内包するような複雑なモデルでも、モジュール構造をきれいに可視化できます。 -
エラーデバッキング
スクラッチで実装しているとき、様々なエラーやバグに悩まされます。Torchvistaは問題が起きているとこまでの可視化を可能とするためtensorの形状が間違っているなどのデバッグをしやすくなります。
Demo Gallery
まずは作者が公開しているDemo page(英語)で実際どのように構造が可視化されているのか確認してみよう。インタラクティブな機能も試すことができます。
シンプルなモデルからResNetなどの有名なモデル、XLNetBaseCasedといった大規模モデルがサンプルとして提供されています。後に解説する各引数による描画の違いも確認することができます。
このモデル構造の可視化が気になったら、実際に使ってみましょう。
使い方
まず、torchvistaをインストールします。
pip install torchvista
次に構造を可視化したいモデルを定義する。Pytorchを使用して定義されたモデルであれば、Pytorchで提供されている訓練済みモデルでもスクラッチ実装されたものでも構いません。
ここでは、シンプルなモデルを定義してテストしていきます。
import torch
import torch.nn as nn
class TestModel(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
# 3x3 conv
self.seq = nn.Sequential(
nn.BatchNorm2d(in_ch),
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.ReLU(),
)
def forward(self, X):
out = self.seq(X)
return out
model = TestModel(192, 128)
x = torch.rand(1, 192, 28, 28)
次に、torchvistaを使用てモデルの可視化していきます。
trace_model(<model(torch.nn.Module)>, <input>)
にモデルと入力tensorを与えます。
from torchvista import trace_model
trace_model(model, x)
このようにわずか1行のコードでモデルの構造を可視化することができます。
エラーデバッキング
エラーが発生している場合、可視化グラフ上のセルが赤く表示されます。この例の場合、各ブランチを結合するときに問題が発生しています。tensorの形状を見てみると、左から3つの目のブランチのConvレイヤーの後、tensorのサイズが一致していません。エラーメッセージも表示されるので見比べてどこに問題が起こっているのか簡単に把握できます。
画像を出力
trace_model()
にgenerate_image=True
の引数を渡すことで静的な画像を別タブで見るためのボタンが表示されます(デフォルトはFalse
)。
また、height
で画像のサイズを指定することもできます(デフォルトでは800)。
trace_model(model, x, generate_image=True)
個人的な感想
シンプルなコードで複雑なモデルの全体像を把握するのに役に立つと感じます。残念な点としては、モジュールを展開したり折りたたんだりするたびに、勝手にリサイズしてしまう。そのため、大規模なモデルだと展開・折りたたみするたびに、何度も拡大・縮小・移動をしないといけないのは、少し手間に感じました。
まとめ
複雑なモデルの全体像をコードだけで追うのは大変なことがあると思います。Torchvistaを使用することによって、視覚的に構造を確認することができ、各ブランチ・レイヤーの処理、input/outputの形状・モジュール内の構造が把握しやすくなると思います。手軽に利用できるので、是非試してみてください。
Reference
GitHub: Torchvista
Demo: Demo page
Discussion