import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository
from huggingface_hub import HfApi, HfFolder, Repository, create_repo
import os
import pandas as pd
import gradio as gr
from PIL import Image
import numpy as np
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
from CLIP import load as load_clip
from rich import print as rp

# Device configuration: use CUDA when present. The "big" (1024px) model is
# selected whenever we are not running on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type != 'cpu'

# Parameters
IMG_SIZE = 1024 if big else 256  # big_UNet is recorded as 1024px, small_UNet as 256px
BATCH_SIZE = 1                   # identical for both model sizes
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"

# Model shared by the Gradio callbacks; assigned by load_model() and training.
global_model = None

# CLIP model and tokenizer (the tokenizer is handed to Pix2PixDataset).
clip_model, clip_tokenizer = load_clip()

def load_model():
    """Load the model into ``global_model`` at startup and return it.

    Reads ``big_model_weights.pth`` when ``big`` is set (non-CPU device) and
    ``small_model_weights.pth`` otherwise. The architecture comes from the
    checkpoint's ``model_config['big']`` flag. If anything goes wrong, the
    error is printed and a freshly constructed (untrained) model matching
    ``big`` is used instead.
    """
    global global_model
    weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
    try:
        checkpoint = torch.load(weights_name, map_location=device)
        model_cls = big_UNet if checkpoint['model_config']['big'] else small_UNet
        model = model_cls()
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        rp("Model loaded successfully!")
    except Exception as e:
        rp(f"Error loading model: {e}")
        model = (big_UNet if big else small_UNet)().to(device)
    global_model = model
    return model

class Pix2PixDataset(torch.utils.data.Dataset):
    """Paired image/prompt dataset for pix2pix training.

    Each row of ``combined_data`` provides ``image_path`` (only its basename is
    used), ``original_prompt`` and ``enhanced_prompt``. The image with that
    basename is read from both ``original_folder`` and ``target_folder``.

    Args:
        combined_data: Table supporting ``len()`` and ``.iloc[idx][column]``
            (a pandas DataFrame in this file).
        transform: Callable applied to each RGB PIL image.
        clip_tokenizer: Tokenizer called as
            ``clip_tokenizer(text, return_tensors="pt", padding=..., truncation=True, max_length=77)``.
        original_folder: Directory holding the original images.
        target_folder: Directory holding the target images.
    """

    def __init__(self, combined_data, transform, clip_tokenizer,
                 original_folder='images_dataset/original/',
                 target_folder='images_dataset/target/'):
        self.data = combined_data
        self.transform = transform
        self.clip_tokenizer = clip_tokenizer
        self.original_folder = original_folder
        self.target_folder = target_folder

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _tokenize(self, text):
        """Tokenize ``text`` padded/truncated to exactly 77 tokens.

        Fixed-length padding keeps token tensors the same shape across samples,
        so the default DataLoader collation also works with batch sizes > 1.
        """
        return self.clip_tokenizer(text, return_tensors="pt", padding="max_length",
                                   truncation=True, max_length=77)

    def __getitem__(self, idx):
        """Return ``(original, target, original_tokens, enhanced_tokens)`` for row ``idx``."""
        row = self.data.iloc[idx]
        filename = os.path.basename(row['image_path'])

        original = self.transform(self._load_rgb(os.path.join(self.original_folder, filename)))
        target = self.transform(self._load_rgb(os.path.join(self.target_folder, filename)))

        original_tokens = self._tokenize(row['original_prompt'])
        enhanced_tokens = self._tokenize(row['enhanced_prompt'])

        return original, target, original_tokens, enhanced_tokens



class UNetWrapper:
    """Bundle a UNet with its training state for checkpointing and Hub uploads.

    Args:
        unet_model: The ``big_UNet`` / ``small_UNet`` instance to save or upload.
        repo_id: Hugging Face Hub model repository to upload to.
        epoch: Epoch number recorded in checkpoints.
        loss: Loss value recorded in checkpoints.
        optimizer: Optimizer whose state is stored in checkpoints, or ``None``
            to store no optimizer state.
        scheduler: Optional learning-rate scheduler whose state is stored too.

    The Hub token is read from the ``NEW_TOKEN`` environment variable.
    """

    def __init__(self, unet_model, repo_id, epoch=0, loss=0.0, optimizer=None, scheduler=None):
        self.loss = loss
        self.epoch = epoch
        self.model = unet_model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Ensure the token is set in the environment
        self.api = HfApi(token=self.token)
        self.save_dict = None  # last checkpoint dict, populated by save_checkpoint()

    def _model_config(self):
        """Return the architecture metadata stored in checkpoints and the model card."""
        is_big = isinstance(self.model, big_UNet)
        return {'big': is_big, 'img_size': 1024 if is_big else 256}

    def save_checkpoint(self, save_path):
        """Save model, optimizer and scheduler states plus metadata to ``save_path``."""
        self.save_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict() if self.optimizer is not None else None,
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler is not None else None,
            'model_config': self._model_config(),
            'epoch': self.epoch,
            'loss': self.loss,
        }
        torch.save(self.save_dict, save_path)
        print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")

    def load_checkpoint(self, checkpoint_path):
        """Restore model, optimizer and scheduler states from ``checkpoint_path``.

        Optimizer/scheduler states are restored only when both this wrapper
        holds the object and the checkpoint contains a state for it.
        """
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        optimizer_state = checkpoint.get('optimizer_state_dict')
        if self.optimizer is not None and optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)
        scheduler_state = checkpoint.get('scheduler_state_dict')
        if self.scheduler is not None and scheduler_state:
            self.scheduler.load_state_dict(scheduler_state)
        self.epoch = checkpoint['epoch']
        self.loss = checkpoint['loss']
        print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")

    def _model_card(self):
        """Return the README.md (model card) text describing ``self.model``."""
        config = self._model_config()
        img_size = config['img_size']
        model_type = 'big' if config['big'] else 'small'
        return f"""---
tags:
- unet
- pix2pix
- pytorch
library_name: pytorch
license: wtfpl
datasets:
- K00B404/pix2pix_flux_set
language:
- en
pipeline_tag: image-to-image
---

# Pix2Pix UNet Model

## Model Description
Custom UNet model for Pix2Pix image translation.
- **Image Size:** {img_size}
- **Model Type:** {model_type}_UNet ({img_size})

## Usage

```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
big = {config['big']}
# Load the model
name='big_model_weights.pth' if big else 'small_model_weights.pth'
checkpoint = torch.load(name)
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

## Model Architecture

```
{self.model}
```
"""

    def push_to_hub(self, pth_name, remove_local=False):
        """Upload a checkpoint file and a generated model card to the Hub.

        Args:
            pth_name: Local checkpoint path; uploaded under the same path in
                the repository.
            remove_local: Delete ``pth_name`` after it was uploaded
                successfully. Off by default so local weights are kept.

        Upload errors are printed rather than raised, so a failed push does
        not interrupt a training run.
        """
        checkpoint_uploaded = False
        try:
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            checkpoint_uploaded = True
            print(f"Model checkpoint successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")

        try:
            # Upload the card from memory rather than writing README.md into the
            # working directory, which may already contain an unrelated README.md.
            self.api.upload_file(
                path_or_fileobj=self._model_card().encode("utf-8"),
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )
            print(f"Model card successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model card: {e}")

        # Never delete a checkpoint that did not reach the Hub.
        if remove_local and checkpoint_uploaded:
            os.remove(pth_name)

def prepare_input(image, device='cpu'):
    """Prepare an image for inference.

    Args:
        image: PIL image or numpy array (converted with ``Image.fromarray``).
        device: Device the resulting tensor is moved to.

    Returns:
        A ``(1, 3, IMG_SIZE, IMG_SIZE)`` tensor with values scaled by ``ToTensor``.
    """
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Match the training pipeline (Pix2PixDataset converts every image to RGB):
    # drops an alpha channel and expands grayscale to three channels.
    image = image.convert('RGB')
    return transform(image).unsqueeze(0).to(device)

def run_inference(image):
    """Run ``global_model`` on a single image and return a uint8 image array.

    Returns the string ``"Error: Model not loaded"`` when no model is set.
    The raw output is min-max rescaled to 0..255; a constant output (zero
    range) maps to an all-zero image instead of dividing by zero.
    """
    if global_model is None:
        return "Error: Model not loaded"

    global_model.eval()
    input_tensor = prepare_input(image, device)

    with torch.no_grad():
        output = global_model(input_tensor)

    # Assumes a (1, C, H, W) output; convert to an (H, W, C) numpy image.
    output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
    lo = output.min()
    value_range = output.max() - lo
    if value_range > 0:
        output = (output - lo) / value_range
    else:
        output = np.zeros_like(output)
    return (output * 255).astype(np.uint8)
    
def to_hub(model, epoch=0, loss=0.0, optimizer=None, scheduler=None):
    """Save ``model`` as a checkpoint and push it to ``model_repo_id``.

    The checkpoint is written to ``big_model_weights.pth`` or
    ``small_model_weights.pth`` (the file names ``load_model`` reads) and then
    uploaded through ``UNetWrapper``.

    Args:
        model: The UNet to save.
        epoch: Epoch number recorded in the checkpoint.
        loss: Loss recorded in the checkpoint; a tensor is converted to a float.
        optimizer: Optimizer whose state to store. When omitted, a fresh Adam
            (``lr=LR``) over ``model``'s parameters is used so the checkpoint
            still carries a loadable ``optimizer_state_dict``.
        scheduler: Optional LR scheduler whose state to store.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=LR)
    if isinstance(loss, torch.Tensor):
        # Store a plain number rather than a tensor that may still hold a graph.
        loss = loss.item()
    wrapper = UNetWrapper(model, model_repo_id, epoch, loss, optimizer, scheduler)
    checkpoint_path = 'big_model_weights.pth' if big else 'small_model_weights.pth'
    wrapper.save_checkpoint(checkpoint_path)
    wrapper.push_to_hub(checkpoint_path)

    
def train_model(epochs, save_interval=1):
    """Train ``global_model`` and periodically checkpoint it to the Hub.

    Reads ``combined_data.csv`` into a ``Pix2PixDataset``. The model is fed the
    *target* image and its output is compared with the *original* image using
    L1 loss, optimised with Adam and a StepLR schedule (lr x0.1 every 10
    epochs). Every ``save_interval`` epochs a checkpoint is saved locally and
    pushed to ``model_repo_id``.

    Args:
        epochs: Number of epochs to train for.
        save_interval: Checkpoint every this many epochs; ``0`` or less
            disables checkpointing.

    Returns:
        ``(model, log)`` where ``log`` holds the newline-joined status lines
        printed every 10 steps.
    """
    global global_model

    combined_data = pd.read_csv('combined_data.csv')
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    num_batches = len(dataloader)

    # load_model() also assigns global_model; use it if startup loading was skipped.
    model = global_model if global_model is not None else load_model()
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)

    output_text = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        # The prompt tokens yielded by the dataset are not consumed by the model yet.
        for i, (original, target, _, _) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(target)
            loss = criterion(output, original)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            running_loss += step_loss

            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{num_batches}], Loss: {step_loss:.8f}"
                print(status)
                output_text.append(status)

        wrapper.epoch = epoch + 1
        # max() guards against an empty dataset (len(dataloader) == 0).
        wrapper.loss = running_loss / max(num_batches, 1)

        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            prefix = 'big' if big else 'small'
            checkpoint_path = f'{prefix}_checkpoint_epoch_{epoch + 1}.pth'
            wrapper.save_checkpoint(checkpoint_path)
            wrapper.push_to_hub(checkpoint_path)

        scheduler.step()  # Update learning rate scheduler

    model.eval()  # leave the shared model ready for inference
    global_model = model
    return model, "\n".join(output_text)

    
def train_model_old(epochs):
    """Legacy training loop; ``gradio_train`` uses ``train_model`` instead.

    Trains ``global_model`` with L1 loss and Adam (no scheduler), printing the
    loss of every batch and calling ``to_hub`` after each epoch with the last
    batch's loss tensor. An empty dataloader leaves that loss unbound, so the
    ``to_hub`` call then raises.

    Returns:
        ``(model, log)`` — the trained model and the newline-joined status lines.
    """
    global global_model

    frame = pd.read_csv('combined_data.csv')  # Adjust this path
    preprocess = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    loader = DataLoader(
        Pix2PixDataset(frame, preprocess, clip_tokenizer),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )

    net = global_model
    l1 = nn.L1Loss()  # image reconstruction loss
    adam = optim.Adam(net.parameters(), lr=LR)
    log_lines = []

    for epoch in range(epochs):
        net.train()
        for step, batch in enumerate(loader):
            src, tgt, src_tokens, enh_tokens = batch
            src, tgt = src.to(device), tgt.to(device)
            # Token ids are moved and cast but never fed to the model.
            src_tokens = src_tokens.input_ids.to(device).float()
            enh_tokens = enh_tokens.input_ids.to(device).float()

            adam.zero_grad()
            prediction = net(tgt)
            step_loss = l1(prediction, src)
            rp(f"Image {step} Loss:{step_loss}")
            step_loss.backward()
            adam.step()

            if step % 10 == 0:
                message = f"Epoch [{epoch}/{epochs}], Step [{step}/{len(loader)}], Loss: {step_loss.item():.8f}"
                rp(message)
                log_lines.append(message)

        # Push the model after every epoch.
        to_hub(net, epoch, step_loss)

    global_model = net
    return net, "\n".join(log_lines)

def gradio_train(epochs):
    """Gradio callback for the Train tab.

    ``train_model`` already saves a checkpoint and pushes it to
    ``model_repo_id`` after every epoch (default ``save_interval=1``), so no
    separate upload is done here.

    Args:
        epochs: Epoch count from the ``gr.Number`` input; cast to ``int``.

    Returns:
        The training log followed by a one-line summary.
    """
    num_epochs = int(epochs)
    _, training_log = train_model(num_epochs)
    return f"{training_log}\n\nModel trained for {num_epochs} epochs and pushed to {model_repo_id}"

def gradio_inference(input_image):
    """Gradio callback for the Inference tab.

    Passes ``input_image`` to ``run_inference`` and returns its result
    unchanged. That result is a uint8 image array, or an error string when no
    model is loaded. The result is also logged.
    """
    result = run_inference(input_image)
    rp(result)
    return result


# Create Gradio interface with tabs.
# Components appear in the order they are created inside each Blocks/Tab context.
with gr.Blocks() as app:
    gr.Markdown("# Pix2Pix Model Training and Inference")
    
    with gr.Tab("Train"):
        # gradio_train casts this value with int() before training.
        epochs_input = gr.Number(value=EPOCHS, label="Number of epochs")
        train_button = gr.Button("Train")
        training_output = gr.Textbox(label="Training Log", interactive=False)
        train_button.click(gradio_train, inputs=[epochs_input], outputs=[training_output])
    
    with gr.Tab("Inference"):
        image_input = gr.Image(type='numpy')
        # NOTE(review): prompt_input is not wired into inference_button.click,
        # so the prompt text is currently ignored by gradio_inference.
        prompt_input = gr.Textbox(label="Prompt")
        inference_button = gr.Button("Generate")
        inference_output = gr.Image(type='numpy', label="Generated Image")
        inference_button.click(gradio_inference, inputs=[image_input], outputs=[inference_output])

# Runs at import time: load weights into global_model, then start the Gradio server.
load_model()
app.launch()