"""Gradio demo: object detection with DETR (facebook/detr-resnet-50).

The user uploads an image. Detections scoring at or above the threshold are
drawn as labelled boxes, and the annotated rendering is returned as a PIL image.
"""
import io
import itertools

import gradio as gr
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

# colors for visualization (RGB components in [0, 1]); cycled over detections
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188]]

# Default minimum score for a detection to be kept by post-processing.
DETECTION_THRESHOLD = 0.9

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")


def fig2img(fig):
    """Render a matplotlib figure to PNG in memory and return it as a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the pixels are materialized up front.
    img.load()
    return img


def plot_results(image, results):
    """Draw detection boxes and labels over *image* and return the rendering.

    Args:
        image: the image that was run through the detector.
        results: one entry of ``processor.post_process_object_detection``,
            a dict with ``"boxes"`` (x_min, y_min, x_max, y_max in pixels),
            ``"labels"`` and ``"scores"`` tensors.

    Returns:
        A PIL image of the annotated figure.
    """
    # A standalone Figure (not pyplot) is used deliberately:
    # - pyplot keeps every figure alive until it is explicitly closed, so a
    #   per-request plt.figure() leaks memory in a long-running server;
    # - pyplot's global "current axes" state is shared across Gradio's
    #   concurrent worker threads.
    fig = Figure(figsize=(16, 10))
    ax = fig.add_subplot()
    ax.imshow(image)
    detections = zip(results["boxes"], results["labels"], results["scores"],
                     itertools.cycle(COLORS))
    for box, label, prob, color in detections:
        xmin, ymin, xmax, ymax = box.tolist()
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        text = f'{model.config.id2label[label.item()]}: {prob.item():0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis("off")
    return fig2img(fig)


def predict(input_img, threshold=DETECTION_THRESHOLD):
    """Run DETR on a PIL image and return the annotated result image.

    Args:
        input_img: PIL image from the Gradio ``Image(type="pil")`` input,
            or ``None`` when the user submitted without uploading.
        threshold: minimum detection score to keep (default 0.9).

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    # Defensive: normalize to 3 channels in case an RGBA/grayscale image arrives.
    input_img = input_img.convert("RGB")
    inputs = processor(images=input_img, return_tensors="pt")
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([input_img.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold)[0]
    return plot_results(input_img, results)


demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="image")
demo.launch()