Streetmarkets committed on
Commit
1c4aa87
·
verified ·
1 Parent(s): 792840c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -12
app.py CHANGED
@@ -269,10 +269,61 @@ def predict_batch(images, urls):
269
 
270
  return batch_results
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  # Clear function
273
  def clear_fields():
274
- # return None, "", None, ""
275
- return None, ""
276
  # Gradio interface
277
  title = "Fashion Item Classifier with Marqo-FashionSigLIP"
278
  description = "Upload an image or provide a URL of a fashion item to classify it using [Marqo-FashionSigLIP](https://huggingface.co/Marqo/marqo-fashionSigLIP)!"
@@ -292,19 +343,18 @@ with gr.Blocks() as demo:
292
  with gr.Column(scale=2):
293
  input_image = gr.Image(type="pil", label="Upload Fashion Item Image", height=312)
294
  input_url = gr.Textbox(label="Or provide an image URL")
295
- # input_images = gr.Image(type="pil", label="Upload Fashion Item Images", height=312)
296
- # input_urls = gr.Textbox(label="Or provide image URLs (comma-separated)", lines=2)
 
297
  with gr.Row():
298
- predict_button = gr.Button("Classify")
299
- # predict_batch_button = gr.Button("Classify Batch")
300
- clear_button = gr.Button("Clear")
301
- gr.Markdown("Or click on one of the images below to classify it:")
302
  gr.Examples(examples=examples, inputs=input_image)
303
  output_label = gr.JSON(label="Top Categories")
304
- # output_batch_label = gr.JSON(label="Top Categories for Each Image")
305
- predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
306
- # predict_batch_button.click(predict_batch, inputs=[input_images, input_urls], outputs=output_batch_label)
307
- # clear_button.click(clear_fields, outputs=[input_image, input_url, input_images, input_urls])
308
 
309
  # Launch the interface
310
  demo.launch()
 
269
 
270
  return batch_results
271
 
272
+ # Fonction de prédiction avec texte
273
+ def predict_with_text(image, url, text_prompt):
274
+ if url:
275
+ response = requests.get(url)
276
+ image = Image.open(BytesIO(response.content))
277
+
278
+ processed_image = preprocess_val(image).unsqueeze(0).to(device)
279
+
280
+ # Encoder l'image
281
+ with torch.no_grad(), torch.amp.autocast(device_type=device):
282
+ image_features = model.encode_image(processed_image)
283
+ image_features /= image_features.norm(dim=-1, keepdim=True)
284
+
285
+ # Encoder le texte fourni par l'utilisateur
286
+ user_text = tokenizer([text_prompt]).to(device)
287
+ user_text_features = model.encode_text(user_text)
288
+ user_text_features /= user_text_features.norm(dim=-1, keepdim=True)
289
+
290
+ # Combiner les caractéristiques de l'image et du texte (moyenne pondérée)
291
+ combined_features = 0.7 * image_features + 0.3 * user_text_features
292
+ combined_features /= combined_features.norm(dim=-1, keepdim=True)
293
+
294
+ # Calculer les probabilités avec les caractéristiques combinées
295
+ text_probs = (100 * combined_features @ text_features.T).softmax(dim=-1)
296
+
297
+ sorted_confidences = sorted(
298
+ {items[i]: float(text_probs[0, i]) for i in range(len(items))}.items(),
299
+ key=lambda x: x[1],
300
+ reverse=True
301
+ )
302
+
303
+ # Inclure les IDs de catégorie dans la réponse
304
+ top_10_categories = [
305
+ {
306
+ "category_name": category["name"],
307
+ "id": category["id"],
308
+ "confidence": confidence
309
+ }
310
+ for category_name, confidence in sorted_confidences[:10]
311
+ for category in categories if category["name"] == category_name
312
+ ]
313
+
314
+ return image, top_10_categories
315
+
316
+ # Fonction de prédiction combinée qui choisit la méthode appropriée
317
+ def predict_combined(image, url, text_prompt=""):
318
+ if text_prompt and text_prompt.strip():
319
+ return predict_with_text(image, url, text_prompt)
320
+ else:
321
+ return predict(image, url)
322
+
323
  # Clear function
324
  def clear_fields():
325
+ return None, "", "", None, ""
326
+
327
  # Gradio interface
328
  title = "Fashion Item Classifier with Marqo-FashionSigLIP"
329
  description = "Upload an image or provide a URL of a fashion item to classify it using [Marqo-FashionSigLIP](https://huggingface.co/Marqo/marqo-fashionSigLIP)!"
 
343
  with gr.Column(scale=2):
344
  input_image = gr.Image(type="pil", label="Upload Fashion Item Image", height=312)
345
  input_url = gr.Textbox(label="Or provide an image URL")
346
+ input_text = gr.Textbox(label="Ajouter une description textuelle (optionnel)", placeholder="Ex: Robe d'été fleurie pour femme")
347
+ input_images = gr.Image(type="pil", label="Upload Fashion Item Images", height=312)
348
+ input_urls = gr.Textbox(label="Or provide image URLs (comma-separated)", lines=2)
349
  with gr.Row():
350
+ predict_button = gr.Button("Classifier")
351
+ clear_button = gr.Button("Effacer")
352
+ gr.Markdown("Ou cliquez sur l'une des images ci-dessous pour la classifier:")
 
353
  gr.Examples(examples=examples, inputs=input_image)
354
  output_label = gr.JSON(label="Top Categories")
355
+ output_batch_label = gr.JSON(label="Top Categories for Each Image")
356
+ predict_button.click(predict_combined, inputs=[input_image, input_url, input_text], outputs=[input_image, output_label])
357
+ clear_button.click(clear_fields, outputs=[input_image, input_url, input_text, input_images, input_urls])
 
358
 
359
  # Launch the interface
360
  demo.launch()