RayanRen commited on
Commit
4ba5a7f
·
1 Parent(s): f83217d

first commit

Browse files
Files changed (2) hide show
  1. app.py +48 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from PIL import Image
3
+ from transformers import DetrImageProcessor, DetrForObjectDetection
4
+ import torch
5
+
6
+ # colors for visualization
7
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188]]
8
+
9
+ import io
10
+
11
+ def fig2img(fig):
12
+ buf = io.BytesIO()
13
+ fig.savefig(buf)
14
+ buf.seek(0)
15
+ img = Image.open(buf)
16
+ return img
17
+
18
+ def plot_results(image, results):
19
+ plt.figure(figsize=(16, 10))
20
+ plt.imshow(image)
21
+ ax = plt.gca()
22
+ colors = COLORS * 100
23
+ for box, label, prob, color in zip(results["boxes"], results["labels"], results["scores"], colors):
24
+ xmin, xmax, ymin, ymax = box[0].item(), box[2].item(), box[1].item(), box[3].item()
25
+ ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
26
+ fill=False, color=color, linewidth=3))
27
+ text = f'{model.config.id2label[label.item()]}: {prob:0.2f}'
28
+ ax.text(xmin, ymin, text, fontsize=15,
29
+ bbox=dict(facecolor='yellow', alpha=0.5))
30
+ ax.axis("off")
31
+ return fig2img(plt.gcf())
32
+
33
+ def predict(input_img):
34
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
35
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
36
+ inputs = processor(images=input_img, return_tensors="pt")
37
+ outputs = model(**inputs)
38
+
39
+ target_sizes = torch.tensor([input_img.size[::-1]])
40
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
41
+ return plot_results(input_img, results)
42
+
43
+ import gradio as gr
44
+
45
+ demo = gr.Interface(fn=predict,
46
+ inputs=gr.Image(type="pil"),
47
+ outputs="image")
48
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.1.4
4
+ transformers==4.25.1