Commit 6917a0d
Parent(s): 0f1766c
added code files
README.md
CHANGED
@@ -1,13 +1,10 @@
 ---
-title: CLIPPhotoSearchEngine
-emoji: 🦀
-colorFrom: red
-colorTo: purple
-sdk: gradio
-sdk_version: 3.46.1
-app_file: app.py
-pinned: false
 license: mit
+title: YoloV3
+sdk: gradio
+colorFrom: yellow
+colorTo: green
+pinned: true
 ---
-
-
+# yolov3
+S13 ERA V1
app.py
ADDED
@@ -0,0 +1,149 @@
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import clip
+import gradio as gr
+from PIL import Image
+
+from utils import *
+
+# Load the open CLIP model
+model, preprocess = clip.load("ViT-B/32", device=device)
+
+# Download the precomputed Unsplash index from GitHub Releases
+if not Path('unsplash-dataset/photo_ids.csv').exists():
+    os.system('wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/photo_ids.csv -O unsplash-dataset/photo_ids.csv')
+
+if not Path('unsplash-dataset/features.npy').exists():
+    os.system('wget https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0/features.npy -O unsplash-dataset/features.npy')
+
+# Load the photo IDs
+photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
+photo_ids = list(photo_ids['photo_id'])
+
+# Load the feature vectors
+photo_features = np.load("unsplash-dataset/features.npy")
+
+# Convert features to tensors: float32 on CPU and float16 on GPU
+if device == "cpu":
+    photo_features = torch.from_numpy(photo_features).float().to(device)
+else:
+    photo_features = torch.from_numpy(photo_features).to(device)
+
+# Print some statistics
+print(f"Photos loaded: {len(photo_ids)}")
+
+
+def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_weight=0.5):
+    # Nothing to search with
+    if not query_text and query_img is None and not query_photo_id:
+        return []
+
+    # Encode the text query
+    text_features = encode_search_query(model, query_text)
+
+    if query_photo_id:
+        # Find the feature vector for the specified photo ID
+        query_photo_index = photo_ids.index(query_photo_id)
+        query_photo_features = photo_features[query_photo_index]
+
+        # Combine the text and photo queries and normalize again
+        search_features = text_features + query_photo_features * photo_weight
+        search_features /= search_features.norm(dim=-1, keepdim=True)
+
+        # Find the best matches
+        best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
+
+    elif query_img is not None:
+        # Preprocess the uploaded image and encode it with CLIP
+        query_img = preprocess(Image.fromarray(query_img)).unsqueeze(0).to(device)
+        with torch.no_grad():
+            query_photo_features = model.encode_image(query_img)
+        query_photo_features = query_photo_features / query_photo_features.norm(dim=-1, keepdim=True)
+
+        # Combine the text and photo queries and normalize again
+        search_features = text_features + query_photo_features * photo_weight
+        search_features /= search_features.norm(dim=-1, keepdim=True)
+
+        # Find the best matches
+        best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
+    else:
+        # Text-only search
+        best_photo_ids = search_unslash(model, query_text, photo_features, photo_ids, 10)
+
+    return best_photo_ids
+
+
+with gr.Blocks() as app:
+    with gr.Row():
+        gr.Markdown(
+            """
+            # CLIP Image Search Engine!
+            ### Enter a search query and/or an input image to find similar images in the database
+            """)
+
+    with gr.Row(visible=True):
+        with gr.Column():
+            with gr.Row():
+                search_text = gr.Textbox(value='', placeholder='Search..', label='Enter Your Query')
+
+            with gr.Row():
+                submit_btn = gr.Button("Submit", variant='primary')
+                clear_btn = gr.ClearButton()
+
+        with gr.Column():
+            search_image = gr.Image(label='Upload Image or Select from results')
+
+    with gr.Row(visible=True):
+        output_images = gr.Gallery(allow_preview=False, label='Results.. ',
+                                   value=[], columns=5, rows=2)
+
+    output_image_ids = gr.State([])
+
+
+    def clear_data():
+        return {
+            search_image: None,
+            output_images: None,
+            search_text: None
+        }
+
+
+    clear_btn.click(clear_data, None, [search_image, output_images, search_text])
+
+
+    def on_select(evt: gr.SelectData, output_image_ids):
+        return {
+            search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=100"
+        }
+
+
+    output_images.select(on_select, output_image_ids, search_image)
+
+
+    def func_search(query, img):
+        best_photo_ids = search_by_text_and_photo(query, img)
+        img_urls = []
+        for p_id in best_photo_ids:
+            url = f"https://unsplash.com/photos/{p_id}/download?w=100"
+            img_urls.append(url)
+
+        # Keep only the URLs that actually resolve
+        valid_images = filter_invalid_urls(img_urls, best_photo_ids)
+
+        return {
+            output_image_ids: valid_images['image_ids'],
+            output_images: valid_images['image_urls']
+        }
+
+
+    submit_btn.click(
+        func_search,
+        [search_text, search_image],
+        [output_images, output_image_ids]
+    )
+
+# Launch the app
+app.launch()
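Note: the weighted text-plus-image query in search_by_text_and_photo is just a vector sum followed by renormalization, so its behavior can be sanity-checked without loading the model. A minimal sketch (not part of the commit; random unit vectors stand in for real CLIP embeddings, and photo_weight mirrors the default above):

import torch

torch.manual_seed(0)

# Fake index of 1000 unit-normalized 512-d "photo" embeddings
photo_features = torch.randn(1000, 512)
photo_features /= photo_features.norm(dim=-1, keepdim=True)

# Fake text and image query embeddings, also unit-normalized
text_features = torch.randn(1, 512)
text_features /= text_features.norm(dim=-1, keepdim=True)
image_features = torch.randn(1, 512)
image_features /= image_features.norm(dim=-1, keepdim=True)

# Weighted sum of the two queries, renormalized to unit length
photo_weight = 0.5
search_features = text_features + image_features * photo_weight
search_features /= search_features.norm(dim=-1, keepdim=True)

# On unit vectors, cosine similarity reduces to a dot product
similarities = (photo_features @ search_features.T).squeeze(1)
print((-similarities).argsort()[:10])  # indices of the 10 best matches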
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+torch
+torchvision
+pillow
+gradio
+numpy
+pandas
+grequests
+git+https://github.com/openai/CLIP.git
utils.py
ADDED
@@ -0,0 +1,65 @@
+from gevent import monkey
+
+# grequests monkey-patches the standard library via gevent, which can break
+# the running server's event loop; stub out the patch before importing it.
+def stub(*args, **kwargs):  # pylint: disable=unused-argument
+    pass
+monkey.patch_all = stub
+import grequests
+import requests
+
+import torch
+import clip
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def encode_search_query(model, search_query):
+    with torch.no_grad():
+        tokenized_query = clip.tokenize(search_query)
+
+        # Encode and normalize the search query using CLIP
+        text_encoded = model.encode_text(tokenized_query.to(device))
+        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+
+    # Return the feature vector
+    return text_encoded
+
+
+def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
+    # Compute the similarity between the search query and each photo (cosine similarity)
+    similarities = (photo_features @ text_features.T).squeeze(1)
+
+    # Sort the photos by their similarity score
+    best_photo_idx = (-similarities).argsort()
+
+    # Return the photo IDs of the best matches
+    return [photo_ids[i] for i in best_photo_idx[:results_count]]
+
+
+def search_unslash(model, search_query, photo_features, photo_ids, results_count=10):
+    # Encode the search query
+    text_features = encode_search_query(model, search_query)
+
+    # Find the best matches
+    best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, results_count)
+
+    return best_photo_ids
+
+
+def filter_invalid_urls(urls, photo_ids):
+    # Request all candidate URLs concurrently
+    rs = (grequests.get(u) for u in urls)
+    results = grequests.map(rs)
+
+    valid_image_ids = []
+    valid_image_urls = []
+    for i, res in enumerate(results):
+        if res and res.status_code == 200:
+            valid_image_urls.append(urls[i])
+            valid_image_ids.append(photo_ids[i])
+
+    return dict(
+        image_ids=valid_image_ids,
+        image_urls=valid_image_urls
+    )
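A quick, illustrative smoke test for the utils.py helpers (not part of the commit; it uses a tiny random index in place of the real Unsplash features, so the returned IDs are meaningless placeholders):

import torch
import clip

from utils import device, encode_search_query, find_best_matches

model, _ = clip.load("ViT-B/32", device=device)

# Fake index: 100 unit-normalized 512-d "photo" vectors with made-up IDs
feats = torch.randn(100, 512).float().to(device)
feats /= feats.norm(dim=-1, keepdim=True)
ids = [f"photo-{i}" for i in range(100)]

query = encode_search_query(model, "two dogs playing in the snow")
print(find_best_matches(query.float(), feats, ids, results_count=5))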