import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from tqdm.auto import tqdm
from huggingface_hub import hf_hub_url, login, HfApi, create_repo
import os
import traceback
from peft import PeftModel
import gradio as gr

def display_image(image):
    """Identity passthrough: hand back the given image object unchanged.

    Args:
        image: Any object; it is not inspected or copied.

    Returns:
        The very same object that was passed in.
    """
    shown = image
    return shown

def load_and_merge_lora(base_model_id, lora_id, lora_adapter_name):
    """Load a diffusion pipeline and attach a PEFT LoRA adapter to its UNet.

    Note: despite the function name, ``PeftModel.from_pretrained`` wraps the
    UNet and injects LoRA layers; it does not fold the LoRA weights into the
    base weights.

    Args:
        base_model_id: Hub id or local path of the base diffusers pipeline.
        lora_id: Hub id or local path of the PEFT adapter to load.
        lora_adapter_name: Name to register the adapter under. An empty value
            (e.g. a blank Gradio textbox) falls back to PEFT's ``"default"``.

    Returns:
        The pipeline with its UNet replaced by the adapter-wrapped one, or
        ``None`` if anything fails. On failure, the full traceback is written
        to ``errors.txt``.
    """
    # float16 is only usable on an accelerator. diffusers warns that fp16
    # pipelines moved to CPU will fail at inference, so use float32 there.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32
    adapter_name = lora_adapter_name or "default"

    try:
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
            # Keep fetching the same fp16 weight files; torch_dtype decides
            # the dtype the components end up in.
            variant="fp16",
            use_safetensors=True,
        )

        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )

        # Wrap the UNet with the LoRA adapter and put it back into the pipeline.
        pipe.unet = PeftModel.from_pretrained(
            pipe.unet,
            lora_id,
            torch_dtype=dtype,
            adapter_name=adapter_name,
        )

        # Move only after the adapter is attached, so its weights end up on
        # the same device as the base model.
        pipe = pipe.to(device)

        print("LoRA merged successfully!")
        return pipe

    except Exception as e:
        # UI boundary: report the failure instead of crashing the handler.
        error_msg = traceback.format_exc()
        print(f"Error merging LoRA: {e}\n\nFull traceback saved to errors.txt")

        with open("errors.txt", "w", encoding="utf-8") as f:
            f.write(error_msg)

        return None

def save_merged_model(pipe, save_path, push_to_hub=False, hf_token=None, *, repo_name=None):
    """Save the pipeline locally and optionally upload it to the Hugging Face Hub.

    Args:
        pipe: Object exposing ``save_pretrained(path)`` (the pipeline).
        save_path: Local directory passed to ``pipe.save_pretrained``. When
            pushing, this is also the folder that gets uploaded.
        push_to_hub: When true, upload ``save_path`` to a Hub model repo.
        hf_token: Hub write token. ``None`` or an empty string (what an empty
            Gradio password box yields) triggers an interactive prompt and a
            ``login`` call.
        repo_name: Target repo id such as ``user/model``. If omitted, the user
            is prompted on stdin, as before.

    Returns:
        ``True`` if every requested step succeeded, ``False`` if any step
        raised. Errors are printed, not propagated.
    """
    try:
        pipe.save_pretrained(save_path)
        print(f"Merged model saved successfully to: {save_path}")

        if push_to_hub:
            # Treat "" like None: an empty textbox means "no token supplied".
            if not hf_token:
                hf_token = input("Enter your Hugging Face write token: ")
                login(token=hf_token)

            if not repo_name:
                repo_name = input("Enter the Hugging Face repository name "
                                  "(e.g., your_username/your_model_name): ")

            # Create the repository if it doesn't exist.
            create_repo(repo_name, token=hf_token, repo_type="model", exist_ok=True)

            api = HfApi()
            api.upload_folder(
                folder_path=save_path,
                repo_id=repo_name,
                token=hf_token,
                repo_type="model",
            )
            print(f"Model pushed successfully to Hugging Face Hub: {repo_name}")

    except Exception as e:
        print(f"Error saving/pushing the merged model: {e}")
        return False

    return True

def generate_and_save(base_model_id, lora_id, lora_adapter_name, prompt, lora_scale, save_path, push_to_hub, hf_token):
    """Gradio handler: load the base model plus LoRA, render one image, then save/push.

    Args mirror the Gradio input components, in order.

    Returns:
        ``(image, status)``, one value per Gradio output component. When the
        model fails to load, ``image`` is ``None`` and ``status`` explains why.
    """
    pipe = load_and_merge_lora(base_model_id, lora_id, lora_adapter_name)
    if pipe is None:
        # Gradio expects one value per output; a bare None would not fill
        # both the Image and the Textbox.
        return None, "Failed to load the model/LoRA; see errors.txt for the traceback."

    image = pipe(
        prompt,
        num_inference_steps=30,
        cross_attention_kwargs={"scale": float(lora_scale)},
        # Fixed seed for reproducibility. A dedicated generator yields the
        # same stream as torch.manual_seed(0) without reseeding the global RNG.
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]

    image.save("generated_image.png")
    print("Image saved to: generated_image.png")

    if not save_path:
        return image, "Image generated; no save path given, so the model was not saved."

    # Treat only an explicit False as failure; a None return keeps the
    # generic message.
    if save_merged_model(pipe, save_path, push_to_hub, hf_token) is False:
        return image, "Image generated, but saving/pushing the model failed (see console)."

    return image, "Image generated and model saved/pushed (if selected)."

# Gradio UI: the input order must match generate_and_save's parameters.
iface = gr.Interface(
    fn=generate_and_save,
    inputs=[
        gr.Textbox(label="Base Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)"),
        gr.Textbox(label="LoRA ID (e.g., your_username/your_lora)"),
        gr.Textbox(label="LoRA Adapter Name"),
        gr.Textbox(label="Prompt"),
        gr.Slider(label="LoRA Scale", minimum=0.0, maximum=1.0, value=0.7, step=0.1),
        gr.Textbox(label="Save Path"),
        gr.Checkbox(label="Push to Hugging Face Hub"),
        gr.Textbox(label="Hugging Face Write Token", type="password")
    ],
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Textbox(label="Status")
    ],
    title="LoRA Merger and Image Generator",
    description="Merge a LoRA with a base Stable Diffusion model and generate images."
)

# Start the server only when run as a script, so that importing this module
# (e.g. to reuse its functions) does not block on a web server.
if __name__ == "__main__":
    iface.launch()