RushabhShah122000 commited on
Commit
46d2a02
·
verified ·
1 Parent(s): e000d3b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from torchvision import transforms, models
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import streamlit as st
8
+ import pickle
9
+ from sklearn.neighbors import NearestNeighbors
10
+ import faiss
11
+
12
# Preprocessing applied to every image before it is embedded: resize to
# 224x224, convert to a tensor, then normalize each channel with the
# mean/std values listed below.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Random augmentations used to create extra variants of each indexed image:
# horizontal flip, rotation of up to 20 degrees, colour jitter, and a random
# crop resized to 224 pixels.
augment_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=20),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
        ),
        transforms.RandomResizedCrop(
            size=224,
            scale=(0.8, 1.0),
            ratio=(0.75, 1.33),
        ),
    ]
)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
28
+
29
@st.cache_resource
def load_model():
    """Build the EfficientNet-B0 network used to embed images.

    The ImageNet-pretrained model is created, its ``classifier`` attribute is
    replaced by an identity layer (removing the final classification layer),
    and the model is moved to the module-level ``device`` and put in
    evaluation mode.

    Returns:
        The prepared model. The function is wrapped in ``st.cache_resource``.
    """
    # ``pretrained=True`` is deprecated in newer torchvision releases in
    # favour of the ``weights=`` enum. Prefer the enum when it exists and pin
    # the same ImageNet V1 weights so features stay compatible with an
    # already saved index; fall back to the legacy flag on old versions.
    weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
    if weights_enum is not None:
        net = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
    else:
        net = models.efficientnet_b0(pretrained=True)

    net.classifier = torch.nn.Identity()  # Remove the final classification layer
    net = net.to(device)
    net.eval()
    return net


model = load_model()
38
+
39
def extract_features(img):
    """Return the L2-normalized embedding of a single image.

    Args:
        img: An image accepted by the module-level ``transform`` pipeline
            (callers in this file pass RGB PIL images).

    Returns:
        A NumPy array holding the model output for ``img`` after L2
        normalization, with singleton dimensions squeezed away.
    """
    # Add a leading batch dimension and move the tensor to the model's device.
    batch = transform(img).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        raw_output = model(batch)

    unit_length = F.normalize(raw_output, p=2, dim=1)
    return unit_length.cpu().squeeze().numpy()
46
+
47
def generate_augmented_images(img, num_augmented=5):
    """Create randomly augmented variants of ``img``.

    Args:
        img: Image passed to the module-level ``augment_transform``.
        num_augmented: Number of variants to generate (default 5).

    Returns:
        A list with ``num_augmented`` results of ``augment_transform(img)``.
    """
    return [augment_transform(img) for _ in range(num_augmented)]
53
+
54
+ # def load_and_index_images(root_dir): #without adding data augmented images
55
+ # image_paths = []
56
+ # features_list = []
57
+ # categories = []
58
+ # for category in os.listdir(root_dir):
59
+ # category_path = os.path.join(root_dir, category)
60
+ # if os.path.isdir(category_path):
61
+ # for img_name in os.listdir(category_path):
62
+ # img_path = os.path.join(category_path, img_name)
63
+ # img = Image.open(img_path).convert('RGB')
64
+
65
+ # features = extract_features(img)
66
+ # image_paths.append(img_path)
67
+ # features_list.append(features)
68
+ # categories.append(category)
69
+
70
+ # features_array = np.array(features_list).astype('float32')
71
+
72
+ # d = features_array.shape[1] # dimension
73
+ # index = faiss.IndexFlatIP(d) # use inner product (cosine similarity on normalized vectors)
74
+ # index.add(features_array)
75
+
76
+ # return index, image_paths, categories
77
+
78
def load_and_index_images(root_dir):
    """Embed every image under ``root_dir`` and build a FAISS index.

    Each sub-directory of ``root_dir`` is treated as one category and its
    name is used as the category label. For every readable image, the
    original and its augmented variants (from ``generate_augmented_images``)
    are embedded with ``extract_features``; all variants are recorded under
    the original file path.

    Args:
        root_dir: Directory containing one sub-directory per category.

    Returns:
        A tuple ``(index, image_paths, categories)`` where ``index`` is a
        ``faiss.IndexFlatIP`` holding the embeddings and the two lists are
        aligned with the rows added to the index.

    Raises:
        ValueError: If no readable image was found under ``root_dir``.
    """
    image_paths = []
    features_list = []
    categories = []

    # Sorted listings make the row order of the index reproducible;
    # os.listdir returns entries in arbitrary order.
    for category in sorted(os.listdir(root_dir)):
        category_path = os.path.join(root_dir, category)
        if not os.path.isdir(category_path):
            continue

        for img_name in sorted(os.listdir(category_path)):
            img_path = os.path.join(category_path, img_name)
            if not os.path.isfile(img_path):
                continue

            try:
                # The context manager closes the underlying file handle;
                # convert() returns an independent in-memory image.
                with Image.open(img_path) as handle:
                    img = handle.convert('RGB')
            except OSError:
                # Not a readable image (stray metadata files, truncated
                # downloads, ...): skip it instead of aborting the build.
                continue

            # Original first, then its augmented variants; every variant is
            # stored under the original path.
            variants = [img] + list(generate_augmented_images(img))
            for variant in variants:
                features_list.append(extract_features(variant))
                image_paths.append(img_path)
                categories.append(category)

    if not features_list:
        raise ValueError(f"No readable images found under {root_dir!r}")

    features_array = np.array(features_list).astype('float32')

    d = features_array.shape[1]  # dimension
    # Inner product equals cosine similarity on the normalized vectors.
    index = faiss.IndexFlatIP(d)
    index.add(features_array)

    return index, image_paths, categories
110
+
111
def save_index_and_metadata(nn, image_paths, categories, index_file, metadata_file):
    """Pickle the index and its metadata to two separate files.

    Args:
        nn: Index object to pickle into ``index_file``.
        image_paths: Image paths stored alongside the index.
        categories: Category labels stored alongside the index.
        index_file: Destination path for the pickled index.
        metadata_file: Destination path for the pickled
            ``(image_paths, categories)`` tuple.
    """
    outputs = (
        (index_file, nn),
        (metadata_file, (image_paths, categories)),
    )
    for destination, payload in outputs:
        with open(destination, 'wb') as sink:
            pickle.dump(payload, sink)
116
+
117
def load_index_and_metadata(index_file, metadata_file):
    """Load the pickled index and its metadata from disk.

    Args:
        index_file: Path of the pickled index.
        metadata_file: Path of the pickled ``(image_paths, categories)`` tuple.

    Returns:
        A tuple ``(nn, image_paths, categories)``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load files that
    # this application wrote itself.
    unpickled = []
    for source_path in (index_file, metadata_file):
        with open(source_path, 'rb') as source:
            unpickled.append(pickle.load(source))

    nn, (image_paths, categories) = unpickled
    return nn, image_paths, categories
123
+
124
def search_similar_images(index, image_paths, categories, query_features, k=20):
    """Look up the ``k`` nearest indexed images for one query embedding.

    Args:
        index: Object exposing ``search(queries, k)`` that returns
            ``(similarities, indices)`` arrays with one row per query
            (a FAISS index in this file).
        image_paths: Paths aligned with the rows of ``index``.
        categories: Category labels aligned with the rows of ``index``.
        query_features: Embedding of the query image; it is reshaped to a
            single float32 row before searching.
        k: Maximum number of neighbours to request (default 20).

    Returns:
        A tuple ``(similar_images, similarity_scores, similar_categories)``
        in the order returned by the index. Fewer than ``k`` entries are
        returned when the index holds fewer than ``k`` vectors.
    """
    query = query_features.reshape(1, -1).astype('float32')
    similarities, indices = index.search(query, k)

    # FAISS pads the result with label -1 when fewer than k neighbours
    # exist. Indexing a Python list with -1 would silently return the LAST
    # image as a bogus hit, so drop those slots.
    found = indices[0] >= 0
    hit_rows = indices[0][found]

    similar_images = [image_paths[row] for row in hit_rows]
    similarity_scores = similarities[0][found]
    similar_categories = [categories[row] for row in hit_rows]

    return similar_images, similarity_scores, similar_categories
133
+
134
def index_files_exist(index_file, metadata_file):
    """Return True only when both the index file and the metadata file exist."""
    required_paths = (index_file, metadata_file)
    return all(map(os.path.exists, required_paths))
136
+
137
def main():
    """Run the Streamlit page for image classification and similarity search.

    Steps:
      1. Load the saved index and metadata, or build and save them from the
         ``Dataset2`` directory when the files are missing.
      2. Embed the uploaded image and fetch its 50 nearest neighbours.
      3. Predict the category by majority vote over the top 5 neighbours.
      4. Show the query next to its best match, then a grid of further
         unique matches from the predicted category whose similarity is at
         or above the chosen threshold.
    """
    # Local import: only needed for the majority vote below.
    from collections import Counter

    st.title("Image Classification and Similarity Search")

    index_file = "faiss-d2-nn_index.pkl"
    metadata_file = "faiss-d2-image_metadata.pkl"

    if not index_files_exist(index_file, metadata_file):
        st.warning("Index files not found. Creating new index...")
        root_dir = "Dataset2"  # Replace with your dataset path
        index, image_paths, categories = load_and_index_images(root_dir)
        save_index_and_metadata(index, image_paths, categories, index_file, metadata_file)
        st.success("Index created and saved successfully!")
    else:
        index, image_paths, categories = load_index_and_metadata(index_file, metadata_file)
        st.success("Index loaded successfully!")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file).convert('RGB')
    query_features = extract_features(image)

    # Search for similar images
    similar_images, similarities, similar_categories = search_similar_images(
        index, image_paths, categories, query_features, k=50
    )

    # Predicted class = most common category among the top 5 neighbours.
    # Counter.most_common breaks ties by first occurrence, so a tie goes to
    # the best-ranked category. The previous max(set(...)) form depended on
    # set iteration order and could differ between runs.
    top_categories = list(similar_categories[:5])
    predicted_class = Counter(top_categories).most_common(1)[0][0] if top_categories else None

    # Display query and matched image
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Query Image")
        st.image(image, caption="Uploaded Image", use_column_width=True)
        st.write(f"Image ID: {uploaded_file.name}")
    with col2:
        if similar_images:
            st.subheader("Matched Image")
            matched_image_path = similar_images[0]
            st.image(Image.open(matched_image_path),
                     caption=f"Matched Image (Similarity: {similarities[0]:.2f})",
                     use_column_width=True)
            st.write(f"Image ID: {os.path.basename(matched_image_path)}")
        else:
            st.write("No matched image found")

    if predicted_class is None:
        # The search returned nothing, so there is no category to report
        # and nothing to filter.
        return

    st.subheader(f"Product Category: {predicted_class}")

    similarity_threshold = st.slider("Similarity threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.05)

    # Filter results based on similarity threshold and predicted class, and
    # remove duplicates. Several index rows share one file path, and the
    # query's own file name is excluded as well.
    seen_file_names = {uploaded_file.name}
    filtered_results = []
    # Start from index 1: the best match is already shown above.
    for img, sim, cat in zip(similar_images[1:], similarities[1:], similar_categories[1:]):
        file_name = os.path.basename(img)
        if sim >= similarity_threshold and cat == predicted_class and file_name not in seen_file_names:
            filtered_results.append((img, sim))
            seen_file_names.add(file_name)

    if not filtered_results:
        st.info("No similar images found above the similarity threshold in the predicted class.")
        return

    max_images = len(filtered_results)
    num_display = st.slider("Number of similar images to display", min_value=0, max_value=max_images, value=min(20, max_images))

    st.subheader("Similar Images")
    st.info(f"Displaying {num_display} out of {max_images} unique similar images found for the uploaded query image.")

    # Lay the selected results out in rows of five columns.
    num_cols = 5
    for row_start in range(0, num_display, num_cols):
        cols = st.columns(num_cols)
        row_end = min(row_start + num_cols, num_display)
        for cell, (img_path, sim) in zip(cols, filtered_results[row_start:row_end]):
            with cell:
                st.image(Image.open(img_path), use_column_width=True)
                st.write(f"Similarity: {sim:.2f}")
                st.write(f"Image ID: {os.path.basename(img_path)}")


if __name__ == "__main__":
    main()