import timm
import torch
import pprint
# from torchviz import make_dot, make_dot_from_trace
timmとは#
`timm` is a deep-learning library created by Ross Wightman and is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations and also training/validating scripts with ability to reproduce ImageNet training results.
(参考)
https://huggingface.co/docs/timm/quickstart
いくつモデルがあるのか#
# Count every model name timm lists, and the subset listed with pretrained=True.
avail_models = timm.list_models()
# Backward-compatible alias: the variable was originally misspelled "aveil_models".
aveil_models = avail_models
avail_pretrained_models = timm.list_models(pretrained=True)
print(f'モデル数: {len(avail_models)}')
print(f'モデル数(pretrained): {len(avail_pretrained_models)}')
モデル数: 964
モデル数(pretrained): 770
どんなモデルがあるのか#
# Show the first ten entries returned by timm.list_models().
timm.list_models()[:10]
['adv_inception_v3',
'bat_resnext26ts',
'beit_base_patch16_224',
'beit_base_patch16_224_in22k',
'beit_base_patch16_384',
'beit_large_patch16_224',
'beit_large_patch16_224_in22k',
'beit_large_patch16_384',
'beit_large_patch16_512',
'beitv2_base_patch16_224']
# Searching by name pattern is also possible (here: names containing "efficientnet").
timm.list_models("*efficientnet*", pretrained=True)
['efficientnet_b0',
'efficientnet_b1',
'efficientnet_b1_pruned',
'efficientnet_b2',
'efficientnet_b2_pruned',
'efficientnet_b3',
'efficientnet_b3_pruned',
'efficientnet_b4',
'efficientnet_el',
'efficientnet_el_pruned',
'efficientnet_em',
'efficientnet_es',
'efficientnet_es_pruned',
'efficientnet_lite0',
'efficientnetv2_rw_m',
'efficientnetv2_rw_s',
'efficientnetv2_rw_t',
'gc_efficientnetv2_rw_t',
'tf_efficientnet_b0',
'tf_efficientnet_b0_ap',
'tf_efficientnet_b0_ns',
'tf_efficientnet_b1',
'tf_efficientnet_b1_ap',
'tf_efficientnet_b1_ns',
'tf_efficientnet_b2',
'tf_efficientnet_b2_ap',
'tf_efficientnet_b2_ns',
'tf_efficientnet_b3',
'tf_efficientnet_b3_ap',
'tf_efficientnet_b3_ns',
'tf_efficientnet_b4',
'tf_efficientnet_b4_ap',
'tf_efficientnet_b4_ns',
'tf_efficientnet_b5',
'tf_efficientnet_b5_ap',
'tf_efficientnet_b5_ns',
'tf_efficientnet_b6',
'tf_efficientnet_b6_ap',
'tf_efficientnet_b6_ns',
'tf_efficientnet_b7',
'tf_efficientnet_b7_ap',
'tf_efficientnet_b7_ns',
'tf_efficientnet_b8',
'tf_efficientnet_b8_ap',
'tf_efficientnet_cc_b0_4e',
'tf_efficientnet_cc_b0_8e',
'tf_efficientnet_cc_b1_8e',
'tf_efficientnet_el',
'tf_efficientnet_em',
'tf_efficientnet_es',
'tf_efficientnet_l2_ns',
'tf_efficientnet_l2_ns_475',
'tf_efficientnet_lite0',
'tf_efficientnet_lite1',
'tf_efficientnet_lite2',
'tf_efficientnet_lite3',
'tf_efficientnet_lite4',
'tf_efficientnetv2_b0',
'tf_efficientnetv2_b1',
'tf_efficientnetv2_b2',
'tf_efficientnetv2_b3',
'tf_efficientnetv2_l',
'tf_efficientnetv2_l_in21ft1k',
'tf_efficientnetv2_l_in21k',
'tf_efficientnetv2_m',
'tf_efficientnetv2_m_in21ft1k',
'tf_efficientnetv2_m_in21k',
'tf_efficientnetv2_s',
'tf_efficientnetv2_s_in21ft1k',
'tf_efficientnetv2_s_in21k',
'tf_efficientnetv2_xl_in21ft1k',
'tf_efficientnetv2_xl_in21k']
# Searching by name pattern is also possible (here: names containing "YOLO").
# NOTE(review): no output is recorded for this query - presumably nothing matched; confirm.
timm.list_models("*YOLO*", pretrained=True)
# Searching by name pattern is also possible (here: names containing "mobilenet").
timm.list_models("*mobilenet*", pretrained=True)
['mobilenetv2_050',
'mobilenetv2_100',
'mobilenetv2_110d',
'mobilenetv2_120d',
'mobilenetv2_140',
'mobilenetv3_large_100',
'mobilenetv3_large_100_miil',
'mobilenetv3_large_100_miil_in21k',
'mobilenetv3_rw',
'mobilenetv3_small_050',
'mobilenetv3_small_075',
'mobilenetv3_small_100',
'tf_mobilenetv3_large_075',
'tf_mobilenetv3_large_100',
'tf_mobilenetv3_large_minimal_100',
'tf_mobilenetv3_small_075',
'tf_mobilenetv3_small_100',
'tf_mobilenetv3_small_minimal_100']
それぞれのモデルのサイズを計測してみる#
参考: https://discuss.pytorch.org/t/finding-model-size/130275
def calc_model_size(model):
    """Return the combined size of a model's parameters and buffers.

    Each tensor contributes ``nelement() * element_size()`` bytes, so the
    result reflects the tensors' current dtypes.

    Args:
        model: An object exposing ``parameters()`` and ``buffers()``
            iterables of tensors (e.g. a ``torch.nn.Module``).

    Returns:
        The total size as a float, in units of 1024**2 bytes (MiB).
        A model with no parameters and no buffers yields ``0.0``.
    """
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    # Bytes -> mebibytes.
    return (param_size + buffer_size) / 1024**2
# Instantiate every model matching "*mobilenet*" and print its size.
# create_model is called without `pretrained`; calc_model_size only looks at
# tensor element counts and element sizes, not at the weight values.
for model_name in timm.list_models("*mobilenet*", pretrained=True):
    model = timm.create_model(model_name)
    model_size = calc_model_size(model)
    print(f"{model_name:40}: {model_size:6.1f}MB")
mobilenetv2_050 : 7.6MB
mobilenetv2_100 : 13.5MB
mobilenetv2_110d : 17.4MB
mobilenetv2_120d : 22.5MB
mobilenetv2_140 : 23.5MB
mobilenetv3_large_100 : 21.0MB
mobilenetv3_large_100_miil : 21.0MB
mobilenetv3_large_100_miil_in21k : 71.0MB
mobilenetv3_rw : 21.0MB
mobilenetv3_small_050 : 6.1MB
mobilenetv3_small_075 : 7.8MB
mobilenetv3_small_100 : 9.7MB
tf_mobilenetv3_large_075 : 15.3MB
tf_mobilenetv3_large_100 : 21.0MB
tf_mobilenetv3_large_minimal_100 : 15.1MB
tf_mobilenetv3_small_075 : 7.8MB
tf_mobilenetv3_small_100 : 9.7MB
tf_mobilenetv3_small_minimal_100 : 7.8MB
# Instantiate every model matching "*efficientnet*" and print its size.
# create_model is called without `pretrained`; calc_model_size only looks at
# tensor element counts and element sizes, not at the weight values.
for model_name in timm.list_models("*efficientnet*", pretrained=True):
    model = timm.create_model(model_name)
    model_size = calc_model_size(model)
    print(f"{model_name:40}: {model_size:6.1f}MB")
efficientnet_b0 : 20.3MB
efficientnet_b1 : 30.0MB
efficientnet_b1_pruned : 24.3MB
efficientnet_b2 : 35.0MB
efficientnet_b2_pruned : 31.9MB
efficientnet_b3 : 47.0MB
efficientnet_b3_pruned : 37.8MB
efficientnet_b4 : 74.3MB
efficientnet_el : 40.7MB
efficientnet_el_pruned : 40.7MB
efficientnet_em : 26.6MB
efficientnet_es : 20.9MB
efficientnet_es_pruned : 20.9MB
efficientnet_lite0 : 17.9MB
efficientnetv2_rw_m : 204.3MB
efficientnetv2_rw_s : 91.9MB
efficientnetv2_rw_t : 52.5MB
gc_efficientnetv2_rw_t : 52.6MB
tf_efficientnet_b0 : 20.3MB
tf_efficientnet_b0_ap : 20.3MB
tf_efficientnet_b0_ns : 20.3MB
tf_efficientnet_b1 : 30.0MB
tf_efficientnet_b1_ap : 30.0MB
tf_efficientnet_b1_ns : 30.0MB
tf_efficientnet_b2 : 35.0MB
tf_efficientnet_b2_ap : 35.0MB
tf_efficientnet_b2_ns : 35.0MB
tf_efficientnet_b3 : 47.0MB
tf_efficientnet_b3_ap : 47.0MB
tf_efficientnet_b3_ns : 47.0MB
tf_efficientnet_b4 : 74.3MB
tf_efficientnet_b4_ap : 74.3MB
tf_efficientnet_b4_ns : 74.3MB
tf_efficientnet_b5 : 116.6MB
tf_efficientnet_b5_ap : 116.6MB
tf_efficientnet_b5_ns : 116.6MB
tf_efficientnet_b6 : 165.0MB
tf_efficientnet_b6_ap : 165.0MB
tf_efficientnet_b6_ns : 165.0MB
tf_efficientnet_b7 : 254.3MB
tf_efficientnet_b7_ap : 254.3MB
tf_efficientnet_b7_ns : 254.3MB
tf_efficientnet_b8 : 334.9MB
tf_efficientnet_b8_ap : 334.9MB
tf_efficientnet_cc_b0_4e : 50.9MB
tf_efficientnet_cc_b0_8e : 91.8MB
tf_efficientnet_cc_b1_8e : 151.7MB
tf_efficientnet_el : 40.7MB
tf_efficientnet_em : 26.6MB
tf_efficientnet_es : 20.9MB
tf_efficientnet_l2_ns : 1836.4MB
tf_efficientnet_l2_ns_475 : 1836.4MB
tf_efficientnet_lite0 : 17.9MB
tf_efficientnet_lite1 : 20.9MB
tf_efficientnet_lite2 : 23.5MB
tf_efficientnet_lite3 : 31.6MB
tf_efficientnet_lite4 : 50.0MB
tf_efficientnetv2_b0 : 27.5MB
tf_efficientnetv2_b1 : 31.3MB
tf_efficientnetv2_b2 : 38.8MB
tf_efficientnetv2_b3 : 55.2MB
tf_efficientnetv2_l : 454.1MB
tf_efficientnetv2_l_in21ft1k : 454.1MB
tf_efficientnetv2_l_in21k : 555.9MB
tf_efficientnetv2_m : 207.6MB
tf_efficientnetv2_m_in21ft1k : 207.6MB
tf_efficientnetv2_m_in21k : 309.5MB
tf_efficientnetv2_s : 82.4MB
tf_efficientnetv2_s_in21ft1k : 82.4MB
tf_efficientnetv2_s_in21k : 184.3MB
tf_efficientnetv2_xl_in21ft1k : 796.9MB
tf_efficientnetv2_xl_in21k : 898.7MB
# Instantiate every model matching "*resnet*" and print its size.
# create_model is called without `pretrained`; calc_model_size only looks at
# tensor element counts and element sizes, not at the weight values.
for model_name in timm.list_models("*resnet*", pretrained=True):
    model = timm.create_model(model_name)
    model_size = calc_model_size(model)
    print(f"{model_name:40}: {model_size:6.1f}MB")
cspresnet50 : 82.6MB
eca_resnet33ts : 75.2MB
ecaresnet26t : 61.2MB
ecaresnet50d : 97.8MB
ecaresnet50d_pruned : 76.2MB
ecaresnet50t : 97.8MB
ecaresnet101d : 170.4MB
ecaresnet101d_pruned : 95.1MB
ecaresnet269d : 390.4MB
ecaresnetlight : 115.3MB
ens_adv_inception_resnet_v2 : 213.3MB
gcresnet33ts : 76.0MB
gcresnet50t : 99.0MB
gluon_resnet18_v1b : 44.6MB
gluon_resnet34_v1b : 83.2MB
gluon_resnet50_v1b : 97.7MB
gluon_resnet50_v1c : 97.8MB
gluon_resnet50_v1d : 97.8MB
gluon_resnet50_v1s : 98.2MB
gluon_resnet101_v1b : 170.3MB
gluon_resnet101_v1c : 170.4MB
gluon_resnet101_v1d : 170.4MB
gluon_resnet101_v1s : 170.8MB
gluon_resnet152_v1b : 230.2MB
gluon_resnet152_v1c : 230.3MB
gluon_resnet152_v1d : 230.3MB
gluon_resnet152_v1s : 230.7MB
inception_resnet_v2 : 213.3MB
lambda_resnet26rpt_256 : 44.1MB
lambda_resnet26t : 41.9MB
lambda_resnet50ts : 82.4MB
legacy_seresnet18 : 45.0MB
legacy_seresnet34 : 83.8MB
legacy_seresnet50 : 107.4MB
legacy_seresnet101 : 188.6MB
legacy_seresnet152 : 255.5MB
nf_resnet50 : 97.5MB
resnet10t : 20.8MB
resnet14t : 38.5MB
resnet18 : 44.6MB
resnet18d : 44.7MB
resnet26 : 61.1MB
resnet26d : 61.2MB
resnet26t : 61.2MB
resnet32ts : 68.7MB
resnet33ts : 75.2MB
resnet34 : 83.2MB
resnet34d : 83.3MB
resnet50 : 97.7MB
resnet50_gn : 97.5MB
resnet50d : 97.8MB
resnet51q : 136.5MB
resnet61q : 140.9MB
resnet101 : 170.3MB
resnet101d : 170.4MB
resnet152 : 230.2MB
resnet152d : 230.3MB
resnet200d : 247.5MB
resnetaa50 : 97.7MB
resnetblur50 : 97.7MB
resnetrs50 : 136.4MB
resnetrs101 : 243.1MB
resnetrs152 : 331.0MB
resnetrs200 : 356.2MB
resnetrs270 : 496.3MB
resnetrs350 : 626.6MB
resnetrs420 : 733.4MB
resnetv2_50 : 97.6MB
resnetv2_50d_evos : 97.6MB
resnetv2_50d_gn : 97.5MB
resnetv2_50x1_bit_distilled : 97.5MB
resnetv2_50x1_bitm : 97.5MB
resnetv2_50x1_bitm_in21k : 260.4MB
resnetv2_50x3_bitm : 829.0MB
resnetv2_50x3_bitm_in21k : 1317.6MB
resnetv2_101 : 170.3MB
resnetv2_101x1_bitm : 169.9MB
resnetv2_101x1_bitm_in21k : 332.8MB
resnetv2_101x3_bitm : 1479.9MB
resnetv2_101x3_bitm_in21k : 1968.4MB
resnetv2_152x2_bit_teacher : 901.5MB
resnetv2_152x2_bit_teacher_384 : 901.5MB
resnetv2_152x2_bitm : 901.5MB
resnetv2_152x2_bitm_in21k : 1227.3MB
resnetv2_152x4_bitm : 3572.6MB
resnetv2_152x4_bitm_in21k : 4224.0MB
seresnet33ts : 75.6MB
seresnet50 : 107.4MB
seresnet152d : 255.6MB
skresnet18 : 45.7MB
skresnet34 : 85.1MB
ssl_resnet18 : 44.6MB
ssl_resnet50 : 97.7MB
swsl_resnet18 : 44.6MB
swsl_resnet50 : 97.7MB
tresnet_l : 214.0MB
tresnet_l_448 : 214.0MB
tresnet_m : 120.0MB
tresnet_m_448 : 120.0MB
tresnet_m_miil_in21k : 199.9MB
tresnet_v2_l : 176.6MB
tresnet_xl : 299.8MB
tresnet_xl_448 : 299.8MB
tv_resnet34 : 83.2MB
tv_resnet50 : 97.7MB
tv_resnet101 : 170.3MB
tv_resnet152 : 230.2MB
wide_resnet50_2 : 263.0MB
wide_resnet101_2 : 484.6MB
モデルを作成する#
# Create EfficientNet-B0 with pretrained=True (the recorded output shows the
# checkpoint being downloaded into the torch hub cache).
model = timm.create_model('efficientnet_b0', pretrained=True)
# model = timm.create_model('mobilenetv2_050')
# Forward a random tensor of shape (1, 3, 224, 224) and print the output shape;
# the recorded result is torch.Size([1, 1000]).
output = model(torch.rand(1, 3, 224, 224))
print(output.shape)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth" to /Users/akirakawai/.cache/torch/hub/checkpoints/efficientnet_b0_ra-3dd342df.pth
torch.Size([1, 1000])
入力や出力のCHを変更することも可能#
# in_chans=10: the model is fed a random (1, 10, 224, 224) tensor;
# the recorded output shape is still torch.Size([1, 1000]).
model = timm.create_model('efficientnet_b0', pretrained=True, in_chans=10)
output = model(torch.rand(1, 10, 224, 224))
print(output.shape)
torch.Size([1, 1000])
# num_classes=10: with a (1, 3, 224, 224) input the recorded output shape
# becomes torch.Size([1, 10]).
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)
output = model(torch.rand(1, 3, 224, 224))
print(output.shape)
torch.Size([1, 10])
# Both options combined: a (1, 5, 224, 224) input and num_classes=5;
# the recorded output shape is torch.Size([1, 5]).
model = timm.create_model('efficientnet_b0', pretrained=True, in_chans=5, num_classes=5)
output = model(torch.rand(1, 5, 224, 224))
print(output.shape)
torch.Size([1, 5])
weightの読み込み方法の変更#
from timm.models.resnet import ResNet, BasicBlock, default_cfgs
from timm.models.helpers import load_pretrained
from copy import deepcopy
# Look up the config entry registered under 'resnet34'.
resnet34_default_cfg = default_cfgs['resnet34']
# Build the ResNet directly from BasicBlock with layers [3, 4, 6, 3] and a
# single input channel.
resnet34 = ResNet(BasicBlock, layers=[3, 4, 6, 3], in_chans=1)
# Attach a deep copy so later changes to the model's config do not touch the
# shared default_cfgs entry.
resnet34.default_cfg = deepcopy(resnet34_default_cfg)
# Inspect the first conv layer; the recorded repr shows Conv2d(1, 64, ...).
resnet34.conv1
Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)