hassonofer's picture
Update README.md
cc46601 verified
---
tags:
- image-classification
- birder
- pytorch
library_name: birder
license: apache-2.0
---
# Model Card for swin_transformer_v2_s_intermediate-arabian-peninsula
A Swin Transformer v2 image classification model. The model follows a two-stage training process: first undergoing intermediate training on a large-scale dataset containing diverse bird species from around the world, then fine-tuned specifically on the `arabian-peninsula` dataset (all the relevant bird species found in the Arabian peninsula inc. rarities).
The species list is derived from data available at <https://avibase.bsc-eoc.org/checklist.jsp?region=ARA>.
## Model Details
- **Model Type:** Image classification and detection backbone
- **Model Stats:**
- Params (M): 49.5
- Input image size: 384 x 384
- **Dataset:** arabian-peninsula (735 classes)
- Intermediate training involved ~4500 species from asia, europe and eastern africa
- **Papers:**
- Swin Transformer V2: Scaling Up Capacity and Resolution: <https://arxiv.org/abs/2111.09883>
## Model Usage
### Image Classification
```python
import birder
from birder.inference.classification import infer_image
(net, model_info) = birder.load_pretrained_model("swin_transformer_v2_s_intermediate-arabian-peninsula", inference=True)
# Get the image size the model was trained on
size = birder.get_size_from_signature(model_info.signature)
# Create an inference transform
transform = birder.classification_transform(size, model_info.rgb_stats)
image = "path/to/image.jpeg" # or a PIL image, must be loaded in RGB format
(out, _) = infer_image(net, image, transform)
# out is a NumPy array with shape of (1, 735), representing class probabilities.
```
### Image Embeddings
```python
import birder
from birder.inference.classification import infer_image
(net, model_info) = birder.load_pretrained_model("swin_transformer_v2_s_intermediate-arabian-peninsula", inference=True)
# Get the image size the model was trained on
size = birder.get_size_from_signature(model_info.signature)
# Create an inference transform
transform = birder.classification_transform(size, model_info.rgb_stats)
image = "path/to/image.jpeg" # or a PIL image
(out, embedding) = infer_image(net, image, transform, return_embedding=True)
# embedding is a NumPy array with shape of (1, 768)
```
### Detection Feature Map
```python
from PIL import Image
import birder
(net, model_info) = birder.load_pretrained_model("swin_transformer_v2_s_intermediate-arabian-peninsula", inference=True)
# Get the image size the model was trained on
size = birder.get_size_from_signature(model_info.signature)
# Create an inference transform
transform = birder.classification_transform(size, model_info.rgb_stats)
image = Image.open("path/to/image.jpeg")
features = net.detection_features(transform(image).unsqueeze(0))
# features is a dict (stage name -> torch.Tensor)
print([(k, v.size()) for k, v in features.items()])
# Output example:
# [('stage1', torch.Size([1, 96, 96, 96])),
# ('stage2', torch.Size([1, 192, 48, 48])),
# ('stage3', torch.Size([1, 384, 24, 24])),
# ('stage4', torch.Size([1, 768, 12, 12]))]
```
## Citation
```bibtex
@misc{liu2022swintransformerv2scaling,
title={Swin Transformer V2: Scaling Up Capacity and Resolution},
author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
year={2022},
eprint={2111.09883},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2111.09883},
}
```