import gradio as gr
import numpy as np
import spaces
import torch
import random
import json
import os
from PIL import Image
from kontext_pipeline import FluxKontextPipeline
from diffusers import FluxTransformer2DModel
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
from safetensors.torch import load_file
import requests
import re
# Largest seed value offered by the UI and used for random seeding.
MAX_SEED = np.iinfo(np.int32).max

# Load Kontext model: download the transformer checkpoint, then build the
# pipeline around it in bfloat16 and move it to the GPU.
kontext_path = hf_hub_download(
    repo_id="diffusers/kontext-v2",
    filename="dev-opt-2-a-3.safetensors",
)
transformer = FluxTransformer2DModel.from_single_file(
    kontext_path,
    torch_dtype=torch.bfloat16,
)
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
# Load the LoRA catalogue shown in the gallery (you'll need to create this JSON
# file or modify this section to load your LoRAs). Every entry must provide
# "image", "title" and "repo"; the remaining keys fall back to defaults.
# JSON is UTF-8 by specification, so the encoding is stated explicitly instead
# of relying on the platform's locale default.
with open("flux_loras.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Built after the file is closed: only the parsed data is needed from here on.
# NOTE(review): a "likes" key is not carried over, so any later sort on
# "likes" sees the default for every entry - confirm whether that is intended.
flux_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item.get("trigger_word", ""),
        "trigger_position": item.get("trigger_position", "prepend"),
        "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
    }
    for item in data
]
print(f"Loaded {len(flux_loras_raw)} LoRAs from JSON")
# Global variables for LoRA management
# current_lora: the LoRA entry (dict) most recently loaded into `pipe`, or None
# when no LoRA has been loaded yet; read and updated by infer_with_lora.
current_lora = None
# lora_cache: local file paths of already-downloaded LoRA weights, filled by
# load_lora_weights so repeated selections skip the download call.
lora_cache = {}
def load_lora_weights(repo_id, weights_filename):
    """Return the local path of a LoRA weights file, downloading it if needed.

    Args:
        repo_id: Hugging Face repository id handed to ``hf_hub_download``.
        weights_filename: Name of the weights file inside that repository.

    Returns:
        The local file path, or ``None`` when the download fails (the error is
        printed rather than raised so the caller can continue without a LoRA).
    """
    # Key on both values: one repository can host several weight files, and
    # keying on the repository alone would return whichever file was fetched
    # first, regardless of the file requested now.
    cache_key = (repo_id, weights_filename)
    try:
        if cache_key not in lora_cache:
            lora_cache[cache_key] = hf_hub_download(repo_id=repo_id, filename=weights_filename)
        return lora_cache[cache_key]
    except Exception as e:
        print(f"Error loading LoRA from {repo_id}: {e}")
        return None
def update_selection(selected_state: gr.SelectData, flux_loras):
    """Update the UI when a LoRA is picked in the gallery.

    Args:
        selected_state: Gradio selection event; its ``index`` is the position
            of the clicked gallery item.
        flux_loras: List of LoRA entries in the order the gallery shows them.

    Returns:
        A tuple ``(markdown_text, prompt_update, selected_index)``. When the
        index is past the end of ``flux_loras`` the text reports that nothing
        is selected, the prompt is left untouched and the index is ``None``.
    """
    index = selected_state.index
    if index >= len(flux_loras):
        return "### No LoRA selected", gr.update(), None

    # Only the repo id is needed here; the trigger word is applied at inference time.
    lora_repo = flux_loras[index]["repo"]

    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
    new_placeholder = "optional description, e.g. 'a man with glasses and a beard'"

    return updated_text, gr.update(placeholder=new_placeholder), index
def get_huggingface_lora(link):
    """Resolve an ``owner/name`` repository id to its LoRA file and trigger word.

    Args:
        link: Repository id made of exactly two ``/``-separated parts.

    Returns:
        A tuple ``(repo_name, weights_filename, trigger_word)`` where
        ``repo_name`` is the part after the slash, ``weights_filename`` is the
        first matching ``.safetensors`` file (or
        ``"pytorch_lora_weights.safetensors"`` when none matches) and
        ``trigger_word`` is the model card's ``instance_prompt`` or ``""``.

    Raises:
        Exception: If ``link`` does not have two parts, or if reading the model
            card or listing the repository fails (the original error is chained).
    """
    split_link = link.split("/")
    if len(split_link) != 2:
        raise Exception("Invalid HuggingFace repository format")

    try:
        model_card = ModelCard.load(link)
        # `or ""` also covers a card that declares the key with a null value.
        trigger_word = model_card.data.get("instance_prompt", "") or ""

        fs = HfFileSystem()
        list_of_files = fs.ls(link, detail=False)

        # NOTE(review): the "lora" test runs on the whole listed path, which
        # presumably includes the repository id - confirm whether matching on
        # the file name alone was intended.
        safetensors_file = next(
            (
                file.split("/")[-1]
                for file in list_of_files
                if file.endswith(".safetensors") and "lora" in file.lower()
            ),
            None,
        )
        if not safetensors_file:
            safetensors_file = "pytorch_lora_weights.safetensors"

        return split_link[1], safetensors_file, trigger_word
    except Exception as e:
        # Chain the cause so the underlying failure stays visible in tracebacks.
        raise Exception(f"Error loading LoRA: {e}") from e
def load_custom_lora(link):
    """Load a custom LoRA from the repository id typed by the user.

    Args:
        link: Repository id of the form ``owner/name``; surrounding whitespace
            is ignored and an empty value resets the custom LoRA widgets.

    Returns:
        A 7-tuple of values in this order: a visibility update, the card text,
        a second visibility update, the custom LoRA dict (or ``None``), a
        gallery value/update, the title text, and ``None``. The caller's
        ``outputs`` list decides which component receives each one.
    """
    import html  # local import: used only to escape text rendered as HTML

    link = (link or "").strip()
    if not link:
        return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on a LoRA in the gallery to select it", None

    try:
        repo_name, weights_file, trigger_word = get_huggingface_lora(link)

        # The repo name comes from user input and the trigger word from a
        # remote model card; both are escaped before being rendered as HTML so
        # markup-like text (e.g. "<token>") is shown literally and not interpreted.
        safe_repo_name = html.escape(repo_name)
        if trigger_word:
            trigger_note = "Using: " + html.escape(trigger_word) + " as trigger word"
        else:
            trigger_note = "No trigger word found"

        card = f'''
        
            Loaded custom LoRA:
            
                
{safe_repo_name}
                {trigger_note}
            
         
        '''

        # The unescaped values are kept here: they feed the model, not the page.
        custom_lora_data = {
            "repo": link,
            "weights": weights_file,
            "trigger_word": trigger_word
        }

        return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}", None

    except Exception as e:
        # The message may echo user input, so it is escaped as well.
        return gr.update(visible=True), f"Error: {html.escape(str(e))}", gr.update(visible=False), None, gr.update(), "### Click on a LoRA in the gallery to select it", None
def remove_custom_lora():
    """Produce the values that clear a previously loaded custom LoRA.

    Returns:
        A 5-tuple: an empty string, two updates that hide a component each,
        and two ``None`` values. The caller's ``outputs`` list decides which
        component or state receives each one.
    """
    cleared_text = ""
    hide_first = gr.update(visible=False)
    hide_second = gr.update(visible=False)
    return cleared_text, hide_first, hide_second, None, None
def classify_gallery(flux_loras):
    """Order the LoRA entries by their "likes" count, highest first.

    Entries without a "likes" key count as 0, and entries with equal counts
    keep their incoming relative order.

    Returns:
        A pair ``(tiles, ranked)``: ``tiles`` is a list of ``(image, title)``
        tuples for the gallery and ``ranked`` is the reordered entry list.
    """
    ranked = list(flux_loras)
    ranked.sort(key=lambda entry: entry.get("likes", 0), reverse=True)

    tiles = []
    for entry in ranked:
        tiles.append((entry["image"], entry["title"]))
    return tiles, ranked
def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
    """Forward every argument, unchanged and in order, to ``infer_with_lora``.

    This is the function the Gradio event handlers call; it returns whatever
    ``infer_with_lora`` returns.
    """
    forwarded = (
        input_image,
        prompt,
        selected_index,
        custom_lora,
        seed,
        randomize_seed,
        guidance_scale,
        lora_scale,
        flux_loras,
        progress,
    )
    return infer_with_lora(*forwarded)
@spaces.GPU
def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
    """Generate a restyled image with the selected LoRA.

    Args:
        input_image: Image to edit; it is converted to RGB before inference.
        prompt: Optional extra description inserted into the prompt template.
        selected_index: Index into ``flux_loras`` of the gallery pick, or None.
        custom_lora: Dict with "repo", "weights" and "trigger_word"; when
            truthy it takes precedence over the gallery selection.
        seed: Seed for the torch generator.
        randomize_seed: When true, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.
        guidance_scale: Forwarded to the pipeline call.
        lora_scale: Adapter weight applied to the loaded LoRA.
        flux_loras: List of LoRA entries that ``selected_index`` refers to.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        ``(image, seed, update)``. On success ``update`` makes its target
        visible; if the pipeline call fails the image is ``None`` and
        ``update`` hides its target.

    Raises:
        gr.Error: If no input image was provided or no LoRA is selected.
    """
    global current_lora

    # Fail with a readable message instead of an AttributeError on `.convert`.
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # The value may arrive from a slider as a float; manual_seed needs an int.
    seed = int(seed)

    # Determine which LoRA to use
    lora_to_use = None
    if custom_lora:
        lora_to_use = custom_lora
    elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
        lora_to_use = flux_loras[selected_index]

    # The prompt template below needs a trigger word, so a LoRA is mandatory.
    if lora_to_use is None:
        raise gr.Error("Please select a LoRA from the gallery or load a custom one.")

    try:
        if lora_to_use != current_lora:
            # Unload current LoRA
            if current_lora:
                pipe.unload_lora_weights()
                # Nothing is loaded any more; forget the old entry so a failed
                # load below cannot leave a stale "already loaded" record.
                current_lora = None

            # Load new LoRA
            lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
            if lora_path:
                pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
                print(f"loaded: {lora_path} with scale {lora_scale}")
                current_lora = lora_to_use
        else:
            print(f"using already loaded lora: {lora_to_use}")

        # Apply the scale on every call so a changed lora_scale takes effect
        # even when the LoRA itself was loaded by an earlier request.
        if current_lora == lora_to_use:
            pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
    except Exception as e:
        print(f"Error loading LoRA: {e}")
        # Continue without LoRA

    input_image = input_image.convert("RGB")

    # Add trigger word to prompt
    trigger_word = lora_to_use.get("trigger_word", "")
    if trigger_word == ", How2Draw":
        prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
    else:
        prompt = f"convert the style of this portrait photo to {trigger_word} while maintaining the identity of the person. {prompt}. Make sure to maintain the person's facial identity and features, while still changing the overall style to {trigger_word}."

    try:
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]

        return image, seed, gr.update(visible=True)

    except Exception as e:
        print(f"Error during inference: {e}")
        return None, seed, gr.update(visible=False)
# CSS styling
# Stylesheet text for the demo. The id selectors correspond to elem_id values
# assigned in the Blocks layout further down (main_app, box_column,
# selected_lora, prompt, run_button, gallery); .custom_lora_card is a class
# selector with no matching element visible in this file.
css = """
#main_app {
    display: flex;
    gap: 20px;
}
#box_column {
    min-width: 400px;
}
#selected_lora {
    color: #2563eb;
    font-weight: bold;
}
#prompt {
    flex-grow: 1;
}
#run_button {
    background: linear-gradient(45deg, #2563eb, #3b82f6);
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 6px;
    font-weight: bold;
}
.custom_lora_card {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 8px;
    padding: 12px;
    margin: 8px 0;
}
#gallery{
    overflow: scroll
}
"""
# Create Gradio interface
# The stylesheet defined in the module-level `css` string is passed in here;
# the previous literal "custom.css" left that stylesheet unused.
with gr.Blocks(css=css) as demo:
    # Per-session copy of the LoRA list; replaced by the sorted list on page load.
    gr_flux_loras = gr.State(value=flux_loras_raw)

    title = gr.HTML(
        """ FLUX.1 Kontext Portrait 👩🏻🎤
        
""",
    )

    # Index of the gallery pick and the dict describing a custom LoRA, if any.
    selected_state = gr.State(value=None)
    custom_loaded_lora = gr.State(value=None)

    with gr.Row(elem_id="main_app"):
        with gr.Column(scale=4, elem_id="box_column"):
            with gr.Group(elem_id="gallery_box"):
                input_image = gr.Image(label="Upload a picture of yourself", type="pil", height=300)

                gallery = gr.Gallery(
                    label="Pick a LoRA",
                    allow_preview=False,
                    columns=3,
                    elem_id="gallery",
                    show_share_button=False,
                    height=400
                )

                # Hidden by default, so the custom-LoRA flow is not reachable
                # from the UI unless this is made visible.
                custom_model = gr.Textbox(
                    label="Or enter a custom HuggingFace FLUX LoRA",
                    placeholder="e.g., username/lora-name",
                    visible=False
                )
                custom_model_card = gr.HTML(visible=False)
                custom_model_button = gr.Button("Remove custom LoRA", visible=False)

        with gr.Column(scale=5):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Editing Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                    placeholder="optional description, e.g. 'a man with glasses and a beard'",
                    elem_id="prompt"
                )
                run_button = gr.Button("Generate", elem_id="run_button")

            result = gr.Image(label="Generated Image", interactive=False)
            reuse_button = gr.Button("Reuse this image", visible=False)

            with gr.Accordion("Advanced Settings", open=False):
                lora_scale = gr.Slider(
                    label="LoRA Scale",
                    minimum=0,
                    maximum=2,
                    step=0.1,
                    value=1.5,
                    info="Controls the strength of the LoRA effect"
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=2.5,
                )

            prompt_title = gr.Markdown(
                value="### Click on a LoRA in the gallery to select it",
                visible=True,
                elem_id="selected_lora",
            )
    # Event handlers
    # custom_model_card is listed twice on purpose: load_custom_lora returns a
    # visibility update followed by the card text for the same component.
    custom_model.input(
        fn=load_custom_lora,
        inputs=[custom_model],
        outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
    )

    custom_model_button.click(
        fn=remove_custom_lora,
        outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
    )

    gallery.select(
        fn=update_selection,
        inputs=[gr_flux_loras],
        outputs=[prompt_title, prompt, selected_state],
        show_progress=False
    )

    # Both the button and pressing Enter in the prompt box start a generation.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer_with_lora_wrapper,
        inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
        outputs=[result, seed, reuse_button]
    )

    # Feed the generated image back in as the next input image.
    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image]
    )

    # Initialize gallery
    demo.load(
        fn=classify_gallery,
        inputs=[gr_flux_loras],
        outputs=[gallery, gr_flux_loras]
    )
demo.queue(default_concurrency_limit=None)
demo.launch()