Update app.py
app.py CHANGED
@@ -269,10 +269,61 @@ def predict_batch(images, urls):
 
     return batch_results
 
+# Prediction function that also uses a text prompt
+def predict_with_text(image, url, text_prompt):
+    if url:
+        response = requests.get(url)
+        image = Image.open(BytesIO(response.content))
+
+    processed_image = preprocess_val(image).unsqueeze(0).to(device)
+
+    # Encode the image
+    with torch.no_grad(), torch.amp.autocast(device_type=device):
+        image_features = model.encode_image(processed_image)
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+
+        # Encode the user-provided text
+        user_text = tokenizer([text_prompt]).to(device)
+        user_text_features = model.encode_text(user_text)
+        user_text_features /= user_text_features.norm(dim=-1, keepdim=True)
+
+        # Combine the image and text features (weighted average)
+        combined_features = 0.7 * image_features + 0.3 * user_text_features
+        combined_features /= combined_features.norm(dim=-1, keepdim=True)
+
+        # Compute the probabilities from the combined features
+        text_probs = (100 * combined_features @ text_features.T).softmax(dim=-1)
+
+    sorted_confidences = sorted(
+        {items[i]: float(text_probs[0, i]) for i in range(len(items))}.items(),
+        key=lambda x: x[1],
+        reverse=True
+    )
+
+    # Include the category IDs in the response
+    top_10_categories = [
+        {
+            "category_name": category["name"],
+            "id": category["id"],
+            "confidence": confidence
+        }
+        for category_name, confidence in sorted_confidences[:10]
+        for category in categories if category["name"] == category_name
+    ]
+
+    return image, top_10_categories
+
+# Combined prediction function that picks the appropriate method
+def predict_combined(image, url, text_prompt=""):
+    if text_prompt and text_prompt.strip():
+        return predict_with_text(image, url, text_prompt)
+    else:
+        return predict(image, url)
+
 # Clear function
 def clear_fields():
-
-
+    return None, "", "", None, ""
+
 # Gradio interface
 title = "Fashion Item Classifier with Marqo-FashionSigLIP"
 description = "Upload an image or provide a URL of a fashion item to classify it using [Marqo-FashionSigLIP](https://huggingface.co/Marqo/marqo-fashionSigLIP)!"
@@ -292,19 +343,18 @@ with gr.Blocks() as demo:
     with gr.Column(scale=2):
         input_image = gr.Image(type="pil", label="Upload Fashion Item Image", height=312)
         input_url = gr.Textbox(label="Or provide an image URL")
-
-
+        input_text = gr.Textbox(label="Ajouter une description textuelle (optionnel)", placeholder="Ex: Robe d'été fleurie pour femme")
+        input_images = gr.Image(type="pil", label="Upload Fashion Item Images", height=312)
+        input_urls = gr.Textbox(label="Or provide image URLs (comma-separated)", lines=2)
         with gr.Row():
-            predict_button = gr.Button("
-
-
-        gr.Markdown("Or click on one of the images below to classify it:")
+            predict_button = gr.Button("Classifier")
+            clear_button = gr.Button("Effacer")
+        gr.Markdown("Ou cliquez sur l'une des images ci-dessous pour la classifier:")
         gr.Examples(examples=examples, inputs=input_image)
         output_label = gr.JSON(label="Top Categories")
-
-        predict_button.click(
-
-        # clear_button.click(clear_fields, outputs=[input_image, input_url, input_images, input_urls])
+        output_batch_label = gr.JSON(label="Top Categories for Each Image")
+        predict_button.click(predict_combined, inputs=[input_image, input_url, input_text], outputs=[input_image, output_label])
+        clear_button.click(clear_fields, outputs=[input_image, input_url, input_text, input_images, input_urls])
 
 # Launch the interface
 demo.launch()
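
The heart of this change is the weighted fusion of the image embedding and the user's text embedding before scoring against the category embeddings. Below is a minimal, self-contained sketch of that step, using random tensors as stand-ins for the encoder outputs (image_features, user_text_features) and for the precomputed category matrix text_features; the 512-dimensional embeddings and the 20 categories are illustrative assumptions, not values taken from app.py.

import torch

# Stand-ins for model.encode_image(...), model.encode_text(...) and the
# precomputed category embeddings; shapes are assumptions for illustration.
image_features = torch.randn(1, 512)
user_text_features = torch.randn(1, 512)
text_features = torch.randn(20, 512)  # one row per category prompt

# L2-normalize each embedding, as the diff does.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
user_text_features = user_text_features / user_text_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Weighted average (70% image, 30% user text), renormalized to unit length.
combined_features = 0.7 * image_features + 0.3 * user_text_features
combined_features = combined_features / combined_features.norm(dim=-1, keepdim=True)

# Scaled cosine similarities, softmaxed into per-category probabilities.
text_probs = (100 * combined_features @ text_features.T).softmax(dim=-1)
print(text_probs.shape)  # torch.Size([1, 20]); each row sums to 1

Because both inputs are unit-normalized before the weighted sum, the 0.7/0.3 weights control how far the text prompt can pull the query away from the pure image embedding.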
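
The nested comprehension that builds top_10_categories joins the sorted (name, confidence) pairs back to the category records by name, so the category id can be returned alongside the confidence. Here is a small sketch with made-up category data; the real categories list is defined elsewhere in app.py and does not appear in this diff.

# Hypothetical stand-ins for the categories list and the sorted scores.
categories = [
    {"id": 1, "name": "dress"},
    {"id": 2, "name": "sneakers"},
    {"id": 3, "name": "handbag"},
]
sorted_confidences = [("dress", 0.62), ("handbag", 0.25), ("sneakers", 0.13)]

# Same double comprehension as in the diff: keep the top-10 scored names,
# then attach the id of each matching category record.
top_10_categories = [
    {
        "category_name": category["name"],
        "id": category["id"],
        "confidence": confidence,
    }
    for category_name, confidence in sorted_confidences[:10]
    for category in categories if category["name"] == category_name
]
print(top_10_categories)
# [{'category_name': 'dress', 'id': 1, 'confidence': 0.62}, ...]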