
[timm] timm Explained Simply

Published on 2024/04/30

1. What is timm?

timm is a deep-learning library created by Ross Wightman and is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations and also training/validating scripts with ability to reproduce ImageNet training results.

Source: timmdocs

Incidentally, the name timm comes from Py"T"orch "Im"age "M"odels → "TImM".

2. Install

!pip install timm

Or, for an editable install:

git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

3. How to use

3.1 Check Available Models

timm.list_models() returns the names of the models available in timm. Passing pretrained=True restricts the list to models that have pretrained weights.
・Basic

import timm

avail_pretrained_models = timm.list_models(pretrained=True)
display(len(avail_pretrained_models), avail_pretrained_models[:5])
# (1329,   <- number of pretrained models available in timm (varies by version)
#  ['bat_resnext26ts.ch_in1k',
#   'beit_base_patch16_224.in22k_ft_in22k',
#   'beit_base_patch16_224.in22k_ft_in22k_in1k',
#   'beit_base_patch16_384.in22k_ft_in22k_in1k',
#   'beit_large_patch16_224.in22k_ft_in22k'])

・Filtered

import timm

# All architectures whose name contains 'densenet'
all_densenet_models = timm.list_models('*densenet*')
display(all_densenet_models)
# ['densenet121',
#  'densenet161',
#  'densenet169',
#  'densenet201',
#  'densenet264d',
#  'densenetblur121d']

# Pretrained DenseNet weights whose name also contains '1k'
pretrained_densenet_models = timm.list_models('*densenet*1k*', pretrained=True)
display(pretrained_densenet_models)
# ['densenet121.ra_in1k',
#  'densenet121.tv_in1k',
#  'densenet161.tv_in1k',
#  'densenet169.tv_in1k',
#  'densenet201.tv_in1k',
#  'densenetblur121d.ra_in1k']
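
The wildcard filters above behave like shell-style globs, and pretrained names have the form architecture.tag. As a rough illustration that needs no timm install, the sketch below applies Python's standard fnmatch module to a hand-copied subset of the names shown above. The `names` list and the `split_name` helper are hypothetical and exist only for this example; they are not part of timm's API.

```python
from fnmatch import fnmatch

# Hand-copied subset of the names printed above (not queried from timm).
names = [
    "densenet121.ra_in1k",
    "densenet121.tv_in1k",
    "densenetblur121d.ra_in1k",
    "beit_base_patch16_224.in22k_ft_in22k",
]


def split_name(name):
    """Hypothetical helper: split 'arch.tag' into (arch, tag); tag is '' if absent."""
    arch, _, tag = name.partition(".")
    return arch, tag


# Same glob as in the example above: 'densenet' anywhere, then '1k' later on.
matched = [n for n in names if fnmatch(n, "*densenet*1k*")]
print(matched)
print([split_name(n) for n in matched])
```

Splitting on the first dot separates the architecture (for example densenet121) from the tag that identifies a particular set of pretrained weights (for example ra_in1k).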

3.2 Create Model

・Basic

import timm 
import torch

model = timm.create_model('resnet34')
x     = torch.randn(1, 3, 224, 224)
display(model(x).shape)
# torch.Size([1, 1000])

・Pretrained

pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)

・Custom number of classes

import timm 
import torch

model = timm.create_model('resnet34', num_classes=10)
x     = torch.randn(1, 3, 224, 224)
display(model(x).shape)
# torch.Size([1, 10])

3.3 Check Model

Here we can inspect the model's architecture.

# Print a summary of the model
print(model)

# Or iterate through the named modules to understand the structure
if False:
    for name, module in model.named_modules():
        print(name, module)

# (output omitted)

3.4 Modify the Model's Layers (If Needed)

If necessary, modify the model's layers here.

# Replace the final fully connected layer with a dropout layer followed by a new fully connected layer.
# resnet34's classifier takes 512 input features (not 2048), so read the width from the model.
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),
    torch.nn.Linear(in_features=model.num_features, out_features=10)  # Adjust to your desired output size
)
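
Hard-coding in_features is easy to get wrong because the width of the last layer differs between architectures. Below is a minimal sketch of reading it from the existing layer before replacing it. It uses a small stand-in network so that it runs with PyTorch alone; `TinyNet` is hypothetical and is not a timm model.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone that ends in an `fc` layer."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x):
        x = self.pool(self.conv(x)).flatten(1)
        return self.fc(x)


tiny = TinyNet()
in_features = tiny.fc.in_features  # read the width instead of guessing it
tiny.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(in_features, 10),
)

head_out = tiny(torch.randn(2, 3, 32, 32))
print(head_out.shape)
```

The width has to be read before the assignment, because the replacement nn.Sequential has no in_features attribute of its own.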

3.5 Test the Modified Model

Check that the modified model behaves as expected.

# To find the expected input size, use the snippet below or check the model's documentation.
if False:
    if hasattr(model, 'default_cfg'):
        input_size = model.default_cfg['input_size']  # Often a tuple like (3, 224, 224)
        print(input_size)

# Dummy data for testing
dummy_input = torch.randn(1, 3, 224, 224)  # Change the shape according to your model's input requirements
outputs = model(dummy_input)

print("Output:", outputs)
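
When the goal is only to check output shapes, it can help to switch to evaluation mode and disable gradient tracking, so that dropout is inactive and no autograd graph is built. A minimal sketch follows, using a hypothetical stand-in module (`net`) rather than a timm model so that it runs with PyTorch alone.

```python
import torch
from torch import nn

# Hypothetical stand-in for the modified model: flatten -> dropout -> linear.
net = nn.Sequential(
    nn.Flatten(),
    nn.Dropout(0.5),
    nn.Linear(3 * 8 * 8, 10),
)

net.eval()                 # dropout becomes a no-op in eval mode
with torch.no_grad():      # no autograd graph is recorded
    dummy = torch.randn(4, 3, 8, 8)
    out_a = net(dummy)
    out_b = net(dummy)

print(out_a.shape)
print(torch.equal(out_a, out_b))
```

In training mode the two forward passes would generally differ, because dropout zeroes a random subset of activations on each call.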

3.6 Fine-Tuning or Re-Training

The model definition is now complete. From here, all that remains is to fine-tune or retrain it as needed.

・Example

# Example of fine-tuning the adjusted model on a custom dataset
# Use DataLoader and loss functions for training
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# Dummy DataLoader and training loop for illustration.
# YourCustomDataset is a placeholder: replace it with your own torch.utils.data.Dataset.
train_loader = DataLoader(YourCustomDataset(), batch_size=32, shuffle=True)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()

# Training loop
for epoch in range(10):  # Adjust the number of epochs as needed
    for inputs, targets in train_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

torch.save(model, 'entire_resnet34.pth')  # Saves the entire model
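
torch.save(model, ...) pickles the whole object, so loading it later requires the same class definitions to be importable. A commonly recommended alternative is to save only the state_dict and rebuild the model before loading the weights. A minimal sketch follows, with a hypothetical stand-in model and a temporary file; `build_model` and the file name weights.pth are assumptions made for this example.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    """Hypothetical factory; with timm this would be a create_model call."""
    return nn.Sequential(nn.Flatten(), nn.Linear(12, 10))


trained = build_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weights.pth")
    torch.save(trained.state_dict(), path)   # save the weights only
    restored = build_model()                 # rebuild the architecture first
    restored.load_state_dict(torch.load(path, map_location="cpu"))

x = torch.randn(2, 3, 2, 2)
same = torch.allclose(trained(x), restored(x))
print(same)
```

The rebuilt model must have the same architecture as the saved one, including any replaced layers, or load_state_dict will report missing or unexpected keys.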

Reference

fast.ai timmdocs
