import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces

# Load model/processor
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
model.eval()

DATASET_DIR = Path("dataset")
CACHE_FILE = "cache.pkl"

# Ensure the dataset directory exists before globbing or saving into it
DATASET_DIR.mkdir(exist_ok=True)

# Define supported image formats
IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]


def get_all_image_files():
    """Get all image files from the dataset directory."""
    image_files = []
    for ext in IMAGE_EXTENSIONS:
        image_files.extend(DATASET_DIR.glob(ext))
        image_files.extend(DATASET_DIR.glob(ext.upper()))  # Also match uppercase extensions
    return image_files


def get_embedding(image: Image.Image, device="cpu"):
    # Use CLIP's built-in preprocessing
    inputs = processor(images=image, return_tensors="pt").to(device)
    # nn.Module.to() moves the (global) model in place and returns it
    model_on_device = model.to(device)
    with torch.no_grad():
        emb = model_on_device.get_image_features(**inputs)
    # L2-normalize the embeddings so dot products equal cosine similarities
    emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
    return emb


@spaces.GPU
def get_reference_embeddings():
    # Collect all current image files
    current_image_files = get_all_image_files()
    current_images = set(img_path.name for img_path in current_image_files)

    # Load the existing cache if it exists
    cached_embeddings = {}
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "rb") as f:
            cached_embeddings = pickle.load(f)

    # If the cache is missing images or has stale entries, rebuild it
    cached_images = set(cached_embeddings.keys())
    if current_images != cached_images:
        print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
        embeddings = {}
        device = "cuda" if torch.cuda.is_available() else "cpu"
        for img_path in current_image_files:
            print(f"Processing {img_path.name}...")
            try:
                img = Image.open(img_path).convert("RGB")
                emb = get_embedding(img, device=device)
                embeddings[img_path.name] = emb.cpu()
            except Exception as e:
                print(f"Error processing {img_path.name}: {e}")
                continue

        # Save the updated cache
        with open(CACHE_FILE, "wb") as f:
            pickle.dump(embeddings, f)
        print(f"Cache updated with {len(embeddings)} images")
        return embeddings

    print(f"Using cached embeddings for {len(cached_embeddings)} images")
    return cached_embeddings


reference_embeddings = get_reference_embeddings()


@spaces.GPU
def search_similar(query_img):
    # Refresh embeddings to pick up any newly added images
    global reference_embeddings
    reference_embeddings = get_reference_embeddings()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb = get_embedding(query_img, device=device)

    results = []
    for name, ref_emb in reference_embeddings.items():
        # Move the reference embedding to the same device as the query
        ref_emb_dev = ref_emb.to(device)
        # Compute cosine similarity
        sim = torch.nn.functional.cosine_similarity(query_emb, ref_emb_dev, dim=1).item()
        results.append((name, sim))

    results.sort(key=lambda x: x[1], reverse=True)

    # Filter out low-similarity matches (adjust the threshold as needed)
    SIMILARITY_THRESHOLD = 0.2  # Only keep results with cosine similarity above 0.2
    filtered_results = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD]

    # Return an empty gallery when nothing clears the threshold;
    # gr.Gallery cannot render a plain text message as an image
    if not filtered_results:
        return []

    # Return the top 5 results as (image path, caption) pairs
    return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]


@spaces.GPU
def add_image(name: str, image):
    if not name.strip():
        return "Please provide a valid image name."
    if image is None:
        return "Please provide an image."

    # Normalize the name so the saved filename matches the cache key
    name = name.strip()

    # Save as PNG to preserve quality for all input formats
    path = DATASET_DIR / f"{name}.png"
    image.save(path, "PNG")

    # Use the GPU for consistency if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = get_embedding(image, device=device)

    # Add to the in-memory embeddings and persist the cache
    reference_embeddings[f"{name}.png"] = emb.cpu()
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)

    return f"Image '{name}' added to dataset. Total images: {len(reference_embeddings)}"


search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)

add_interface = gr.Interface(
    fn=add_image,
    inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
    outputs="text",
    allow_flagging="never",
)

demo = gr.TabbedInterface(
    [search_interface, add_interface],
    tab_names=["Search", "Add Product"],
)

demo.launch(mcp_server=True)
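
# ---------------------------------------------------------------------------
# Usage sketch (an addition, not part of the original app). Assuming this file
# is saved as app.py, the demo can be run locally with:
#
#   pip install "gradio[mcp]" transformers torch pillow spaces
#   python app.py
#
# Once running, the search endpoint can also be called programmatically with
# gradio_client; the exact api_name may differ, so check client.view_api():
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")  # hypothetical local URL
#   print(client.predict(handle_file("query.jpg"), api_name="/predict"))
#
# `mcp_server=True` additionally exposes these endpoints as MCP tools, which
# requires the optional `gradio[mcp]` extra to be installed.
# ---------------------------------------------------------------------------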