import os
import pickle

import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms, models
import streamlit as st
import faiss

# Set up the image transformation used for feature extraction
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Data augmentation transforms (applied to PIL images before feature extraction)
augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.75, 1.33)),
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model():
    # torchvision's newer "weights" API; equivalent to the old pretrained=True
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
    model.classifier = torch.nn.Identity()  # Remove the final classification layer
    model = model.to(device)
    model.eval()
    return model


model = load_model()


def extract_features(img):
    # Embed a PIL image as an L2-normalized EfficientNet-B0 feature vector
    img_t = transform(img)
    batch_t = torch.unsqueeze(img_t, 0).to(device)
    with torch.no_grad():
        features = model(batch_t)
        features = F.normalize(features, p=2, dim=1)
    return features.cpu().squeeze().numpy()


def generate_augmented_images(img, num_augmented=5):
    augmented_images = []
    for _ in range(num_augmented):
        augmented = augment_transform(img)
        augmented_images.append(augmented)
    return augmented_images


# def load_and_index_images(root_dir):  # without adding data-augmented images
#     image_paths = []
#     features_list = []
#     categories = []
#     for category in os.listdir(root_dir):
#         category_path = os.path.join(root_dir, category)
#         if os.path.isdir(category_path):
#             for img_name in os.listdir(category_path):
#                 img_path = os.path.join(category_path, img_name)
#                 img = Image.open(img_path).convert('RGB')
#                 features = extract_features(img)
#                 image_paths.append(img_path)
#                 features_list.append(features)
#                 categories.append(category)
#     features_array = np.array(features_list).astype('float32')
#     d = features_array.shape[1]  # dimension
#     index = faiss.IndexFlatIP(d)  # use inner product (cosine similarity on normalized vectors)
#     index.add(features_array)
#     return index, image_paths, categories


def load_and_index_images(root_dir):
    image_paths = []
    features_list = []
    categories = []
    for category in os.listdir(root_dir):
        category_path = os.path.join(root_dir, category)
        if os.path.isdir(category_path):
            for img_name in os.listdir(category_path):
                img_path = os.path.join(category_path, img_name)
                img = Image.open(img_path).convert('RGB')

                # Generate augmented images
                augmented_images = generate_augmented_images(img)

                features = extract_features(img)
                image_paths.append(img_path)
                features_list.append(features)
                categories.append(category)

                for aug_img in augmented_images:
                    aug_features = extract_features(aug_img)
                    features_list.append(aug_features)
                    image_paths.append(img_path)  # Use original path for augmented images
                    categories.append(category)

    features_array = np.array(features_list).astype('float32')
    d = features_array.shape[1]  # dimension
    index = faiss.IndexFlatIP(d)  # use inner product (cosine similarity on normalized vectors)
    index.add(features_array)
    return index, image_paths, categories


def save_index_and_metadata(index, image_paths, categories, index_file, metadata_file):
    # FAISS indexes wrap SWIG objects and cannot be pickled directly,
    # so serialize the index to a byte array first.
    with open(index_file, 'wb') as f:
        pickle.dump(faiss.serialize_index(index), f)
    with open(metadata_file, 'wb') as f:
        pickle.dump((image_paths, categories), f)
def load_index_and_metadata(index_file, metadata_file):
    with open(index_file, 'rb') as f:
        index = faiss.deserialize_index(pickle.load(f))
    with open(metadata_file, 'rb') as f:
        image_paths, categories = pickle.load(f)
    return index, image_paths, categories


def search_similar_images(index, image_paths, categories, query_features, k=20):
    query_features = query_features.reshape(1, -1).astype('float32')
    k = min(k, index.ntotal)  # cannot retrieve more neighbours than indexed vectors
    similarities, indices = index.search(query_features, k)
    similar_images = [image_paths[i] for i in indices[0]]
    similarity_scores = similarities[0]
    similar_categories = [categories[i] for i in indices[0]]
    return similar_images, similarity_scores, similar_categories


def index_files_exist(index_file, metadata_file):
    return os.path.exists(index_file) and os.path.exists(metadata_file)


def main():
    st.title("Image Classification and Similarity Search")

    index_file = "faiss-d2-nn_index.pkl"
    metadata_file = "faiss-d2-image_metadata.pkl"

    if not index_files_exist(index_file, metadata_file):
        st.warning("Index files not found. Creating new index...")
        root_dir = "Dataset2"  # Replace with your dataset path
        index, image_paths, categories = load_and_index_images(root_dir)
        save_index_and_metadata(index, image_paths, categories, index_file, metadata_file)
        st.success("Index created and saved successfully!")
    else:
        index, image_paths, categories = load_index_and_metadata(index_file, metadata_file)
        st.success("Index loaded successfully!")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert('RGB')
        query_features = extract_features(image)

        # Search for similar images
        similar_images, similarities, similar_categories = search_similar_images(
            index, image_paths, categories, query_features, k=50)

        # Get the predicted class (most common category among the top 5 similar images)
        top_categories = similar_categories[:5]
        predicted_class = max(set(top_categories), key=top_categories.count)

        # Display query and matched image side by side
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Query Image")
            st.image(image, caption="Uploaded Image", use_column_width=True)
            st.write(f"Image ID: {uploaded_file.name}")
        with col2:
            if similar_images:
                st.subheader("Matched Image")
                matched_image_path = similar_images[0]
                st.image(Image.open(matched_image_path),
                         caption=f"Matched Image (Similarity: {similarities[0]:.2f})",
                         use_column_width=True)
                st.write(f"Image ID: {os.path.basename(matched_image_path)}")
            else:
                st.write("No matched image found")

        st.subheader(f"Product Category: {predicted_class}")

        similarity_threshold = st.slider("Similarity threshold",
                                         min_value=0.0, max_value=1.0, value=0.5, step=0.05)

        # Filter results by similarity threshold and predicted class, and remove duplicates
        query_file_name = uploaded_file.name
        seen_file_names = set([query_file_name])  # Add query image to seen set
        filtered_results = []
        # Start from index 1 to skip the top match already shown above
        for img, sim, cat in zip(similar_images[1:], similarities[1:], similar_categories[1:]):
            file_name = os.path.basename(img)
            if sim >= similarity_threshold and cat == predicted_class and file_name not in seen_file_names:
                filtered_results.append((img, sim))
                seen_file_names.add(file_name)

        if filtered_results:
            max_images = len(filtered_results)
            num_display = st.slider("Number of similar images to display",
                                    min_value=0, max_value=max_images, value=min(20, max_images))

            st.subheader("Similar Images")
            st.info(f"Displaying {num_display} out of {max_images} unique similar images found for the uploaded query image.")

            # Create a grid for displaying similar images
            num_cols = 5
            num_rows = (num_display + num_cols - 1) // num_cols
            for row in range(num_rows):
                cols = st.columns(num_cols)
                for col in range(num_cols):
                    idx = row * num_cols + col
                    if idx < num_display:
                        img_path, sim = filtered_results[idx]
                        with cols[col]:
                            st.image(Image.open(img_path), use_column_width=True)
                            st.write(f"Similarity: {sim:.2f}")
                            st.write(f"Image ID: {os.path.basename(img_path)}")
        else:
            st.info("No similar images found above the similarity threshold in the predicted class.")


if __name__ == "__main__":
    main()
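# Usage sketch (assumptions: this script is saved as app.py and the images live in
# ./Dataset2/<category>/<image>, one sub-folder per product category):
#
#   streamlit run app.py
#
# On first launch the app builds and saves the FAISS index and metadata pickles;
# subsequent launches reuse the saved files instead of re-indexing.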