import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer
import os
import zipfile
import gradio as gr

# Compute device shared by training and inference: the default CUDA GPU when
# PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define your custom dataset
class CustomImageDataset(Dataset):
    def __init__(self, images, prompts, transform=None):
        self.images = images
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        prompt = self.prompts[idx]
        return image, prompt

def fine_tune_model(
    images,
    prompts,
    model_save_path,
    num_epochs=3,
    *,
    resolution=512,
    batch_size=4,
    learning_rate=5e-6,
    base_model="stabilityai/stable-diffusion-2",
):
    """Fine-tune the UNet of a Stable Diffusion pipeline and save the pipeline.

    Only the UNet is optimized; the VAE and text encoder are frozen. Each step
    applies the latent-diffusion training objective:
    1. Encode the images to scaled VAE latents.
    2. Add scheduler noise at a random training timestep.
    3. Regress the UNet output onto the scheduler's prediction target: the
       noise for ``"epsilon"`` models, the velocity for ``"v_prediction"``
       models.

    Args:
        images: Sequence of PIL images, one per training example.
        prompts: Sequence of caption strings, indexed in parallel with ``images``.
        model_save_path: Directory the full fine-tuned pipeline is written to.
        num_epochs: Number of passes over the dataset.
        resolution: Square side length images are resized to before encoding.
            NOTE(review): the default checkpoint may have been trained at
            768x768; confirm the intended value.
        batch_size: DataLoader batch size.
        learning_rate: AdamW learning rate for the UNet.
        base_model: Hub id or local path of the pipeline to start from.
    """
    from diffusers import DDPMScheduler

    transform = transforms.Compose([
        transforms.Resize((resolution, resolution)),
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1].
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = CustomImageDataset(images, prompts, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    pipeline = StableDiffusionPipeline.from_pretrained(base_model).to(device)
    vae = pipeline.vae
    unet = pipeline.unet
    text_encoder = pipeline.text_encoder
    # Use the tokenizer shipped with the checkpoint so the token ids match
    # what its text encoder was trained on.
    tokenizer = pipeline.tokenizer
    # A DDPM scheduler built from the checkpoint's scheduler config supplies
    # the training noise schedule (add_noise / get_velocity).
    noise_scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    prediction_type = getattr(noise_scheduler.config, "prediction_type", "epsilon")
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    # Freeze everything except the UNet.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()
    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

    for _epoch in range(num_epochs):
        for pixel_values, batch_prompts in dataloader:
            pixel_values = pixel_values.to(device)
            token_ids = tokenizer(
                list(batch_prompts),
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(device)

            # The frozen encoders need no autograd graph.
            with torch.no_grad():
                latents = vae.encode(pixel_values).latent_dist.sample() * latent_scale
                text_embeddings = text_encoder(token_ids).last_hidden_state

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.shape[0],), device=device, dtype=torch.long
            )
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            if prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                target = noise

            model_pred = unet(
                noisy_latents, timesteps, encoder_hidden_states=text_embeddings
            ).sample
            loss = torch.nn.functional.mse_loss(model_pred.float(), target.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    unet.eval()
    # pipeline.unet is the module trained above, so this saves the tuned weights.
    pipeline.save_pretrained(model_save_path)

def tensor_to_pil(tensor):
    """Convert an image tensor to a ``PIL.Image``.

    Accepts ``(C, H, W)``, ``(H, W)``, or a single-image batch
    ``(1, C, H, W)``. Values are clamped to ``[0, 1]`` before conversion.
    Tensors normalized to ``[-1, 1]`` (as by the training transform in this
    file) therefore need ``(t + 1) / 2`` applied first.
    """
    tensor = tensor.cpu()
    if tensor.dim() == 4:
        # Drop only the batch axis. A blanket squeeze() would also remove any
        # height/width/channel axis of size 1 and distort the image shape.
        # A batch larger than one stays 4-D and is rejected by ToPILImage.
        tensor = tensor.squeeze(0)
    return transforms.ToPILImage()(tensor.clamp(0, 1))

def generate_images(pipeline, prompt):
    """Call ``pipeline`` on ``prompt`` with autograd disabled.

    Returns the first entry of the result's ``images`` sequence.
    """
    with torch.no_grad():
        return pipeline(prompt).images[0]

def zip_model(model_path, *, compression=zipfile.ZIP_STORED):
    """Archive every file under ``model_path`` into ``<model_path>.zip``.

    Member names are relative to ``model_path``, so extracting the archive
    recreates the directory's contents. An existing archive at the target
    path is overwritten.

    Args:
        model_path: Directory to archive. A trailing path separator is
            ignored, so ``"model/"`` yields ``"model.zip"`` next to the
            directory rather than ``"model/.zip"`` inside it. Inside it, the
            walk would pick up the archive being written.
        compression: ``zipfile`` compression constant. The default
            (``ZIP_STORED``) stores members uncompressed.

    Returns:
        The path of the written archive.
    """
    root_dir = os.path.normpath(model_path)
    zip_path = f"{root_dir}.zip"
    zip_abs = os.path.abspath(zip_path)
    with zipfile.ZipFile(zip_path, "w", compression=compression) as zipf:
        for root, dirs, files in os.walk(root_dir):
            dirs.sort()  # Deterministic traversal order.
            for name in sorted(files):
                full_path = os.path.join(root, name)
                # Never add the archive to itself (possible when model_path is ".").
                if os.path.abspath(full_path) == zip_abs:
                    continue
                zipf.write(full_path, os.path.relpath(full_path, root_dir))
    return zip_path

def save_uploaded_file(uploaded_file, save_path):
    """Write an uploaded file's contents to ``save_path``.

    ``uploaded_file`` may be any of the following, checked in this order:
    - an object exposing its raw bytes as ``.data`` (the original contract);
    - ``bytes``/``bytearray``;
    - a filesystem path (``str`` or ``os.PathLike``), which is copied;
    - a binary file-like object with ``.read()``.

    NOTE(review): recent Gradio ``gr.File`` callbacks appear to receive file
    paths rather than objects with ``.data``. Confirm against the installed
    version.

    Returns:
        A status message naming ``save_path``.

    Raises:
        TypeError: if ``uploaded_file`` is none of the supported kinds.
    """
    import shutil

    data = getattr(uploaded_file, "data", None)
    if data is None and isinstance(uploaded_file, (bytes, bytearray)):
        data = bytes(uploaded_file)

    if data is not None:
        with open(save_path, "wb") as f:
            f.write(data)
    elif isinstance(uploaded_file, (str, os.PathLike)):
        shutil.copyfile(uploaded_file, save_path)
    elif hasattr(uploaded_file, "read"):
        with open(save_path, "wb") as f:
            shutil.copyfileobj(uploaded_file, f)
    else:
        raise TypeError(f"Unsupported upload type: {type(uploaded_file).__name__}")
    return f"File saved at {save_path}"

# Gradio interface functions
def start_fine_tuning(uploaded_files, prompts, num_epochs):
    """Gradio handler: fine-tune on the uploaded images and report status.

    Args:
        uploaded_files: Files from the "Upload Images" component, each passed
            to ``PIL.Image.open``. May be ``None`` or empty when nothing was
            uploaded.
        prompts: The raw comma-separated text from the prompts textbox. It
            must hold either one prompt (applied to every image) or exactly
            one prompt per image, in upload order.
        num_epochs: Epoch count from the number field; truncated to ``int``.

    Returns:
        A status or validation message for the output textbox.
    """
    if not uploaded_files:
        return "Please upload at least one image."

    # The textbox delivers a single string. Split it into individual prompts;
    # passing it unsplit would make the dataset index single characters.
    prompt_list = [p.strip() for p in (prompts or "").split(",") if p.strip()]
    if not prompt_list:
        return "Please enter at least one prompt."
    if len(prompt_list) == 1:
        prompt_list = prompt_list * len(uploaded_files)
    elif len(prompt_list) != len(uploaded_files):
        return (
            f"Got {len(prompt_list)} prompts for {len(uploaded_files)} images; "
            "enter one prompt per image or a single prompt for all images."
        )

    if num_epochs is None or int(num_epochs) < 1:
        return "Number of epochs must be at least 1."

    images = []
    for file in uploaded_files:
        # The with-block releases the image file. convert() returns a fully
        # loaded copy, which stays valid after the file is closed.
        with Image.open(file) as img:
            images.append(img.convert("RGB"))

    model_save_path = "fine_tuned_model"
    fine_tune_model(images, prompt_list, model_save_path, num_epochs=int(num_epochs))
    return "Fine-tuning completed! Model is ready for download."

def download_model():
    """Gradio handler: zip the saved fine-tuned model for download.

    Returns the archive path built by ``zip_model``. If the
    ``fine_tuned_model`` path does not exist yet, nothing is zipped and
    ``None`` is returned. The archive is rebuilt on every call.
    """
    model_save_path = "fine_tuned_model"
    if not os.path.exists(model_save_path):
        return None
    return zip_model(model_save_path)

def generate_new_image(prompt):
    """Gradio handler: render ``prompt`` and return the saved PNG's path.

    Uses the pipeline saved in ``fine_tuned_model`` when that path exists,
    otherwise the base ``stabilityai/stable-diffusion-2`` checkpoint. The
    pipeline is loaded (and moved to ``device``) on every call. The image is
    written to ``generated_image.png``, overwriting any previous output.
    """
    model_save_path = "fine_tuned_model"
    source = (
        model_save_path
        if os.path.exists(model_save_path)
        else "stabilityai/stable-diffusion-2"
    )
    pipe = StableDiffusionPipeline.from_pretrained(source).to(device)
    output_path = "generated_image.png"
    generate_images(pipe, prompt).save(output_path)
    return output_path

# Gradio interface: three tabs wiring UI components to the handler functions
# defined earlier in this file. Components are created inside nested
# `with` contexts (Blocks / Tab / Row), so the statement order below
# determines how they are laid out.
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")

    with gr.Tab("Fine-Tune Model"):
        with gr.Row():
            # Multiple image uploads, restricted to .png/.jpg/.jpeg.
            uploaded_files = gr.File(label="Upload Images", file_types=[".png", ".jpg", ".jpeg"], file_count="multiple")
        with gr.Row():
            # The raw textbox string is handed to start_fine_tuning as-is.
            prompts = gr.Textbox(label="Enter Prompts (comma-separated)")
            # start_fine_tuning converts this value with int().
            num_epochs = gr.Number(label="Number of Epochs", value=3)
        with gr.Row():
            fine_tune_button = gr.Button("Start Fine-Tuning")
        fine_tune_output = gr.Textbox(label="Output")

        # Clicking runs start_fine_tuning(uploaded_files, prompts, num_epochs)
        # and shows its returned status string in fine_tune_output.
        fine_tune_button.click(start_fine_tuning, [uploaded_files, prompts, num_epochs], fine_tune_output)

    with gr.Tab("Download Fine-Tuned Model"):
        download_button = gr.Button("Download Fine-Tuned Model")
        # Receives the zip path returned by download_model (None when there is no saved model).
        download_output = gr.File()

        # download_model takes no inputs.
        download_button.click(download_model, [], download_output)

    with gr.Tab("Generate New Images"):
        prompt_input = gr.Textbox(label="Enter a Prompt")
        generate_button = gr.Button("Generate Image")
        # Displays the image file whose path generate_new_image returns.
        generated_image = gr.Image(label="Generated Image")

        generate_button.click(generate_new_image, [prompt_input], generated_image)

# Start the web server only when this file is run as a script. Importing the
# module (e.g. to reuse the helper functions) then does not block on launch().
if __name__ == "__main__":
    demo.launch()