
【PyTorch】How to visualize an ML model

Published on 2024/06/01

In this article, I'll explain how to visualize a PyTorch model.

1. print(model)

This is the simplest way to see which layers a model uses. It lists the layers in the order they are defined in __init__(), regardless of the order in which forward() actually runs them.
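The ordering comes from nn.Module itself: submodules assigned as attributes in __init__() are registered in assignment order, and print(model) walks that registry. A minimal sketch, using a hypothetical TinyModel made up for this example, shows that named_children() follows the same registration order while forward() decides the real execution order:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical example: fc is registered before conv, but conv runs first."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4 * 8 * 8, 10)                      # registered first
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # registered second

    def forward(self, x):
        x = self.conv(x)                # runs first
        return self.fc(x.flatten(1))    # runs second

tiny = TinyModel()

# named_children() iterates in the same registration order that print() uses
registered = [name for name, _ in tiny.named_children()]
print(registered)  # ['fc', 'conv']

# forward() still works, because the execution order is defined only in forward()
out = tiny(torch.randn(2, 1, 8, 8))
print(out.shape)  # torch.Size([2, 10])
```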

・print(model)

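!pip install torchviz
# torchviz also needs the Graphviz system package to render images
import torch
import torch.nn as nn
from torchviz import make_dot  # used later for the computational graph

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Note: layers are registered in a different order from the one forward() uses
        self.fc1 = nn.Linear(16 * 28 * 28, 128)  # conv1 output: 16 channels x 28 x 28
        self.fc2 = nn.Linear(128, 10)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()  # reused after conv1 and after fc1

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 16 * 28 * 28)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = MyModel()

# Create a sample input: a batch of one 3-channel 28 x 28 image
x = torch.randn(1, 3, 28, 28)

# Run a forward pass; y keeps the autograd graph that make_dot draws later
y = model(x)

print(model)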

・Layers

MyModel(
  (fc1): Linear(in_features=12544, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
)
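The in_features=12544 of fc1 follows from the shape of conv1's output. For Conv2d, each spatial size is H_out = floor((H_in + 2 * padding - kernel_size) / stride) + 1 when dilation is 1, so a 3x3 kernel with stride 1 and padding 1 keeps 28x28 unchanged, and flattening 16 channels gives 16 * 28 * 28 = 12544. A small pure-Python check of that arithmetic (conv2d_out_size is a helper written only for this example, using the general formula from the PyTorch Conv2d docs):

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height or width of a Conv2d layer for one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# conv1: kernel_size=3, stride=1, padding=1 on a 28 x 28 input
h = conv2d_out_size(28, kernel_size=3, stride=1, padding=1)
w = conv2d_out_size(28, kernel_size=3, stride=1, padding=1)
flat = 16 * h * w  # 16 output channels, flattened by x.view(x.size(0), -1)
print(h, w, flat)  # 28 28 12544
```

Changing stride or padding in conv1 changes this number, so fc1's in_features must be updated to match.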

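This is handy for a quick look at which layers are used, but it doesn't show the architecture, i.e. how the layers are connected and in which order they run. For example, the output lists fc1 and fc2 before conv1 even though conv1 runs first, and it shows relu only once although forward() applies it twice.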
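If you only need the order in which layers actually run, one option (a sketch, not something the libraries below require) is to register a forward hook on each submodule with nn.Module.register_forward_hook and record its name during one forward pass. The model is a copy of MyModel so the block runs on its own; call_order is just an illustrative name.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    # Same model as above: layers registered in a different order than they run
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

model = MyModel()
call_order = []

# Append the submodule's name every time its forward() is called;
# name=name binds the current loop value into each lambda
handles = [
    module.register_forward_hook(lambda mod, inp, out, name=name: call_order.append(name))
    for name, module in model.named_children()
]

with torch.no_grad():
    model(torch.randn(1, 3, 28, 28))

for handle in handles:
    handle.remove()  # detach the hooks once we are done

print(call_order)  # ['conv1', 'relu', 'fc1', 'relu', 'fc2']
```

This recovers the execution order but still not the connections between layers, which is what the graph tools below are for.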

2. torchview and torchviz

These libraries draw the model as a graph, so you can see the architecture and how the layers are connected. I find this useful for understanding how a model behaves, too.

・Architecture

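!pip install torchview
from IPython.display import display
from torchview import draw_graph

# Trace the model with an input of the given shape and build the graph
model_graph = draw_graph(model, input_size=x.shape, expand_nested=True)

# Show the graph in a Jupyter notebook
display(model_graph.visual_graph)

# Save the graph (writes the DOT source 'model_graph' and 'model_graph.png')
model_graph.visual_graph.render(filename='model_graph', format='png')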

If you want to see the computational graph (the autograd operations that produced the output and the parameters they depend on), use make_dot from torchviz.
・Computational graph

from torchviz import make_dot
from IPython.display import Image, display

# 'model' and 'y' come from the first example; y must still carry its autograd graph
dot = make_dot(y, params=dict(model.named_parameters()))

# Save to a PNG file (a different name, so it does not overwrite torchview's model_graph.png)
dot.render('computational_graph', format='png')

# Display in a Jupyter notebook
display(Image(filename='computational_graph.png'))
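To get a feel for what make_dot draws, here is a rough sketch of the same traversal in plain PyTorch, without rendering anything: it starts at y.grad_fn, follows next_functions back through the recorded operations, and stops at AccumulateGrad nodes, whose .variable is a leaf tensor (here, a model parameter). walk_autograd is a hypothetical helper written for this example, not part of torchviz, and the exact operation names (e.g. AddmmBackward0) can differ between PyTorch versions.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    # Same model as above
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

def walk_autograd(output, params):
    """Visit every autograd node reachable from output.grad_fn."""
    param_names = {id(p): name for name, p in params.items()}
    seen, stack = set(), [output.grad_fn]
    op_names, leaf_params = [], []
    while stack:
        node = stack.pop()
        if node is None or node in seen:  # None marks inputs that need no gradient
            continue
        seen.add(node)
        if hasattr(node, "variable"):
            # AccumulateGrad node: a leaf tensor, i.e. one of the model parameters
            leaf_params.append(param_names.get(id(node.variable), "?"))
        else:
            op_names.append(type(node).__name__)  # a recorded operation
        stack.extend(fn for fn, _ in node.next_functions)
    return op_names, leaf_params

model = MyModel()
y = model(torch.randn(1, 3, 28, 28))
ops, leaves = walk_autograd(y, dict(model.named_parameters()))
print(ops)  # operation node names; exact names depend on the PyTorch version
print(sorted(leaves))  # ['conv1.bias', 'conv1.weight', 'fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight']
```

In make_dot's picture, the operation nodes correspond to ops and the parameter boxes correspond to leaves, which is why passing params=dict(model.named_parameters()) lets it label them by name.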

I often use torchview's draw_graph to check model architectures; it helps me in my daily work.

Summary

In this article, I explained how to visualize a PyTorch model. There are other ways to do this, so look into them if these don't fit your needs.
