Ahsen Khaliq
committed on
Commit
·
0c301a6
1
Parent(s):
cdfa28a
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import timm
|
5 |
+
import torchvision
|
6 |
+
import torchvision.transforms as T
|
7 |
+
|
8 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
# Inference-only demo: switch autograd off globally so no graph is recorded.
torch.set_grad_enabled(False)
|
12 |
+
|
13 |
+
# Class names, one per line. The list position is used as the class index when
# predictions are looked up, so no line may be dropped or reordered here.
# The encoding is pinned so the result does not depend on the host locale.
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    imagenet_categories = [line.strip() for line in f]
|
15 |
+
|
16 |
+
# Preprocessing applied to every input image before it reaches the model.
transform = T.Compose(
    [
        # Resize with size 256; interpolation code 3 is PIL's bicubic filter.
        T.Resize(256, interpolation=3),
        # Keep the central 224x224 region.
        T.CenterCrop(224),
        # PIL image -> tensor.
        T.ToTensor(),
        # Standardize channels with timm's ImageNet mean/std constants.
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]
)
|
22 |
+
|
23 |
+
# Pretrained DeiT-base (patch 16, 224x224 input) fetched through torch.hub.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
# Modules start in training mode; switch to evaluation mode so train-only
# layers (dropout and similar) are disabled while serving predictions.
model.eval()
|
24 |
+
|
25 |
+
def detr(im):
    """Classify an image and return its five most likely classes.

    Args:
        im: PIL image supplied by the Gradio image input.

    Returns:
        dict mapping each predicted class name to its softmax probability
        (a float), for the five highest-scoring classes.
    """
    # Normalize is configured for three channels, but uploads can be
    # grayscale, palette-based or RGBA; force RGB before preprocessing.
    img = transform(im.convert("RGB")).unsqueeze(0)

    # compute the predictions
    out = model(img)

    # and convert them into probabilities
    scores = torch.nn.functional.softmax(out, dim=-1)[0]

    # finally get the indices of the predictions with the highest scores
    topk_scores, topk_labels = torch.topk(scores, k=5, dim=-1)

    # Return plain floats and unpadded names: the label output takes numeric
    # confidences, and console-style padding would only add trailing
    # whitespace to the labels shown in the UI.
    return {
        imagenet_categories[label]: score
        for label, score in zip(topk_labels.tolist(), topk_scores.tolist())
    }
|
46 |
+
|
47 |
+
# Gradio components: a PIL image goes in, the top five labels with their
# confidences come out.
inputs = gr.inputs.Image(type='pil', label="Original Image")
outputs = gr.outputs.Label(type="confidences", num_top_classes=5)

# Text shown on the demo page.
title = "Deit"
description = "demo for Facebook DeiT: Data-efficient Image Transformers. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2012.12877'>Training data-efficient image transformers & distillation through attention</a> | <a href='https://github.com/facebookresearch/deit'>Github Repo</a></p>"

# Sample images offered as one-click examples (one input per row).
examples = [['deer.jpg'], ['cat.jpg']]

# Build the interface around the classifier and start serving it.
demo = gr.Interface(
    fn=detr,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()
|