# Author: piyushgrover
# Last commit: be4bf7c "Update app.py" (verified)
import pandas as pd
import numpy as np
import clip
import gradio as gr
from utils import *
import os
# Load the OpenAI CLIP ViT-B/32 model together with its matching image
# preprocessing transform.
# NOTE(review): `device` is not defined in this file; it presumably comes from
# `from utils import *` (e.g. "cuda" or "cpu") — confirm.
model, preprocess = clip.load("ViT-B/32", device=device)

from pathlib import Path
import urllib.request

# Precomputed Unsplash artifacts (photo IDs + CLIP image features) published
# on the natural-language-image-search GitHub release.
_DATASET_DIR = Path("unsplash-dataset")
_RELEASE_URL = (
    "https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0"
)


def _download_if_missing(filename):
    """Download `filename` from the release into `_DATASET_DIR` unless present.

    Raises whatever `urllib.request.urlretrieve` raises on a failed download,
    so a network problem surfaces here rather than as a confusing
    missing-file error when the data is loaded.
    """
    target = _DATASET_DIR / filename
    if target.exists():
        return
    # The output directory must exist before anything can be written into it.
    _DATASET_DIR.mkdir(parents=True, exist_ok=True)
    # Download under a temporary name and rename only on success, so an
    # interrupted download never leaves a partial file that the exists()
    # check above would accept on the next start.
    partial = target.with_name(target.name + ".part")
    urllib.request.urlretrieve(f"{_RELEASE_URL}/{filename}", partial)
    partial.replace(target)


_download_if_missing("photo_ids.csv")
_download_if_missing("features.npy")
# Photo IDs, read from the "photo_id" column. Position i is presumably the
# photo whose features are row i of `photo_features` below (both files come
# from the same release) — TODO confirm.
photo_ids = list(pd.read_csv("unsplash-dataset/photo_ids.csv")['photo_id'])

# Precomputed image feature vectors, wrapped as a tensor without copying.
photo_features = torch.from_numpy(np.load("unsplash-dataset/features.npy"))
# On CPU, upcast to float32. On GPU, keep the stored dtype (float16 according
# to the original loader's note).
if device == "cpu":
    photo_features = photo_features.float()
photo_features = photo_features.to(device)

# Report how many photos are searchable.
print(f"Photos loaded: {len(photo_ids)}")
from PIL import Image
def encode_search_query(net, search_query):
    """Embed a text query with a CLIP model and L2-normalise the result.

    Args:
        net: CLIP model exposing ``encode_text``.
        search_query: text handed to ``clip.tokenize``.

    Returns:
        The unit-length text embedding tensor on ``device``. For a single
        string this is presumably shape (1, D) — verify against the CLIP version.
    """
    tokens = clip.tokenize(search_query).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embedding = net.encode_text(tokens)
        embedding /= embedding.norm(dim=-1, keepdim=True)
    return embedding
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    """Return the IDs of the photos whose features best match a query embedding.

    The score is the dot product between each photo's feature row and the
    query. It equals cosine similarity when both sides are L2-normalised.
    Callers in this file normalise the query; the photo features are assumed
    to be normalised already — TODO confirm for features.npy.

    Args:
        text_features: query embedding, shape (D,) or (1, D).
        photo_features: (N, D) tensor of photo embeddings.
        photo_ids: sequence of photo IDs aligned with the rows of photo_features.
        results_count: maximum number of IDs to return; <= 0 yields [].

    Returns:
        List of at most ``results_count`` photo IDs, most similar first.
    """
    if results_count <= 0:
        return []
    # Flatten (1, D) -> (D,) so the product is a plain (N,) score vector. Also
    # match the index dtype, so a float32 query against a float16 index
    # doesn't raise a matmul dtype error.
    query = text_features.reshape(-1).to(photo_features.dtype)
    similarities = photo_features @ query
    count = min(results_count, similarities.shape[0])
    # topk selects only the best `count` rows instead of sorting the whole
    # index. tolist() yields plain ints, so indexing photo_ids doesn't convert
    # (and sync) one tensor element at a time.
    top_indices = similarities.topk(count).indices.tolist()
    return [photo_ids[i] for i in top_indices]
def search_unslash(net, search_query, photo_features, photo_ids, results_count=10):
    """Text-only search: return the ``results_count`` best-matching photo IDs.

    Encodes ``search_query`` with ``net`` and ranks ``photo_features``
    against it. The public name keeps its original spelling ("unslash").
    """
    return find_best_matches(
        encode_search_query(net, search_query),
        photo_features,
        photo_ids,
        results_count,
    )
def _encode_query_photo(image):
    """Encode a PIL image with the module-level CLIP `model`; return a unit-length (1, D) embedding."""
    # CLIP's `preprocess` already produces a channel-first image tensor, so the
    # model only needs a leading batch axis (CLIP README usage:
    # `preprocess(image).unsqueeze(0).to(device)`). Do not permute the axes.
    pixels = preprocess(image).unsqueeze(0).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        features = model.encode_image(pixels)
    return features / features.norm(dim=-1, keepdim=True)


def _blend_query_features(text_features, photo_vec, photo_weight):
    """Combine text and photo query embeddings into one unit-length query.

    Returns normalise(text_features + photo_weight * photo_vec). When there is
    no text term (text_features is None), returns just the normalised photo
    embedding. This also avoids a zero vector (and NaNs) when photo_weight is 0.
    """
    if text_features is None:
        combined = photo_vec
    else:
        combined = text_features + photo_vec * photo_weight
    return combined / combined.norm(dim=-1, keepdim=True)


def search_by_text_and_photo(query_text, query_photo=None, query_photo_id=None,
                             photo_weight=0.5, results_count=10):
    """Search the photo index by text, by a photo, or by both combined.

    At most one photo source is used. ``query_photo_id`` (a photo already in
    the index) takes precedence over ``query_photo`` (an image to encode).
    When both text and a photo are given, the photo embedding is weighted by
    ``photo_weight``, added to the text embedding, and the sum is
    re-normalised before ranking.

    Args:
        query_text: free-text query; may be empty/None when a photo is given.
        query_photo: PIL image to encode with CLIP, or None.
        query_photo_id: ID (from ``photo_ids``) of an indexed photo, or None.
        photo_weight: weight of the photo embedding relative to the text.
        results_count: number of photo IDs to return (default 10).

    Returns:
        List of up to ``results_count`` photo IDs, best match first, or ``[]``
        when neither text nor a photo is supplied.

    Raises:
        ValueError: if ``query_photo_id`` is not in ``photo_ids``.
    """
    if not query_text and query_photo is None and not query_photo_id:
        return []

    # Leave out the text term for an empty query rather than blending in the
    # embedding of an empty prompt. NOTE(review): clip.tokenize presumably
    # cannot handle None at all.
    text_features = encode_search_query(model, query_text) if query_text else None

    if query_photo_id:
        # Stored features of an indexed photo, given a batch axis so the
        # shape is (1, D) like the text embedding. list.index raises
        # ValueError for an unknown ID.
        photo_vec = photo_features[photo_ids.index(query_photo_id)].unsqueeze(0)
    elif query_photo is not None:
        photo_vec = _encode_query_photo(query_photo)
    else:
        photo_vec = None

    if photo_vec is None:
        # Text-only search. The early return guarantees query_text is non-empty.
        search_features = text_features
    else:
        search_features = _blend_query_features(text_features, photo_vec, photo_weight)
    return find_best_matches(search_features, photo_features, photo_ids, results_count)
def fn_query_on_load():
    """Return the example query shown in the search box when the page loads."""
    default_query = "Dogs playing during sunset"
    return default_query
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
# CLIP Image Search Engine!
### Enter search query or/and select image to find the similar images
"""
        )
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()
        with gr.Column(visible=True) as input_image_col:
            search_image = gr.Image(label='Select from results', interactive=False)
            # Photo ID of the image shown in `search_image` when it was picked
            # from the results gallery; None otherwise.
            search_image_id = gr.State(None)
    with gr.Row(visible=True):
        output_images = gr.Gallery(allow_preview=False, label='Results.. ',
                                   value=[], columns=5, rows=2)
        # Photo IDs parallel to the gallery items, so a gallery click index
        # can be mapped back to a photo ID.
        output_image_ids = gr.State([])

    def clear_data():
        """Reset the query text, the selected image and the results (including their IDs)."""
        return {
            search_image: None,
            output_images: None,
            output_image_ids: [],
            search_text: None,
            search_image_id: None,
            input_image_col: gr.update(visible=True)
        }

    clear_btn.click(
        clear_data,
        None,
        [search_image, output_images, output_image_ids, search_text, search_image_id, input_image_col],
    )

    def on_select(evt: gr.SelectData, result_ids):
        """Show the clicked gallery photo as the query image and remember its ID."""
        selected_id = result_ids[evt.index]
        return {
            search_image: f"https://unsplash.com/photos/{selected_id}/download?w=320",
            search_image_id: selected_id,
            input_image_col: gr.update(visible=True)
        }

    output_images.select(on_select, output_image_ids, [search_image, search_image_id, input_image_col])

    def _to_pil(img):
        """Convert a gr.Image value to a PIL image.

        NOTE(review): gr.Image passes a numpy array unless `type=` says
        otherwise (this component does not set it). PIL images and file
        paths are accepted too.
        """
        if isinstance(img, Image.Image):
            return img
        if isinstance(img, np.ndarray):
            return Image.fromarray(img)
        return Image.open(img)

    def func_search(query, img, img_id):
        """Run a search from the current query text, query image and/or selected photo ID.

        Precedence: a photo selected from the results (img_id) > an image
        without an ID > text alone. Returns updates for the gallery and the
        parallel list of photo IDs; both are empty for an invalid request.
        """
        best_photo_ids = []
        if img_id:
            best_photo_ids = search_by_text_and_photo(query, query_photo_id=img_id)
        elif img is not None:
            best_photo_ids = search_by_text_and_photo(query, query_photo=_to_pil(img))
        elif query:
            best_photo_ids = search_by_text_and_photo(query)

        if not best_photo_ids:
            print("Invalid Search Request")
            return {
                output_image_ids: [],
                output_images: []
            }

        # NOTE(review): thumbnails are requested at w=20 px (the selected
        # image uses w=320). Confirm this tiny size is intentional.
        img_urls = [f"https://unsplash.com/photos/{p_id}/download?w=20" for p_id in best_photo_ids]
        # `filter_invalid_urls` is defined outside this file (presumably in
        # utils). It is expected to return the surviving 'image_urls' and
        # their matching 'image_ids'.
        valid_images = filter_invalid_urls(img_urls, best_photo_ids)
        return {
            output_image_ids: valid_images['image_ids'],
            output_images: valid_images['image_urls']
        }

    submit_btn.click(
        func_search,
        [search_text, search_image, search_image_id],
        [output_images, output_image_ids]
    )

    def on_upload():
        """Forget any previously selected photo ID once a new image is uploaded."""
        # An upload event carries no SelectData, so this handler takes no
        # arguments. NOTE(review): search_image is interactive=False, so
        # uploads may never fire — confirm whether uploading is meant to be
        # supported.
        return {
            search_image_id: None
        }

    search_image.upload(on_upload, None, search_image_id)
# Launch the app
app.launch()