# app.py -- last change: "Update app.py" (commit f5e7a4a, verified) by RushabhShah122000
import os
import numpy as np
from PIL import Image
from torchvision import transforms, models
import torch
import torch.nn.functional as F
import pickle
import faiss
import gradio as gr
import time
import os
# KMP_DUPLICATE_LIB_OK lets the process continue when more than one OpenMP
# runtime gets loaded (presumably torch and faiss each bundle one -- confirm).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Set up the image transformation: deterministic preprocessing applied to
# every image right before it is fed to the network.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Channel statistics conventionally used with torchvision's ImageNet models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

# Data augmentation transforms: random perturbations used to create extra
# variants of each catalogue image while the index is being built.
_augment_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.75, 1.33)),
]
augment_transform = transforms.Compose(_augment_steps)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
def load_model():
    """Build the EfficientNet-B0 backbone used as a feature extractor.

    The classification head is replaced by an identity layer so the forward
    pass returns the pooled feature vector instead of class logits. The model
    is moved to the module-level ``device`` and switched to eval mode.

    Returns:
        The ready-to-use ``torch.nn.Module``.
    """
    weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
    if weights_enum is not None:
        # Newer torchvision deprecates ``pretrained=``. Pin the exact checkpoint
        # that ``pretrained=True`` used to load so embeddings stay compatible
        # with an index that was built and saved earlier.
        model = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision without the weights enum API.
        model = models.efficientnet_b0(pretrained=True)
    model.classifier = torch.nn.Identity()  # Remove the final classification layer
    model = model.to(device)
    model.eval()
    return model


model = load_model()
def process_image(image, similarity_threshold):
    """Search the index for images similar to ``image``.

    Args:
        image: Query image, passed unchanged to ``extract_features``.
        similarity_threshold: Minimum similarity as a percentage (0-100);
            it is divided by 100 before being compared with the scores.

    Returns:
        A 4-tuple ``(matched_images, similarities, file_names, predicted_class)``:
        the opened PIL images whose score reaches the threshold, their scores,
        their base file names, and the most frequent category among the five
        best hits (``None`` when the search returned nothing).
    """
    from collections import Counter

    threshold = similarity_threshold / 100.0
    query_features = extract_features(image)
    similar_images, similarities, similar_categories = search_similar_images(
        index, image_paths, categories, query_features, k=50
    )

    matched_images = []
    filtered_similarities = []
    filtered_file_names = []
    for img_path, sim in zip(similar_images, similarities):
        if sim >= threshold:
            matched_images.append(Image.open(img_path))
            filtered_similarities.append(sim)
            filtered_file_names.append(os.path.basename(img_path))

    # Majority vote over the five best hits. Counter keeps first-seen order on
    # ties, so the better-ranked category wins deterministically (iterating a
    # set of strings, as before, gave a run-dependent winner). An empty result
    # list yields None instead of raising.
    top_votes = Counter(similar_categories[:5]).most_common(1)
    predicted_class = top_votes[0][0] if top_votes else None

    return matched_images, filtered_similarities, filtered_file_names, predicted_class
def update_max_images(similarity_threshold, query_features=None):
    """Count indexed vectors that reach a similarity threshold, capped at 50.

    Args:
        similarity_threshold: Threshold as a percentage (0-100).
        query_features: Optional query embedding. When given, every stored
            vector is scored by its inner product with this query. When
            omitted, the legacy behaviour is kept: each stored vector is
            scored against itself (its squared norm), which does not depend
            on any query -- callers almost certainly want to pass a query.

    Returns:
        ``min(count, 50)`` as an ``int``.
    """
    threshold = similarity_threshold / 100.0
    if index.ntotal == 0:
        return 0

    stored = index.reconstruct_n(0, index.ntotal)
    if query_features is None:
        # Legacy path: dot(v, v) for every stored vector.
        scores = (stored * stored).sum(axis=1)
    else:
        query = np.asarray(query_features, dtype='float32').reshape(-1)
        scores = stored @ query

    count = int(np.count_nonzero(scores >= threshold))
    return min(count, 50)
def extract_features(img):
    """Return the L2-normalised embedding of a PIL image as a 1-D numpy array.

    Args:
        img: PIL image. Any mode is accepted; non-RGB images are converted,
            because the normalisation step in ``transform`` is defined for
            exactly three channels and would fail on RGBA/grayscale input.

    Returns:
        The feature vector produced by the module-level ``model``, scaled to
        unit L2 norm so that inner products equal cosine similarities.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_t = transform(img)
    batch_t = torch.unsqueeze(img_t, 0).to(device)
    with torch.no_grad():
        features = model(batch_t)
    features = F.normalize(features, p=2, dim=1)
    # Drop only the batch axis so the result is always one vector.
    return features.cpu().squeeze(0).numpy()
def generate_augmented_images(img, num_augmented=5):
    """Create ``num_augmented`` randomly augmented variants of ``img``.

    Each variant is an independent application of the module-level
    ``augment_transform`` to the same input image.
    """
    return [augment_transform(img) for _ in range(num_augmented)]
# def load_and_index_images(root_dir): #without adding data augmented images
# image_paths = []
# features_list = []
# categories = []
# for category in os.listdir(root_dir):
# category_path = os.path.join(root_dir, category)
# if os.path.isdir(category_path):
# for img_name in os.listdir(category_path):
# img_path = os.path.join(category_path, img_name)
# img = Image.open(img_path).convert('RGB')
# features = extract_features(img)
# image_paths.append(img_path)
# features_list.append(features)
# categories.append(category)
# features_array = np.array(features_list).astype('float32')
# d = features_array.shape[1] # dimension
# index = faiss.IndexFlatIP(d) # use inner product (cosine similarity on normalized vectors)
# index.add(features_array)
# return index, image_paths, categories
def load_and_index_images(root_dir):
    """Embed every image under ``root_dir`` and build a FAISS index.

    ``root_dir`` is expected to contain one sub-directory per category. For
    each image the original plus its augmented variants are embedded; all
    variants are recorded under the original file path and category.

    Args:
        root_dir: Dataset root directory.

    Returns:
        ``(index, image_paths, categories)`` where ``index`` is a
        ``faiss.IndexFlatIP`` and the two lists are aligned with its rows.

    Raises:
        ValueError: If no readable image was found under ``root_dir``.
    """
    image_paths = []
    features_list = []
    categories = []

    # Sorted traversal gives a reproducible row order (listdir order is arbitrary).
    for category in sorted(os.listdir(root_dir)):
        category_path = os.path.join(root_dir, category)
        if not os.path.isdir(category_path):
            continue
        for img_name in sorted(os.listdir(category_path)):
            img_path = os.path.join(category_path, img_name)
            try:
                # The context manager closes the file handle once the pixels
                # have been copied into the converted image.
                with Image.open(img_path) as raw:
                    img = raw.convert('RGB')
            except OSError as exc:
                # Covers non-image files, nested directories and corrupt
                # images; skip them instead of aborting the whole build.
                print(f"Skipping unreadable image {img_path}: {exc}")
                continue

            # Generate augmented images first, then embed the original
            # followed by each variant.
            augmented_images = generate_augmented_images(img)
            for variant in [img] + augmented_images:
                features_list.append(extract_features(variant))
                image_paths.append(img_path)  # Use original path for augmented images
                categories.append(category)

    if not features_list:
        raise ValueError(f"No readable images found under {root_dir!r}")

    features_array = np.array(features_list).astype('float32')
    d = features_array.shape[1]  # dimension
    index = faiss.IndexFlatIP(d)  # use inner product (cosine similarity on normalized vectors)
    index.add(features_array)
    return index, image_paths, categories
def save_index_and_metadata(nn, image_paths, categories, index_file, metadata_file):
    """Persist the index object and its metadata as two pickle files.

    ``nn`` is pickled to ``index_file``; the tuple
    ``(image_paths, categories)`` is pickled to ``metadata_file``.
    """
    payloads = (
        (index_file, nn),
        (metadata_file, (image_paths, categories)),
    )
    for target, payload in payloads:
        with open(target, 'wb') as handle:
            pickle.dump(payload, handle)
def load_index_and_metadata(index_file, metadata_file):
    """Load the pickled index and metadata written by ``save_index_and_metadata``.

    Returns:
        ``(nn, image_paths, categories)``.

    NOTE(review): ``pickle.load`` executes arbitrary code from the file; only
    point this at files the application itself produced.
    """
    def _unpickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    nn = _unpickle(index_file)
    image_paths, categories = _unpickle(metadata_file)
    return nn, image_paths, categories
def search_similar_images(index, image_paths, categories, query_features, k=20):
    """Return the paths, scores and categories of the ``k`` nearest index entries.

    Args:
        index: FAISS index exposing ``ntotal`` and ``search(queries, k)``.
        image_paths: Paths aligned with the index rows.
        categories: Category labels aligned with the index rows.
        query_features: Query embedding; reshaped to a single float32 row.
        k: Maximum number of neighbours to return.

    Returns:
        ``(similar_images, similarity_scores, similar_categories)`` -- a list
        of paths, a 1-D numpy array of scores and a list of categories, all
        in the order returned by the index. Fewer than ``k`` entries are
        returned when the index cannot supply that many hits.
    """
    # Never ask for more neighbours than the index holds.
    k = min(k, index.ntotal)
    if k <= 0:
        return [], np.empty(0, dtype='float32'), []

    query = query_features.reshape(1, -1).astype('float32')
    similarities, indices = index.search(query, k)

    # FAISS pads missing results with label -1; used as a Python index that
    # would silently select the *last* image, so drop those slots.
    valid = indices[0] >= 0
    hit_ids = indices[0][valid]
    similarity_scores = similarities[0][valid]

    similar_images = [image_paths[i] for i in hit_ids]
    similar_categories = [categories[i] for i in hit_ids]
    return similar_images, similarity_scores, similar_categories
def index_files_exist(index_file, metadata_file):
    """Return True only when both the index file and the metadata file exist."""
    required_paths = (index_file, metadata_file)
    return all(os.path.exists(path) for path in required_paths)
def search_and_display(image, similarity_threshold, num_display):
    """Search for ``image`` and assemble images and captions for display.

    Args:
        image: Query image (PIL).
        similarity_threshold: Minimum similarity as a percentage (0-100),
            the same scale used by ``process_image`` and ``update_results``.
        num_display: Maximum number of additional similar images to return.

    Returns:
        ``(output_images, output_labels, category_text)``. The first entry is
        the query image, the second (when the search returned anything) the
        best match, followed by up to ``num_display`` further hits that reach
        the threshold, share the predicted category and have a file name not
        shown yet.
    """
    from collections import Counter

    # Scores are in the 0-1 range, the threshold arrives as a percentage.
    threshold = similarity_threshold / 100.0

    query_features = extract_features(image)
    similar_images, similarities, similar_categories = search_similar_images(
        index, image_paths, categories, query_features, k=50
    )

    # Deterministic majority vote over the five best hits (ties go to the
    # better-ranked category); None when there are no hits at all.
    top_votes = Counter(similar_categories[:5]).most_common(1)
    predicted_class = top_votes[0][0] if top_votes else None

    query_file_name = "query_image.jpg"
    seen_file_names = {query_file_name}
    if similar_images:
        # The best match is shown separately below. Augmented variants share
        # its path, so mark the name as seen to avoid listing it again.
        seen_file_names.add(os.path.basename(similar_images[0]))

    filtered_results = []
    for img, sim, cat in zip(similar_images[1:], similarities[1:], similar_categories[1:]):
        file_name = os.path.basename(img)
        if sim >= threshold and cat == predicted_class and file_name not in seen_file_names:
            filtered_results.append((img, sim))
            seen_file_names.add(file_name)

    output_images = [image]
    output_labels = [f"Query Image\nPredicted Category: {predicted_class}"]

    if similar_images:
        matched_image_path = similar_images[0]
        matched_image = Image.open(matched_image_path)
        output_images.append(matched_image)
        output_labels.append(f"Matched Image\nSimilarity: {similarities[0]:.2f}\nImage ID: {os.path.basename(matched_image_path)}")

    for img_path, sim in filtered_results[:num_display]:
        output_images.append(Image.open(img_path))
        # Convert the 0-1 score to a percentage so it matches the "%" suffix.
        output_labels.append(f"Similarity: {sim * 100:.2f}%\nFile Name: {os.path.basename(img_path)}")

    return output_images, output_labels, f"Product Category: {predicted_class}"
# Load index and metadata: reuse the files saved by a previous run when both
# are present, otherwise build the index from the dataset and save it.
index_file = "faiss-d2-nn_index.pkl"
metadata_file = "faiss-d2-image_metadata.pkl"

if index_files_exist(index_file, metadata_file):
    index, image_paths, categories = load_index_and_metadata(index_file, metadata_file)
else:
    root_dir = "Dataset2"
    index, image_paths, categories = load_and_index_images(root_dir)
    save_index_and_metadata(index, image_paths, categories, index_file, metadata_file)

# Startup summary of what was loaded or built.
print(f"Index size: {index.ntotal}")
print(f"Number of image paths: {len(image_paths)}")
print(f"Number of categories: {len(set(categories))}")
# Define Gradio interface
def gradio_interface(image, similarity_threshold, num_display):
    """Run a search and split the hits into a best match and similar products.

    Args:
        image: Query image (PIL).
        similarity_threshold: Minimum similarity as a percentage (0-100).
        num_display: Maximum number of similar products to return.

    Returns:
        ``(category, match_image, match_text, similar_products)``. The match
        is the first hit scoring at least 0.99, or else the highest scoring
        hit. When nothing reached the threshold the image is ``None`` and the
        text is ``"No perfect match found"``.
    """
    matched_images, similarities, file_names, predicted_category = process_image(image, similarity_threshold)

    print(f"Debug: Number of matched images: {len(matched_images)}")
    print(f"Debug: Similarity scores: {similarities}")

    perfect_match = None
    similar_products = []
    seen_file_names = set()
    highest_similarity = 0
    highest_similarity_match = None

    for img, sim, name in zip(matched_images, similarities, file_names):
        print(f"Debug: Processing image {name} with similarity {sim}")
        if sim > highest_similarity:
            highest_similarity = sim
            highest_similarity_match = (img, f"Similarity: {sim*100:.2f}%\nProduct name: {name}")

        if sim >= 0.99 and perfect_match is None:
            perfect_match = (img, f"Similarity: {sim*100:.2f}%\nProduct name: {name}")
            seen_file_names.add(name)
            print(f"Debug: Near-perfect match found: {name}")
        elif name not in seen_file_names:
            # Augmented variants share a file name; list each product once.
            similar_products.append((img, f"{sim*100:.2f}% - {name}"))
            seen_file_names.add(name)

    if perfect_match is None:
        perfect_match = highest_similarity_match
        print(f"Debug: Using highest similarity match: {highest_similarity}")

    if perfect_match is None:
        # Nothing passed the threshold; indexing into None would raise.
        return f"{predicted_category}", None, "No perfect match found", similar_products[:num_display]

    return (
        f"{predicted_category}",
        perfect_match[0],
        perfect_match[1],
        similar_products[:num_display]
    )
class ImageSearchState:
    """Mutable holder for the results of the most recent image search.

    All fields start as ``None``. ``process_uploaded_image`` fills the search
    results and ``update_results`` stores the filtered product list.
    """

    def __init__(self):
        # Raw search output: opened images, their scores and file names.
        self.matched_images = self.similarities = self.file_names = None
        # Category predicted for the query image.
        self.predicted_category = None
        # (image, caption) pairs left after threshold filtering.
        self.filtered_products = None


# Single module-level instance shared by the UI callbacks below.
# NOTE(review): being module-global, it is shared by every browser session.
state = ImageSearchState()
def process_uploaded_image(image):
    """Handle a new (or cleared) upload and refresh every output widget.

    Args:
        image: The uploaded PIL image, or ``None`` when the input was cleared.

    Returns:
        The same 5-tuple as ``update_results``: category text, best-match
        image, best-match caption, gallery items and an updated slider.
    """
    if image is None:
        # Forget the previous search so a later slider move cannot bring back
        # results that belong to an image that is no longer shown.
        state.matched_images = None
        state.similarities = None
        state.file_names = None
        state.predicted_category = None
        state.filtered_products = None
        return None, None, None, None, gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Set value to display total images")

    state.matched_images, state.similarities, state.file_names, state.predicted_category = process_image(image, 0)  # Use 0 to get all matches

    # update_results builds the slider that is actually returned, so no slider
    # is constructed here.
    return update_results(50, 10)  # Default values
def update_results(similarity_threshold, num_display):
    """Re-filter the cached search results for the current slider settings.

    Args:
        similarity_threshold: Minimum similarity as a percentage (0-100).
        num_display: Maximum number of similar products to show.

    Returns:
        ``(category, match_image, match_text, similar_products, slider)``.
        All outputs are empty, with a default slider, when no search has been
        run yet.
    """
    if state.matched_images is None:
        return None, None, None, None, gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Set value to display total images")

    # Used as a slice bound below, so make sure it is an integer.
    num_display = int(num_display)

    perfect_match = None
    similar_products = []
    seen_file_names = set()
    highest_similarity = 0
    highest_similarity_match = None
    similarity_threshold = similarity_threshold / 100.0  # Convert to 0-1 range

    for img, sim, name in zip(state.matched_images, state.similarities, state.file_names):
        if sim > highest_similarity:
            highest_similarity = sim
            highest_similarity_match = (img, f"Similarity: {sim*100:.2f}%\nProduct name: {name}")

        if sim >= 0.99 and perfect_match is None:
            perfect_match = (img, f"Similarity: {sim*100:.2f}%\nProduct name: {name}")
            seen_file_names.add(name)
        elif sim >= similarity_threshold and name not in seen_file_names:
            # Augmented variants share a file name; list each product once.
            similar_products.append((img, f"{sim*100:.2f}% - {name}"))
            seen_file_names.add(name)

    if perfect_match is None:
        # No near-exact hit: fall back to the best-scoring one (may be None).
        perfect_match = highest_similarity_match

    state.filtered_products = similar_products
    max_images = len(similar_products)

    # Keep the slider valid when nothing matched: maximum and value must not
    # drop below minimum=1.
    slider_maximum = max(1, max_images)
    slider_value = max(1, min(num_display, max_images))
    updated_num_images_slider = gr.Slider(minimum=1, maximum=slider_maximum, value=slider_value, step=1, label=f"Set value to display total images (max: {max_images})")

    return (
        f"{state.predicted_category}",
        perfect_match[0] if perfect_match else None,
        perfect_match[1] if perfect_match else "No perfect match found",
        similar_products[:num_display],
        updated_num_images_slider
    )
def update_display(num_display):
    """Return the first ``num_display`` cached similar products for the gallery.

    Returns an empty list when no filtered results have been stored yet.
    """
    # NOTE(review): presumably gives a concurrent update_results call time to
    # refresh state.filtered_products first -- confirm before removing.
    time.sleep(1)  # 1 second delay
    if state.filtered_products is None:
        # No search has produced results yet; slicing None would raise.
        return []
    return state.filtered_products[:int(num_display)]
# Build the Gradio UI. Components are created in display order; the event
# handlers at the bottom connect the widgets to the callbacks defined above.
with gr.Blocks() as demo:
    gr.Markdown("# Product Image Search")

    # Input row: image upload on the left, the two control sliders on the right.
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Product Image")
        with gr.Column(scale=1):
            similarity_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Set percentage to get similar images")
            num_images_slider = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Set value to display total images")

    gr.Markdown("## Product Category")
    with gr.Row():
        category_output = gr.Textbox(label="Detected Category", placeholder="Detected product category will appear here.")

    # Best match: image on the left, caption on the right.
    gr.Markdown("## 100% Match Result")
    with gr.Row():
        with gr.Column(scale=1):
            perfect_match_image = gr.Image(label="100% Matched Image", show_label=False)
        with gr.Column(scale=1):
            perfect_match_info = gr.Textbox(label="Match Information", placeholder="Details of the 100% match will appear here.")

    gr.Markdown("## Similar Images")
    similar_products_gallery = gr.Gallery(label="Similar Products", show_label=False, columns=5, rows=None, height="auto", object_fit="contain")

    # A new or cleared upload runs the search and refreshes every output.
    image_input.change(
        fn=process_uploaded_image,
        inputs=[image_input],
        outputs=[category_output, perfect_match_image, perfect_match_info, similar_products_gallery, num_images_slider]
    )

    # Changing the threshold re-filters the cached results.
    similarity_slider.change(
        fn=update_results,
        inputs=[similarity_slider, num_images_slider],
        outputs=[category_output, perfect_match_image, perfect_match_info, similar_products_gallery, num_images_slider]
    )

    # Releasing the count slider only re-slices the gallery.
    num_images_slider.release(
        fn=update_display,
        inputs=[num_images_slider],
        outputs=[similar_products_gallery]
    )

# Start the web server for the app.
demo.launch()