Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -111,6 +111,55 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
|
| 111 |
def change_rank_default(concept_name):
    """Return the default rank for ``concept_name``.

    The value is looked up in ``RANKS_MAP``; concepts without an entry
    fall back to 30.
    """
    fallback_rank = 30
    return RANKS_MAP.get(concept_name, fallback_rank)
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
@spaces.GPU
|
| 115 |
def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
|
| 116 |
"""Get CLIP image embeddings for a given PIL image"""
|
|
@@ -464,9 +513,21 @@ Following the algorithm proposed in IP-Composer: Semantic Composition of Visual
|
|
| 464 |
inputs=[concept_name3],
|
| 465 |
outputs=[rank3]
|
| 466 |
)
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
|
|
|
| 111 |
def change_rank_default(concept_name):
    """Return the default rank for ``concept_name``.

    The value is looked up in ``RANKS_MAP``; concepts without an entry
    fall back to 30.
    """
    fallback_rank = 30
    return RANKS_MAP.get(concept_name, fallback_rank)
|
| 113 |
|
| 114 |
+
@spaces.GPU
def match_image_to_concept(image):
    """Match an uploaded image to the closest concept type using CLIP embeddings.

    The image embedding is compared, by cosine similarity, against every
    text embedding stored for each entry of ``CONCEPTS_MAP``. A concept's
    score is the mean of its five highest similarities, and the concept
    with the highest score wins.

    Args:
        image: Image data accepted by ``Image.fromarray``, or ``None``.
            NOTE(review): presumably the NumPy array Gradio passes from the
            image component -- confirm against the component's ``type``.

    Returns:
        The name of the best-scoring concept, or ``None`` when no image was
        given, the image embedding has a zero or non-finite norm, or no
        concept could be scored.
    """
    if image is None:
        return None

    top_k = 5  # number of best similarities averaged per concept

    # Embed and normalise the image once; both steps are loop-invariant.
    img_pil = Image.fromarray(image).convert("RGB")
    img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
    img_vec = np.asarray(img_embed).ravel()
    img_norm = np.linalg.norm(img_vec)
    if not np.isfinite(img_norm) or img_norm == 0:
        # A degenerate embedding would turn every similarity into NaN.
        return None
    img_unit = img_vec / img_norm

    similarities = {}
    for concept_name, concept_file in CONCEPTS_MAP.items():
        try:
            # Load concept embeddings
            embeds_path = f"./IP_Composer/text_embeddings/{concept_file}"
            with open(embeds_path, "rb") as f:
                concept_embeds = np.load(f)

            if len(concept_embeds) == 0:
                raise ValueError("embedding file contains no embeddings")

            # One row per stored text embedding, whatever its trailing shape.
            concept_mat = concept_embeds.reshape(len(concept_embeds), -1)
            row_norms = np.linalg.norm(concept_mat, axis=1)

            # Zero-length rows have no direction; skip them instead of
            # dividing by zero.
            usable = row_norms > 0
            if not usable.any():
                raise ValueError("no usable text embeddings")

            # Cosine similarity of the image against every usable row at once.
            sims = (concept_mat[usable] @ img_unit) / row_norms[usable]

            # Use the average of the top similarities for better matching;
            # the slice simply takes everything when fewer than top_k exist.
            best = np.sort(sims)[-top_k:]
            similarities[concept_name] = float(best.mean())
        except Exception as e:
            # Deliberately best-effort: one unreadable or malformed concept
            # file must not prevent matching against the remaining concepts.
            print(f"Error processing concept {concept_name}: {e}")

    if not similarities:
        return None

    # Return the concept with highest similarity
    matched_concept = max(similarities, key=similarities.get)
    # Display a notification to the user
    gr.Info(f"Image automatically matched to concept: {matched_concept}")
    return matched_concept
|
| 162 |
+
|
| 163 |
@spaces.GPU
|
| 164 |
def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
|
| 165 |
"""Get CLIP image embeddings for a given PIL image"""
|
|
|
|
| 513 |
inputs=[concept_name3],
|
| 514 |
outputs=[rank3]
|
| 515 |
)
|
| 516 |
+
# When an image is uploaded into a concept slot, auto-select that slot's
# concept by running match_image_to_concept on the uploaded image.
for image_input, concept_selector in (
    (concept_image1, concept_name1),
    (concept_image2, concept_name2),
    (concept_image3, concept_name3),
):
    image_input.upload(
        fn=match_image_to_concept,
        inputs=[image_input],
        outputs=[concept_selector],
    )
|
| 531 |
+
|
| 532 |
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|
|
|
|