from torch import nn

from timm.models import register_model
from timm.models.vision_transformer import VisionTransformer, _create_vision_transformer, Mlp


@register_model
def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Tiny (ViT-Ti/14)
    """
    model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3)
    model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Small (ViT-S/14)
    """
    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6)
    model = _create_vision_transformer('vit_small_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Base (ViT-B/14)
    """
    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12)
    model = _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)
    if pretrained:
        # No ViT-H/16 checkpoint is registered in timm, so fall back to the
        # ViT-H/14 weights; timm resamples the patch embedding on load.
        model = _create_vision_transformer('vit_huge_patch14_224', pretrained=True, **dict(model_args, **kwargs))
    else:
        model = _create_vision_transformer('vit_huge_patch16_224', pretrained=False, **dict(model_args, **kwargs))
    return model


@register_model
def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929),
    with a LayerNorm inserted into every MLP block.
    """
    model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)

    # Replace the (identity) norm inside each MLP block with a LayerNorm over
    # the hidden dimension; timm's Mlp applies self.norm between fc1 and fc2.
    for m in model.modules():
        if isinstance(m, Mlp) and not isinstance(m.norm, nn.LayerNorm):
            m.norm = nn.LayerNorm(m.fc1.out_features)

    return model


@register_model
def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-bigG (ViT-G/14)
    """
    model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)
    # No pretrained weights are registered for this variant, so the
    # `pretrained` argument is ignored here.
    model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))
    return model
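

# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the original
# registrations): the @register_model decorator exposes the variants above to
# timm.create_model, so they can be built like any built-in timm model.
# Assumes torch and timm are installed.
if __name__ == "__main__":
    import timm
    import torch

    model = timm.create_model('vit_tiny_patch14_224', pretrained=False)
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000]) with the default classifier head

    # The mlpnorm variant should carry a LayerNorm inside every MLP block.
    # Note: ViT-Huge has ~0.6B parameters, so this allocation is slow on CPU.
    huge = timm.create_model('vit_huge_patch16_224_mlpnorm', pretrained=False)
    assert all(isinstance(m.norm, nn.LayerNorm) for m in huge.modules() if isinstance(m, Mlp))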