# utils/ai_generator_diffusers_flux.py
import os
import utils.constants as constants
#import spaces
import gradio as gr
from torch import __version__ as torch_version_, version, cuda, bfloat16, float32, Generator, no_grad, backends
from diffusers import FluxPipeline,FluxImg2ImgPipeline,FluxControlPipeline
import accelerate 
from transformers import AutoTokenizer
import safetensors
#import xformers
#from diffusers.utils import load_image
#from huggingface_hub import hf_hub_download
from PIL import Image
from tempfile import NamedTemporaryFile

from utils.image_utils import (
     crop_and_resize_image,
)
from utils.version_info import (
    get_torch_info,
    # get_diffusers_version,
    # get_transformers_version,
    # get_xformers_version,
    initialize_cuda,
    release_torch_resources
)
import gc
from utils.lora_details import get_trigger_words, approximate_token_count, split_prompt_precisely
#from utils.color_utils import detect_color_format
#import utils.misc as misc
#from pathlib import Path
import warnings
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
#print(torch_version_)  # Ensure it's 2.0 or newer
#print(cuda.is_available())  # Ensure CUDA is available

# Lookup table from the `pipeline_name` strings accepted by generate_image_lowmem
# to the diffusers pipeline class it instantiates via `from_pretrained`.
PIPELINE_CLASSES = dict(
    FluxPipeline=FluxPipeline,
    FluxImg2ImgPipeline=FluxImg2ImgPipeline,
    FluxControlPipeline=FluxControlPipeline,
)
#@spaces.GPU()
def generate_image_from_text(
    text,
    model_name="black-forest-labs/FLUX.1-dev",
    lora_weights=None,
    conditioned_image=None,
    image_width=1344,
    image_height=848,
    guidance_scale=3.5,
    num_inference_steps=50,
    seed=0,
    additional_parameters=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a single image from a text prompt with a FLUX ``FluxPipeline``.

    The pipeline is loaded on every call and released before returning, even
    when generation fails.

    Args:
        text: Prompt text.
        model_name: Repo id / path of a diffusers-format FLUX model.
        lora_weights: Optional iterable of LoRA repo ids. Each id is looked up in
            ``constants.LORA_DETAILS``; every config found there is loaded with its
            ``weight_name`` / ``adapter_name``, otherwise the id is loaded as-is.
        conditioned_image: Optional image; cropped/resized to 1024x1024 and
            wrapped in ``Condition("subject", ...)``.
        image_width, image_height: Output size in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: Seed for the ``torch.Generator``.
        additional_parameters: Optional dict merged into the pipeline kwargs; its
            entries override the defaults built here. ``None`` values are dropped.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        The first image of the pipeline result (``result.images[0]``).
    """
    from src.condition import Condition
    device = "cuda" if cuda.is_available() else "cpu"
    print(f"device:{device}\nmodel_name:{model_name}\n")

    # Load on CPU. Model CPU offload moves each sub-model to the GPU only while it
    # runs, so moving the whole pipeline to CUDA first would defeat it (and can
    # OOM). Offload needs an accelerator, so on CPU just place the pipeline there.
    pipe = FluxPipeline.from_pretrained(
        model_name,
        torch_dtype=bfloat16 if device == "cuda" else float32
    )
    if device == "cuda":
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)

    try:
        # Swap in a slow tokenizer when the current one has add_prefix_space set.
        # Diffusers-format FLUX repos keep the CLIP tokenizer in "tokenizer/".
        if getattr(pipe.tokenizer, 'add_prefix_space', False):
            pipe.tokenizer = AutoTokenizer.from_pretrained(
                model_name, subfolder="tokenizer", use_fast=False
            )

        # Load LoRA weights. `token=` is the keyword the diffusers LoRA loader
        # reads for Hub authentication (`use_auth_token` is not forwarded).
        for lora_weight in lora_weights or []:
            lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
            if lora_configs:
                for config in lora_configs:
                    pipe.load_lora_weights(
                        lora_weight,
                        weight_name=config.get("weight_name"),
                        adapter_name=config.get("adapter_name"),
                        token=constants.HF_API_TOKEN
                    )
            else:
                pipe.load_lora_weights(lora_weight, token=constants.HF_API_TOKEN)

        # Seeded generator for reproducibility.
        generator = Generator(device=device).manual_seed(seed)
        conditions = []
        if conditioned_image is not None:
            conditioned_image = crop_and_resize_image(conditioned_image, 1024, 1024)
            conditions.append(Condition("subject", conditioned_image))

        # NOTE(review): the stock diffusers FluxPipeline.__call__ has no
        # `conditions` argument; passing one presumably requires a pipeline that
        # accepts it — confirm before relying on conditioned_image here.
        generate_params = {
            "prompt": text,
            "height": image_height,
            "width": image_width,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "conditions": conditions if conditions else None
        }
        if additional_parameters:
            generate_params.update(additional_parameters)
        generate_params = {k: v for k, v in generate_params.items() if v is not None}

        result = pipe(**generate_params)
        image = result.images[0]
        del result
        pipe.unload_lora_weights()
    finally:
        # Release the pipeline and cached accelerator memory even on failure.
        # ipc_collect() initialises CUDA, so only call it when CUDA exists.
        del pipe
        gc.collect()
        if cuda.is_available():
            cuda.empty_cache()
            cuda.ipc_collect()

    return image

# Condition types a LoRA config may declare; they are only logged here.
_KNOWN_CONDITION_TYPES = ("coloring", "deblurring", "fill", "depth", "canny", "subject")


def _load_lora_from_config(pipe, lora_weight, config):
    """Load the LoRA described by one ``constants.LORA_DETAILS`` config into *pipe*.

    With a ``weight_name``, the file is fetched from ``lora_collection`` when the
    config names one, otherwise from *lora_weight*. ``adapter_name`` is passed
    only alongside a ``weight_name``. Without a ``weight_name``, the repo
    *lora_weight* is loaded as-is.
    """
    weight_name = config.get("weight_name")
    adapter_name = config.get("adapter_name")
    lora_collection = config.get("lora_collection")
    load_kwargs = {"token": constants.HF_API_TOKEN}
    source = lora_weight
    if weight_name:
        load_kwargs["weight_name"] = weight_name
        if adapter_name:
            load_kwargs["adapter_name"] = adapter_name
        if lora_collection:
            source = lora_collection
    pipe.load_lora_weights(source, **load_kwargs)
    if weight_name and lora_collection:
        print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}, lora_collection={lora_collection}\n")
    else:
        print(f"\npipe.load_lora_weights({lora_weight}, weight_name={weight_name}, adapter_name={adapter_name}\n")


def _apply_pipe_methods(pipe, pipe_config):
    """Call ``pipe.<method>(**params)`` for each entry of a config's ``pipe`` mapping."""
    for method_name, params in pipe_config.items():
        method = getattr(pipe, method_name, None)
        if method:
            print(f"Applying pipe method: {method_name} with params: {params}")
            method(**params)
        else:
            print(f"Method {method_name} not found in pipe.")


def _load_lowmem_loras(pipe, lora_weights, condition_type="subject"):
    """Load *lora_weights* into *pipe* and apply their ``constants.LORA_DETAILS`` configs.

    Only the first config carrying a ``weight_name`` is loaded per LoRA (several
    named adapters would additionally need ``pipe.set_adapters``). Returns the
    last ``condition_type`` declared by any config, else *condition_type*.
    """
    for lora_weight in lora_weights or []:
        lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
        if not lora_configs:
            pipe.load_lora_weights(lora_weight, token=constants.HF_API_TOKEN)
            continue
        lora_weight_set = False
        for config in lora_configs:
            if 'weight_name' in config and not lora_weight_set:
                _load_lora_from_config(pipe, lora_weight, config)
                lora_weight_set = True
            if 'pipe' in config:
                _apply_pipe_methods(pipe, config['pipe'])
            if 'condition_type' in config:
                condition_type = config['condition_type']
                if condition_type in _KNOWN_CONDITION_TYPES:
                    print(f"\nEnabled {condition_type}.\n")
                else:
                    print(f"Condition type {condition_type} not implemented.")
    return condition_type


#@spaces.GPU(progress=gr.Progress(track_tqdm=True))
def generate_image_lowmem(
    text,
    neg_prompt=None,
    model_name="black-forest-labs/FLUX.1-dev",
    lora_weights=None,
    conditioned_image=None,
    image_width=1368,
    image_height=848,
    guidance_scale=3.5,
    num_inference_steps=30,
    seed=0,
    true_cfg_scale=1.0,
    pipeline_name="FluxPipeline",
    strength=0.75,
    additional_parameters=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate one image with a FLUX pipeline while keeping GPU memory low.

    The pipeline class is chosen by *pipeline_name*. On CUDA it runs with model
    CPU offload, so sub-models are moved to the GPU only while they execute. The
    pipeline is released before returning, even when generation fails.

    Args:
        text: Prompt. When ``approximate_token_count`` exceeds 76, it is split
            with ``split_prompt_precisely`` into ``prompt`` and ``prompt_2``.
        neg_prompt: Negative prompt; only sent when no *conditioned_image* is given.
        model_name: Diffusers-format model repo id / path.
        lora_weights: Optional iterable of LoRA repo ids (see ``constants.LORA_DETAILS``).
        conditioned_image: Optional image; cropped/resized to the output size and
            sent as ``image`` together with *strength*.
        image_width, image_height: Output size in pixels.
        guidance_scale, num_inference_steps: Forwarded to the pipeline.
        seed: Seed for the ``torch.Generator``.
        true_cfg_scale: True-CFG scale sent with *neg_prompt*. It is raised to 1.1
            when a non-empty negative prompt is given and it is <= 1.0.
        pipeline_name: Key of ``PIPELINE_CLASSES``.
        strength: Image-to-image strength sent with *conditioned_image*.
        additional_parameters: Extra pipeline kwargs (e.g. from LoRA configs).
            They override the mode-specific kwargs but not the prompt(s).
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``result.images[0]`` from the pipeline call.

    Raises:
        ValueError: If *pipeline_name* is not in ``PIPELINE_CLASSES``.
    """
    pipeline_class = PIPELINE_CLASSES.get(pipeline_name)
    if not pipeline_class:
        raise ValueError(f"Unsupported pipeline type '{pipeline_name}'. "
                        f"Available options: {list(PIPELINE_CLASSES.keys())}")

    initialize_cuda()
    device = "cuda" if cuda.is_available() else "cpu"
    from src.condition import Condition

    print(f"device:{device}\nmodel_name:{model_name}\nlora_weights:{lora_weights}\n")
    print(f"\n {get_torch_info()}\n")
    # Copy so the caller's dict is never mutated; merged into the call below.
    caller_parameters = dict(additional_parameters or {})

    with no_grad():
        # Load on CPU. Moving the full model to CUDA before enabling offload would
        # defeat the offload (and can OOM). Offload needs an accelerator, so on a
        # CPU-only machine the pipeline is simply placed on the CPU.
        pipe = pipeline_class.from_pretrained(
            model_name,
            torch_dtype=bfloat16 if device == "cuda" else float32
        )
        try:
            if device == "cuda":
                pipe.enable_model_cpu_offload()
            else:
                pipe.to(device)
            if pipeline_name == "FluxPipeline":
                pipe.vae.enable_slicing()
                pipe.vae.enable_tiling()

            # Swap in a slow tokenizer when the current one has add_prefix_space
            # set. Diffusers-format FLUX repos keep it in "tokenizer/". No
            # pipe.to(device) afterwards: that would undo the CPU offload.
            if getattr(pipe.tokenizer, 'add_prefix_space', False):
                pipe.tokenizer = AutoTokenizer.from_pretrained(
                    model_name, subfolder="tokenizer", use_fast=False
                )

            # Only reports the PyTorch SDPA flash backend state; nothing is switched
            # here. NOTE(review): a plain `pipe.attn_implementation` attribute is
            # presumably not read by diffusers pipelines, so it is no longer set.
            if backends.cuda.flash_sdp_enabled():
                print("\nFlash SDP attention backend is enabled.\n")
            else:
                print("\nFlash SDP attention backend is disabled; using default attention.\n")

            condition_type = _load_lowmem_loras(pipe, lora_weights)

            generator = Generator(device=device).manual_seed(seed)
            # NOTE(review): `conditions` is built but not passed to the pipeline.
            conditions = []
            if conditioned_image is not None:
                conditioned_image = crop_and_resize_image(conditioned_image, image_width, image_height)
                conditions.append(Condition(condition_type, conditioned_image))
                print(f"\nAdded conditioned image.\n {conditioned_image.size}")
                # NOTE(review): `image`/`strength` match FluxImg2ImgPipeline's call
                # signature; other pipeline classes may reject them — confirm.
                mode_parameters = {
                    "strength": strength,
                    "image": conditioned_image,
                }
            else:
                print("\nNo conditioned image provided.")
                # Enable true CFG only for a non-empty negative prompt, and never
                # lower a scale the caller already raised. The caller passes ""
                # when there is no negative prompt.
                if neg_prompt and true_cfg_scale <= 1.0:
                    true_cfg_scale = 1.1
                mode_parameters = {
                    "negative_prompt": neg_prompt,
                    "true_cfg_scale": true_cfg_scale,
                }

            # Split long prompts: the CLIP encoder sees `prompt` and T5 sees
            # `prompt_2`.
            if approximate_token_count(text) > 76:
                prompt, prompt2 = split_prompt_precisely(text)
                prompt_parameters = {"prompt": prompt, "prompt_2": prompt2}
            else:
                prompt_parameters = {"prompt": text}

            # Precedence: base < mode-specific < caller's additional_parameters < prompt(s).
            generate_params = {
                "height": image_height,
                "width": image_width,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "generator": generator,
            }
            generate_params.update(mode_parameters)
            generate_params.update(caller_parameters)
            generate_params.update(prompt_parameters)
            generate_params = {k: v for k, v in generate_params.items() if v is not None}
            print(f"generate_params: {generate_params}")

            result = pipe(**generate_params)
            image = result.images[0]
            del result
        finally:
            # Release the pipeline and cached accelerator memory even on failure.
            # ipc_collect()/memory_summary() initialise CUDA, so guard them.
            del pipe
            gc.collect()
            if cuda.is_available():
                cuda.empty_cache()
                cuda.ipc_collect()
                print(cuda.memory_summary(device=None, abbreviated=False))

    return image

def generate_ai_image_local(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model="black-forest-labs/FLUX.1-dev",
    lora_weights=None,
    conditioned_image=None,
    height=512,
    width=912,
    num_inference_steps=30,
    guidance_scale=3.5,
    seed=777,
    pipeline_name="FluxPipeline",
    strength=0.75,
    progress=gr.Progress(track_tqdm=True)
):
    """Resolve prompts and LoRA settings, render an image and save it as a temp PNG.

    Prompt selection:
        When *map_option* is not ``"Prompt"``, the prompt comes from
        ``constants.PROMPTS[map_option]`` and the negative prompt from
        ``constants.NEGATIVE_PROMPTS`` (default ``""``). Otherwise the textbox
        values are used.

    LoRA handling:
        ``parameters`` from every matching ``constants.LORA_DETAILS`` config are
        merged. ``height``, ``width``, ``num_inference_steps`` and
        ``guidance_scale`` there override the arguments. When any config of a LoRA
        declares ``trigger_words``, that LoRA's trigger words are prepended to the
        prompt once.

    Returns:
        Path of the saved PNG (also appended to ``constants.temp_files``), or
        ``None`` if anything failed (the error and traceback are printed).
    """
    release_torch_resources()
    print("Generating image with lowmem")
    try:
        if map_option != "Prompt":
            prompt = constants.PROMPTS[map_option]
            negative_prompt = constants.NEGATIVE_PROMPTS.get(map_option, "")
        else:
            prompt = prompt_textbox_value
            negative_prompt = neg_prompt_textbox_value or ""

        additional_parameters = {}
        for lora_weight in lora_weights or []:
            lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
            for config in lora_configs:
                if 'parameters' in config:
                    additional_parameters.update(config['parameters'])
            # Independent of 'parameters' (previously an `elif` skipped trigger
            # words for such configs), and added once per LoRA, not once per config.
            if any('trigger_words' in config for config in lora_configs):
                trigger_words = get_trigger_words(lora_weight)
                prompt = f"{trigger_words} {prompt}"

        # Config values may arrive as strings; coerce the numeric ones.
        for key in ('height', 'width', 'num_inference_steps', 'max_sequence_length'):
            if key in additional_parameters:
                additional_parameters[key] = int(additional_parameters[key])
        for key in ('guidance_scale', 'true_cfg_scale'):
            if key in additional_parameters:
                additional_parameters[key] = float(additional_parameters[key])
        height = additional_parameters.pop('height', height)
        width = additional_parameters.pop('width', width)
        num_inference_steps = additional_parameters.pop('num_inference_steps', num_inference_steps)
        guidance_scale = additional_parameters.pop('guidance_scale', guidance_scale)

        print("Generating image with the following parameters:")
        print(f"Model: {model}")
        print(f"LoRA Weights: {lora_weights}")
        print(f"Prompt: {prompt}")
        print(f"Neg Prompt: {negative_prompt}")
        print(f"Height: {height}")
        print(f"Width: {width}")
        print(f"Number of Inference Steps: {num_inference_steps}")
        print(f"Guidance Scale: {guidance_scale}")
        print(f"Seed: {seed}")
        print(f"Additional Parameters: {additional_parameters}")
        print(f"Conditioned Image: {conditioned_image}")
        print(f"Conditioned Image Strength: {strength}")
        print(f"pipeline: {pipeline_name}")
        image = generate_image_lowmem(
            text=prompt,
            model_name=model,
            neg_prompt=negative_prompt,
            lora_weights=lora_weights,
            conditioned_image=conditioned_image,
            image_width=width,
            image_height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            pipeline_name=pipeline_name,
            strength=strength,
            additional_parameters=additional_parameters
        )
        # Write through the open handle rather than reopening the path while the
        # temp file is still open (not portable on every platform).
        with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            image.save(tmp, format="PNG")
        constants.temp_files.append(tmp.name)
        print(f"Image saved to {tmp.name}")
        return tmp.name
    except Exception as e:
        # UI boundary: report and return None, but keep the traceback for debugging.
        import traceback
        print(f"Error generating AI image: {e}")
        traceback.print_exc()
        return None
    finally:
        gc.collect()

# NOTE(review): follows the diffusers fuse_lora -> unload_lora_weights ->
# save_pretrained recipe; not exercised end-to-end here.
def merge_LoRA_weights(model="black-forest-labs/FLUX.1-dev",
    lora_weights="Borcherding/FLUX.1-dev-LoRA-FractalLand-v0.1",
    output_dir=None):
    """Fuse a LoRA into a FLUX base model and save the merged pipeline.

    The previous version called ``save_lora_weights`` with only a directory,
    which needs LoRA layer dicts; ``$TMPDIR`` may also be unset. It also never
    fused the LoRA before ``save_pretrained``.

    Args:
        model: Base model repo id / path (diffusers format).
        lora_weights: LoRA repo id. It must contain the base model's repo suffix
            (e.g. ``FLUX.1-dev``) as a simple compatibility guard.
        output_dir: Where to save the merged pipeline. Defaults to the LoRA repo
            name with ``-merged`` appended (relative to the working directory).

    Returns:
        The directory the merged pipeline was saved to.

    Raises:
        ValueError: If the model suffix is not contained in *lora_weights*.
    """
    model_suffix = model.split("/")[-1]
    if model_suffix not in lora_weights:
        raise ValueError(f"The model suffix '{model_suffix}' must be in the lora_weights string '{lora_weights}' to proceed.")
    if not output_dir:
        output_dir = lora_weights.split("/")[-1] + "-merged"

    pipe = FluxPipeline.from_pretrained(model, torch_dtype=bfloat16)
    try:
        pipe.load_lora_weights(lora_weights)
        # Bake the LoRA deltas into the base weights, then drop the now-redundant
        # adapter layers so save_pretrained writes plain merged weights.
        pipe.fuse_lora()
        pipe.unload_lora_weights()
        pipe.save_pretrained(output_dir)
    finally:
        del pipe
        gc.collect()
    return output_dir