
【PyTorch】How to get the output shape of timm models

Published on 2024/08/17

1. Get the shapes of all feature maps (including the final output) of a timm model

I'll briefly introduce a way to check the outputs of a timm model along its whole flow.
We sometimes have to check the size of a model's output, either because it determines the input size of the pooling layer (or of the dense layer that adjusts the output shape) placed after the model, or because we want to inspect how the model behaves internally.

2. Practical code

The following code prints the shape of every feature map:
・Check code

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'resnet18.a1_in1k',
    # "efficientnet_b0.ra_in1k",
    pretrained=True,
    features_only=True,
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

for o in output:
    # print shape of each feature map in output
    # e.g.:
    #  torch.Size([1, 64, 112, 112])  # output shape of the first layer block
    #  torch.Size([1, 64, 56, 56])
    #  torch.Size([1, 128, 28, 28])
    #  torch.Size([1, 256, 14, 14])
    #  torch.Size([1, 512, 7, 7])  # output shape of the last layer block
    print(o.shape)

・resnet18.a1_in1k outputs

#  torch.Size([1, 64, 112, 112])  # output shape of the first layer block
#  torch.Size([1, 64, 56, 56])
#  torch.Size([1, 128, 28, 28])
#  torch.Size([1, 256, 14, 14])
#  torch.Size([1, 512, 7, 7])  # output shape of the last layer block

This is the result with features_only=True, which makes the model return the output of every layer block as a list. With features_only=False (the default), the model returns only one output. For example, resnet18 returns a tensor of shape torch.Size([1, 1000]) for a batch of one image, because the model then includes the final pooling layer and the dense layer for 1000 classes.

3. Summary

In this article, I explained how to check the output shapes of timm models. Please try it.
