import gradio as gr
import open_clip
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO

# Sidebar content: Markdown rendered with gr.Markdown in the first UI column,
# after the title and the description.
# NOTE(review): the text below advertises 200 classifiable items, while the
# `categories` list in this file has 138 entries — confirm which figure is intended.
sidebar_markdown = """
Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.

## Documentation
📚 [Blog Post](https://www.marqo.ai/blog/search-model-for-fashion)

📝 [Use Case Blog Post](https://www.marqo.ai/blog/ecommerce-image-classification-with-marqo-fashionclip)

## Code
💻 [GitHub Repo](https://github.com/marqo-ai/marqo-FashionCLIP)

🤝 [Google Colab](https://colab.research.google.com/drive/1nq978xFJjJcnyrJ2aE5l82GHAXOvTmfd?usp=sharing)

🤗 [Hugging Face Collection](https://huggingface.co/collections/Marqo/marqo-fashionclip-and-marqo-fashionsiglip-66b43f2d09a06ad2368d4af6)
"""

# Catalogue of classifiable categories (beauty, CBD and fashion items).
# Each entry maps a display name (in French) to its numeric category ID.
# The order of this list matters: the category text embeddings are built from
# it in the same order, so positions here line up with embedding rows.
categories = [
    {"name": "Nettoyants visage", "id": 101},
    {"name": "Exfoliants visage", "id": 102},
    {"name": "Hydratants visage", "id": 103},
    {"name": "Masques visage", "id": 104},
    {"name": "Soins ciblés visage", "id": 105},
    {"name": "Protection solaire visage", "id": 106},
    {"name": "Nettoyants visage homme", "id": 107},
    {"name": "Crèmes hydratantes homme", "id": 108},
    {"name": "Soins après-rasage", "id": 109},
    {"name": "Hydratants corps", "id": 110},
    {"name": "Exfoliants corps", "id": 111},
    {"name": "Soins fermeté & minceur", "id": 112},
    {"name": "Auto-bronzants", "id": 113},
    {"name": "Soins des mains", "id": 114},
    {"name": "Soins des pieds", "id": 115},
    {"name": "Hydratants corps homme", "id": 116},
    {"name": "Déodorants corps homme", "id": 117},
    {"name": "Shampoings", "id": 118},
    {"name": "Après-shampoings", "id": 119},
    {"name": "Masques capillaires", "id": 120},
    {"name": "Huiles capillaires", "id": 121},
    {"name": "Coiffants", "id": 122},
    {"name": "Accessoires cheveux", "id": 123},
    {"name": "Soins cheveux homme", "id": 124},
    {"name": "Produits coiffants homme", "id": 125},
    {"name": "Fond de teint", "id": 126},
    {"name": "BB/CC crèmes", "id": 127},
    {"name": "Poudres", "id": 128},
    {"name": "Fards à paupières", "id": 129},
    {"name": "Mascaras", "id": 130},
    {"name": "Eyeliners", "id": 131},
    {"name": "Rouges à lèvres", "id": 132},
    {"name": "Gloss", "id": 133},
    {"name": "Crayons à sourcils", "id": 134},
    {"name": "Accessoires maquillage", "id": 135},
    {"name": "Correcteurs teint homme", "id": 136},
    {"name": "Poudres matifiantes homme", "id": 137},
    {"name": "Parfums", "id": 138},
    {"name": "Brumes corporelles", "id": 139},
    {"name": "Huiles essentielles", "id": 140},
    {"name": "Diffuseurs d'huiles", "id": 141},
    {"name": "Bougies parfumées", "id": 142},
    {"name": "Déodorants solides", "id": 143},
    {"name": "Déodorants sprays", "id": 144},
    {"name": "Savons solides", "id": 145},
    {"name": "Savons liquides", "id": 146},
    {"name": "Produits bain", "id": 147},
    {"name": "Hygiène intime", "id": 148},
    {"name": "Cups menstruelles", "id": 149},
    {"name": "Produits zéro déchet", "id": 150},
    {"name": "Brosses nettoyantes visage", "id": 151},
    {"name": "Pinces à épiler", "id": 152},
    {"name": "Trousse de voyage", "id": 153},
    {"name": "Huiles de CBD", "id": 154},
    {"name": "Cosmétiques au CBD", "id": 155},
    {"name": "Infusions au CBD", "id": 156},
    {"name": "Bonbons au CBD", "id": 157},
    {"name": "Accessoires CBD", "id": 158},
    {"name": "Robes femme", "id": 201},
    {"name": "Tops femme", "id": 202},
    {"name": "Chemisiers femme", "id": 203},
    {"name": "T-shirts femme", "id": 204},
    {"name": "Pulls femme", "id": 205},
    {"name": "Jeans femme", "id": 206},
    {"name": "Pantalons femme", "id": 207},
    {"name": "Jupes femme", "id": 208},
    {"name": "Shorts femme", "id": 209},
    {"name": "Vestes femme", "id": 210},
    {"name": "Manteaux femme", "id": 211},
    {"name": "Maillots de bain femme", "id": 212},
    {"name": "Lingerie femme", "id": 213},
    {"name": "Chaussures femme", "id": 214},
    {"name": "Sacs femme", "id": 215},
    {"name": "Bijoux femme", "id": 216},
    {"name": "Chemises homme", "id": 301},
    {"name": "T-shirts homme", "id": 302},
    {"name": "Polos homme", "id": 303},
    {"name": "Pulls homme", "id": 304},
    {"name": "Jeans homme", "id": 305},
    {"name": "Pantalons homme", "id": 306},
    {"name": "Shorts homme", "id": 307},
    {"name": "Vestes homme", "id": 308},
    {"name": "Manteaux homme", "id": 309},
    {"name": "Costumes homme", "id": 310},
    {"name": "Maillots de bain homme", "id": 311},
    {"name": "Sous-vêtements homme", "id": 312},
    {"name": "Chaussures homme", "id": 313},
    {"name": "Accessoires homme", "id": 314},
    {"name": "Montres homme", "id": 315},
    {"name": "Vêtements bébé (0-2 ans)", "id": 401},
    {"name": "T-shirts enfant", "id": 402},
    {"name": "Pulls enfant", "id": 403},
    {"name": "Pantalons enfant", "id": 404},
    {"name": "Robes enfant", "id": 405},
    {"name": "Jeans enfant", "id": 406},
    {"name": "Vestes enfant", "id": 407},
    {"name": "Pyjamas enfant", "id": 408},
    {"name": "Chaussures enfant", "id": 409},
    {"name": "Accessoires enfant", "id": 410},
    {"name": "Vêtements de sport enfant", "id": 411},
    {"name": "Maillots de bain enfant", "id": 412},
    {"name": "Sous-vêtements enfant", "id": 413},
    {"name": "Déguisements enfant", "id": 414},
    {"name": "Cartables et sacs enfant", "id": 415},
    # Detailed women's shoes
    {"name": "Sneakers femme", "id": 217},
    {"name": "Boots femme", "id": 218},
    {"name": "Escarpins femme", "id": 219},
    {"name": "Sandales femme", "id": 220},
    {"name": "Ballerines femme", "id": 221},
    {"name": "Mocassins femme", "id": 222},
    {"name": "Bottines femme", "id": 223},
    {"name": "Espadrilles femme", "id": 224},
    {"name": "Mules femme", "id": 225},
    {"name": "Chaussures de sport femme", "id": 226},
    {"name": "Bottes hautes femme", "id": 227},
    {"name": "Chaussures compensées femme", "id": 228},
    # Detailed men's shoes
    {"name": "Sneakers homme", "id": 316},
    {"name": "Boots homme", "id": 317},
    {"name": "Chaussures de ville homme", "id": 318},
    {"name": "Mocassins homme", "id": 319},
    {"name": "Sandales homme", "id": 320},
    {"name": "Chaussures bateau homme", "id": 321},
    {"name": "Bottines homme", "id": 322},
    {"name": "Chaussures de sport homme", "id": 323},
    {"name": "Espadrilles homme", "id": 324},
    {"name": "Derbies homme", "id": 325},
    {"name": "Richelieus homme", "id": 326},
    {"name": "Chaussures de randonnée homme", "id": 327},
    # Detailed children's shoes
    {"name": "Sneakers enfant", "id": 416},
    {"name": "Bottes enfant", "id": 417},
    {"name": "Sandales enfant", "id": 418},
    {"name": "Chaussures de sport enfant", "id": 419},
    {"name": "Chaussures premiers pas", "id": 420},
    {"name": "Chaussures à scratch enfant", "id": 421},
    {"name": "Chaussures d'école enfant", "id": 422},
    {"name": "Pantoufles enfant", "id": 423},
    {"name": "Chaussures de cérémonie enfant", "id": 424},
    {"name": "Bottes de pluie enfant", "id": 425},
]


# Display names only, kept in the same order as `categories`
items = [entry["name"] for entry in categories]

# Load the FashionSigLIP checkpoint from the Hugging Face hub together with
# its train/validation image transforms, then fetch the matching tokenizer
model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
(model, preprocess_train, preprocess_val) = open_clip.create_model_and_transforms(
    model_name
)
tokenizer = open_clip.get_tokenizer(model_name)

# Prompt template applied to every category name
def generate_description(item):
    """Return the text prompt used to describe ``item`` to the text encoder."""
    prompt = "A fashion item called {item}".format(item=item)
    return prompt

# Build one text prompt per category and tokenize them in a single call
items_desc = [generate_description(item) for item in items]
text = tokenizer(items_desc)

# Pick the compute device and move the model onto it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# open_clip hands the model back in training mode; switch to eval so that
# train-time layers (dropout, batch-norm, stochastic depth) do not perturb
# inference results.
model.eval()

if device == 'cuda':
    # The model is already loaded at this point; this only releases cached
    # allocator blocks before the text-encoding pass below.
    torch.cuda.empty_cache()

# Encode the category prompts once; every prediction reuses these
# L2-normalised embeddings (row i corresponds to items[i]).
with torch.no_grad(), torch.amp.autocast(device_type=device):
    text_features = model.encode_text(text.to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Prediction function
def predict(image, url):
    """Classify one image and return its ten most likely categories.

    Args:
        image: Image to classify (a PIL image when called from the Gradio UI).
            Ignored when ``url`` is non-empty.
        url: Optional image URL. When given, the image is downloaded and used
            instead of ``image``.

    Returns:
        A ``(image, top_10_categories)`` tuple. ``image`` is the image that was
        actually classified. ``top_10_categories`` is a list of at most ten
        dicts with keys ``"category_name"``, ``"id"`` and ``"confidence"``,
        ordered by decreasing confidence.

    Raises:
        gr.Error: If the URL cannot be downloaded or decoded, or if neither an
            image nor a URL was supplied.
    """
    url = (url or "").strip()
    if url:
        try:
            # A timeout keeps a stalled server from blocking the worker forever;
            # raise_for_status turns HTTP error pages into a clear failure
            # instead of an obscure image-decoding error.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            # Image.open is lazy: force decoding here so corrupt data is
            # reported as a download problem.
            image.load()
        except (requests.RequestException, OSError) as err:
            raise gr.Error(f"Could not load an image from the URL: {err}") from err
    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")

    processed_image = preprocess_val(image).unsqueeze(0).to(device)

    with torch.no_grad(), torch.amp.autocast(device_type=device):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Scaled cosine similarities against every category, as probabilities
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # One device-to-host transfer instead of one per category.
    probs = text_probs[0].float().tolist()
    # sorted() is stable, so ties keep the declaration order of `categories`.
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:10]

    # Row i of text_features was built from categories[i], so the index gives
    # the name and ID directly (no rescan of the list per result).
    top_10_categories = [
        {
            "category_name": categories[i]["name"],
            "id": categories[i]["id"],
            "confidence": probs[i],
        }
        for i in ranked
    ]

    return image, top_10_categories

# Batch prediction function
def predict_batch(images, urls):
    """Classify several images in one forward pass.

    Images and URLs are paired by position. When a pair has a non-empty URL,
    the image is downloaded from it; otherwise the supplied image is used. The
    shorter of the two sequences is padded with ``None``, so neither input is
    silently truncated.

    Args:
        images: Sequence of images (``None`` entries allowed when the matching
            URL is set), a single PIL image, or ``None``.
        urls: Sequence of URLs (empty/``None`` entries allowed when the
            matching image is set), a comma-separated string, or ``None``.

    Returns:
        One list per input pair, in input order. Each inner list holds at most
        ten dicts with keys ``"category_name"``, ``"id"`` and ``"confidence"``,
        ordered by decreasing confidence. An empty input yields ``[]``.

    Raises:
        gr.Error: If a URL cannot be downloaded or decoded, or if a pair has
            neither an image nor a URL.
    """
    from itertools import zip_longest

    # Normalise the inputs to two lists.
    if isinstance(urls, str):
        urls = [part.strip() for part in urls.split(",")] if urls.strip() else []
    if isinstance(images, Image.Image):
        images = [images]
    images = list(images or [])
    urls = list(urls or [])

    # Combine images coming from URLs with directly supplied ones
    tensors = []
    for position, (image, url) in enumerate(zip_longest(images, urls)):
        url = (url or "").strip()
        if url:
            try:
                response = requests.get(url, timeout=10)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content))
                # Image.open is lazy: force decoding inside the try block.
                image.load()
            except (requests.RequestException, OSError) as err:
                raise gr.Error(
                    f"Could not load image {position} from its URL: {err}"
                ) from err
        if image is None:
            raise gr.Error(f"Entry {position} has neither an image nor a URL.")
        tensors.append(preprocess_val(image).unsqueeze(0))

    # torch.cat rejects an empty list, so short-circuit the empty batch.
    if not tensors:
        return []

    # Stack every processed image into one batch and move it in one transfer
    batch_images = torch.cat(tensors, dim=0).to(device)

    with torch.no_grad(), torch.amp.autocast(device_type=device):
        image_features = model.encode_image(batch_images)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # Iterate over the rows actually produced rather than over len(images),
    # which may differ from the number of processed pairs.
    batch_results = []
    for probs in text_probs.float().tolist():
        # Stable sort: ties keep the declaration order of `categories`.
        ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:10]
        # Row j of text_features was built from categories[j].
        batch_results.append([
            {
                "category_name": categories[j]["name"],
                "id": categories[j]["id"],
                "confidence": probs[j],
            }
            for j in ranked
        ])

    return batch_results

# Prediction function guided by a text prompt
def predict_with_text(image, url, text_prompt):
    """Classify one image, biased by a user-supplied text description.

    The image embedding and the prompt embedding are blended (70% image,
    30% text), re-normalised, and compared with the category embeddings.

    Args:
        image: Image to classify (a PIL image when called from the Gradio UI).
            Ignored when ``url`` is non-empty.
        url: Optional image URL. When given, the image is downloaded and used
            instead of ``image``.
        text_prompt: Free-text description passed to the tokenizer.

    Returns:
        A ``(image, top_10_categories)`` tuple. ``image`` is the image that was
        actually classified. ``top_10_categories`` is a list of at most ten
        dicts with keys ``"category_name"``, ``"id"`` and ``"confidence"``,
        ordered by decreasing confidence.

    Raises:
        gr.Error: If the URL cannot be downloaded or decoded, or if neither an
            image nor a URL was supplied.
    """
    url = (url or "").strip()
    if url:
        try:
            # Bounded wait plus an explicit HTTP status check.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            # Image.open is lazy: force decoding inside the try block.
            image.load()
        except (requests.RequestException, OSError) as err:
            raise gr.Error(f"Could not load an image from the URL: {err}") from err
    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")

    processed_image = preprocess_val(image).unsqueeze(0).to(device)

    with torch.no_grad(), torch.amp.autocast(device_type=device):
        # Encode the image
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Encode the text supplied by the user
        user_text = tokenizer([text_prompt]).to(device)
        user_text_features = model.encode_text(user_text)
        user_text_features /= user_text_features.norm(dim=-1, keepdim=True)

        # Weighted blend of both embeddings; re-normalise because the sum of
        # two unit vectors is generally not unit length.
        combined_features = 0.7 * image_features + 0.3 * user_text_features
        combined_features /= combined_features.norm(dim=-1, keepdim=True)

        # Probabilities computed from the blended embedding
        text_probs = (100 * combined_features @ text_features.T).softmax(dim=-1)

    # One device-to-host transfer instead of one per category.
    probs = text_probs[0].float().tolist()
    # sorted() is stable, so ties keep the declaration order of `categories`.
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:10]

    # Row i of text_features was built from categories[i], so the index gives
    # the name and ID directly.
    top_10_categories = [
        {
            "category_name": categories[i]["name"],
            "id": categories[i]["id"],
            "confidence": probs[i],
        }
        for i in ranked
    ]

    return image, top_10_categories

# Dispatcher that picks the prediction routine matching the inputs
def predict_combined(image, url, text_prompt=""):
    """Route to the text-guided predictor when a non-blank prompt is given.

    A missing, empty or whitespace-only ``text_prompt`` falls back to the
    image-only ``predict``; otherwise ``predict_with_text`` is used. The
    chosen function's result is returned unchanged.
    """
    if not (text_prompt and text_prompt.strip()):
        return predict(image, url)
    return predict_with_text(image, url, text_prompt)

# Reset values for the clear button
def clear_fields():
    """Return five reset values: ``None`` for image slots, ``""`` for text slots.

    The order is image, text, text, image, text.
    """
    empty_image, empty_text = None, ""
    return empty_image, empty_text, empty_text, empty_image, empty_text

# Gradio interface: page title and description shown in the first column
title = "Fashion Item Classifier with Marqo-FashionSigLIP"
description = "Upload an image or provide a URL of a fashion item to classify it using [Marqo-FashionSigLIP](https://huggingface.co/Marqo/marqo-fashionSigLIP)!"

# Clickable example images.
# NOTE(review): each row carries two values (path, label) but gr.Examples below
# is bound to a single input component — confirm the label column is intended.
examples = [
    ["images/dress.jpg", "Dress"],
    ["images/sweatpants.jpg", "Sweatpants"],
    ["images/t-shirt.jpg", "T-Shirt"],
]

# Components are laid out in creation order, so statement order here is the
# on-screen order.
with gr.Blocks() as demo:
    with gr.Row():
        # First column: title, description and the sidebar links
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)
        # Second column: inputs, action buttons, examples and results
        with gr.Column(scale=2):
            # Single-image inputs consumed by predict_combined
            input_image = gr.Image(type="pil", label="Upload Fashion Item Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")
            input_text = gr.Textbox(label="Ajouter une description textuelle (optionnel)", placeholder="Ex: Robe d'été fleurie pour femme")
            # Batch inputs.
            # NOTE(review): these two widgets are only referenced by the clear
            # button; no handler in this file reads them (predict_batch is not
            # wired to any event) — confirm whether batch mode is unfinished.
            input_images = gr.Image(type="pil", label="Upload Fashion Item Images", height=312)
            input_urls = gr.Textbox(label="Or provide image URLs (comma-separated)", lines=2)
            with gr.Row():
                predict_button = gr.Button("Classifier")
                clear_button = gr.Button("Effacer")
            gr.Markdown("Ou cliquez sur l'une des images ci-dessous pour la classifier:")
            gr.Examples(examples=examples, inputs=input_image)
            output_label = gr.JSON(label="Top Categories")
            # NOTE(review): no handler in this file writes to output_batch_label.
            output_batch_label = gr.JSON(label="Top Categories for Each Image")
            # The handler returns (image, categories): the image is written back
            # to input_image and the ranked categories to output_label.
            predict_button.click(predict_combined, inputs=[input_image, input_url, input_text], outputs=[input_image, output_label])
            # clear_fields returns one reset value per component listed here;
            # the two JSON outputs are not cleared.
            clear_button.click(clear_fields, outputs=[input_image, input_url, input_text, input_images, input_urls])

# Launch the interface
demo.launch()