import gc
import os

import torch
import torch._dynamo
from torch import Generator

from PIL.Image import Image
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from pipelines.models import TextToImageRequest


# Fall back to eager execution instead of crashing if torch.compile/dynamo fails.
torch._dynamo.config.suppress_errors = True
# Reduce CUDA allocator fragmentation and allow parallel HF tokenization.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

# Hub repo and pinned revision for the optimized FLUX.1-schnell checkpoint.
ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
# Placeholder type alias; load_pipeline() and infer() annotate with it.
Pipeline = None
# When True, inference is routed through the torch.compile wrapper below.
use_com = False


def remove_cache():
    """Release cached GPU memory and reset the allocator's peak statistics."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() supersedes the deprecated reset_max_memory_allocated().
    torch.cuda.reset_peak_memory_stats()



def text_t5_loader() -> T5EncoderModel:
    print("Loading text encoder...")
    # Pre-converted bf16 copy of the T5-XXL encoder, swapped in below as the
    # pipeline's text_encoder_2.
    text_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    return text_encoder.to(memory_format=torch.channels_last)


class StableDiffusionTransformerCompile:
    """Thin wrapper that optionally compiles the pipeline's denoiser with torch.compile."""

    def __init__(self, pipeline, optimize=False):
        self.pipeline = pipeline
        self.optimize = optimize
        if self.optimize:
            self.model_compiling()

    def model_compiling(self):
        # FLUX pipelines expose their denoiser as `transformer`, not `unet`.
        self.pipeline.transformer = torch.compile(self.pipeline.transformer)

    def __call__(self, *args, **kwargs):
        return self.pipeline(*args, **kwargs)

    def __getattr__(self, name):
        # Delegate attribute access (e.g. `.device`) to the wrapped pipeline.
        return getattr(self.pipeline, name)
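
# A minimal usage sketch for the wrapper above, assuming `pipeline` is a loaded
# FluxPipeline; the first call after compiling pays a one-off compile cost,
# which is what the warmup prompts in load_pipeline() would then amortize:
#
#     wrapped = StableDiffusionTransformerCompile(pipeline, optimize=True)
#     image = wrapped(prompt="warmup", width=1024, height=1024,
#                     guidance_scale=0.0, num_inference_steps=4,
#                     max_sequence_length=256).images[0]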
    
def load_pipeline() -> Pipeline:

    text_t5_encoder = text_t5_loader()

    # Local HF cache path of ckpt_root pinned at revision_root.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False
    )

    try:
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
    except Exception:
        # Fall back to the checkpoint's own transformer weights.
        pipeline = DiffusionPipeline.from_pretrained(
            ckpt_root,
            revision=revision_root,
            torch_dtype=torch.bfloat16,
        )

    pipeline.to("cuda")

    # Swap in the pre-loaded T5 encoder first, so a failure in the optional
    # tweaks below can no longer skip this step.
    pipeline.text_encoder_2 = text_t5_encoder

    try:
        # disable_vae_compress() is not a standard diffusers method on every
        # pipeline class, so keep it guarded (but report the actual error
        # instead of swallowing everything with a bare `except:`).
        pipeline.disable_vae_compress()
    except Exception as e:
        print(f"disable_vae_compress unavailable: {e}")

    if use_com:
        # optimize=True actually triggers torch.compile; the original passed
        # optimize=False, which made the wrapper a no-op even when enabled.
        pipeline = StableDiffusionTransformerCompile(pipeline, optimize=True)
    else:
        print("torch.compile disabled; running the eager pipeline")

    # Two throwaway generations to warm up CUDA kernels (and the compile cache
    # when use_com is enabled) before serving real requests.
    warmup_prompts = [
        "albaspidin, pillmonger, palaeocrystalline",
        "obe, kilometrage, circuition",
    ]
    for prompt in warmup_prompts:
        pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Free cached GPU memory from the previous request before generating.
    remove_cache()
    generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]
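

# A minimal local smoke test, assuming TextToImageRequest accepts prompt, seed,
# width, and height as keyword fields (its exact schema lives in
# pipelines.models and is not shown here); the output filename is arbitrary.
if __name__ == "__main__":
    pipe = load_pipeline()
    req = TextToImageRequest(  # hypothetical field names, inferred from infer()
        prompt="a red fox at dawn", seed=0, width=1024, height=1024
    )
    infer(req, pipe).save("smoke_test.png")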