【timm】timm Explained Simply
1. What's timm
timm is a deep-learning library created by Ross Wightman and is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and also training/validating scripts with the ability to reproduce ImageNet training results.
Quote: timmdocs
By the way, the name timm comes from Py"T"orch "Im"age "M"odels → "TImM".
2. Install
!pip install timm
Or, for an editable install from source:
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
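After installing, it's worth confirming that timm imports and checking which version you got, since the available model list grows between releases.
import timm
print(timm.__version__)  # version string, e.g. '0.9.12' (depends on when you install)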
3. How to use
3.1 Check available models
timm.list_models() returns a complete list of the models available in timm.
・Basic
import timm
avail_pretrained_models = timm.list_models(pretrained=True)
display(len(avail_pretrained_models), avail_pretrained_models[:5])
# 1329  # the number of pretrained models available in timm
# ['bat_resnext26ts.ch_in1k',
# 'beit_base_patch16_224.in22k_ft_in22k',
# 'beit_base_patch16_224.in22k_ft_in22k_in1k',
# 'beit_base_patch16_384.in22k_ft_in22k_in1k',
# 'beit_large_patch16_224.in22k_ft_in22k'])
・Filtered with wildcards
import timm
all_densenet_models = timm.list_models('*densenet*')
display(all_densenet_models)
all_densenet_models = timm.list_models('*densenet*1k*', pretrained=True)
display(all_densenet_models)
# ['densenet121',
# 'densenet161',
# 'densenet169',
# 'densenet201',
# 'densenet264d',
# 'densenetblur121d']
# pretrained
# ['densenet121.ra_in1k',
# 'densenet121.tv_in1k',
# 'densenet161.tv_in1k',
# 'densenet169.tv_in1k',
# 'densenet201.tv_in1k',
# 'densenetblur121d.ra_in1k']
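Besides name wildcards, list_models can also filter by the module a model is defined in. A small sketch, assuming your timm version supports the module argument:
import timm
# Only models defined in timm's resnet module
resnet_family = timm.list_models(module='resnet')
display(resnet_family[:3])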
3.2 Create Model
・Basic
import timm
import torch
model = timm.create_model('resnet34')
x = torch.randn(1, 3, 224, 224)
display(model(x).shape)
# torch.Size([1, 1000])
・Pretrained
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
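A pretrained model can be used for inference right away. Below is a minimal sketch with a random tensor; a real image would need the matching preprocessing (see 3.5):
import torch
pretrained_resnet_34.eval()  # disable dropout and batch-norm updates for inference
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    probs = pretrained_resnet_34(x).softmax(dim=-1)  # 1000 ImageNet class probabilities
display(probs.argmax(dim=-1))  # index of the most likely class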
・Custom number of classes
import timm
import torch
model = timm.create_model('resnet34', num_classes=10)
x = torch.randn(1, 3, 224, 224)
display(model(x).shape)
# torch.Size([1, 10])
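As a related trick, passing num_classes=0 removes the classifier head entirely, so the model returns pooled features instead of logits. This is handy when you only need a backbone:
import timm
import torch
backbone = timm.create_model('resnet34', num_classes=0)
x = torch.randn(1, 3, 224, 224)
display(backbone(x).shape)
# torch.Size([1, 512])  # resnet34's pooled feature dimension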
3.3 Check Model
Here, we can inspect the model's architecture.
# Print a summary of the model
print(model)

# Or iterate through the named modules to understand the structure
if False:  # flip to True to print every submodule
    for name, module in model.named_modules():
        print(name, module)
# Output omitted
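Counting parameters is another quick sanity check; plain PyTorch is enough here:
# Total number of trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{num_params:,} trainable parameters')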
3.4 Modify the Model's Layer(s) (if needed)
If needed, modify the model's layers here.
# Replace the final fully connected layer, adding a dropout layer before the new classifier.
# In timm's resnet34 the classifier is model.fc; grab its input size (512) before replacing it.
in_features = model.fc.in_features
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),
    torch.nn.Linear(in_features=in_features, out_features=10),  # adjust out_features to your task
)
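Note that timm models also provide reset_classifier(), which swaps the head for you (without the extra dropout). An alternative sketch; check your version's signature:
# Built-in alternative: replace the classifier with a fresh 10-class head
model.reset_classifier(num_classes=10)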
3.5 Test the Modified Model
Test that the modified model behaves as expected.
# Find the expected input size with the snippet below, or check the model's docs.
if hasattr(model, 'default_cfg'):
    input_size = model.default_cfg['input_size']  # often a tuple like (3, 224, 224)
    print(input_size)
# Dummy data for testing
dummy_input = torch.randn(1, 3, 224, 224) # Change the shape according to your model's input requirements
outputs = model(dummy_input)
print("Output:", outputs)
3.6 Fine-Tuning or Re-Training
At this point, the model definition is complete. All that's left is to fine-tune or re-train it as you like.
・Example
# Example of fine-tuning the adjusted model on a custom dataset
# Use DataLoader and loss functions for training
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# Dummy DataLoader and training loop for illustration
train_loader = DataLoader(YourCustomDataset(), batch_size=32, shuffle=True)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()

# Training loop
model.train()  # enable dropout and batch-norm updates
for epoch in range(10):  # adjust the number of epochs as needed
    for inputs, targets in train_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

torch.save(model, 'entire_resnet34.pth')  # saves the entire model
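Saving the entire model pickles the class definition along with the weights; saving just the state_dict is usually more portable:
# Alternative: save only the weights
torch.save(model.state_dict(), 'resnet34_finetuned.pth')
# To load later, rebuild the same (modified) architecture first, then:
# model.load_state_dict(torch.load('resnet34_finetuned.pth'))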
Reference
fast.ai timmdocs: https://timm.fast.ai/