|
import os
import pickle
import warnings
from collections import Counter

import faiss
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from torchvision import transforms, models
|
|
|
|
|
# Per-channel mean/std passed to Normalize (the values commonly published
# for ImageNet-trained models).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing used before every forward pass:
# resize to 224x224, convert to a tensor, then normalize each channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocess_steps)

# Random perturbations used to create extra variants of each indexed image:
# horizontal flip, rotation of up to 20 degrees, colour jitter, and a random
# crop resized back to 224 pixels.
_augmentation_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.75, 1.33)),
]
augment_transform = transforms.Compose(_augmentation_steps)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
@st.cache_resource
def load_model():
    """Build the EfficientNet-B0 network used as a feature extractor.

    The classifier head is swapped for an identity layer, so a forward pass
    returns the output of the layers that precede it instead of class scores.
    The network is moved to ``device`` and put in evaluation mode.

    Returns:
        The prepared ``torch.nn.Module``; Streamlit caches it as a resource.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=`` - confirm against the installed version.
    backbone = models.efficientnet_b0(pretrained=True)
    backbone.classifier = torch.nn.Identity()
    # Both ``.to()`` and ``.eval()`` return the module itself.
    return backbone.to(device).eval()


# Shared feature extractor used by ``extract_features``.
model = load_model()
|
|
|
def extract_features(img):
    """Compute the L2-normalized embedding of one image.

    Args:
        img: An image accepted by the module-level ``transform`` (callers in
            this file pass RGB PIL images).

    Returns:
        A NumPy array holding the model output for ``img``, normalized to unit
        L2 norm along dimension 1, with all size-1 dimensions squeezed out.
    """
    # Preprocess and add the leading batch dimension the model expects.
    batch = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = F.normalize(model(batch), p=2, dim=1)
    return embedding.cpu().squeeze().numpy()
|
|
|
def generate_augmented_images(img, num_augmented=5):
    """Create randomly augmented variants of an image.

    Args:
        img: The source image, passed unchanged to ``augment_transform``.
        num_augmented: How many variants to generate.

    Returns:
        A list with ``num_augmented`` results of ``augment_transform(img)``.
    """
    return [augment_transform(img) for _ in range(num_augmented)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_and_index_images(root_dir):
    """Embed every image under ``root_dir`` and build a FAISS index over them.

    ``root_dir`` is scanned for sub-directories; each sub-directory name is
    used as the category of the files inside it. Every readable image adds its
    own embedding plus one embedding per augmented copy produced by
    ``generate_augmented_images``. All of those vectors share the image's path
    and category, so a hit on an augmented vector maps back to the source file.

    Args:
        root_dir: Directory containing one sub-directory per category.

    Returns:
        A tuple ``(index, image_paths, categories)`` where ``index`` is a
        ``faiss.IndexFlatIP`` and the two lists are aligned with the vector
        ids stored in the index.

    Raises:
        ValueError: If no readable image was found under ``root_dir``.
        OSError: If ``root_dir`` itself cannot be listed.
    """
    image_paths = []
    features_list = []
    categories = []

    # Sorted traversal keeps vector ids reproducible between runs;
    # os.listdir returns entries in arbitrary order.
    for category in sorted(os.listdir(root_dir)):
        category_path = os.path.join(root_dir, category)
        if not os.path.isdir(category_path):
            continue

        for img_name in sorted(os.listdir(category_path)):
            img_path = os.path.join(category_path, img_name)
            if not os.path.isfile(img_path):
                continue

            try:
                # Image.open is lazy and keeps the file open. convert() forces
                # the pixel data to load and returns an independent image, so
                # the handle can be released straight away.
                with Image.open(img_path) as raw_img:
                    img = raw_img.convert('RGB')
            except OSError as err:
                # Covers non-image and truncated files; one bad file should
                # not abort the whole indexing run.
                warnings.warn(f"Skipping unreadable image {img_path}: {err}")
                continue

            # The original image first, followed by its augmented variants.
            variants = [img, *generate_augmented_images(img)]
            for variant in variants:
                features_list.append(extract_features(variant))
                image_paths.append(img_path)
                categories.append(category)

    if not features_list:
        raise ValueError(
            f"No readable images found under {root_dir!r}; cannot build an index."
        )

    features_array = np.array(features_list).astype('float32')

    # Inner-product index sized to the embedding dimension.
    d = features_array.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(features_array)

    return index, image_paths, categories
|
|
|
def save_index_and_metadata(nn, image_paths, categories, index_file, metadata_file):
    """Pickle the index and its metadata to two separate files.

    Args:
        nn: The index object to persist.
        image_paths: Image paths stored alongside ``categories``.
        categories: Category labels stored alongside ``image_paths``.
        index_file: Destination path for the pickled index.
        metadata_file: Destination path for the pickled
            ``(image_paths, categories)`` tuple.
    """
    payloads = (
        (index_file, nn),
        (metadata_file, (image_paths, categories)),
    )
    # The index is written first, then the metadata.
    for path, payload in payloads:
        with open(path, 'wb') as fh:
            pickle.dump(payload, fh)
|
|
|
def load_index_and_metadata(index_file, metadata_file):
    """Load a pickled index and its metadata from disk.

    Args:
        index_file: Path of the pickled index.
        metadata_file: Path of the pickled ``(image_paths, categories)`` tuple.

    Returns:
        A tuple ``(nn, image_paths, categories)``.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point this at index files the application created itself.
    def _read_pickle(path):
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    nn = _read_pickle(index_file)
    image_paths, categories = _read_pickle(metadata_file)
    return nn, image_paths, categories
|
|
|
def search_similar_images(index, image_paths, categories, query_features, k=20):
    """Return the paths, scores and categories of the nearest indexed vectors.

    Args:
        index: Object exposing ``search(queries, k)`` that returns
            ``(scores, ids)`` arrays with one row per query (a FAISS index).
        image_paths: Image path for each vector id in ``index``.
        categories: Category label for each vector id in ``index``.
        query_features: Embedding of the query image; reshaped to one row.
        k: Maximum number of neighbours to request.

    Returns:
        A tuple ``(similar_images, similarity_scores, similar_categories)``,
        aligned with each other and ordered as returned by the index. Fewer
        than ``k`` entries are returned when the index holds fewer than ``k``
        vectors.
    """
    query = np.ascontiguousarray(query_features, dtype='float32').reshape(1, -1)
    scores, ids = index.search(query, k)

    # FAISS pads the result with id -1 when it finds fewer than k neighbours.
    # Indexing a Python list with -1 would silently return its last element,
    # so those slots are dropped instead.
    row_ids = ids[0]
    found = row_ids >= 0
    hit_ids = [int(i) for i in row_ids[found]]

    similar_images = [image_paths[i] for i in hit_ids]
    similarity_scores = scores[0][found]
    similar_categories = [categories[i] for i in hit_ids]

    return similar_images, similarity_scores, similar_categories
|
|
|
def index_files_exist(index_file, metadata_file):
    """Return True only when both the index and metadata paths exist."""
    required_paths = (index_file, metadata_file)
    return all(os.path.exists(path) for path in required_paths)
|
|
|
def _majority_category(top_categories):
    """Return the most frequent category, or None when the list is empty.

    Ties go to the category that appears first, i.e. the one attached to the
    most similar hit, so the result is the same on every run.
    """
    if not top_categories:
        return None
    return Counter(top_categories).most_common(1)[0][0]


def main():
    """Run the Streamlit app: load or build the index, then serve queries.

    On each run the index and metadata are loaded from disk, or built from
    the ``Dataset2`` directory and saved when the files are missing. Once an
    image is uploaded, the app shows the closest indexed image, predicts a
    category by majority vote over the five nearest hits, and displays a
    grid of further matches from that category above a chosen threshold.
    """
    st.title("Image Classification and Similarity Search")

    index_file = "faiss-d2-nn_index.pkl"
    metadata_file = "faiss-d2-image_metadata.pkl"

    if not index_files_exist(index_file, metadata_file):
        st.warning("Index files not found. Creating new index...")
        root_dir = "Dataset2"
        index, image_paths, categories = load_and_index_images(root_dir)
        save_index_and_metadata(index, image_paths, categories, index_file, metadata_file)
        st.success("Index created and saved successfully!")
    else:
        index, image_paths, categories = load_index_and_metadata(index_file, metadata_file)
        st.success("Index loaded successfully!")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file).convert('RGB')
    query_features = extract_features(image)

    similar_images, similarities, similar_categories = search_similar_images(
        index, image_paths, categories, query_features, k=50
    )

    # Majority vote over the five nearest hits; None when nothing was found.
    predicted_class = _majority_category(list(similar_categories[:5]))

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Query Image")
        st.image(image, caption="Uploaded Image", use_column_width=True)
        st.write(f"Image ID: {uploaded_file.name}")
    with col2:
        if similar_images:
            st.subheader("Matched Image")
            matched_image_path = similar_images[0]
            st.image(Image.open(matched_image_path),
                     caption=f"Matched Image (Similarity: {similarities[0]:.2f})",
                     use_column_width=True)
            st.write(f"Image ID: {os.path.basename(matched_image_path)}")
        else:
            st.write("No matched image found")

    if predicted_class is None:
        # Nothing was retrieved, so there is no category or grid to show.
        return

    st.subheader(f"Product Category: {predicted_class}")

    similarity_threshold = st.slider("Similarity threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.05)

    # Keep one entry per file name. The index stores several vectors per
    # image (the original plus augmentations), and a file named like the
    # upload is treated as the query itself and left out.
    seen_file_names = {uploaded_file.name}
    filtered_results = []
    # The top hit is already shown as the matched image, so start at rank 2.
    for img, sim, cat in zip(similar_images[1:], similarities[1:], similar_categories[1:]):
        file_name = os.path.basename(img)
        if sim >= similarity_threshold and cat == predicted_class and file_name not in seen_file_names:
            filtered_results.append((img, sim))
            seen_file_names.add(file_name)

    if not filtered_results:
        st.info("No similar images found above the similarity threshold in the predicted class.")
        return

    max_images = len(filtered_results)
    num_display = st.slider("Number of similar images to display", min_value=0, max_value=max_images, value=min(20, max_images))

    st.subheader("Similar Images")
    st.info(f"Displaying {num_display} out of {max_images} unique similar images found for the uploaded query image.")

    # Lay the selected results out in rows of five columns.
    num_cols = 5
    shown = filtered_results[:num_display]
    for row_start in range(0, len(shown), num_cols):
        cols = st.columns(num_cols)
        for cell, (img_path, sim) in zip(cols, shown[row_start:row_start + num_cols]):
            with cell:
                st.image(Image.open(img_path), use_column_width=True)
                st.write(f"Similarity: {sim:.2f}")
                st.write(f"Image ID: {os.path.basename(img_path)}")


if __name__ == "__main__":
    main()