import timm
import torch
import pprint
# from torchviz import make_dot, make_dot_from_trace

## timmとは

`timm` is a deep-learning library created by Ross Wightman and is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations and also training/validating scripts with ability to reproduce ImageNet training results.

(参考)

https://timm.fast.ai/

https://huggingface.co/docs/timm/quickstart

## モデルはいくつあるのか

# Count every registered architecture, and the subset that ships pretrained weights.
avail_models = timm.list_models()
avail_pretrained_models = timm.list_models(pretrained=True)

print(f'モデル数: {len(avail_models)}')
print(f'モデル数(pretrained): {len(avail_pretrained_models)}')
モデル数: 964
モデル数(pretrained): 770

## どんなモデルがあるのか

# Peek at the first ten registered architecture names (pretrained=False is the default).
timm.list_models(pretrained=False)[:10]
['adv_inception_v3',
 'bat_resnext26ts',
 'beit_base_patch16_224',
 'beit_base_patch16_224_in22k',
 'beit_base_patch16_384',
 'beit_large_patch16_224',
 'beit_large_patch16_224_in22k',
 'beit_large_patch16_384',
 'beit_large_patch16_512',
 'beitv2_base_patch16_224']
# Wildcard (glob-style) search over model names is also supported
timm.list_models(filter="*efficientnet*", pretrained=True)
['efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b1_pruned',
 'efficientnet_b2',
 'efficientnet_b2_pruned',
 'efficientnet_b3',
 'efficientnet_b3_pruned',
 'efficientnet_b4',
 'efficientnet_el',
 'efficientnet_el_pruned',
 'efficientnet_em',
 'efficientnet_es',
 'efficientnet_es_pruned',
 'efficientnet_lite0',
 'efficientnetv2_rw_m',
 'efficientnetv2_rw_s',
 'efficientnetv2_rw_t',
 'gc_efficientnetv2_rw_t',
 'tf_efficientnet_b0',
 'tf_efficientnet_b0_ap',
 'tf_efficientnet_b0_ns',
 'tf_efficientnet_b1',
 'tf_efficientnet_b1_ap',
 'tf_efficientnet_b1_ns',
 'tf_efficientnet_b2',
 'tf_efficientnet_b2_ap',
 'tf_efficientnet_b2_ns',
 'tf_efficientnet_b3',
 'tf_efficientnet_b3_ap',
 'tf_efficientnet_b3_ns',
 'tf_efficientnet_b4',
 'tf_efficientnet_b4_ap',
 'tf_efficientnet_b4_ns',
 'tf_efficientnet_b5',
 'tf_efficientnet_b5_ap',
 'tf_efficientnet_b5_ns',
 'tf_efficientnet_b6',
 'tf_efficientnet_b6_ap',
 'tf_efficientnet_b6_ns',
 'tf_efficientnet_b7',
 'tf_efficientnet_b7_ap',
 'tf_efficientnet_b7_ns',
 'tf_efficientnet_b8',
 'tf_efficientnet_b8_ap',
 'tf_efficientnet_cc_b0_4e',
 'tf_efficientnet_cc_b0_8e',
 'tf_efficientnet_cc_b1_8e',
 'tf_efficientnet_el',
 'tf_efficientnet_em',
 'tf_efficientnet_es',
 'tf_efficientnet_l2_ns',
 'tf_efficientnet_l2_ns_475',
 'tf_efficientnet_lite0',
 'tf_efficientnet_lite1',
 'tf_efficientnet_lite2',
 'tf_efficientnet_lite3',
 'tf_efficientnet_lite4',
 'tf_efficientnetv2_b0',
 'tf_efficientnetv2_b1',
 'tf_efficientnetv2_b2',
 'tf_efficientnetv2_b3',
 'tf_efficientnetv2_l',
 'tf_efficientnetv2_l_in21ft1k',
 'tf_efficientnetv2_l_in21k',
 'tf_efficientnetv2_m',
 'tf_efficientnetv2_m_in21ft1k',
 'tf_efficientnetv2_m_in21k',
 'tf_efficientnetv2_s',
 'tf_efficientnetv2_s_in21ft1k',
 'tf_efficientnetv2_s_in21k',
 'tf_efficientnetv2_xl_in21ft1k',
 'tf_efficientnetv2_xl_in21k']
# Wildcard search is also possible. Every model name listed in this notebook is
# lowercase, and the wildcard match can be case-sensitive (it is on POSIX systems),
# so the pattern is written in lowercase; an upper-case "*YOLO*" would not match there.
timm.list_models("*yolo*", pretrained=True)
# Search for MobileNet variants that have pretrained weights
timm.list_models(filter="*mobilenet*", pretrained=True)
['mobilenetv2_050',
 'mobilenetv2_100',
 'mobilenetv2_110d',
 'mobilenetv2_120d',
 'mobilenetv2_140',
 'mobilenetv3_large_100',
 'mobilenetv3_large_100_miil',
 'mobilenetv3_large_100_miil_in21k',
 'mobilenetv3_rw',
 'mobilenetv3_small_050',
 'mobilenetv3_small_075',
 'mobilenetv3_small_100',
 'tf_mobilenetv3_large_075',
 'tf_mobilenetv3_large_100',
 'tf_mobilenetv3_large_minimal_100',
 'tf_mobilenetv3_small_075',
 'tf_mobilenetv3_small_100',
 'tf_mobilenetv3_small_minimal_100']

## それぞれのモデルのサイズを計測してみる

参考: https://discuss.pytorch.org/t/finding-model-size/130275

def calc_model_size(model):
    """Return the in-memory size of *model*'s parameters and buffers in MiB.

    The byte count of each tensor is ``nelement() * element_size()``; the sizes of
    all tensors from ``model.parameters()`` and ``model.buffers()`` are added and
    the total is divided by ``1024**2``. Gradients, optimizer state and
    activations are not counted.
    """
    total_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    total_bytes += sum(b.nelement() * b.element_size() for b in model.buffers())
    return total_bytes / 1024**2
# Build each architecture (create_model without pretrained=True, so weights are
# randomly initialised) and report its parameter+buffer size. The model is passed
# straight to calc_model_size instead of being bound to a variable, so it can be
# released before the next one is built; otherwise two models coexist in memory.
for model_name in timm.list_models("*mobilenet*", pretrained=True):
    model_size = calc_model_size(timm.create_model(model_name))
    print(f"{model_name:40}: {model_size:6.1f}MB")
mobilenetv2_050                         :    7.6MB
mobilenetv2_100                         :   13.5MB
mobilenetv2_110d                        :   17.4MB
mobilenetv2_120d                        :   22.5MB
mobilenetv2_140                         :   23.5MB
mobilenetv3_large_100                   :   21.0MB
mobilenetv3_large_100_miil              :   21.0MB
mobilenetv3_large_100_miil_in21k        :   71.0MB
mobilenetv3_rw                          :   21.0MB
mobilenetv3_small_050                   :    6.1MB
mobilenetv3_small_075                   :    7.8MB
mobilenetv3_small_100                   :    9.7MB
tf_mobilenetv3_large_075                :   15.3MB
tf_mobilenetv3_large_100                :   21.0MB
tf_mobilenetv3_large_minimal_100        :   15.1MB
tf_mobilenetv3_small_075                :    7.8MB
tf_mobilenetv3_small_100                :    9.7MB
tf_mobilenetv3_small_minimal_100        :    7.8MB
# Build each architecture (create_model without pretrained=True, so weights are
# randomly initialised) and report its parameter+buffer size. The model is passed
# straight to calc_model_size instead of being bound to a variable, so it can be
# released before the next one is built; otherwise two models coexist in memory.
for model_name in timm.list_models("*efficientnet*", pretrained=True):
    model_size = calc_model_size(timm.create_model(model_name))
    print(f"{model_name:40}: {model_size:6.1f}MB")
efficientnet_b0                         :  20.3MB
efficientnet_b1                         :  30.0MB
efficientnet_b1_pruned                  :  24.3MB
efficientnet_b2                         :  35.0MB
efficientnet_b2_pruned                  :  31.9MB
efficientnet_b3                         :  47.0MB
efficientnet_b3_pruned                  :  37.8MB
efficientnet_b4                         :  74.3MB
efficientnet_el                         :  40.7MB
efficientnet_el_pruned                  :  40.7MB
efficientnet_em                         :  26.6MB
efficientnet_es                         :  20.9MB
efficientnet_es_pruned                  :  20.9MB
efficientnet_lite0                      :  17.9MB
efficientnetv2_rw_m                     : 204.3MB
efficientnetv2_rw_s                     :  91.9MB
efficientnetv2_rw_t                     :  52.5MB
gc_efficientnetv2_rw_t                  :  52.6MB
tf_efficientnet_b0                      :  20.3MB
tf_efficientnet_b0_ap                   :  20.3MB
tf_efficientnet_b0_ns                   :  20.3MB
tf_efficientnet_b1                      :  30.0MB
tf_efficientnet_b1_ap                   :  30.0MB
tf_efficientnet_b1_ns                   :  30.0MB
tf_efficientnet_b2                      :  35.0MB
tf_efficientnet_b2_ap                   :  35.0MB
tf_efficientnet_b2_ns                   :  35.0MB
tf_efficientnet_b3                      :  47.0MB
tf_efficientnet_b3_ap                   :  47.0MB
tf_efficientnet_b3_ns                   :  47.0MB
tf_efficientnet_b4                      :  74.3MB
tf_efficientnet_b4_ap                   :  74.3MB
tf_efficientnet_b4_ns                   :  74.3MB
tf_efficientnet_b5                      : 116.6MB
tf_efficientnet_b5_ap                   : 116.6MB
tf_efficientnet_b5_ns                   : 116.6MB
tf_efficientnet_b6                      : 165.0MB
tf_efficientnet_b6_ap                   : 165.0MB
tf_efficientnet_b6_ns                   : 165.0MB
tf_efficientnet_b7                      : 254.3MB
tf_efficientnet_b7_ap                   : 254.3MB
tf_efficientnet_b7_ns                   : 254.3MB
tf_efficientnet_b8                      : 334.9MB
tf_efficientnet_b8_ap                   : 334.9MB
tf_efficientnet_cc_b0_4e                :  50.9MB
tf_efficientnet_cc_b0_8e                :  91.8MB
tf_efficientnet_cc_b1_8e                : 151.7MB
tf_efficientnet_el                      :  40.7MB
tf_efficientnet_em                      :  26.6MB
tf_efficientnet_es                      :  20.9MB
tf_efficientnet_l2_ns                   : 1836.4MB
tf_efficientnet_l2_ns_475               : 1836.4MB
tf_efficientnet_lite0                   :  17.9MB
tf_efficientnet_lite1                   :  20.9MB
tf_efficientnet_lite2                   :  23.5MB
tf_efficientnet_lite3                   :  31.6MB
tf_efficientnet_lite4                   :  50.0MB
tf_efficientnetv2_b0                    :  27.5MB
tf_efficientnetv2_b1                    :  31.3MB
tf_efficientnetv2_b2                    :  38.8MB
tf_efficientnetv2_b3                    :  55.2MB
tf_efficientnetv2_l                     : 454.1MB
tf_efficientnetv2_l_in21ft1k            : 454.1MB
tf_efficientnetv2_l_in21k               : 555.9MB
tf_efficientnetv2_m                     : 207.6MB
tf_efficientnetv2_m_in21ft1k            : 207.6MB
tf_efficientnetv2_m_in21k               : 309.5MB
tf_efficientnetv2_s                     :  82.4MB
tf_efficientnetv2_s_in21ft1k            :  82.4MB
tf_efficientnetv2_s_in21k               : 184.3MB
tf_efficientnetv2_xl_in21ft1k           : 796.9MB
tf_efficientnetv2_xl_in21k              : 898.7MB
# Build each architecture (create_model without pretrained=True, so weights are
# randomly initialised) and report its parameter+buffer size. The model is passed
# straight to calc_model_size instead of being bound to a variable, so it can be
# released before the next one is built; otherwise two models coexist in memory,
# which matters here because several consecutive entries are multi-GB.
for model_name in timm.list_models("*resnet*", pretrained=True):
    model_size = calc_model_size(timm.create_model(model_name))
    print(f"{model_name:40}: {model_size:6.1f}MB")
cspresnet50                             :   82.6MB
eca_resnet33ts                          :   75.2MB
ecaresnet26t                            :   61.2MB
ecaresnet50d                            :   97.8MB
ecaresnet50d_pruned                     :   76.2MB
ecaresnet50t                            :   97.8MB
ecaresnet101d                           :  170.4MB
ecaresnet101d_pruned                    :   95.1MB
ecaresnet269d                           :  390.4MB
ecaresnetlight                          :  115.3MB
ens_adv_inception_resnet_v2             :  213.3MB
gcresnet33ts                            :   76.0MB
gcresnet50t                             :   99.0MB
gluon_resnet18_v1b                      :   44.6MB
gluon_resnet34_v1b                      :   83.2MB
gluon_resnet50_v1b                      :   97.7MB
gluon_resnet50_v1c                      :   97.8MB
gluon_resnet50_v1d                      :   97.8MB
gluon_resnet50_v1s                      :   98.2MB
gluon_resnet101_v1b                     :  170.3MB
gluon_resnet101_v1c                     :  170.4MB
gluon_resnet101_v1d                     :  170.4MB
gluon_resnet101_v1s                     :  170.8MB
gluon_resnet152_v1b                     :  230.2MB
gluon_resnet152_v1c                     :  230.3MB
gluon_resnet152_v1d                     :  230.3MB
gluon_resnet152_v1s                     :  230.7MB
inception_resnet_v2                     :  213.3MB
lambda_resnet26rpt_256                  :   44.1MB
lambda_resnet26t                        :   41.9MB
lambda_resnet50ts                       :   82.4MB
legacy_seresnet18                       :   45.0MB
legacy_seresnet34                       :   83.8MB
legacy_seresnet50                       :  107.4MB
legacy_seresnet101                      :  188.6MB
legacy_seresnet152                      :  255.5MB
nf_resnet50                             :   97.5MB
resnet10t                               :   20.8MB
resnet14t                               :   38.5MB
resnet18                                :   44.6MB
resnet18d                               :   44.7MB
resnet26                                :   61.1MB
resnet26d                               :   61.2MB
resnet26t                               :   61.2MB
resnet32ts                              :   68.7MB
resnet33ts                              :   75.2MB
resnet34                                :   83.2MB
resnet34d                               :   83.3MB
resnet50                                :   97.7MB
resnet50_gn                             :   97.5MB
resnet50d                               :   97.8MB
resnet51q                               :  136.5MB
resnet61q                               :  140.9MB
resnet101                               :  170.3MB
resnet101d                              :  170.4MB
resnet152                               :  230.2MB
resnet152d                              :  230.3MB
resnet200d                              :  247.5MB
resnetaa50                              :   97.7MB
resnetblur50                            :   97.7MB
resnetrs50                              :  136.4MB
resnetrs101                             :  243.1MB
resnetrs152                             :  331.0MB
resnetrs200                             :  356.2MB
resnetrs270                             :  496.3MB
resnetrs350                             :  626.6MB
resnetrs420                             :  733.4MB
resnetv2_50                             :   97.6MB
resnetv2_50d_evos                       :   97.6MB
resnetv2_50d_gn                         :   97.5MB
resnetv2_50x1_bit_distilled             :   97.5MB
resnetv2_50x1_bitm                      :   97.5MB
resnetv2_50x1_bitm_in21k                :  260.4MB
resnetv2_50x3_bitm                      :  829.0MB
resnetv2_50x3_bitm_in21k                : 1317.6MB
resnetv2_101                            :  170.3MB
resnetv2_101x1_bitm                     :  169.9MB
resnetv2_101x1_bitm_in21k               :  332.8MB
resnetv2_101x3_bitm                     : 1479.9MB
resnetv2_101x3_bitm_in21k               : 1968.4MB
resnetv2_152x2_bit_teacher              :  901.5MB
resnetv2_152x2_bit_teacher_384          :  901.5MB
resnetv2_152x2_bitm                     :  901.5MB
resnetv2_152x2_bitm_in21k               : 1227.3MB
resnetv2_152x4_bitm                     : 3572.6MB
resnetv2_152x4_bitm_in21k               : 4224.0MB
seresnet33ts                            :   75.6MB
seresnet50                              :  107.4MB
seresnet152d                            :  255.6MB
skresnet18                              :   45.7MB
skresnet34                              :   85.1MB
ssl_resnet18                            :   44.6MB
ssl_resnet50                            :   97.7MB
swsl_resnet18                           :   44.6MB
swsl_resnet50                           :   97.7MB
tresnet_l                               :  214.0MB
tresnet_l_448                           :  214.0MB
tresnet_m                               :  120.0MB
tresnet_m_448                           :  120.0MB
tresnet_m_miil_in21k                    :  199.9MB
tresnet_v2_l                            :  176.6MB
tresnet_xl                              :  299.8MB
tresnet_xl_448                          :  299.8MB
tv_resnet34                             :   83.2MB
tv_resnet50                             :   97.7MB
tv_resnet101                            :  170.3MB
tv_resnet152                            :  230.2MB
wide_resnet50_2                         :  263.0MB
wide_resnet101_2                        :  484.6MB

## モデルを作成する

# Build EfficientNet-B0 with pretrained weights (downloaded on first use) and
# check the output shape on a random dummy batch.
model = timm.create_model('efficientnet_b0', pretrained=True)
# model = timm.create_model('mobilenetv2_050')  # lighter alternative
# eval() disables dropout and makes BatchNorm use its stored running statistics;
# in the default training mode the forward pass would overwrite the pretrained
# running_mean / running_var buffers. no_grad() skips building an autograd graph.
model.eval()
with torch.no_grad():
    output = model(torch.rand(1, 3, 224, 224))
print(output.shape)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth" to /Users/akirakawai/.cache/torch/hub/checkpoints/efficientnet_b0_ra-3dd342df.pth
torch.Size([1, 1000])

## 入力チャンネル数や出力クラス数を変更することも可能

# Change only the number of input channels (3 -> 10); the classifier still
# produces 1000 outputs.
model = timm.create_model('efficientnet_b0', pretrained=True, in_chans=10)
# eval() + no_grad(): inference only, so dropout is off, the BatchNorm running
# statistics are left untouched, and no autograd graph is built.
model.eval()
with torch.no_grad():
    output = model(torch.rand(1, 10, 224, 224))
print(output.shape)
torch.Size([1, 1000])
# Change only the number of output classes (1000 -> 10); the input stays RGB.
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)
# eval() + no_grad(): inference only, so dropout is off, the BatchNorm running
# statistics are left untouched, and no autograd graph is built.
model.eval()
with torch.no_grad():
    output = model(torch.rand(1, 3, 224, 224))
print(output.shape)
torch.Size([1, 10])
# Change both the number of input channels (3 -> 5) and output classes (1000 -> 5).
model = timm.create_model('efficientnet_b0', pretrained=True, in_chans=5, num_classes=5)
# eval() + no_grad(): inference only, so dropout is off, the BatchNorm running
# statistics are left untouched, and no autograd graph is built.
model.eval()
with torch.no_grad():
    output = model(torch.rand(1, 5, 224, 224))
print(output.shape)
torch.Size([1, 5])

## weightの読み込み方法の変更

from timm.models.resnet import ResNet, BasicBlock, default_cfgs
from timm.models.helpers import load_pretrained
from copy import deepcopy
# Build a ResNet-34 that accepts a single input channel and attach a private
# copy of timm's default config entry for 'resnet34'.
resnet34_default_cfg = default_cfgs['resnet34']
resnet34 = ResNet(
    BasicBlock,
    layers=[3, 4, 6, 3],  # number of BasicBlocks in each of the four stages
    in_chans=1,           # single-channel input (e.g. grayscale)
)
# deepcopy so edits to this model's cfg never leak into the shared registry dict
resnet34.default_cfg = deepcopy(resnet34_default_cfg)

# Display the stem convolution to confirm it now expects 1 input channel.
resnet34.conv1
Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)