# RayanRen's picture
# Update app.py
# f11e3e8
import matplotlib.pyplot as plt
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
# Palette used when drawing detection boxes: one [r, g, b] triple per entry,
# each component given as a float between 0 and 1.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
]
import io
# The image processor and the detection model are both loaded from the same
# pretrained checkpoint. This happens once, at import time, and the resulting
# objects are reused by every call to predict().
_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)
def fig2img(fig):
    """Render ``fig`` into an in-memory buffer and reopen it as a PIL image.

    Args:
        fig: Object exposing ``savefig(file_obj)``; it is called with a
            ``io.BytesIO`` buffer and no other arguments.

    Returns:
        The result of ``Image.open`` on that buffer.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    # Rewind so the reader starts at the first byte savefig wrote.
    stream.seek(0)
    return Image.open(stream)
def plot_results(image, results):
    """Draw the detections in ``results`` on top of ``image``.

    Args:
        image: Picture to display; handed unchanged to ``Axes.imshow``.
        results: Mapping with ``"boxes"``, ``"labels"`` and ``"scores"``
            entries, iterated in lockstep. Each box is read as
            ``(xmin, ymin, xmax, ymax)`` and its elements, like each label,
            must support ``.item()``.

    Returns:
        The rendered figure converted to an image by ``fig2img``.

    Class names are looked up in the module-level ``model.config.id2label``.
    """
    # Work on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", so overlapping calls cannot draw on each other's plot.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(image)
        detections = zip(results["boxes"], results["labels"], results["scores"])
        for index, (box, label, prob) in enumerate(detections):
            xmin, ymin, xmax, ymax = (box[i].item() for i in range(4))
            # Cycle through the palette by index; no detection is dropped,
            # however many there are.
            color = COLORS[index % len(COLORS)]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=color, linewidth=3))
            text = f'{model.config.id2label[label.item()]}: {prob:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps every figure it created alive until it is closed.
        # Without this, each call would leak one figure in a long-running app.
        plt.close(fig)
def predict(input_img):
    """Run object detection on an image and return it with detections drawn.

    Args:
        input_img: Image exposing a ``size`` attribute (the Gradio input in
            this file is declared with ``type="pil"``), or ``None`` when no
            image was supplied.

    Returns:
        The annotated image produced by ``plot_results``, or ``None`` when
        ``input_img`` is ``None``.
    """
    if input_img is None:
        # Nothing to analyse (e.g. the form was submitted without an image).
        # Return no image rather than failing on ``None.size`` further down.
        return None

    inputs = processor(images=input_img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping so no graph is built or kept.
    with torch.no_grad():
        outputs = model(**inputs)

    # ``size`` is reversed before being passed as the target size
    # (presumably width/height -> height/width for a PIL image).
    target_sizes = torch.tensor([input_img.size[::-1]])
    # Only detections scoring above 0.9 are kept; [0] selects the result for
    # the single image in the batch.
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]
    return plot_results(input_img, results)
import gradio as gr
# Web UI: a single image input (delivered to predict() as a PIL image) and a
# single image output showing whatever predict() returns.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="image",
)
# Start serving the interface.
demo.launch()