import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces
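# `spaces` provides the @spaces.GPU decorator, which requests a GPU for the
# duration of the decorated call when running on Hugging Face ZeroGPU Spaces.
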
# Load model/processor
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
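# Inference only: eval() disables training-time behavior such as dropout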
model.eval()
DATASET_DIR = Path("dataset")
CACHE_FILE = "cache.pkl"
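# Embeddings are cached on disk so the dataset is only re-encoded when it changes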
# Define supported image formats
IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]

def get_all_image_files():
    """Get all image files from the dataset directory."""
    image_files = []
    for ext in IMAGE_EXTENSIONS:
        image_files.extend(DATASET_DIR.glob(ext))
        image_files.extend(DATASET_DIR.glob(ext.upper()))  # Also check uppercase extensions
    return image_files

def get_embedding(image: Image.Image, device="cpu"):
    """Encode an image into a unit-length CLIP feature vector."""
    # Use CLIP's built-in preprocessing (resize, center-crop, normalize)
    inputs = processor(images=image, return_tensors="pt").to(device)
    model_on_device = model.to(device)
    with torch.no_grad():
        emb = model_on_device.get_image_features(**inputs)
    # L2-normalize so cosine similarity reduces to a dot product
    emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
    return emb
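
# The reference index maps filename -> embedding. The cache is considered valid
# only when its set of filenames exactly matches the files on disk; any added
# or removed image triggers a full re-encode of the dataset.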
@spaces.GPU
def get_reference_embeddings():
    """Return {filename: embedding} for the dataset, rebuilding the cache when it is stale."""
    # Get all current image files
    current_image_files = get_all_image_files()
    current_images = set(img_path.name for img_path in current_image_files)
    # Load the existing cache if it exists
    cached_embeddings = {}
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "rb") as f:
            cached_embeddings = pickle.load(f)
    # Check whether the cache is up to date
    cached_images = set(cached_embeddings.keys())
    # If the cache is missing images or has stale entries, rebuild it from scratch
    if current_images != cached_images:
        print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
        embeddings = {}
        device = "cuda" if torch.cuda.is_available() else "cpu"
        for img_path in current_image_files:
            print(f"Processing {img_path.name}...")
            try:
                img = Image.open(img_path).convert("RGB")
                emb = get_embedding(img, device=device)
                embeddings[img_path.name] = emb.cpu()
            except Exception as e:
                print(f"Error processing {img_path.name}: {e}")
                continue
        # Save the updated cache
        with open(CACHE_FILE, "wb") as f:
            pickle.dump(embeddings, f)
        print(f"Cache updated with {len(embeddings)} images")
        return embeddings
    else:
        print(f"Using cached embeddings for {len(cached_embeddings)} images")
        return cached_embeddings
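
# Build (or load) the reference index once at startup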
reference_embeddings = get_reference_embeddings()
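
# Search: score the query against every cached embedding. With unit-length
# vectors, cosine similarity is just a dot product, so a linear scan stays
# cheap for small datasets.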
@spaces.GPU
def search_similar(query_img):
    """Embed the query image and rank dataset images by cosine similarity."""
    # Refresh embeddings to pick up any newly added images
    global reference_embeddings
    reference_embeddings = get_reference_embeddings()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb = get_embedding(query_img, device=device)
    results = []
    for name, ref_emb in reference_embeddings.items():
        # Move the reference embedding to the same device as the query
        ref_emb = ref_emb.to(device)
        # Compute cosine similarity (embeddings are already L2-normalized)
        sim = torch.nn.functional.cosine_similarity(query_emb, ref_emb, dim=1).item()
        results.append((name, sim))
    results.sort(key=lambda x: x[1], reverse=True)
    # Filter out low-similarity results (adjust the threshold as needed)
    SIMILARITY_THRESHOLD = 0.2  # Only keep results above 20% similarity
    filtered_results = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD]
    if not filtered_results:
        # An empty list renders an empty gallery; a plain string is not a valid gallery item
        return []
    # Return the top 5 matches as (image path, caption) pairs
    return [(str(DATASET_DIR / name), f"Score: {score:.4f}") for name, score in filtered_results[:5]]
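
# Adding a product makes it searchable immediately: the image is written to
# dataset/ and its embedding is appended to both the in-memory index and the
# on-disk cache.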
@spaces.GPU
def add_image(name: str, image):
    """Save a new product image to the dataset and index its embedding."""
    if not name or not name.strip():
        return "Please provide a valid image name."
    if image is None:
        return "Please provide an image."
    # Save as PNG to preserve quality for all input formats
    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    path = DATASET_DIR / f"{name.strip()}.png"
    image.save(path, "PNG")
    # Use the GPU for consistency when available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = get_embedding(image, device=device)
    # Add to the in-memory index and persist the cache
    reference_embeddings[path.name] = emb.cpu()
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    return f"Image '{name.strip()}' added to dataset. Total images: {len(reference_embeddings)}"
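
# Two-tab UI: one tab for similarity search, one for adding new products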
search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)
add_interface = gr.Interface(
    fn=add_image,
    inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
    outputs="text",
    allow_flagging="never",
)
demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
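
# mcp_server=True also exposes the app's functions as MCP tools (requires a
# recent Gradio release with MCP support)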
demo.launch(mcp_server=True)