from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
import os
import requests
import hashlib
from pathlib import Path
import re
import random

# Default LoRA for fallback: a Hugging Face repo id loaded whenever no usable
# LoRA URL/path is given or loading the requested LoRA fails. It is also the
# initial value of the "LoRA URL/Path" textbox in the UI.
DEFAULT_LORA = "OedoSoldier/detail-tweaker-lora"
# Directory (relative to the current working directory) where downloaded
# Civitai LoRA files are cached, one "<md5 of URL>.safetensors" per URL.
LORA_CACHE_DIR = "lora_cache"

def download_lora(url, timeout=(10, 60)):
    """Download a LoRA file from a Civitai URL and cache it locally.

    The cache key is the MD5 hex digest of *url*, so the same URL always maps
    to ``<LORA_CACHE_DIR>/<md5>.safetensors``. A cached file is returned
    without touching the network.

    The download is written to a temporary ``.part`` file. That file is
    renamed into place only once the transfer has completed, so a failed or
    interrupted download never leaves a truncated file behind that later
    calls would mistake for a valid cache hit.

    Args:
        url: Direct download URL (callers check it with ``is_civitai_url``).
        timeout: ``requests`` timeout, either ``(connect, read)`` seconds or a
            single number. This prevents a stalled server from hanging the
            caller forever.

    Returns:
        The local file path on success (including cache hits), or ``None`` if
        the download failed.
    """
    os.makedirs(LORA_CACHE_DIR, exist_ok=True)

    # MD5 is only a stable cache key here, not a security measure.
    url_hash = hashlib.md5(url.encode()).hexdigest()
    local_path = os.path.join(LORA_CACHE_DIR, f"{url_hash}.safetensors")

    if os.path.exists(local_path):
        print()
        print("********** Lora Already Exists **********")
        print()
        return local_path

    tmp_path = local_path + ".part"
    try:
        # Using the response as a context manager releases the connection
        # even if writing to disk fails part-way.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                # iter_content works whether or not Content-Length is sent.
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        # Atomic on the same filesystem: readers see either no file or the
        # complete one.
        os.replace(tmp_path, local_path)
        print()
        print("********** Lora Downloading Successfull **********")
        print()
        return local_path
    except (requests.RequestException, OSError) as e:
        print()
        print(f"Error downloading LoRA: {str(e)}")
        print()
        # Best-effort cleanup of the partial file; it is never used as cache.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return None

def is_civitai_url(url):
    """Tell whether *url* looks like a Civitai model download link.

    Only the beginning of the string is checked
    (``http(s)://civitai.com/api/download/models/<digits>``). Anything after
    the model id, such as ``?type=Model&format=SafeTensor``, is accepted.
    """
    download_link = re.compile(r'https?://civitai\.com/api/download/models/\d+')
    return download_link.match(url) is not None

# UI model key -> Hugging Face repo id. Unknown keys (including "Real6.0")
# resolve to _DEFAULT_MODEL_ID.
_MODEL_IDS = {
    "Real5.0": "SG161222/Realistic_Vision_V5.0_noVAE",
    "Real5.1": "SG161222/Realistic_Vision_V5.1_noVAE",
    "majicv7": "digiplay/majicMIX_realistic_v7",
}
_DEFAULT_MODEL_ID = "SG161222/Realistic_Vision_V6.0_B1_noVAE"


def _banner(message):
    """Print *message* between blank lines, framed by asterisks (status log)."""
    print()
    print(f"********** {message} **********")
    print()


def _load_lora(pipe, lora_url):
    """Load the requested LoRA into *pipe*, falling back to DEFAULT_LORA.

    *lora_url* may be a Civitai download URL (downloaded and cached via
    ``download_lora``) or a Hugging Face repo path such as ``"user/repo"``.
    The default LoRA is loaded instead when the value is empty, has any other
    form, its download fails, or loading it raises. If loading the default
    itself fails, that error propagates to the caller.
    """
    # Pasted values often carry stray whitespace that would break the
    # URL/repo lookup.
    lora_url = (lora_url or "").strip()
    try:
        if lora_url:
            if is_civitai_url(lora_url):
                lora_path = download_lora(lora_url)
                if lora_path:
                    pipe.load_lora_weights(lora_path)
                    _banner("URL Lora Loaded")
                    return
            elif '/' in lora_url and not lora_url.startswith('http'):
                pipe.load_lora_weights(lora_url)
                _banner("URL Lora Loaded")
                return
        pipe.load_lora_weights(DEFAULT_LORA)
        _banner("Default Lora Loaded")
    except Exception as e:
        print()
        print(f"Error loading LoRA weights: {str(e)}")
        print()
        pipe.load_lora_weights(DEFAULT_LORA)


def _encode_prompts(tokenizer, text_encoder, prompt, negative_prompt):
    """Encode the prompt and negative prompt with *text_encoder* on CUDA.

    Each text is padded or truncated to ``tokenizer.model_max_length`` tokens.
    Returns ``(prompt_embeds, negative_prompt_embeds)``, i.e. the first output
    of the text encoder for each.
    """
    def encode(text):
        input_ids = tokenizer(
            text or "",
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        ).input_ids.to("cuda")
        return text_encoder(input_ids)[0]

    # Inference only: without no_grad the encoder forward pass records an
    # autograd graph that is never used but still holds GPU memory.
    with torch.no_grad():
        return encode(prompt), encode(negative_prompt)


@spaces.GPU
def generate_image(prompt, negative_prompt, lora_url, num_inference_steps=30, guidance_scale=7.0,
                   model="Real6.0", num_images=1, width=512, height=512, seed=None):
    """Run the selected diffusers pipeline with a LoRA applied and return images.

    Args:
        prompt: Text describing the image. It is truncated to the tokenizer's
            maximum length.
        negative_prompt: Text describing what to avoid. It is truncated the
            same way.
        lora_url: Civitai download URL, Hugging Face repo path, or empty. See
            ``_load_lora`` for the fallback rules.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance scale.
        model: "Real5.0", "Real5.1", or "majicv7". Anything else, including
            "Real6.0", selects Realistic Vision V6.0 B1.
        num_images: Images to generate for the prompt.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: RNG seed. ``None`` picks a random seed in ``[0, 2**32 - 1]``.

    Returns:
        ``(result.images, seed)``, where ``seed`` is the int seed actually used.
    """
    model_id = _MODEL_IDS.get(model, _DEFAULT_MODEL_ID)

    # NOTE(review): all components are reloaded on every call; caching them
    # would speed things up, but LoRA weights would then accumulate across
    # calls.
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda")
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")

    # Hand every pre-loaded component to the pipeline. Previously the UNet
    # loaded above was dropped, and from_pretrained loaded a second copy.
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        vae=vae
    ).to("cuda")

    _load_lora(pipe, lora_url)

    if model == "Real6.0":
        # Pass-through checker: returns the images unchanged and flags none.
        pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True
    )

    # Gradio number/slider widgets may deliver floats. The seed is normalised
    # to int so it reproduces and displays cleanly (no trailing ".0").
    seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
    generator = torch.manual_seed(seed)

    prompt_embeds, negative_prompt_embeds = _encode_prompts(
        tokenizer, text_encoder, prompt, negative_prompt
    )

    result = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        cross_attention_kwargs={"scale": 1},
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        width=int(width),
        height=int(height),
        num_images_per_prompt=int(num_images),
        generator=generator
    )
    torch.cuda.empty_cache()
    return result.images, seed

def clean_lora_cache():
    """Remove every regular file from the LoRA cache directory.

    A missing cache directory is a no-op, and subdirectories are left in
    place. A file that cannot be deleted is reported on stdout and skipped,
    so one failure does not stop the rest of the cleanup.
    """
    cache_dir = LORA_CACHE_DIR
    if not os.path.exists(cache_dir):
        return

    for entry_name in os.listdir(cache_dir):
        target = os.path.join(cache_dir, entry_name)
        try:
            if os.path.isfile(target):
                os.unlink(target)
        except Exception as err:
            print(f"Error deleting {target}: {str(err)}")

title = """<h1 align="center">ProFaker</h1>"""
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.HTML(title)
    
    with gr.Row():
        with gr.Column():
            # Input components
            prompt = gr.Textbox(
                label="Prompt",
                info="Enter your image description here...",
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="Enter what you don't want in Image...",
                lines=3
            )
            lora_input = gr.Textbox(
                label="LoRA URL/Path",
                info="Enter Civitai download URL or HuggingFace path (e.g., 'username/model-name')",
                value=DEFAULT_LORA
            )
            clear_cache = gr.Button("Clear LoRA Cache")
            generate_button = gr.Button("Generate Image")
            
            with gr.Accordion("Advanced Options", open=False):
                # Keys must match the model names handled by generate_image.
                model = gr.Dropdown(
                    choices=["Real6.0","Real5.1","Real5.0","majicv7"],
                    value="Real6.0",
                    label="Model",
                )
                
                num_images = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Number of Images to Generate"
                )
                width = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Image Width"
                )
                height = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Image Height"
                )
                steps_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=30,
                    step=1,
                    label="Number of Steps"
                )
                guidance_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=7.0,
                    step=0.5,
                    label="Guidance Scale"
                )
                # precision=0 makes Gradio round the value and deliver it as an
                # int. Without it the seed arrives as a float and "Seed Used"
                # shows e.g. "123.0". The default seed is drawn once, when the
                # app starts.
                seed_input = gr.Number(value=random.randint(0, 2**32 - 1), label="Seed (optional)", precision=0)
        
        with gr.Column():
            # Output component
            gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2
            )
            seed_display = gr.Textbox(label="Seed Used", interactive=False)
    
    # Connect the interface to the generation function; the input order must
    # match generate_image's positional parameters.
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, lora_input, steps_slider, guidance_slider, 
                model, num_images, width, height,seed_input],
        outputs=[gallery,seed_display]
    )
    
    # Connect clear cache button
    clear_cache.click(fn=clean_lora_cache)

# At most 10 requests may wait in the queue.
demo.queue(max_size=10).launch(share=False)