import gradio as gr
import open_clip
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from itertools import zip_longest

# Categories the classifier can predict. Each entry pairs a display name (also
# used to build the text prompt) with the ID returned to callers. List order
# fixes the column order of `text_features`, so keep it stable.
categories = [
    {"name": "Nettoyants visage", "id": 101},
    {"name": "Exfoliants visage", "id": 102},
    {"name": "Hydratants visage", "id": 103},
    {"name": "Masques visage", "id": 104},
    {"name": "Soins ciblés visage", "id": 105},
    {"name": "Protection solaire visage", "id": 106},
    {"name": "Nettoyants visage homme", "id": 107},
    {"name": "Crèmes hydratantes homme", "id": 108},
    {"name": "Soins après-rasage", "id": 109},
    {"name": "Hydratants corps", "id": 110},
    {"name": "Exfoliants corps", "id": 111},
    {"name": "Soins fermeté & minceur", "id": 112},
    {"name": "Auto-bronzants", "id": 113},
    {"name": "Soins des mains", "id": 114},
    {"name": "Soins des pieds", "id": 115},
    {"name": "Hydratants corps homme", "id": 116},
    {"name": "Déodorants corps homme", "id": 117},
    {"name": "Shampoings", "id": 118},
    {"name": "Après-shampoings", "id": 119},
    {"name": "Masques capillaires", "id": 120},
    {"name": "Huiles capillaires", "id": 121},
    {"name": "Coiffants", "id": 122},
    {"name": "Accessoires cheveux", "id": 123},
    {"name": "Soins cheveux homme", "id": 124},
    {"name": "Produits coiffants homme", "id": 125},
    {"name": "Fond de teint", "id": 126},
    {"name": "BB/CC crèmes", "id": 127},
    {"name": "Poudres", "id": 128},
    {"name": "Fards à paupières", "id": 129},
    {"name": "Mascaras", "id": 130},
    {"name": "Eyeliners", "id": 131},
    {"name": "Rouges à lèvres", "id": 132},
    {"name": "Gloss", "id": 133},
    {"name": "Crayons à sourcils", "id": 134},
    {"name": "Accessoires maquillage", "id": 135},
    {"name": "Correcteurs teint homme", "id": 136},
    {"name": "Poudres matifiantes homme", "id": 137},
    {"name": "Parfums", "id": 138},
    {"name": "Brumes corporelles", "id": 139},
    {"name": "Huiles essentielles", "id": 140},
    {"name": "Diffuseurs d'huiles", "id": 141},
    {"name": "Bougies parfumées", "id": 142},
    {"name": "Déodorants solides", "id": 143},
    {"name": "Déodorants sprays", "id": 144},
    {"name": "Savons solides", "id": 145},
    {"name": "Savons liquides", "id": 146},
    {"name": "Produits bain", "id": 147},
    {"name": "Hygiène intime", "id": 148},
    {"name": "Cups menstruelles", "id": 149},
    {"name": "Produits zéro déchet", "id": 150},
    {"name": "Brosses nettoyantes visage", "id": 151},
    {"name": "Pinces à épiler", "id": 152},
    {"name": "Trousse de voyage", "id": 153},
    {"name": "Huiles de CBD", "id": 154},
    {"name": "Cosmétiques au CBD", "id": 155},
    {"name": "Infusions au CBD", "id": 156},
    {"name": "Bonbons au CBD", "id": 157},
    {"name": "Accessoires CBD", "id": 158},
    {"name": "Robes femme", "id": 201},
    {"name": "Tops femme", "id": 202},
    {"name": "Chemisiers femme", "id": 203},
    {"name": "T-shirts femme", "id": 204},
    {"name": "Pulls femme", "id": 205},
    {"name": "Jeans femme", "id": 206},
    {"name": "Pantalons femme", "id": 207},
    {"name": "Jupes femme", "id": 208},
    {"name": "Shorts femme", "id": 209},
    {"name": "Vestes femme", "id": 210},
    {"name": "Manteaux femme", "id": 211},
    {"name": "Maillots de bain femme", "id": 212},
    {"name": "Lingerie femme", "id": 213},
    {"name": "Chaussures femme", "id": 214},
    {"name": "Sacs femme", "id": 215},
    {"name": "Bijoux femme", "id": 216},
    {"name": "Chemises homme", "id": 301},
    {"name": "T-shirts homme", "id": 302},
    {"name": "Polos homme", "id": 303},
    {"name": "Pulls homme", "id": 304},
    {"name": "Jeans homme", "id": 305},
    {"name": "Pantalons homme", "id": 306},
    {"name": "Shorts homme", "id": 307},
    {"name": "Vestes homme", "id": 308},
    {"name": "Manteaux homme", "id": 309},
    {"name": "Costumes homme", "id": 310},
    {"name": "Maillots de bain homme", "id": 311},
    {"name": "Sous-vêtements homme", "id": 312},
    {"name": "Chaussures homme", "id": 313},
    {"name": "Accessoires homme", "id": 314},
    {"name": "Montres homme", "id": 315},
    {"name": "Vêtements bébé (0-2 ans)", "id": 401},
    {"name": "T-shirts enfant", "id": 402},
    {"name": "Pulls enfant", "id": 403},
    {"name": "Pantalons enfant", "id": 404},
    {"name": "Robes enfant", "id": 405},
    {"name": "Jeans enfant", "id": 406},
    {"name": "Vestes enfant", "id": 407},
    {"name": "Pyjamas enfant", "id": 408},
    {"name": "Chaussures enfant", "id": 409},
    {"name": "Accessoires enfant", "id": 410},
    {"name": "Vêtements de sport enfant", "id": 411},
    {"name": "Maillots de bain enfant", "id": 412},
    {"name": "Sous-vêtements enfant", "id": 413},
    {"name": "Déguisements enfant", "id": 414},
    {"name": "Cartables et sacs enfant", "id": 415},
    # Detailed women's shoes
    {"name": "Sneakers femme", "id": 217},
    {"name": "Boots femme", "id": 218},
    {"name": "Escarpins femme", "id": 219},
    {"name": "Sandales femme", "id": 220},
    {"name": "Ballerines femme", "id": 221},
    {"name": "Mocassins femme", "id": 222},
    {"name": "Bottines femme", "id": 223},
    {"name": "Espadrilles femme", "id": 224},
    {"name": "Mules femme", "id": 225},
    {"name": "Chaussures de sport femme", "id": 226},
    {"name": "Bottes hautes femme", "id": 227},
    {"name": "Chaussures compensées femme", "id": 228},
    # Detailed men's shoes
    {"name": "Sneakers homme", "id": 316},
    {"name": "Boots homme", "id": 317},
    {"name": "Chaussures de ville homme", "id": 318},
    {"name": "Mocassins homme", "id": 319},
    {"name": "Sandales homme", "id": 320},
    {"name": "Chaussures bateau homme", "id": 321},
    {"name": "Bottines homme", "id": 322},
    {"name": "Chaussures de sport homme", "id": 323},
    {"name": "Espadrilles homme", "id": 324},
    {"name": "Derbies homme", "id": 325},
    {"name": "Richelieus homme", "id": 326},
    {"name": "Chaussures de randonnée homme", "id": 327},
    # Detailed children's shoes
    {"name": "Sneakers enfant", "id": 416},
    {"name": "Bottes enfant", "id": 417},
    {"name": "Sandales enfant", "id": 418},
    {"name": "Chaussures de sport enfant", "id": 419},
    {"name": "Chaussures premiers pas", "id": 420},
    {"name": "Chaussures à scratch enfant", "id": 421},
    {"name": "Chaussures d'école enfant", "id": 422},
    {"name": "Pantoufles enfant", "id": 423},
    {"name": "Chaussures de cérémonie enfant", "id": 424},
    {"name": "Bottes de pluie enfant", "id": 425},
]

# Category names, index-aligned with `categories`.
items = [category["name"] for category in categories]

# Sidebar content. The item count is derived from `categories` so it cannot
# drift from the list above (it previously said 200 while 138 are defined).
sidebar_markdown = f"""

Note, this demo can classify {len(categories)} items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.

## Documentation
📚 [Blog Post](https://www.marqo.ai/blog/search-model-for-fashion)

📝 [Use Case Blog Post](https://www.marqo.ai/blog/ecommerce-image-classification-with-marqo-fashionclip)

## Code
💻 [GitHub Repo](https://github.com/marqo-ai/marqo-FashionCLIP)

🤝 [Google Colab](https://colab.research.google.com/drive/1nq978xFJjJcnyrJ2aE5l82GHAXOvTmfd?usp=sharing)

🤗 [Hugging Face Collection](https://huggingface.co/collections/Marqo/marqo-fashionclip-and-marqo-fashionsiglip-66b43f2d09a06ad2368d4af6)
"""

# Seconds to wait for a remote image before giving up (requests has no
# default timeout and would otherwise block the worker indefinitely).
HTTP_TIMEOUT_SECONDS = 15
# Number of categories returned per image.
TOP_K = 10
# Blend weights for image vs. user-prompt embeddings in predict_with_text.
IMAGE_WEIGHT = 0.7
TEXT_WEIGHT = 0.3

# Initialize the model and tokenizer
model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
tokenizer = open_clip.get_tokenizer(model_name)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# open_clip returns models in train mode; eval() disables dropout / stochastic
# depth so inference is deterministic.
model.eval()
if device == 'cuda':
    # Release any cached allocator blocks left over from loading.
    torch.cuda.empty_cache()


def generate_description(item):
    """Return the text prompt used to embed the category named *item*."""
    return f"A fashion item called {item}"


items_desc = [generate_description(item) for item in items]
text = tokenizer(items_desc)

# Encode every category prompt once at startup; rows are L2-normalized so a
# matrix product with a normalized query yields cosine similarities.
with torch.no_grad(), torch.amp.autocast(device_type=device):
    text_features = model.encode_text(text.to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)


def _load_image(image, url):
    """Return the PIL image to classify: fetched from *url* if given, else *image*.

    Raises:
        gr.Error: if neither input is provided, or the URL cannot be fetched
            or decoded as an image.
    """
    if url and url.strip():
        try:
            response = requests.get(url.strip(), timeout=HTTP_TIMEOUT_SECONDS)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        except requests.RequestException as err:
            raise gr.Error(f"Could not download image from URL: {err}") from err
        except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
            raise gr.Error(f"The URL did not return a readable image: {err}") from err
    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")
    return image


def _encode_images(pil_images):
    """Embed a list of PIL images into L2-normalized feature rows.

    Must be called inside the caller's no_grad/autocast context.
    """
    batch = torch.stack([preprocess_val(img) for img in pil_images]).to(device)
    features = model.encode_image(batch)
    return features / features.norm(dim=-1, keepdim=True)


def _category_probs(query_features):
    """Softmax over categories of the 100x-scaled cosine similarity to each prompt."""
    return (100 * query_features @ text_features.T).softmax(dim=-1)


def _top_categories(probs_row, k=TOP_K):
    """Return the *k* most probable categories for one row of probabilities.

    Each result is ``{"category_name", "id", "confidence"}``, highest first;
    ties keep the order of `categories`.
    """
    scores = probs_row.float().tolist()
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]
    return [
        {
            "category_name": categories[j]["name"],
            "id": categories[j]["id"],
            "confidence": scores[j],
        }
        for j in ranked
    ]


def predict(image, url):
    """Classify a single image (uploaded, or fetched from *url* when given).

    Returns:
        ``(image, top_categories)`` — the image actually classified and its
        top-10 categories with IDs and confidences.
    """
    image = _load_image(image, url)
    with torch.no_grad(), torch.amp.autocast(device_type=device):
        probs = _category_probs(_encode_images([image]))
    return image, _top_categories(probs[0])


def predict_batch(images, urls):
    """Classify several images in one forward pass.

    *images* and *urls* are parallel lists; for each position a non-empty URL
    takes precedence over the image. Lists of unequal length are padded with
    ``None`` rather than truncated.

    Returns:
        One top-10 category list per input position (``[]`` for an empty batch).
    """
    pairs = list(zip_longest(images or [], urls or []))
    if not pairs:
        return []
    loaded = [_load_image(img, url) for img, url in pairs]
    with torch.no_grad(), torch.amp.autocast(device_type=device):
        probs = _category_probs(_encode_images(loaded))
    return [_top_categories(row) for row in probs]


def predict_with_text(image, url, text_prompt):
    """Classify an image guided by a free-text prompt.

    The image and prompt embeddings are blended (IMAGE_WEIGHT / TEXT_WEIGHT),
    re-normalized, and scored against the category prompts.

    Returns:
        ``(image, top_categories)`` as in :func:`predict`.
    """
    image = _load_image(image, url)
    with torch.no_grad(), torch.amp.autocast(device_type=device):
        image_features = _encode_images([image])
        prompt_tokens = tokenizer([text_prompt]).to(device)
        prompt_features = model.encode_text(prompt_tokens)
        prompt_features = prompt_features / prompt_features.norm(dim=-1, keepdim=True)
        combined = IMAGE_WEIGHT * image_features + TEXT_WEIGHT * prompt_features
        combined = combined / combined.norm(dim=-1, keepdim=True)
        probs = _category_probs(combined)
    return image, _top_categories(probs[0])


def predict_combined(image, url, text_prompt=""):
    """Dispatch to predict_with_text when a non-blank prompt is given, else predict."""
    if text_prompt and text_prompt.strip():
        return predict_with_text(image, url, text_prompt)
    return predict(image, url)


def clear_fields():
    """Return reset values for the image, URL, prompt, batch-image and batch-URL inputs."""
    return None, "", "", None, ""


# Gradio interface
title = "Fashion Item Classifier with Marqo-FashionSigLIP"
description = "Upload an image or provide a URL of a fashion item to classify it using [Marqo-FashionSigLIP](https://huggingface.co/Marqo/marqo-fashionSigLIP)!"
# Example rows for the single-image input. Each row holds an image path and a
# label; only one input component is bound to gr.Examples below.
# NOTE(review): the label column appears unused — confirm this is intended.
examples = [
    ["images/dress.jpg", "Dress"],
    ["images/sweatpants.jpg", "Sweatpants"],
    ["images/t-shirt.jpg", "T-Shirt"],
]


def _predict_batch_from_ui(image, urls_text):
    """Adapt the batch widgets to predict_batch's parallel (images, urls) lists.

    The optional uploaded image becomes one entry; each non-blank
    comma-separated URL becomes another.
    """
    urls = [u.strip() for u in (urls_text or "").split(",") if u.strip()]
    batch_images, batch_urls = [], []
    if image is not None:
        batch_images.append(image)
        batch_urls.append("")
    for u in urls:
        batch_images.append(None)
        batch_urls.append(u)
    if not batch_images:
        raise gr.Error("Please upload an image or provide at least one image URL.")
    return predict_batch(batch_images, batch_urls)


def _clear_all():
    """Reset every input widget plus both result panels (so no stale results remain)."""
    return (*clear_fields(), None, None)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Fashion Item Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")
            input_text = gr.Textbox(label="Ajouter une description textuelle (optionnel)", placeholder="Ex: Robe d'été fleurie pour femme")
            input_images = gr.Image(type="pil", label="Upload Fashion Item Images", height=312)
            input_urls = gr.Textbox(label="Or provide image URLs (comma-separated)", lines=2)
            with gr.Row():
                predict_button = gr.Button("Classifier")
                batch_button = gr.Button("Classifier le lot")
                clear_button = gr.Button("Effacer")
            gr.Markdown("Ou cliquez sur l'une des images ci-dessous pour la classifier:")
            gr.Examples(examples=examples, inputs=input_image)
            output_label = gr.JSON(label="Top Categories")
            output_batch_label = gr.JSON(label="Top Categories for Each Image")

    predict_button.click(
        predict_combined,
        inputs=[input_image, input_url, input_text],
        outputs=[input_image, output_label],
    )
    # The batch widgets were previously created but never connected to anything.
    batch_button.click(
        _predict_batch_from_ui,
        inputs=[input_images, input_urls],
        outputs=output_batch_label,
    )
    clear_button.click(
        _clear_all,
        outputs=[
            input_image,
            input_url,
            input_text,
            input_images,
            input_urls,
            output_label,
            output_batch_label,
        ],
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()