Spaces:
Build error
Build error
import matplotlib.pyplot as plt | |
from PIL import Image | |
from transformers import DetrImageProcessor, DetrForObjectDetection | |
import torch | |
# Colors for visualization: each entry is a triple of floats in [0, 1],
# passed as the ``color`` of a bounding-box rectangle in plot_results.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188]]
import io
# Load the image processor and detection model for the
# "facebook/detr-resnet-50" checkpoint once, at import time, so that
# predict() can reuse them on every call.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
def fig2img(fig):
    """Render ``fig`` into an in-memory buffer and open it as an image.

    Args:
        fig: Object exposing ``savefig(file_like)`` (a matplotlib figure
            in this file).

    Returns:
        The image returned by ``Image.open`` on the in-memory buffer.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so the reader starts at the beginning of the written data.
    buffer.seek(0)
    return Image.open(buffer)
def plot_results(image, results):
    """Draw detected boxes and their labels on top of ``image``.

    Args:
        image: Image to display; handed directly to ``imshow``.
        results: Mapping with ``"boxes"``, ``"labels"`` and ``"scores"``
            entries that are iterated in lockstep. Each box is unpacked as
            ``(xmin, ymin, xmax, ymax)``; boxes, labels and scores must
            support ``.tolist()`` / ``.item()`` (torch tensors when called
            from ``predict``).

    Returns:
        The rendered figure converted to an image by ``fig2img``.
    """
    # Work on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", so this call never draws on a figure it did not create.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(image)
        detections = zip(results["boxes"], results["labels"], results["scores"])
        for index, (box, label, score) in enumerate(detections):
            xmin, ymin, xmax, ymax = box.tolist()
            # Cycle through the palette so any number of detections is drawn.
            color = COLORS[index % len(COLORS)]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=color, linewidth=3))
            text = f'{model.config.id2label[label.item()]}: {score.item():0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Figures created through pyplot stay registered until closed;
        # without this every call would leave one more figure in memory.
        plt.close(fig)
def predict(input_img):
    """Run object detection on ``input_img`` and return the annotated image.

    Args:
        input_img: PIL image supplied by the Gradio input component
            (declared with ``type="pil"``), or ``None`` when nothing was
            uploaded.

    Returns:
        The image produced by ``plot_results`` with the detections scoring
        above 0.9 drawn on it, or ``None`` if ``input_img`` is ``None``.
    """
    # Nothing was uploaded: return an empty output instead of failing
    # inside the processor.
    if input_img is None:
        return None

    inputs = processor(images=input_img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); reverse it to (height, width)
    # for post-processing.
    target_sizes = torch.tensor([input_img.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]
    return plot_results(input_img, results)
import gradio as gr
# Build the web UI: the uploaded image is handed to predict() as a PIL image
# (type="pil"), and the image predict() returns is displayed as the output.
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    outputs="image")
# Start serving the interface.
demo.launch()