import torch
from PIL import Image
from diffusers import DDIMScheduler
from accelerate.utils import set_seed
from torchvision.transforms.functional import to_pil_image, to_tensor, resize

from pipeline_sd import ADPipeline
from pipeline_sdxl import ADPipeline as ADXLPipeline
from utils import Controller

import spaces
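
# Note: model_path_or_name below is a diffusers checkpoint id or local path.
# Illustrative examples (assumptions, not pinned by this file):
#   "runwayml/stable-diffusion-v1-5"            -> SD 1.5 branch
#   "stabilityai/stable-diffusion-xl-base-1.0"  -> SDXL branch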


class Runner:
    def __init__(self):
        self.sd15 = None
        self.sdxl = None
        self.loss_fn = torch.nn.L1Loss(reduction="mean")
    
    def load_pipeline(self, model_path_or_name):
        # Lazily load and cache the requested pipeline so each variant is
        # instantiated at most once per Runner.
        if 'xl' in model_path_or_name:
            if self.sdxl is None:
                scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
                self.sdxl = ADXLPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
                self.sdxl.classifier = self.sdxl.unet
        elif self.sd15 is None:
            scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
            self.sd15 = ADPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
            self.sd15.classifier = self.sd15.unet

    def preprocess(self, image: Image.Image, height=None, width=None):
        if width is None or height is None:
            # No target size given: scale the shorter edge to 512, keeping
            # the aspect ratio.
            image = resize(image, size=512)
            width, height = image.size
        # Snap both dimensions down to multiples of 64 so the resulting
        # latent resolution is valid for the UNet.
        new_width = (width // 64) * 64
        new_height = (height // 64) * 64
        image = image.resize((new_width, new_height), Image.BICUBIC)
        return to_tensor(image).unsqueeze(0)
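
    # Shape sketch (assuming an 800x600 RGB input and no explicit size): the
    # shorter edge is scaled to 512 (~683x512), both sides are snapped down
    # to multiples of 64 (640x512), and the returned tensor has shape
    # (1, 3, 512, 640).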

    @spaces.GPU
    def run_style_transfer(self, content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model, **kwargs):
        self.load_pipeline(model)
        # Style transfer is only implemented on the SD 1.5 pipeline.
        if self.sd15 is None:
            raise ValueError("Style transfer requires a non-XL Stable Diffusion checkpoint.")

        content_image = self.preprocess(content_image)
        style_image = self.preprocess(style_image, height=512, width=512)

        height, width = content_image.shape[-2:]
        set_seed(seed)
        controller = Controller(self_layers=(10, 16))
        result = self.sd15.optimize(
            lr=lr,
            batch_size=1,
            iters=1,
            width=width,
            height=height,
            weight=content_weight,
            controller=controller,
            style_image=style_image,
            content_image=content_image,
            mixed_precision=mixed_precision,
            num_inference_steps=num_steps,
            enable_gradient_checkpoint=False,
        )
        output_image = to_pil_image(result[0].float())
        del result
        torch.cuda.empty_cache()
        return [output_image]

    @spaces.GPU
    def run_style_t2i_generation(self, style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model):
        self.load_pipeline(model)

        use_xl = 'xl' in model
        # The requested height/width are overridden with the model's native
        # resolution: 1024 for SDXL, 512 for SD 1.5.
        height, width = (1024, 1024) if use_xl else (512, 512)
        style_image = self.preprocess(style_image, height=height, width=width)

        set_seed(seed)
        self_layers = (64, 70) if use_xl else (10, 16)
        
        controller = Controller(self_layers=self_layers)

        pipeline = self.sdxl if use_xl else self.sd15
        images = pipeline.sample(
            controller=controller,
            iters=iterations,
            lr=lr,
            adain=is_adain,
            height=height,
            width=width,
            mixed_precision=mixed_precision,
            style_image=style_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            enable_gradient_checkpoint=False
        )
        output_images = [to_pil_image(image.float()) for image in images]

        del images
        torch.cuda.empty_cache()
        return output_images

    @spaces.GPU
    def run_texture_synthesis(self, texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way, model):
        self.load_pipeline(model)
        # Texture synthesis is only implemented on the SD 1.5 pipeline.
        if self.sd15 is None:
            raise ValueError("Texture synthesis requires a non-XL Stable Diffusion checkpoint.")

        texture_image = self.preprocess(texture_image, height=512, width=512)

        set_seed(seed)
        controller = Controller(self_layers=(10, 16))

        if synthesis_way == 'Sampling':
            results = self.sd15.sample(
                lr=lr,
                adain=False,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        elif synthesis_way == 'MultiDiffusion':
            results = self.sd15.panorama(
                lr=lr,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                stride=8,
                view_batch_size=8,
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        else:
            raise ValueError(f"Unknown synthesis_way: {synthesis_way!r}; expected 'Sampling' or 'MultiDiffusion'.")
        
        output_images = [to_pil_image(image.float()) for image in results]
        del results
        torch.cuda.empty_cache()
        return output_images
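

# Minimal local usage sketch. Assumptions not fixed by this file: the
# checkpoint id is a stand-in for any SD 1.5 model, the image paths and
# hyperparameters are illustrative, and outside a Hugging Face ZeroGPU Space
# the @spaces.GPU decorator is effectively a pass-through.
if __name__ == "__main__":
    runner = Runner()
    content = Image.open("content.jpg").convert("RGB")  # hypothetical path
    style = Image.open("style.jpg").convert("RGB")      # hypothetical path
    outputs = runner.run_style_transfer(
        content_image=content,
        style_image=style,
        seed=0,
        num_steps=50,
        lr=0.05,               # illustrative learning rate
        content_weight=0.25,   # illustrative content/style trade-off
        mixed_precision="no",  # or "fp16"/"bf16", depending on hardware
        model="runwayml/stable-diffusion-v1-5",  # assumed checkpoint id
    )
    outputs[0].save("stylized.png")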