import torch
from enum import Enum
import gc
import numpy as np
import jax.numpy as jnp
import jax

from PIL import Image
from typing import List

from flax.training.common_utils import shard
from flax.jax_utils import replicate
from flax import jax_utils
import einops

from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
    FlaxDDIMScheduler,
    FlaxAutoencoderKL,
    FlaxStableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
    FlaxUNet2DConditionModel as VanillaFlaxUNet2DConditionModel,
)
from text_to_animation.models.unet_2d_condition_flax import (
    FlaxUNet2DConditionModel
)
from diffusers import FlaxControlNetModel

from text_to_animation.pipelines.text_to_video_pipeline_flax import (
    FlaxTextToVideoPipeline,
)

import utils.utils as utils
import utils.gradio_utils as gradio_utils
import os

# Evaluated once at import time: True only when the SPACE_AUTHOR_NAME
# environment variable equals "PAIR" (presumably set by the hosting
# Hugging Face Spaces runtime — verify).
on_huggingspace = os.getenv("SPACE_AUTHOR_NAME") == "PAIR"

def unshard(x):
    """Merge the first two axes of ``x``: ``(d, b, ...) -> (d * b, ...)``.

    Inverse of ``flax.training.common_utils.shard`` for arrays whose leading
    axis is the device axis ``d`` and second axis the per-device batch ``b``.
    """
    return einops.rearrange(x, "d b ... -> (d b) ...")


class ModelType(Enum):
    """Enumeration of model kinds, with explicit integer values 1 through 3."""

    Text2Video = 1
    ControlNetPose = 2
    StableDiffusion = 3


def replicate_devices(array):
    """Stack ``jax.device_count()`` copies of ``array`` along a new leading axis.

    The result has shape ``(n_devices, *array.shape)``; every slice along
    axis 0 equals the input.
    """
    n_devices = jax.device_count()
    with_device_axis = jnp.expand_dims(array, 0)
    return jnp.repeat(with_device_axis, n_devices, axis=0)


class ControlAnimationModel:
    """Wrapper around a Flax ``FlaxTextToVideoPipeline`` conditioned on an
    OpenPose ControlNet.

    ``generate_initial_frames`` loads the requested model (via ``set_model``)
    and renders candidate frames for a motion video; ``generate_video_from_frame``
    renders a full clip with the already-loaded pipeline and returns a GIF.
    """

    # Style keywords prepended to every user prompt.
    ADDED_PROMPT = (
        "high quality, best quality, HD, clay stop-motion, claymation, HQ, "
        "masterpiece, art, smooth"
    )
    # Keywords prepended to every user negative prompt.
    # NOTE(review): "fewer difits" looks like a typo for "fewer digits"; kept
    # verbatim so outputs stay reproducible for a given seed.
    ADDED_N_PROMPT = (
        "longbody, lowres, bad anatomy, bad hands, missing fingers, "
        "extra digit, fewer difits, cropped, worst quality, low quality, "
        "deformed body, bloated, ugly"
    )
    # Hub id of the ControlNet that set_model always loads.
    CONTROLNET_ID = "fusing/stable-diffusion-v1-5-controlnet-openpose"

    def __init__(self, dtype, **kwargs):
        """Create an empty wrapper; weights are loaded later by ``set_model``.

        Args:
            dtype: dtype passed to the UNet, VAE and text-encoder loaders.
            **kwargs: Accepted for interface compatibility; unused.
        """
        self.dtype = dtype
        self.rng = jax.random.PRNGKey(0)
        self.pipe = None
        self.model_type = None

        self.states = {}
        self.model_name = ""
        # Host and device-replicated weights; populated by set_model.
        self.params = None
        self.p_params = None

    def set_model(
        self,
        model_id: str,
        **kwargs,
    ):
        """Load ``model_id`` plus the OpenPose ControlNet and build the pipeline.

        If ``model_id`` is already loaded this is a no-op, so repeated calls
        (e.g. one per ``generate_initial_frames``) do not re-read every
        checkpoint from disk or re-replicate the weights across devices.

        Args:
            model_id: Diffusers model id/path with ``scheduler``, ``tokenizer``,
                ``feature_extractor``, ``unet``, ``vae`` and ``text_encoder``
                subfolders.
            **kwargs: Ignored.
        """
        if (
            getattr(self, "pipe", None) is not None
            and getattr(self, "model_name", "") == model_id
        ):
            return

        # Release the previous pipeline *and* its device-replicated weights
        # before loading, so two full weight sets are never alive at once.
        self.pipe = None
        self.params = None
        self.p_params = None
        gc.collect()

        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            self.CONTROLNET_ID,
            from_pt=True,
            # NOTE(review): pinned to float16 regardless of self.dtype — confirm.
            dtype=jnp.float16,
        )

        scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
            model_id, subfolder="scheduler", from_pt=True
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            model_id, subfolder="feature_extractor"
        )
        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        # Only the module definition is built here (from_config loads no
        # weights); presumably the pipeline runs it with unet_params.
        # NOTE(review): passing a model id to from_config is deprecated in
        # diffusers — verify against the pinned version.
        unet_vanilla = VanillaFlaxUNet2DConditionModel.from_config(
            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        vae, vae_params = FlaxAutoencoderKL.from_pretrained(
            model_id, subfolder="vae", from_pt=True, dtype=self.dtype
        )
        text_encoder = FlaxCLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
        )
        self.pipe = FlaxTextToVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            unet_vanilla=unet_vanilla,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        )
        self.params = {
            "unet": unet_params,
            "vae": vae_params,
            "scheduler": scheduler_state,
            "controlnet": controlnet_params,
            "text_encoder": text_encoder.params,
        }
        self.p_params = jax_utils.replicate(self.params)
        # Recorded last so a failed load is retried on the next call.
        self.model_name = model_id

    @classmethod
    def _build_prompts(cls, prompt: str, n_prompt: str):
        """Return ``(positive, negative)`` prompts with the fixed keywords prepended."""
        positive = cls.ADDED_PROMPT + ", " + prompt
        negative = cls.ADDED_N_PROMPT + ", " + n_prompt
        return positive, negative

    def generate_initial_frames(
        self,
        prompt: str,
        video_path: str,
        n_prompt: str = "",
        seed: int = 0,
        num_imgs: int = 4,
        resolution: int = 512,
        model_id: str = "runwayml/stable-diffusion-v1-5",
    ) -> tuple:
        """Render candidate starting frames conditioned on a pose video.

        Args:
            prompt: User prompt; ``ADDED_PROMPT`` is prepended.
            video_path: Value resolved through
                ``gradio_utils.motion_to_video_path`` before loading.
            n_prompt: User negative prompt; ``ADDED_N_PROMPT`` is prepended.
            seed: Seed for the PRNG key given to every candidate.
            num_imgs: Number of PRNG keys (candidates) passed to the pipeline.
            resolution: Passed to ``utils.prepare_video``.
            model_id: Model loaded via ``set_model`` (no-op if already loaded).

        Returns:
            ``(video, images)``: the video from ``utils.prepare_video`` and a
            list with one ``np.ndarray`` per entry along the leading axis of the
            pipeline output.
        """
        self.set_model(model_id=model_id)

        video_path = gradio_utils.motion_to_video_path(video_path)
        prompts, negative_prompts = self._build_prompts(prompt, n_prompt)

        video, _fps = utils.prepare_video(
            video_path, resolution, None, self.dtype, False, output_fps=4
        )
        control = utils.pre_process_pose(video, apply_pose_detect=False)

        # NOTE(review): every candidate receives the *same* PRNG key; unless
        # generate_starting_frames derives distinct keys internally, the
        # candidates may be identical — confirm this is intended.
        prngs = [jax.random.PRNGKey(seed)] * num_imgs
        images = self.pipe.generate_starting_frames(
            params=self.p_params,
            prngs=prngs,
            controlnet_image=control,
            prompt=prompts,
            neg_prompt=negative_prompts,
        )

        images = [np.array(frame) for frame in images]

        return video, images

    def generate_video_from_frame(
        self,
        controlnet_video,
        prompt,
        n_prompt,
        seed,
        *,
        motion_field_strength_x=3,
        motion_field_strength_y=4,
        smooth_bg_strength=0.8,
    ):
        """Render a full clip with the loaded pipeline and encode it as a GIF.

        Requires a model loaded by ``set_model`` (``self.pipe`` / ``self.p_params``).

        Args:
            controlnet_video: Conditioning frames; ``shape[0]`` is used as the
                number of frames.
            prompt: User prompt; ``ADDED_PROMPT`` is prepended.
            n_prompt: User negative prompt; ``ADDED_N_PROMPT`` is prepended.
            seed: Seed for the pipeline PRNG key.
            motion_field_strength_x: Forwarded to the pipeline (default 3).
            motion_field_strength_y: Forwarded to the pipeline (default 4).
            smooth_bg_strength: Forwarded to the pipeline (default 0.8).

        Returns:
            Whatever ``utils.create_gif`` returns for the first device's
            output at 4 fps with ``path=None``.
        """
        prng_seed = jax.random.PRNGKey(seed)
        len_vid = controlnet_video.shape[0]
        prompts, negative_prompts = self._build_prompts(prompt, n_prompt)

        prompt_ids = self.pipe.prepare_text_inputs([prompts] * len_vid)
        n_prompt_ids = self.pipe.prepare_text_inputs([negative_prompts] * len_vid)

        # Every input gets a leading device axis (presumably consumed by a
        # pmap when jit=True). NOTE(review): all devices receive the same key
        # and inputs, so their outputs presumably match; only index 0 is kept.
        vid = self.pipe(
            image=replicate_devices(controlnet_video),
            prompt_ids=replicate_devices(prompt_ids),
            neg_prompt_ids=replicate_devices(n_prompt_ids),
            params=self.p_params,
            prng_seed=replicate_devices(prng_seed),
            jit=True,
            smooth_bg_strength=replicate_devices(jnp.array(smooth_bg_strength)),
            motion_field_strength_x=replicate_devices(
                jnp.array(motion_field_strength_x)
            ),
            motion_field_strength_y=replicate_devices(
                jnp.array(motion_field_strength_y)
            ),
        ).images[0]
        return utils.create_gif(np.array(vid), 4, path=None, watermark=None)