# app.py — Streamlit image classification and similarity search demo.
# Author: RushabhShah122000 (commit 46d2a02, "Create app.py").
import os
import numpy as np
from PIL import Image
from torchvision import transforms, models
import torch
import torch.nn.functional as F
import streamlit as st
import pickle
from sklearn.neighbors import NearestNeighbors
import faiss
# Channel statistics passed to Normalize below (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing applied to every image before feature extraction:
# resize to 224x224, convert to a tensor, then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Random augmentations used to enlarge the index. There is deliberately no
# ToTensor here: augmented images are later passed through `transform`
# inside extract_features.
augment_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.75, 1.33)),
    ]
)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
@st.cache_resource
def load_model():
    """Load an ImageNet-pretrained EfficientNet-B0 to use as a feature extractor.

    The classification head is replaced by ``torch.nn.Identity`` so the forward
    pass returns the backbone's feature vector instead of class logits. The
    model is moved to the module-level ``device`` and switched to eval mode.
    ``st.cache_resource`` makes Streamlit reruns reuse the same instance.
    """
    weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` resolved to.
        model = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision only understands the legacy flag.
        model = models.efficientnet_b0(pretrained=True)
    model.classifier = torch.nn.Identity()  # drop the final classification layer
    model = model.to(device)
    model.eval()
    return model


model = load_model()
def extract_features(img):
    """Embed one image as an L2-normalized feature vector.

    ``img`` is expected to be an RGB PIL image (that is what every caller in
    this file passes). It goes through the module-level ``transform``, is
    batched as a single sample on ``device``, and is fed to the global
    ``model`` with gradients disabled. The output is L2-normalized along the
    feature axis, so inner products between vectors are cosine similarities.
    Returns a NumPy array with the batch dimension squeezed away.
    """
    batch = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = F.normalize(model(batch), p=2, dim=1)
    return embedding.cpu().squeeze().numpy()
def generate_augmented_images(img, num_augmented=5):
    """Return a list of ``num_augmented`` randomly augmented copies of ``img``.

    Each copy comes from a separate call to the random ``augment_transform``
    pipeline, so the copies generally differ from one another.
    """
    return [augment_transform(img) for _ in range(num_augmented)]
def load_and_index_images(root_dir, num_augmented=5):
    """Embed every image under ``root_dir`` and build an inner-product FAISS index.

    ``root_dir`` must contain one sub-directory per category; each file in a
    category directory is treated as an image of that category. For every
    image, the original plus ``num_augmented`` randomly augmented copies are
    embedded. All of those vectors share the original image's path and
    category, so a hit on an augmented copy still resolves to the source file.
    Pass ``num_augmented=0`` to index only the original images.

    ``extract_features`` L2-normalizes its output, so inner product on this
    index equals cosine similarity.

    Files that PIL cannot open (e.g. ``.DS_Store``) are skipped with a warning
    instead of aborting the whole build.

    Returns:
        ``(index, image_paths, categories)``, where ``image_paths[i]`` and
        ``categories[i]`` describe vector ``i`` in the index.

    Raises:
        ValueError: if no readable image was found under ``root_dir``.
    """
    import warnings

    image_paths = []
    features_list = []
    categories = []
    # sorted() makes the vector order (and therefore tie-breaking) reproducible.
    for category in sorted(os.listdir(root_dir)):
        category_path = os.path.join(root_dir, category)
        if not os.path.isdir(category_path):
            continue
        for img_name in sorted(os.listdir(category_path)):
            img_path = os.path.join(category_path, img_name)
            try:
                with Image.open(img_path) as raw:
                    img = raw.convert('RGB')
            except OSError as err:  # PIL's UnidentifiedImageError is an OSError
                warnings.warn(f"Skipping unreadable file {img_path}: {err}")
                continue
            # The original image comes first, followed by its augmented copies.
            for variant in [img, *generate_augmented_images(img, num_augmented)]:
                features_list.append(extract_features(variant))
                image_paths.append(img_path)
                categories.append(category)

    if not features_list:
        raise ValueError(f"No readable images found under {root_dir!r}")

    features_array = np.array(features_list).astype('float32')
    d = features_array.shape[1]  # embedding dimension
    index = faiss.IndexFlatIP(d)  # inner product == cosine similarity on normalized vectors
    index.add(features_array)
    return index, image_paths, categories
def save_index_and_metadata(nn, image_paths, categories, index_file, metadata_file):
    """Persist a search index and its ``(image_paths, categories)`` metadata with pickle.

    A FAISS index is a SWIG-wrapped C++ object. Pickling the wrapper directly
    is not supported by every faiss build, so it is first converted to a uint8
    byte array with ``faiss.serialize_index``. Any other index object (e.g. a
    scikit-learn ``NearestNeighbors``) is pickled as-is.
    """
    payload = faiss.serialize_index(nn) if isinstance(nn, faiss.Index) else nn
    with open(index_file, 'wb') as f:
        pickle.dump(payload, f)
    with open(metadata_file, 'wb') as f:
        pickle.dump((image_paths, categories), f)


def load_index_and_metadata(index_file, metadata_file):
    """Load what ``save_index_and_metadata`` wrote.

    A pickled ``numpy.uint8`` array is treated as a serialized FAISS index and
    rebuilt with ``faiss.deserialize_index``. Anything else, including index
    files pickled directly by earlier versions of this app, is returned
    unchanged.

    Returns:
        ``(index, image_paths, categories)``.

    Security note: ``pickle.load`` can execute arbitrary code, so only load
    files this app produced.
    """
    with open(index_file, 'rb') as f:
        nn = pickle.load(f)
    if isinstance(nn, np.ndarray) and nn.dtype == np.uint8:
        nn = faiss.deserialize_index(nn)
    with open(metadata_file, 'rb') as f:
        image_paths, categories = pickle.load(f)
    return nn, image_paths, categories
def search_similar_images(index, image_paths, categories, query_features, k=20):
    """Return up to ``k`` indexed images nearest to ``query_features``, best first.

    Args:
        index: a FAISS index (anything exposing ``ntotal`` and ``search(x, k)``).
        image_paths: per-vector image paths, aligned with the index.
        categories: per-vector category labels, aligned with the index.
        query_features: one embedding vector (array-like, reshaped to ``(1, d)``).
        k: maximum number of neighbours to return.

    Returns:
        ``(similar_images, similarity_scores, similar_categories)``, where
        ``similarity_scores`` is a 1-D float32 NumPy array. Fewer than ``k``
        entries come back when the index holds fewer vectors, and none when it
        is empty.
    """
    query = np.ascontiguousarray(query_features, dtype='float32').reshape(1, -1)
    # Requesting more neighbours than exist makes FAISS pad the result with
    # label -1, which would silently map to image_paths[-1] (the last image).
    k = min(k, index.ntotal)
    if k <= 0:
        return [], np.empty(0, dtype='float32'), []
    similarities, indices = index.search(query, k)
    valid = indices[0] >= 0  # drop any remaining "no result" slots
    ids = indices[0][valid]
    similar_images = [image_paths[i] for i in ids]
    similar_categories = [categories[i] for i in ids]
    similarity_scores = similarities[0][valid]
    return similar_images, similarity_scores, similar_categories
def index_files_exist(index_file, metadata_file):
    """Return True only if both the index file and the metadata file exist on disk."""
    return all(os.path.exists(path) for path in (index_file, metadata_file))
def _predict_class(similar_categories, top_n=5):
    """Return the majority category among the first ``top_n`` neighbours, or None if there are none.

    Ties go to the category encountered first, i.e. the one attached to the
    higher-ranked neighbour, so the prediction is deterministic.
    """
    from collections import Counter

    top = list(similar_categories[:top_n])
    if not top:
        return None
    return Counter(top).most_common(1)[0][0]


def _filter_similar_results(similar_images, similarities, similar_categories,
                            predicted_class, threshold, excluded_names):
    """Keep neighbours at or above ``threshold`` that belong to ``predicted_class``.

    Only the first (highest-ranked) hit per file name is kept, and any name in
    ``excluded_names`` is skipped. This matters because augmented copies of one
    image share its path. Returns a list of ``(image_path, similarity)`` pairs.
    """
    seen = set(excluded_names)
    results = []
    for img_path, sim, cat in zip(similar_images, similarities, similar_categories):
        file_name = os.path.basename(img_path)
        if sim >= threshold and cat == predicted_class and file_name not in seen:
            results.append((img_path, sim))
            seen.add(file_name)
    return results


def _show_image_grid(results, num_display, num_cols=5):
    """Render the first ``num_display`` ``(path, similarity)`` pairs in a ``num_cols``-wide grid."""
    for row_start in range(0, num_display, num_cols):
        row = results[row_start:min(row_start + num_cols, num_display)]
        for col, (img_path, sim) in zip(st.columns(num_cols), row):
            with col:
                st.image(Image.open(img_path), use_column_width=True)
                st.write(f"Similarity: {sim:.2f}")
                st.write(f"Image ID: {os.path.basename(img_path)}")


def main(index_file="faiss-d2-nn_index.pkl",
         metadata_file="faiss-d2-image_metadata.pkl",
         root_dir="Dataset2"):
    """Streamlit entry point: build or load the index, then classify and search an uploaded image.

    Args:
        index_file: pickle file holding the search index.
        metadata_file: pickle file holding ``(image_paths, categories)``.
        root_dir: dataset directory (one sub-directory per category), used only
            when the index files do not exist yet.
    """
    st.title("Image Classification and Similarity Search")
    if not index_files_exist(index_file, metadata_file):
        st.warning("Index files not found. Creating new index...")
        index, image_paths, categories = load_and_index_images(root_dir)
        save_index_and_metadata(index, image_paths, categories, index_file, metadata_file)
        st.success("Index created and saved successfully!")
    else:
        index, image_paths, categories = load_index_and_metadata(index_file, metadata_file)
        st.success("Index loaded successfully!")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file).convert('RGB')
    query_features = extract_features(image)
    similar_images, similarities, similar_categories = search_similar_images(
        index, image_paths, categories, query_features, k=50)
    # Predicted class = most common category among the top 5 neighbours.
    predicted_class = _predict_class(similar_categories, top_n=5)

    # Show the query next to its best match.
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Query Image")
        st.image(image, caption="Uploaded Image", use_column_width=True)
        st.write(f"Image ID: {uploaded_file.name}")
    with col2:
        if similar_images:
            st.subheader("Matched Image")
            matched_image_path = similar_images[0]
            st.image(Image.open(matched_image_path),
                     caption=f"Matched Image (Similarity: {similarities[0]:.2f})",
                     use_column_width=True)
            st.write(f"Image ID: {os.path.basename(matched_image_path)}")
        else:
            st.write("No matched image found")

    if predicted_class is None:
        st.info("The index returned no neighbours, so no product category could be predicted.")
        return
    st.subheader(f"Product Category: {predicted_class}")

    similarity_threshold = st.slider("Similarity threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
    # The top hit is already shown as "Matched Image". Its augmented copies share
    # its path, so exclude its file name (and the query's own name) from the grid.
    excluded_names = {uploaded_file.name, os.path.basename(similar_images[0])}
    filtered_results = _filter_similar_results(
        similar_images[1:], similarities[1:], similar_categories[1:],
        predicted_class, similarity_threshold, excluded_names)

    if not filtered_results:
        st.info("No similar images found above the similarity threshold in the predicted class.")
        return

    max_images = len(filtered_results)
    num_display = st.slider("Number of similar images to display", min_value=0, max_value=max_images, value=min(20, max_images))
    st.subheader("Similar Images")
    st.info(f"Displaying {num_display} out of {max_images} unique similar images found for the uploaded query image.")
    _show_image_grid(filtered_results, num_display, num_cols=5)
if __name__ == "__main__":
main()