# ImgSearch / app.py
# Last change: "Update app.py" by AkinyemiAra (commit c0e2011, verified)
# File size: 5.25 kB
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces
# Hugging Face checkpoint shared by indexing (dataset images) and querying.
_MODEL_ID = "openai/clip-vit-large-patch14"

# The model and its matching preprocessor are loaded once, at import time.
model = CLIPModel.from_pretrained(_MODEL_ID)
processor = CLIPProcessor.from_pretrained(_MODEL_ID)
model.eval()  # inference only: puts dropout-style layers into eval mode

# Folder holding the reference images, and the pickle file caching their embeddings.
DATASET_DIR = Path("dataset")
CACHE_FILE = "cache.pkl"

# Glob patterns of the image formats treated as dataset images.
IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
def get_all_image_files(directory=None, extensions=None):
    """Return every supported image file directly inside the dataset folder.

    Files are matched on their suffix, case-insensitively, so ``a.jpg``,
    ``a.JPG`` and ``a.Jpg`` are all found. Each file is returned exactly once.
    Globbing ``*.jpg`` and ``*.JPG`` separately would return the same file
    twice on case-insensitive filesystems such as macOS or Windows.
    Sub-directories are ignored even if their name ends in an image suffix.

    Args:
        directory: Folder to scan. Defaults to ``DATASET_DIR``.
        extensions: Glob-style patterns such as ``"*.png"``. Defaults to
            ``IMAGE_EXTENSIONS``.

    Returns:
        A list of ``Path`` objects sorted by file name. The list is empty if
        the folder does not exist.
    """
    folder = Path(DATASET_DIR if directory is None else directory)
    patterns = IMAGE_EXTENSIONS if extensions is None else extensions
    # "*.jpg" -> ".jpg". Compare in lowercase so any capitalisation matches.
    suffixes = {pattern.lstrip("*").lower() for pattern in patterns}
    if not folder.is_dir():
        return []
    return sorted(
        (
            path
            for path in folder.iterdir()
            if path.is_file() and path.suffix.lower() in suffixes
        ),
        key=lambda path: path.name,
    )
def get_embedding(image: Image.Image, device="cpu"):
    """Encode ``image`` with CLIP and return its L2-normalized feature vector.

    The image goes through the CLIP processor's own preprocessing, and the
    model is moved to ``device`` before encoding. The projected image features
    are then divided by their L2 norm, so every returned embedding has unit
    length.

    Args:
        image: PIL image to encode.
        device: Torch device string that both the inputs and the model are
            placed on.

    Returns:
        The normalized embedding tensor, left on ``device``.
    """
    pixel_batch = processor(images=image, return_tensors="pt").to(device)
    clip = model.to(device)
    with torch.no_grad():
        features = clip.get_image_features(**pixel_batch)
    return features / features.norm(p=2, dim=-1, keepdim=True)
def _load_embedding_cache():
    """Read the pickled ``{file name: embedding}`` cache; ``{}`` if absent or unreadable."""
    if not os.path.exists(CACHE_FILE):
        return {}
    try:
        # NOTE(security): pickle.load can execute arbitrary code. This is only
        # acceptable because CACHE_FILE is written by this app itself; never
        # point it at a file that users can upload or replace.
        with open(CACHE_FILE, "rb") as f:
            cached = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as e:
        # A truncated or corrupt cache (e.g. from an interrupted write) should
        # only trigger re-embedding, not crash the app at startup.
        print(f"Ignoring unreadable cache {CACHE_FILE}: {e}")
        return {}
    return cached if isinstance(cached, dict) else {}


def _embed_image_file(img_path, device):
    """Embed one image file on ``device``; return a CPU tensor, or None if it fails."""
    try:
        # Open inside ``with`` so the file handle is released immediately.
        # convert() returns an independent in-memory copy.
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")
        return get_embedding(rgb, device=device).cpu()
    except Exception as e:  # best-effort: one bad file must not stop indexing
        print(f"Error processing {img_path.name}: {e}")
        return None


@spaces.GPU
def get_reference_embeddings():
    """Return ``{file name: CPU embedding}`` for every image in the dataset folder.

    Embeddings are cached in ``CACHE_FILE``, keyed by file name.

    - If the cached names match the folder contents, the cache is returned
      unchanged.
    - Otherwise only images missing from the cache are embedded, on CUDA when
      available. Entries for deleted files are dropped, and the cache is
      rewritten if its set of names changed.
    - Files that fail to load are logged and skipped; they are retried on the
      next call.

    Only the failing images are retried, not the whole dataset.

    NOTE: a file replaced under the same name keeps its old embedding, because
    the cache is keyed by name only.
    """
    current_image_files = get_all_image_files()
    current_names = {img_path.name for img_path in current_image_files}

    cached_embeddings = _load_embedding_cache()
    cached_names = set(cached_embeddings)

    if current_names == cached_names:
        print(f"Using cached embeddings for {len(cached_embeddings)} images")
        return cached_embeddings

    print(f"Cache outdated. Current: {len(current_names)}, Cached: {len(cached_names)}")
    # Reuse embeddings of images that are still present; drop deleted ones.
    embeddings = {
        name: emb for name, emb in cached_embeddings.items() if name in current_names
    }
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for img_path in current_image_files:
        if img_path.name in embeddings:
            continue  # already cached (or a duplicate path listing)
        print(f"Processing {img_path.name}...")
        emb = _embed_image_file(img_path, device)
        if emb is not None:
            embeddings[img_path.name] = emb

    # Only rewrite the cache when its set of names actually changed.
    if set(embeddings) != cached_names:
        with open(CACHE_FILE, "wb") as f:
            pickle.dump(embeddings, f)
        print(f"Cache updated with {len(embeddings)} images")
    return embeddings
# Embed the dataset (or load it from the cache) once at import time, so the
# index exists before the UI launches. Each search request refreshes it again.
reference_embeddings = get_reference_embeddings()
@spaces.GPU
def search_similar(query_img):
    """Find the dataset images most similar to a query image.

    The reference embeddings are refreshed first, so images added to the
    dataset folder since the last call are searchable. Results are ranked by
    cosine similarity between CLIP image embeddings. Only matches scoring
    above 0.2 are kept, and at most 5 are returned.

    Args:
        query_img: PIL image to search with. ``None`` if nothing was uploaded.

    Returns:
        A list of ``(image path, "Score: x.xxxx")`` pairs for the gallery.
        The list is empty (with a warning shown) when there is no query
        image, the dataset is empty, or nothing passes the threshold.
    """
    global reference_embeddings
    SIMILARITY_THRESHOLD = 0.2  # only show results above 20% similarity
    TOP_K = 5

    if query_img is None:
        gr.Warning("Please upload a query image.")
        return []

    # Refresh embeddings to catch any new images.
    reference_embeddings = get_reference_embeddings()
    if not reference_embeddings:
        gr.Warning("The dataset is empty - add some images first.")
        return []

    # Fall back to CPU so the app also works without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb = get_embedding(query_img, device=device)

    # Score all references in one call: query_emb broadcasts against an
    # (N, D) matrix holding one reference embedding per row.
    names = list(reference_embeddings)
    ref_matrix = torch.cat(
        [reference_embeddings[name].reshape(1, -1) for name in names]
    ).to(device)
    scores = torch.nn.functional.cosine_similarity(query_emb, ref_matrix, dim=1).tolist()

    ranked = sorted(zip(names, scores), key=lambda item: item[1], reverse=True)
    matches = [(name, score) for name, score in ranked if score > SIMILARITY_THRESHOLD]
    if not matches:
        # A gallery item must be an image path, so report this via a warning
        # and return an empty gallery instead of a text placeholder.
        gr.Warning("No matches above similarity threshold")
        return []

    return [(str(DATASET_DIR / name), f"Score: {score:.4f}") for name, score in matches[:TOP_K]]
@spaces.GPU
def add_image(name: str, image):
    """Add a product image to the searchable dataset.

    The image is saved as ``<name>.png`` in the dataset folder and indexed
    immediately, so it can be found by the next search.

    Args:
        name: File name (without extension) for the new image. Surrounding
            whitespace is stripped. Path separators are rejected so the file
            is always written inside the dataset folder.
        image: PIL image to add. ``None`` if nothing was uploaded.

    Returns:
        A status message describing the result.
    """
    clean_name = (name or "").strip()
    if not clean_name:
        return "Please provide a valid image name."
    # The name comes straight from the UI / MCP client: block path traversal.
    if "/" in clean_name or "\\" in clean_name:
        return "Image name must not contain path separators."
    if image is None:
        return "Please upload an image."

    file_name = f"{clean_name}.png"
    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    # Save as PNG to preserve quality for all input formats.
    image.save(DATASET_DIR / file_name, "PNG")

    # Use GPU for consistency if available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = get_embedding(image, device=device)

    # Add to the in-memory index and persist it so the cache matches the folder.
    reference_embeddings[file_name] = emb.cpu()
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    return f"Image '{clean_name}' added to dataset. Total images: {len(reference_embeddings)}"
# "Search" tab: upload a query image, get the closest dataset matches.
# `flagging_mode` replaces the `allow_flagging` argument deprecated in Gradio 5.
search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    flagging_mode="never",
)

# "Add Product" tab: store a named image in the dataset and index it.
add_interface = gr.Interface(
    fn=add_image,
    inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
    outputs="text",
    flagging_mode="never",
)

demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
# mcp_server=True additionally exposes the app's functions as MCP tools (Gradio 5).
demo.launch(mcp_server=True)