import os
# faiss and PyTorch both bundle an OpenMP runtime; without this flag the process
# can abort with a duplicate-runtime error, so set it before those imports run.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import numpy as np
from PIL import Image
from torchvision import transforms, models
import torch
import torch.nn.functional as F
import pickle
import faiss
import gradio as gr
import time
# Set up the image transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
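# The mean/std above are the ImageNet statistics the pretrained EfficientNet
# weights were trained with, so query features come from the same input
# distribution as the indexed features.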
# Data augmentation transforms
augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.75, 1.33)),
])
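# Augmentation here is used at indexing time, not for training: each catalog
# image is also indexed under several randomized variants, which makes retrieval
# more tolerant of flips, rotation, and lighting differences in user photos.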
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    # Use the torchvision weights API (the pretrained= keyword is deprecated)
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
    model.classifier = torch.nn.Identity()  # Remove the final classification layer
    model = model.to(device)
    model.eval()
    return model
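# With the classifier replaced by Identity, efficientnet_b0 returns the globally
# pooled backbone features: a 1280-dimensional vector per image.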
model = load_model()
def process_image(image, similarity_threshold):
    similarity_threshold = similarity_threshold / 100.0
    query_features = extract_features(image)
    similar_images, similarities, similar_categories = search_similar_images(
        index, image_paths, categories, query_features, k=50
    )
    matched_images = []
    filtered_similarities = []
    filtered_file_names = []
    for img_path, sim in zip(similar_images, similarities):
        if sim >= similarity_threshold:
            img = Image.open(img_path)
            matched_images.append(img)
            filtered_similarities.append(sim)
            filtered_file_names.append(os.path.basename(img_path))
    # Predict the category by majority vote over the top-5 neighbours
    predicted_class = max(set(similar_categories[:5]), key=similar_categories[:5].count)
    return matched_images, filtered_similarities, filtered_file_names, predicted_class
def update_max_images(query_features, similarity_threshold):
    # Currently unused by the UI below. Counts how many indexed vectors reach
    # the similarity threshold against the query; vectors are L2-normalized,
    # so the dot product is the cosine similarity.
    similarity_threshold = similarity_threshold / 100.0
    count = 0
    for features in index.reconstruct_n(0, index.ntotal):
        similarity = np.dot(features, query_features)
        if similarity >= similarity_threshold:
            count += 1
    max_images = min(count, 50)
    return max_images
def extract_features(img):
    img_t = transform(img)
    batch_t = torch.unsqueeze(img_t, 0).to(device)
    with torch.no_grad():
        features = model(batch_t)
        features = F.normalize(features, p=2, dim=1)
    return features.cpu().squeeze().numpy()
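# Because every feature vector is L2-normalized here, the inner product computed
# by the FAISS IndexFlatIP below equals the cosine similarity; scores fall in
# [-1, 1], and the UI sliders treat them as 0-100%.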
def generate_augmented_images(img, num_augmented=5):
    augmented_images = []
    for _ in range(num_augmented):
        augmented = augment_transform(img)
        augmented_images.append(augmented)
    return augmented_images
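# The augmentation pipeline contains only PIL-level ops (no ToTensor), so each
# augmented result is still a PIL image and can be fed straight back into
# extract_features.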
# def load_and_index_images(root_dir):  # without adding data-augmented images
#     image_paths = []
#     features_list = []
#     categories = []
#     for category in os.listdir(root_dir):
#         category_path = os.path.join(root_dir, category)
#         if os.path.isdir(category_path):
#             for img_name in os.listdir(category_path):
#                 img_path = os.path.join(category_path, img_name)
#                 img = Image.open(img_path).convert('RGB')
#                 features = extract_features(img)
#                 image_paths.append(img_path)
#                 features_list.append(features)
#                 categories.append(category)
#     features_array = np.array(features_list).astype('float32')
#     d = features_array.shape[1]  # dimension
#     index = faiss.IndexFlatIP(d)  # use inner product (cosine similarity on normalized vectors)
#     index.add(features_array)
#     return index, image_paths, categories
def load_and_index_images(root_dir):
    image_paths = []
    features_list = []
    categories = []
    for category in os.listdir(root_dir):
        category_path = os.path.join(root_dir, category)
        if os.path.isdir(category_path):
            for img_name in os.listdir(category_path):
                img_path = os.path.join(category_path, img_name)
                img = Image.open(img_path).convert('RGB')
                # Generate augmented images
                augmented_images = generate_augmented_images(img)
                features = extract_features(img)
                image_paths.append(img_path)
                features_list.append(features)
                categories.append(category)
                for aug_img in augmented_images:
                    aug_features = extract_features(aug_img)
                    features_list.append(aug_features)
                    image_paths.append(img_path)  # Use original path for augmented images
                    categories.append(category)
    features_array = np.array(features_list).astype('float32')
    d = features_array.shape[1]  # dimension
    index = faiss.IndexFlatIP(d)  # use inner product (cosine similarity on normalized vectors)
    index.add(features_array)
    return index, image_paths, categories
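# Each original image contributes 1 + num_augmented (= 6) vectors, all mapped
# to the same file path and category, so index.ntotal is roughly six times the
# number of files under root_dir.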
def save_index_and_metadata(index, image_paths, categories, index_file, metadata_file):
    # FAISS indexes wrap SWIG objects and are not reliably picklable;
    # faiss.write_index is the supported way to persist them.
    faiss.write_index(index, index_file)
    with open(metadata_file, 'wb') as f:
        pickle.dump((image_paths, categories), f)
def load_index_and_metadata(index_file, metadata_file):
    index = faiss.read_index(index_file)
    with open(metadata_file, 'rb') as f:
        image_paths, categories = pickle.load(f)
    return index, image_paths, categories
def search_similar_images(index, image_paths, categories, query_features, k=20):
    query_features = query_features.reshape(1, -1).astype('float32')
    similarities, indices = index.search(query_features, k)
    similar_images = [image_paths[i] for i in indices[0]]
    similarity_scores = similarities[0]
    similar_categories = [categories[i] for i in indices[0]]
    return similar_images, similarity_scores, similar_categories
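# Callers request k=50 neighbours even when far fewer are shown: the threshold,
# category, and de-duplication filters downstream discard many candidates, and
# over-fetching keeps the final gallery from running empty.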
def index_files_exist(index_file, metadata_file):
    return os.path.exists(index_file) and os.path.exists(metadata_file)
def search_and_display(image, similarity_threshold, num_display):
    # Note: not wired into the Blocks UI below; kept as a standalone search helper.
    query_features = extract_features(image)
    similar_images, similarities, similar_categories = search_similar_images(
        index, image_paths, categories, query_features, k=50
    )
    predicted_class = max(set(similar_categories[:5]), key=similar_categories[:5].count)
    query_file_name = "query_image.jpg"
    seen_file_names = set([query_file_name])
    filtered_results = []
    for img, sim, cat in zip(similar_images[1:], similarities[1:], similar_categories[1:]):
        file_name = os.path.basename(img)
        if sim >= similarity_threshold and cat == predicted_class and file_name not in seen_file_names:
            filtered_results.append((img, sim))
            seen_file_names.add(file_name)
    output_images = []
    output_labels = []
    output_images.append(image)
    output_labels.append(f"Query Image\nPredicted Category: {predicted_class}")
    if similar_images:
        matched_image_path = similar_images[0]
        matched_image = Image.open(matched_image_path)
        output_images.append(matched_image)
        output_labels.append(f"Matched Image\nSimilarity: {similarities[0]:.2f}\nImage ID: {os.path.basename(matched_image_path)}")
    for i, (img_path, sim) in enumerate(filtered_results[:num_display]):
        img = Image.open(img_path)
        output_images.append(img)
        output_labels.append(f"Similarity: {sim * 100:.2f}%\nFile Name: {os.path.basename(img_path)}")
    return output_images, output_labels, f"Product Category: {predicted_class}"
# Load index and metadata (building them on first run)
index_file = "faiss-d2-nn_index.pkl"
metadata_file = "faiss-d2-image_metadata.pkl"
if not index_files_exist(index_file, metadata_file):
    root_dir = "Dataset2"
    index, image_paths, categories = load_and_index_images(root_dir)
    save_index_and_metadata(index, image_paths, categories, index_file, metadata_file)
else:
    index, image_paths, categories = load_index_and_metadata(index_file, metadata_file)
print(f"Index size: {index.ntotal}")
print(f"Number of image paths: {len(image_paths)}")
print(f"Number of categories: {len(set(categories))}")
# Define Gradio interface
def gradio_interface(image, similarity_threshold, num_display):
    # Note: not wired into the Blocks UI below, which uses
    # process_uploaded_image / update_results / update_display instead.
    matched_images, similarities, file_names, predicted_category = process_image(image, similarity_threshold)
    print(f"Debug: Number of matched images: {len(matched_images)}")
    print(f"Debug: Similarity scores: {similarities}")
    perfect_match = None
    similar_products = []
    seen_file_names = set()
    highest_similarity = 0
    highest_similarity_match = None
    for img, sim, name in zip(matched_images, similarities, file_names):
        print(f"Debug: Processing image {name} with similarity {sim}")
        if sim > highest_similarity:
            highest_similarity = sim
            highest_similarity_match = (img, f"Similarity: {sim * 100:.2f}%\nProduct name: {name}")
        if sim >= 0.99 and perfect_match is None:
            perfect_match = (img, f"Similarity: {sim * 100:.2f}%\nProduct name: {name}")
            seen_file_names.add(name)
            print(f"Debug: Near-perfect match found: {name}")
        elif name not in seen_file_names:
            similar_products.append((img, f"{sim * 100:.2f}% - {name}"))
            seen_file_names.add(name)
    if perfect_match is None:
        perfect_match = highest_similarity_match
        print(f"Debug: Using highest similarity match: {highest_similarity}")
    return (
        f"{predicted_category}",
        perfect_match[0],
        perfect_match[1],
        similar_products[:num_display]
    )
class ImageSearchState:
    def __init__(self):
        self.matched_images = None
        self.similarities = None
        self.file_names = None
        self.predicted_category = None
        self.filtered_products = None
state = ImageSearchState()
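# Caveat: this module-level state object is shared by every visitor to the app;
# concurrent users can overwrite each other's results. Gradio's gr.State would
# give each session its own copy if isolation is needed.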
def process_uploaded_image(image):
    if image is None:
        return None, None, None, None, gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Set value to display total images")
    # Threshold 0 retrieves all matches; update_results then filters them and
    # rebuilds the image-count slider with the correct maximum.
    state.matched_images, state.similarities, state.file_names, state.predicted_category = process_image(image, 0)
    return update_results(50, 10)  # Default slider values
def update_results(similarity_threshold, num_display):
    if state.matched_images is None:
        return None, None, None, None, gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Set value to display total images")
    perfect_match = None
    similar_products = []
    seen_file_names = set()
    highest_similarity = 0
    highest_similarity_match = None
    similarity_threshold = similarity_threshold / 100.0  # Convert to 0-1 range
    for img, sim, name in zip(state.matched_images, state.similarities, state.file_names):
        if sim > highest_similarity:
            highest_similarity = sim
            highest_similarity_match = (img, f"Similarity: {sim * 100:.2f}%\nProduct name: {name}")
        if sim >= 0.99 and perfect_match is None:
            perfect_match = (img, f"Similarity: {sim * 100:.2f}%\nProduct name: {name}")
            seen_file_names.add(name)
        elif sim >= similarity_threshold and name not in seen_file_names:
            similar_products.append((img, f"{sim * 100:.2f}% - {name}"))
            seen_file_names.add(name)
    if perfect_match is None:
        perfect_match = highest_similarity_match
    state.filtered_products = similar_products
    max_images = len(similar_products)
    updated_num_images_slider = gr.Slider(minimum=1, maximum=max_images, value=min(num_display, max_images), step=1, label=f"Set value to display total images (max: {max_images})")
    return (
        f"{state.predicted_category}",
        perfect_match[0] if perfect_match else None,
        perfect_match[1] if perfect_match else "No perfect match found",
        similar_products[:num_display],
        updated_num_images_slider
    )
def update_display(num_display):
    time.sleep(1)  # brief delay so rapid slider drags settle before re-rendering
    if state.filtered_products is None:
        return []
    return state.filtered_products[:num_display]
with gr.Blocks() as demo:
    gr.Markdown("# Product Image Search")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Product Image")
        with gr.Column(scale=1):
            similarity_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Set percentage to get similar images")
            num_images_slider = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Set value to display total images")
    gr.Markdown("## Product Category")
    with gr.Row():
        category_output = gr.Textbox(label="Detected Category", placeholder="Detected product category will appear here.")
    gr.Markdown("## 100% Match Result")
    with gr.Row():
        with gr.Column(scale=1):
            perfect_match_image = gr.Image(label="100% Matched Image", show_label=False)
        with gr.Column(scale=1):
            perfect_match_info = gr.Textbox(label="Match Information", placeholder="Details of the 100% match will appear here.")
    gr.Markdown("## Similar Images")
    similar_products_gallery = gr.Gallery(label="Similar Products", show_label=False, columns=5, rows=None, height="auto", object_fit="contain")
    image_input.change(
        fn=process_uploaded_image,
        inputs=[image_input],
        outputs=[category_output, perfect_match_image, perfect_match_info, similar_products_gallery, num_images_slider]
    )
    similarity_slider.change(
        fn=update_results,
        inputs=[similarity_slider, num_images_slider],
        outputs=[category_output, perfect_match_image, perfect_match_info, similar_products_gallery, num_images_slider]
    )
    num_images_slider.release(
        fn=update_display,
        inputs=[num_images_slider],
        outputs=[similar_products_gallery]
    )
demo.launch()