piyushgrover commited on
Commit
6917a0d
·
1 Parent(s): 0f1766c

added code files

Browse files
Files changed (4) hide show
  1. README.md +7 -10
  2. app.py +145 -0
  3. requirements.txt +8 -0
  4. utils.py +67 -0
README.md CHANGED
@@ -1,13 +1,10 @@
1
  ---
2
- title: CLIPPhotoSearchEngine
3
- emoji: 🦀
4
- colorFrom: red
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 3.46.1
8
- app_file: app.py
9
- pinned: false
10
  license: mit
 
 
 
 
 
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
 
 
 
 
 
 
 
 
2
  license: mit
3
+ title: CLIPPhotoSearchEngine
4
+ sdk: gradio
5
+ colorFrom: yellow
6
+ colorTo: green
7
+ pinned: true
8
  ---
9
+ # CLIP Photo Search Engine
10
+ S13 ERA V1
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
from pathlib import Path

import numpy as np
import pandas as pd
import clip
import gradio as gr
from utils import *  # provides torch, device, encode_search_query, find_best_matches, search_unslash, filter_invalid_urls

# Load the open CLIP model (ViT-B/32) onto the device selected in utils.
model, preprocess = clip.load("ViT-B/32", device=device)

# Download the precomputed Unsplash photo IDs and CLIP feature vectors from
# GitHub Releases on first run.  `wget -O` fails when the target directory
# does not exist, so create it up front.
os.makedirs('unsplash-dataset', exist_ok=True)

if not Path('unsplash-dataset/photo_ids.csv').exists():
    os.system('''wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/photo_ids.csv -O unsplash-dataset/photo_ids.csv''')

if not Path('unsplash-dataset/features.npy').exists():
    # BUGFIX: the original flag was written "- O" (with a space), which wget
    # parses as a bare "-" plus a URL argument "O" instead of the
    # output-document flag, so features.npy was never saved to the right path.
    os.system('''wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/features.npy -O unsplash-dataset/features.npy''')

# Load the photo IDs (one Unsplash photo ID per row, aligned with features.npy).
photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

# Load the precomputed feature vectors (numpy array, one row per photo).
photo_features = np.load("unsplash-dataset/features.npy")

# Convert features to Tensors: Float32 on CPU and Float16 on GPU
if device == "cpu":
    photo_features = torch.from_numpy(photo_features).float().to(device)
else:
    photo_features = torch.from_numpy(photo_features).to(device)

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")
def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_weight=0.5):
    """Search the Unsplash index by text and, optionally, a reference photo.

    Parameters
    ----------
    query_text : str
        Free-text query; may be empty.
    query_img : numpy.ndarray or None
        Uploaded query image (as produced by ``gr.Image``), or None.
    query_photo_id : str, optional
        ID of an already-indexed photo whose features are blended in.
    photo_weight : float
        Weight of the photo features relative to the text features.

    Returns
    -------
    list
        Photo IDs of the best matches (up to 10), or [] when there is
        neither text nor a reference photo ID to search with.
    """
    # Nothing to go on: no text query and no reference photo ID.
    if not query_text and not query_photo_id:
        return []

    text_features = encode_search_query(model, query_text)

    if query_photo_id:
        # Find the feature vector for the specified (already indexed) photo ID.
        query_photo_index = photo_ids.index(query_photo_id)
        query_photo_features = photo_features[query_photo_index]

        # Combine the text and photo queries and normalize again.
        search_features = text_features + query_photo_features * photo_weight
        search_features /= search_features.norm(dim=-1, keepdim=True)

        # Find the best matches.
        best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)

    elif query_img is not None:
        # BUGFIX: the original `elif query_img:` raises ValueError for a numpy
        # array ("the truth value of an array ... is ambiguous"); an explicit
        # None check is required for gr.Image inputs.
        # NOTE(review): encode_image is fed the raw array here; CLIP normally
        # expects a `preprocess`-ed, batched tensor — confirm the caller's
        # preprocessing before relying on this branch.
        query_photo_features = model.encode_image(query_img)
        query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)

        # Combine the text and photo queries and normalize again.
        search_features = text_features + query_photo_features * photo_weight
        search_features /= search_features.norm(dim=-1, keepdim=True)

        # Find the best matches.
        best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
    else:
        # Text-only search.
        print("Test search result")
        best_photo_ids = search_unslash(query_text, photo_features, photo_ids, 10)

    return best_photo_ids
72
+
73
+
# ---- Gradio UI ------------------------------------------------------------
# Layout: a text query box plus an optional query image on top, a gallery of
# results below.  Clicking a gallery item feeds that photo back in as the
# query image for the next search.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # CLIP Image Search Engine!
            ### Enter search query or/and input image to find the similar images from the database -
            """)

    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value='', placeholder='Search..', label='Enter Your Query')

            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()

        with gr.Column():
            search_image = gr.Image(label='Upload Image or Select from results')

    with gr.Row(visible=True):
        # NOTE(review): `info` is not a documented gr.Gallery constructor
        # parameter — verify against the pinned Gradio version.
        output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
                                   value=[], columns=5, rows=2)

    # Photo IDs backing the gallery entries, kept in per-session state so the
    # select handler can map a clicked gallery index back to an Unsplash ID.
    output_image_ids = gr.State([])


    def clear_data():
        # Reset the query image, results gallery and query text.
        return {
            search_image: None,
            output_images: None,
            search_text: None
        }


    clear_btn.click(clear_data, None, [search_image, output_images, search_text])


    def on_select(evt: gr.SelectData, output_image_ids):
        # A result thumbnail was clicked: load that photo's download URL into
        # the query-image slot (evt.index is the gallery position).
        return {
            search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=100"
        }


    output_images.select(on_select, output_image_ids, search_image)


    def func_search(query, img):
        # Run the combined text+image search, build download URLs for the
        # returned photo IDs, then drop any URL that is no longer reachable.
        best_photo_ids = search_by_text_and_photo(query, img)
        img_urls = []
        for p_id in best_photo_ids:
            url = f"https://unsplash.com/photos/{p_id}/download?w=100"
            img_urls.append(url)

        valid_images = filter_invalid_urls(img_urls, best_photo_ids)

        return {
            output_image_ids: valid_images['image_ids'],
            output_images: valid_images['image_urls']
        }


    submit_btn.click(
        func_search,
        [search_text, search_image],
        [output_images, output_image_ids]
    )

'''
Launch the app
'''
app.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ gradio
5
+ numpy
6
+ pandas
7
+ grequests
8
+ git+https://github.com/openai/CLIP.git
utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# grequests requires gevent; patch_all is stubbed to a no-op before importing
# it — presumably to keep gevent's stdlib monkey-patching from interfering
# with the rest of the process (TODO confirm the original motivation).
from gevent import monkey
def stub(*args, **kwargs):  # pylint: disable=unused-argument
    pass
monkey.patch_all = stub
import grequests
import requests

import torch
import clip
# Run CLIP on the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
+
12
+ def encode_search_query(model, search_query):
13
+ with torch.no_grad():
14
+ tokenized_query = clip.tokenize(search_query)
15
+ # print("tokenized_query: ", tokenized_query.shape)
16
+ # Encode and normalize the search query using CLIP
17
+ text_encoded = model.encode_text(tokenized_query.to(device))
18
+ text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
19
+
20
+ # Retrieve the feature vector
21
+ # print("text_encoded: ", text_encoded.shape)
22
+ return text_encoded
23
+
24
+
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    """Rank photos by similarity to the query features.

    Both feature tensors are assumed to be L2-normalized, so the matrix
    product below is the cosine similarity.

    Args:
        text_features: (1, dim) query embedding.
        photo_features: (n, dim) photo embeddings.
        photo_ids: photo IDs aligned with the rows of photo_features.
        results_count: number of top matches to return.

    Returns:
        list: photo IDs of the best matches, most similar first.
    """
    # (n, dim) @ (dim, 1) -> (n, 1) -> (n,)
    scores = (photo_features @ text_features.T).squeeze(1)

    # Ascending argsort on the negated scores == descending on the scores.
    ranking = (-scores).argsort()

    # Map the top row indices back to their photo IDs.
    return [photo_ids[idx] for idx in ranking[:results_count]]
40
+
41
+
def search_unslash(search_query, photo_features, photo_ids, results_count=10, model=None):
    """Text-only search over the photo index.

    BUGFIX: the original called ``encode_search_query(search_query)``, but
    that helper requires the CLIP model as its first argument, so every call
    raised TypeError.  The model is now an appended, optional parameter
    (keeping the original signature backward-compatible) and is passed
    through; callers must supply a loaded CLIP model.

    Args:
        search_query: free-text query.
        photo_features: (n, dim) photo embeddings.
        photo_ids: photo IDs aligned with photo_features rows.
        results_count: number of matches to return.
        model: loaded CLIP model used to encode the query.

    Returns:
        list: photo IDs of the best matches, most similar first.
    """
    # Encode the search query into a normalized text embedding.
    text_features = encode_search_query(model, search_query)

    # Rank the photos against the query.
    best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, results_count)

    return best_photo_ids
50
+
51
+
52
+
def filter_invalid_urls(urls, photo_ids):
    """Probe each URL concurrently and keep only those answering HTTP 200.

    Args:
        urls: candidate image URLs.
        photo_ids: photo IDs aligned one-to-one with ``urls``.

    Returns:
        dict: {'image_ids': [...], 'image_urls': [...]} for the reachable
        URLs, preserving the input order.
    """
    # Fire all requests concurrently via grequests; a failed request maps
    # to None in the results list.
    responses = grequests.map(grequests.get(u) for u in urls)

    kept_ids = []
    kept_urls = []
    for url, photo_id, response in zip(urls, photo_ids, responses):
        # Keep only live URLs that answered exactly 200.
        if response and response.status_code == 200:
            kept_urls.append(url)
            kept_ids.append(photo_id)

    return dict(
        image_ids=kept_ids,
        image_urls=kept_urls
    )