File size: 1,750 Bytes
4ba5a7f
 
 
 
 
 
 
 
 
 
f11e3e8
 
 
4ba5a7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import matplotlib.pyplot as plt
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch

# Palette for the detection boxes drawn in plot_results. Each entry is an
# RGB triple of floats in [0, 1], passed straight to matplotlib as `color=`.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
]

import io

# Load the preprocessor and detection model for the `facebook/detr-resnet-50`
# checkpoint once, at import time, so every request reuses the same objects.
_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)

def fig2img(fig):
    """Render a matplotlib figure into a PIL image.

    The figure is saved into an in-memory byte stream and re-opened with
    ``Image.open``. ``Image.open`` reads lazily, so the stream is left open:
    the returned image still needs it.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    stream.seek(0)  # rewind so PIL reads from the start of the saved data
    return Image.open(stream)

def plot_results(image, results):
    """Draw detection boxes and labels over *image* and return the rendering.

    Args:
        image: Image to draw on (anything ``imshow`` accepts; ``predict``
            passes the input PIL image).
        results: One element of ``processor.post_process_object_detection``
            output: a mapping with ``"boxes"``, ``"labels"`` and ``"scores"``
            tensors. Each box is read as ``(xmin, ymin, xmax, ymax)``.

    Returns:
        A PIL image of the annotated figure, produced by ``fig2img``.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(image)
        detections = zip(results["boxes"], results["labels"], results["scores"])
        for idx, (box, label, prob) in enumerate(detections):
            # Cycle through the palette by index. The old `COLORS * 100` list
            # made zip() silently drop every detection past the 500th.
            color = COLORS[idx % len(COLORS)]
            xmin, ymin, xmax, ymax = box.tolist()
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=color, linewidth=3))
            text = f'{model.config.id2label[label.item()]}: {prob.item():0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed. Without
        # this, each request leaks one figure in the long-running app.
        # Closing is safe because fig2img has already rendered to a buffer.
        plt.close(fig)

def predict(input_img, threshold=0.9):
    """Run DETR object detection on *input_img* and return an annotated image.

    Args:
        input_img: PIL image to analyse. ``None`` (e.g. Gradio submitting an
            empty input) yields ``None`` instead of raising.
        threshold: Minimum score for a detection to be kept. Defaults to
            0.9, the value that was previously hard-coded.

    Returns:
        A PIL image with boxes and labels drawn (see ``plot_results``), or
        ``None`` when no image was supplied.
    """
    if input_img is None:
        return None
    # The processor expects 3-channel input. Normalise RGBA, grayscale,
    # palette, etc. so direct callers do not hit a channel-format error.
    if input_img.mode != "RGB":
        input_img = input_img.convert("RGB")

    inputs = processor(images=input_img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([input_img.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    return plot_results(input_img, results)

import gradio as gr

# Web UI: one PIL image in, one annotated image out, served by `predict`.
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    outputs="image")

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block inside launch().
if __name__ == "__main__":
    demo.launch()