ArtificialCoder02 committed on
Commit 6892c1c · 1 Parent(s): 9e0d2af

Uploaded Version 0.1 of App and Models

Files changed (1)
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
+ import os
+ import numpy as np
+ import time
+ import random
+ import torch
+ import torchvision.transforms as transforms
+ import gradio as gr
+ import matplotlib.pyplot as plt
+
+ from models import get_model
+ from dotmap import DotMap
+ from PIL import Image
+
+ #os.environ['TERM'] = 'linux'
+ #os.environ['TERMINFO'] = '/etc/terminfo'
+
+ # args
+ args = DotMap()
+ args.deploy = 'vanilla'
+ args.arch = 'dino_small_patch16'
+ args.no_pretrain = True
+ # use a 'resolve' URL (raw file) rather than 'blob' (HTML page), otherwise torch.hub cannot download the checkpoint
+ args.resume = 'https://huggingface.co/spaces/ArtificialCoder02/pmf_zeroshot/resolve/main/best.pth'
+ # prefer a GIS_API_KEY environment secret (variable name chosen here) over hardcoding the key in a public Space
+ args.api_key = os.environ.get('GIS_API_KEY', 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY')
+ args.cx = '06d75168141bc47f1'
+
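+ # get_model(args) is expected to assemble the DINO-pretrained ViT-small backbone
+ # named by args.arch; deploy='vanilla' appears to select plain ProtoNet inference
+ # with no test-time fine-tuning (see the demo description below)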
+
+ # model
+ device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = get_model(args)
+ model.to(device)
+ checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
+ model.load_state_dict(checkpoint['model'], strict=True)
+
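+ # the checkpoint stores the weights under the 'model' key (see load_state_dict
+ # above); if loading fails, inspecting the keys is a quick first check:
+ #print(list(checkpoint.keys()))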
+
+ # image transforms
+ def test_transform():
+     def _convert_image_to_rgb(im):
+         return im.convert('RGB')
+
+     return transforms.Compose([
+         transforms.Resize(256),
+         transforms.CenterCrop(224),
+         _convert_image_to_rgb,
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+
+ preprocess = test_transform()
+
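+ # e.g. preprocess(im) maps any PIL image to a float tensor of shape (3, 224, 224),
+ # normalized with the standard ImageNet mean/std used above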
+ @torch.no_grad()
+ def denormalize(x, mean, std):
+     # 3, H, W
+     t = x.clone()
+     t.mul_(std).add_(mean)
+     return torch.clamp(t, 0, 1)
+
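+ # denormalize inverts the Normalize transform for visualization; mean/std must be
+ # broadcastable to (3, H, W), e.g. torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)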
+
+ # Google image search
+ from google_images_search import GoogleImagesSearch
+
+ class MyGIS(GoogleImagesSearch):
+     def __enter__(self):
+         return self
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         return
+
+ # define search params
+ # Options for the commonly used search params are shown below for easy reference.
+ # For params marked with '##':
+ #   - multiselect is currently not feasible; choose ONE option only
+ #   - the param can also be omitted from _search_params if you do not wish to set it
+ _search_params = {
+     'q': '...',
+     'num': 10,
+     'fileType': 'png', #'jpg|gif|png',
+     'rights': 'cc_publicdomain', #'cc_publicdomain|cc_attribute|cc_sharealike|cc_noncommercial|cc_nonderived',
+     #'safe': 'active|high|medium|off|safeUndefined', ##
+     'imgType': 'photo', #'clipart|face|lineart|stock|photo|animated|imgTypeUndefined', ##
+     #'imgSize': 'huge|icon|large|medium|small|xlarge|xxlarge|imgSizeUndefined', ##
+     #'imgDominantColor': 'black|blue|brown|gray|green|orange|pink|purple|red|teal|white|yellow|imgDominantColorUndefined', ##
+     'imgColorType': 'color', #'color|gray|mono|trans|imgColorTypeUndefined' ##
+ }
+
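+ # For reference, a one-off search outside the Gradio flow looks like this
+ # (it mirrors the loop in inference below; 'labrador' is just a placeholder query):
+ #   gis = GoogleImagesSearch(args.api_key, args.cx)
+ #   _search_params['q'] = 'labrador'
+ #   gis.search(search_params=_search_params)
+ #   for res in gis.results():
+ #       res.download('./')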
+
+ # Gradio UI
+ def inference(query, labels, n_supp=10,
+               file_type='png', rights='cc_publicdomain',
+               image_type='photo', color_type='color'):
+     '''
+     query: PIL image
+     labels: comma-separated string of class names
+     '''
+     labels = [l.strip() for l in labels.split(',')]  # strip spaces so GIS queries and output keys stay clean
+     n_supp = int(n_supp)
+
+     _search_params['num'] = n_supp
+     _search_params['fileType'] = file_type
+     _search_params['rights'] = rights
+     _search_params['imgType'] = image_type
+     _search_params['imgColorType'] = color_type
+
+     # squeeze=False keeps axs 2D even when only one class is given
+     fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4), squeeze=False)
+
+     with torch.no_grad():
+         # query image
+         query = preprocess(query).unsqueeze(0).unsqueeze(0).to(device) # (1, 1, 3, H, W)
+
+         supp_x = []
+         supp_y = []
+
+         # search support images
+         for idx, y in enumerate(labels):
+             gis = GoogleImagesSearch(args.api_key, args.cx)
+             _search_params['q'] = y
+             gis.search(search_params=_search_params, custom_image_name='my_image')
+             gis._custom_image_name = 'my_image' # fix: image name sometimes too long
+
+             for j, x in enumerate(gis.results()):
+                 x.download('./')
+                 x_im = Image.open(x.path)
+
+                 # vis
+                 axs[idx, j].imshow(x_im)
+                 axs[idx, j].set_title(f'{y}{j}:{x.url}')
+                 axs[idx, j].axis('off')
+
+                 x_im = preprocess(x_im) # (3, H, W)
+                 supp_x.append(x_im)
+                 supp_y.append(idx)
+
+         print('Searching for support images is done.')
+
+         supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device) # (1, n_supp*n_labels, 3, H, W)
+         supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device) # (1, n_supp*n_labels)
+
+         # autocast is a no-op on CPU, so only enable it when running on CUDA
+         with torch.cuda.amp.autocast(enabled=(device != 'cpu')):
+             output = model(supp_x, supp_y, query) # (1, 1, n_labels)
+
+         probs = output.softmax(dim=-1).detach().cpu().numpy()
+
+     return {k: float(v) for k, v in zip(labels, probs[0, 0])}, fig
+
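+ # How the episode is classified (a sketch based on the ProtoNet description in
+ # the demo text below, not this repo's exact code; emb_s/emb_q are hypothetical
+ # support/query embedding tensors): each class prototype is the mean of that
+ # class's support embeddings, and the query is scored against each prototype:
+ #   protos = torch.stack([emb_s[supp_y[0] == c].mean(0) for c in range(len(labels))])  # (n_labels, d)
+ #   logits = emb_q @ protos.t()  # illustrative dot-product scoring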
+
+ # DEBUG
+ ##query = Image.open('../labrador-puppy.jpg')
+ #query = Image.open('/Users/hushell/Documents/Dan_tr.png')
+ ##labels = 'dog, cat'
+ #labels = 'girl, sussie'
+ #output = inference(query, labels, n_supp=2)
+ #print(output)
+
+
+ title = "P>M>F few-shot learning pipeline with Google Image Search (GIS)"
+ description = "Short description: we take a ViT-small backbone pre-trained with DINO and meta-trained on Meta-Dataset; for few-shot classification we use a ProtoNet classifier. The demo can be viewed as zero-shot, since the support set is built by searching Google for images of each class name. Note that you may need to play with the GIS parameters to get good support examples, and that GIS is not very stable: search requests may fail for many reasons (e.g., the daily request quota has been reached)."
+ article = "<p style='text-align: center'><a href='http://arxiv.org/abs/2204.07305' target='_blank'>Arxiv</a></p>"
+
+
+ gr.Interface(fn=inference,
+              inputs=[
+                  gr.inputs.Image(label="Image to classify", type="pil"),
+                  gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','"),
+                  gr.inputs.Slider(minimum=2, maximum=10, step=1, label="GIS: Number of support examples per class"),
+                  gr.inputs.Dropdown(['png', 'jpg'], default='png', label='GIS: Image file type'),
+                  gr.inputs.Dropdown(['cc_publicdomain', 'cc_attribute', 'cc_sharealike', 'cc_noncommercial', 'cc_nonderived'], default='cc_publicdomain', label='GIS: Copyright'),
+                  gr.inputs.Dropdown(['clipart', 'face', 'lineart', 'stock', 'photo', 'animated', 'imgTypeUndefined'], default='photo', label='GIS: Image type'),
+                  gr.inputs.Dropdown(['color', 'gray', 'mono', 'trans', 'imgColorTypeUndefined'], default='color', label='GIS: Image color type'),
+              ],
+              theme="grass",
+              outputs=[
+                  gr.outputs.Label(label="Predicted class probabilities"),
+                  gr.outputs.Image(type='plot', label="Support examples from Google image search"),
+              ],
+              title=title,
+              description=description,
+              article=article,
+              ).launch(debug=True)