Pytorch Image Models (timm) with fastai
timm with fastaiのOverviewを読みます。
様々なモデルが簡単に利用できます。
Pytorch Image Models (timm)
'timm' は Ross Wightman によって作成されたディープラーニングライブラリで、SOTA コンピュータビジョンモデル、レイヤー、ユーティリティ、オプティマイザ、スケジューラ、データローダ、拡張、および ImageNet トレーニング結果を再現する機能を備えたトレーニング/検証スクリプトのコレクションです。
Install
pip install timm
または
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
How to use
Create a model
import timm
import torch
model = timm.create_model('resnet34')
x = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 1000])
timmを使用してモデルを作成するのはとても簡単です。create_model
関数は、timmライブラリにある300を超えるモデルを作成するために使用できるファクトリメソッドです。
事前学習済モデルを作成するには、pretrained=True
を渡します。
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/tmabraham/.cache/torch/hub/checkpoints/resnet34-43635321.pth
モデルのクラス数を変更するには、num_classes=<number_of_classes>
を渡すだけです。
import timm
import torch
model = timm.create_model('resnet34', num_classes=10)
x = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 10])
List Models with Pretrained Weights
timm.list_models()
で使用可能なモデルの完全なリストを返します。事前学習済モデルの完全なリストを確認するには、list_models
に pretrained=True
を渡します。
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]
(592,
['adv_inception_v3',
'bat_resnext26ts',
'beit_base_patch16_224',
'beit_base_patch16_224_in22k',
'beit_base_patch16_384'])
現在、timmでは事前学習済の重みを持つ合計271のモデルが利用可能です!
Search for model architectures by Wildcard
以下のようにワイルドカードを使用してモデルアーキテクチャを検索することもできます。
all_densenet_models = timm.list_models('*densenet*')
all_densenet_models
['densenet121',
'densenet121d',
'densenet161',
'densenet169',
'densenet201',
'densenet264',
'densenet264d_iabn',
'densenetblur121d',
'tv_densenet121']
Fine-tune timm model in fastai
fastaiは、timm のfine-tuningをサポートしています。
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
path, get_image_files(path), valid_pct=0.2,
label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
# if a string is passed into the model argument, it will now use timm (if it is installed)
learn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)
learn.fine_tune(1)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.201583 | 0.024980 | 0.006766 | 00:08 |
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.040622 | 0.024036 | 0.005413 | 00:10 |
補足
上記だけだと、学習しただけなので、推論と保存、読込を補足しておきます。
推論
fine-tuningしたモデルで推論するには以下のようになります。
- ひとまず学習画像を利用します
trains = get_image_files(path)
print(trains)
[Path('/root/.fastai/data/oxford-iiit-pet/images/boxer_30.jpg'), Path('/root/.fastai/data/oxford-iiit-pet/images/american_pit_bull_terrier_125.jpg'), Path('/root/.fastai/data/oxford-iiit-pet/images/boxer_2.jpg'), ...]
- 画像を表示
from PIL import Image
Image.open(trains[0])
- 推論
learn.predict(trains[0])
('False', tensor(0), tensor([1.0000e+00, 1.2900e-06]))
モデルの保存と読込
保存
learn.save('myModel')
もしくは
learn.export('myModel.pkl')
読込
learn2 = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)
learn2.load('/content/myModel')
learn2.predict(trains[0])
('False', tensor(0), tensor([1.0000e+00, 1.2900e-06]))
learn3 = load_learner('/content/myModel.pkl')
learn3.predict(trains[0])
('False', tensor(0), tensor([1.0000e+00, 1.2900e-06]))
モデルの保存と読込と推論を補足(2023-06-04追記)
保存
torch.save(learn.model.state_dict(), 'myModel.pth')
読込
from fastai.vision.all import *
model, cfg = create_timm_model('vit_tiny_patch16_224', n_out=2, pretrained=False)
state = torch.load('myModel.pth')
model.load_state_dict(state)
model.eval()
推論
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
transform = create_transform(**resolve_data_config(cfg, model=model))
model(transform(image).unsqueeze(0))
Discussion