import os
import pickle
from pathlib import Path

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

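# Load the CLIP model and processor once at startup; eval() switches the
# model to inference mode (disables dropout and similar training-only behavior).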
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
model.eval()

DATASET_DIR = Path("dataset")
CACHE_FILE = "cache.pkl"

# Create the dataset directory up front so globbing and image saves
# cannot fail on a fresh checkout.
DATASET_DIR.mkdir(exist_ok=True)

IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]

def get_all_image_files():
    """Return all image files in the dataset directory, deduplicated and sorted."""
    # Use a set: on case-insensitive filesystems "*.jpg" and "*.JPG" can
    # match the same file twice.
    image_files = set()
    for ext in IMAGE_EXTENSIONS:
        image_files.update(DATASET_DIR.glob(ext))
        image_files.update(DATASET_DIR.glob(ext.upper()))
    return sorted(image_files)

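# Embeddings are L2-normalized below, so the cosine similarity used in
# search_similar() is equivalent to a plain dot product.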
def get_embedding(image: Image.Image, device="cpu"):
    """Compute a unit-length CLIP embedding for a single PIL image."""
    inputs = processor(images=image, return_tensors="pt").to(device)
    model_device = model.to(device)
    with torch.no_grad():
        emb = model_device.get_image_features(**inputs)
    emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
    return emb

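# Cache strategy: the pickle maps filename -> embedding. If the set of
# filenames on disk no longer matches the cached keys, the whole cache is
# rebuilt. Note this recomputes every embedding after any single add or
# delete; incremental updates would be a straightforward optimization.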
@spaces.GPU
def get_reference_embeddings():
    """Load reference embeddings from the cache, rebuilding it if the dataset changed."""
    current_image_files = get_all_image_files()
    current_images = set(img_path.name for img_path in current_image_files)

    cached_embeddings = {}
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "rb") as f:
            cached_embeddings = pickle.load(f)

    cached_images = set(cached_embeddings.keys())

    if current_images != cached_images:
        print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
        embeddings = {}
        device = "cuda" if torch.cuda.is_available() else "cpu"

        for img_path in current_image_files:
            print(f"Processing {img_path.name}...")
            try:
                img = Image.open(img_path).convert("RGB")
                emb = get_embedding(img, device=device)
                # Store on CPU so the cache can be unpickled on CPU-only machines.
                embeddings[img_path.name] = emb.cpu()
            except Exception as e:
                print(f"Error processing {img_path.name}: {e}")
                continue

        with open(CACHE_FILE, "wb") as f:
            pickle.dump(embeddings, f)
        print(f"Cache updated with {len(embeddings)} images")
        return embeddings
    else:
        print(f"Using cached embeddings for {len(cached_embeddings)} images")
        return cached_embeddings


reference_embeddings = get_reference_embeddings()

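# Search flow: refresh the reference cache, embed the query, score it
# against every cached embedding, then keep the top five matches above a
# fixed similarity threshold.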
@spaces.GPU
def search_similar(query_img):
    """Return up to five (image path, score) matches for the query image."""
    global reference_embeddings
    reference_embeddings = get_reference_embeddings()

    # Fall back to CPU so the search also works when no GPU is allocated.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb = get_embedding(query_img, device=device)
    results = []

    for name, ref_emb in reference_embeddings.items():
        ref_emb = ref_emb.to(device)
        sim = torch.nn.functional.cosine_similarity(query_emb, ref_emb, dim=1).item()
        results.append((name, sim))

    results.sort(key=lambda x: x[1], reverse=True)

    SIMILARITY_THRESHOLD = 0.2
    filtered_results = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD]

    # gr.Gallery cannot render text-only entries, so return an empty gallery
    # when nothing clears the threshold.
    if not filtered_results:
        return []

    return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]

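# Adding a product saves the image as PNG, embeds it, and updates both the
# in-memory dictionary and the on-disk cache so search picks it up immediately.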
@spaces.GPU
def add_image(name: str, image):
    """Save a new product image to the dataset and update the embedding cache."""
    if not name.strip():
        return "Please provide a valid image name."

    path = DATASET_DIR / f"{name}.png"
    image.save(path, "PNG")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = get_embedding(image, device=device)

    reference_embeddings[f"{name}.png"] = emb.cpu()

    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)

    return f"Image '{name}' added to dataset. Total images: {len(reference_embeddings)}"

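# Two-tab UI: one tab for similarity search, one for adding new products.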
search_interface = gr.Interface(
    fn=search_similar,
    inputs=gr.Image(type="pil", label="Query Image"),
    outputs=gr.Gallery(label="Top Matches", columns=5),
    allow_flagging="never",
)

add_interface = gr.Interface(
    fn=add_image,
    inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
    outputs="text",
    allow_flagging="never",
)

demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
demo.launch(mcp_server=True)
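
# mcp_server=True additionally exposes the app's functions over the Model
# Context Protocol; this requires a recent Gradio installed with the
# gradio[mcp] extra.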