qubvel-hf's picture
Update README.md
b6e5a45 verified
---
pipeline_tag: image-segmentation
---
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Instance Segmentation Example
Content:
- [PyTorch Version with Accelerate](#pytorch-version-with-accelerate)
- [Reload and Perform Inference](#reload-and-perform-inference)
## PyTorch Version with Accelerate
This model is based on the script [`run_instance_segmentation_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py).
The script uses [🤗 Accelerate](https://github.com/huggingface/accelerate) to write your own training loop in PyTorch and run it on various environments, including CPU, multi-CPU, GPU, multi-GPU, and TPU, with support for mixed precision.
First, configure the environment:
```bash
accelerate config
```
Answer the questions regarding your training environment. Then, run:
```bash
accelerate test
```
This command ensures everything is ready for training. Finally, launch training with:
```bash
accelerate launch run_instance_segmentation_no_trainer.py \
--model_name_or_path facebook/mask2former-swin-tiny-coco-instance \
--output_dir finetune-instance-segmentation-ade20k-mini-mask2former-no-trainer \
--dataset_name qubvel-hf/ade20k-mini \
--do_reduce_labels \
--image_height 256 \
--image_width 256 \
--num_train_epochs 40 \
--learning_rate 1e-5 \
--lr_scheduler_type constant \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--dataloader_num_workers 8 \
--push_to_hub
```
## Reload and Perform Inference
You can easily load this trained model and perform inference as follows:
```python
import torch
import requests
import matplotlib.pyplot as plt
from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
# Load image
image = Image.open(requests.get("http://farm4.staticflickr.com/3017/3071497290_31f0393363_z.jpg", stream=True).raw)
# Load model and image processor
device = "cuda"
checkpoint = "qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former-no-trainer"
model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint, device_map=device)
image_processor = Mask2FormerImageProcessor.from_pretrained(checkpoint)
# Run inference on image
inputs = image_processor(images=[image], return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
# Post-process outputs
outputs = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])
print("Mask shape: ", outputs[0]["segmentation"].shape)
print("Mask values: ", outputs[0]["segmentation"].unique())
for segment in outputs[0]["segments_info"]:
print("Segment: ", segment)
```
```
Mask shape: torch.Size([427, 640])
Mask values: tensor([-1., 0., 1., 2., 3., 4., 5., 6.])
Segment: {'id': 0, 'label_id': 0, 'was_fused': False, 'score': 0.946127}
Segment: {'id': 1, 'label_id': 1, 'was_fused': False, 'score': 0.961582}
Segment: {'id': 2, 'label_id': 1, 'was_fused': False, 'score': 0.968367}
Segment: {'id': 3, 'label_id': 1, 'was_fused': False, 'score': 0.819527}
Segment: {'id': 4, 'label_id': 1, 'was_fused': False, 'score': 0.655761}
Segment: {'id': 5, 'label_id': 1, 'was_fused': False, 'score': 0.531299}
Segment: {'id': 6, 'label_id': 1, 'was_fused': False, 'score': 0.929477}
```
Use the following code to visualize the results:
```python
import numpy as np
import matplotlib.pyplot as plt
segmentation = outputs[0]["segmentation"].numpy()
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(np.array(image))
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(segmentation)
plt.axis("off")
plt.show()
```
![Result](https://i.imgur.com/rZmaRjD.png)