Ahsen Khaliq commited on
Commit
0c301a6
·
1 Parent(s): cdfa28a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ import torch
4
+ import timm
5
+ import torchvision
6
+ import torchvision.transforms as T
7
+
8
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9
+ import gradio as gr
10
+
11
+ torch.set_grad_enabled(False);
12
+
13
+ with open("imagenet_classes.txt", "r") as f:
14
+ imagenet_categories = [s.strip() for s in f.readlines()]
15
+
16
+ transform = T.Compose([
17
+ T.Resize(256, interpolation=3),
18
+ T.CenterCrop(224),
19
+ T.ToTensor(),
20
+ T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
21
+ ])
22
+
23
+ model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
24
+
25
+ def detr(im):
26
+ img = transform(im).unsqueeze(0)
27
+
28
+ # compute the predictions
29
+ out = model(img)
30
+
31
+ # and convert them into probabilities
32
+ scores = torch.nn.functional.softmax(out, dim=-1)[0]
33
+
34
+ # finally get the index of the prediction with highest score
35
+ topk_scores, topk_label = torch.topk(scores, k=5, dim=-1)
36
+
37
+
38
+ d = {}
39
+ for i in range(5):
40
+ pred_name = imagenet_categories[topk_label[i]]
41
+ pred_name = f"{pred_name:<25}"
42
+ score = topk_scores[i]
43
+ score = f"{score:.3f}"
44
+ d[pred_name] = score
45
+ return d
46
+
47
+ inputs = gr.inputs.Image(type='pil', label="Original Image")
48
+ outputs = gr.outputs.Label(type="confidences",num_top_classes=5)
49
+
50
+ title = "Deit"
51
+ description = "demo for Facebook DeiT: Data-efficient Image Transformers. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
52
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2012.12877'>Training data-efficient image transformers & distillation through attention</a> | <a href='https://github.com/facebookresearch/deit'>Github Repo</a></p>"
53
+
54
+ examples = [
55
+ ['deer.jpg'],
56
+ ['cat.jpg']
57
+ ]
58
+
59
+ gr.Interface(detr, inputs, outputs, title=title, description=description, article=article, examples=examples).launch()