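"""Gradio app for CLIP-based image similarity search.

Embeds every image in the dataset/ folder with openai/clip-vit-large-patch14,
caches the L2-normalized embeddings in cache.pkl, and serves two tabs:
"Search" (find the closest dataset images for a query image) and
"Add Product" (save a new image into the dataset and update the cache).
"""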
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import pickle
from pathlib import Path
import os
import spaces
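# `spaces` is the Hugging Face Spaces ZeroGPU helper; the @spaces.GPU decorators
# below request a GPU for the duration of each decorated call when deployed there.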

# Load model/processor
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
model.eval()

DATASET_DIR = Path("dataset")
DATASET_DIR.mkdir(exist_ok=True)  # ensure the dataset folder exists before globbing or saving
CACHE_FILE = "cache.pkl"

# Define supported image formats
IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]

def get_all_image_files():
    """Get all image files from dataset directory"""
    image_files = []
    for ext in IMAGE_EXTENSIONS:
        image_files.extend(DATASET_DIR.glob(ext))
        image_files.extend(DATASET_DIR.glob(ext.upper()))  # Also check uppercase
    return image_files

def get_embedding(image: Image.Image, device="cpu"):
    # Use CLIP's built-in preprocessing
    inputs = processor(images=image, return_tensors="pt").to(device)
    model_device = model.to(device)
    with torch.no_grad():
        emb = model_device.get_image_features(**inputs)
    # L2 normalize the embeddings
    emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
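    # With unit-norm embeddings, cosine similarity reduces to a dot product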
    return emb

@spaces.GPU
def get_reference_embeddings():
    # Get all current image files
    current_image_files = get_all_image_files()
    current_images = set(img_path.name for img_path in current_image_files)
    
    # Load existing cache if it exists
    cached_embeddings = {}
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "rb") as f:
            cached_embeddings = pickle.load(f)
    
    # Check if cache is up to date
    cached_images = set(cached_embeddings.keys())
    
    # If cache is missing images or has extra images, rebuild
    if current_images != cached_images:
        print(f"Cache outdated. Current: {len(current_images)}, Cached: {len(cached_images)}")
        embeddings = {}
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        for img_path in current_image_files:
            print(f"Processing {img_path.name}...")
            try:
                img = Image.open(img_path).convert("RGB")
                emb = get_embedding(img, device=device)
                embeddings[img_path.name] = emb.cpu()
            except Exception as e:
                print(f"Error processing {img_path.name}: {e}")
                continue
        
        # Save updated cache
        with open(CACHE_FILE, "wb") as f:
            pickle.dump(embeddings, f)
        print(f"Cache updated with {len(embeddings)} images")
        return embeddings
    else:
        print(f"Using cached embeddings for {len(cached_embeddings)} images")
        return cached_embeddings

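# Build (or load) the embedding cache once at startup so the first search is fast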
reference_embeddings = get_reference_embeddings()

@spaces.GPU
def search_similar(query_img):
    # Refresh embeddings so newly added images are searchable immediately
    global reference_embeddings
    reference_embeddings = get_reference_embeddings()

    # Fall back to CPU when no GPU is available (same check as the other functions)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_emb = get_embedding(query_img, device=device)
    results = []

    for name, ref_emb in reference_embeddings.items():
        # Move the reference embedding to the query's device and compute cosine similarity
        sim = torch.nn.functional.cosine_similarity(query_emb, ref_emb.to(device), dim=1).item()
        results.append((name, sim))
    
    results.sort(key=lambda x: x[1], reverse=True)
    
    # Filter out low similarity results (adjust threshold as needed)
    SIMILARITY_THRESHOLD = 0.2  # Only show results above 20% similarity
    filtered_results = [(name, score) for name, score in results if score > SIMILARITY_THRESHOLD]
    
    if not filtered_results:
        # gr.Gallery expects image paths, so return an empty gallery rather than a text placeholder
        return []
    
    # Return top 5 results
    return [(f"dataset/{name}", f"Score: {score:.4f}") for name, score in filtered_results[:5]]

@spaces.GPU
def add_image(name: str, image):
    if image is None:
        return "Please upload an image."
    name = name.strip()
    if not name:
        return "Please provide a valid image name."
    
    # Save as PNG to preserve quality for all input formats
    path = DATASET_DIR / f"{name}.png"
    image.save(path, "PNG")
    
    # Use GPU for consistency if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = get_embedding(image, device=device)
    
    # Add to current embeddings and save cache
    reference_embeddings[f"{name}.png"] = emb.cpu()
    
    with open(CACHE_FILE, "wb") as f:
        pickle.dump(reference_embeddings, f)
    
    return f"Image '{name}' added to dataset. Total images: {len(reference_embeddings)}"

search_interface = gr.Interface(fn=search_similar,
                                inputs=gr.Image(type="pil", label="Query Image"),
                                outputs=gr.Gallery(label="Top Matches", columns=5),
                                allow_flagging="never")

add_interface = gr.Interface(fn=add_image,
                             inputs=[gr.Text(label="Image Name"), gr.Image(type="pil", label="Product Image")],
                             outputs="text",
                             allow_flagging="never")

demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
demo.launch(mcp_server=True)
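# Note: mcp_server=True also exposes the functions above as MCP tools; this assumes
# a Gradio version with MCP support (the `gradio[mcp]` extra).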