import torch.multiprocessing as mp
import torch
import os
import re
import random
from collections import deque
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from accelerate import Accelerator
import spaces

# Use the 'spawn' start method for torch.multiprocessing. set_start_method()
# raises RuntimeError if a start method was already fixed and force=True is not
# passed; this guard only skips the call when 'spawn' is already active.
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')

# Instantiate the Accelerator
# NOTE(review): `accelerator` is not referenced anywhere else in the visible code.
accelerator = Accelerator()

# Weight dtype for the FLUX components loaded in initialize_diffusers().
# NOTE(review): the Llama model below is loaded in torch.float16 instead of this
# dtype — confirm the mismatch is intended.
dtype = torch.bfloat16

# Set environment variables for local path
# NOTE(review): FLUX_DEV and AE are not read by any visible code; presumably
# consumed by a library or left over from an earlier version — verify.
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'
# NOTE(review): huggingface_hub is already imported (via transformers) by this
# point and appears to read HF_HUB_ENABLE_HF_TRANSFER at import time, so this
# assignment likely has no effect on downloads — verify; move it above the
# imports if it matters.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = 'false'  # Disable HF_HUB_ENABLE_HF_TRANSFER

# Seed words pool: generate_descriptions() appends one randomly chosen
# description per successful LLM call; nothing in the visible code reads it.
seed_words = []

# NOTE(review): unused in the visible code.
used_words = set()

# Queue to store parsed descriptions. generate_descriptions() extends it and no
# visible code ever clears it, so it grows for the lifetime of the process.
parsed_descriptions_queue = deque()

# Usage limits
MAX_DESCRIPTIONS = 30  # NOTE(review): not referenced by the visible code
MAX_IMAGES = 4  # Limit to 4 images

# Preload models and checkpoints
print("Preloading models and checkpoints...")
model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
# device_map='auto' lets accelerate decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Shared text-generation pipeline; generate_descriptions() passes it to
# generate_description_prompt() on every call.
text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

def initialize_diffusers():
    """Build the FLUX.1-schnell text-to-image pipeline with 8-bit weights.

    Loads the scheduler, the CLIP text encoder/tokenizer, the T5 text
    encoder/tokenizer, the VAE and the transformer (model weights in the
    module-level ``dtype``). It then quantizes the transformer and the T5
    encoder to ``qfloat8`` with optimum-quanto and enables model CPU offload.

    Returns:
        FluxPipeline: a pipeline ready to be called with a text prompt.
    """
    # Heavy dependencies are imported lazily, only when the pipeline is built.
    from optimum.quanto import freeze, qfloat8, quantize
    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    bfl_repo = 'black-forest-labs/FLUX.1-schnell'
    revision = 'refs/pr/1'

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    # Tokenizers hold no tensors, so torch_dtype is meaningless for them;
    # passing it would only be forwarded into the tokenizer's init_kwargs.
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
    tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', revision=revision)
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
    transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

    # Quantize the two largest modules to 8-bit float weights; freeze() makes
    # the quantized weights permanent.
    quantize(transformer, weights=qfloat8)
    freeze(transformer)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    # The quantized modules are attached after construction rather than passed
    # to the constructor.
    # NOTE(review): this mirrors the quanto FLUX recipe; presumably it avoids
    # constructor-time handling of quantized modules — confirm.
    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    # Keep sub-models on CPU and move each to the GPU only while it runs.
    pipe.enable_model_cpu_offload()

    return pipe

# Build the FLUX pipeline once at import time; generate_images() reuses this
# module-level `pipe` for every request.
pipe = initialize_diffusers()
print("Models and checkpoints preloaded.")

def generate_description_prompt(user_prompt, text_generator, max_length=110):
    """Ask the LLM for short bracketed descriptions of ``user_prompt``.

    The model is instructed to write three descriptions, each wrapped in
    square brackets. Every bracketed span of at least four words in the
    model's *generated* text is returned, with surrounding whitespace stripped.

    Args:
        user_prompt: Subject to describe; interpolated verbatim into the
            instruction.
        text_generator: A ``transformers`` text-generation pipeline, or any
            callable with the same call signature and
            ``[{'generated_text': str}]`` return shape.
        max_length: Token budget forwarded as ``max_length``. For
            text-generation pipelines this total includes the prompt tokens.
            NOTE(review): 110 leaves little room for three descriptions of up
            to 100 words; consider switching to ``max_new_tokens``.

    Returns:
        list[str] of descriptions, or None if none were found or generation
        raised.
    """
    injected_prompt = f"write three concise descriptions enclosed in brackets like [ <description> ] less than 100 words each of {user_prompt}. "

    try:
        generated_text = text_generator(
            injected_prompt,
            max_length=max_length,
            num_return_sequences=1,
            truncation=True,
            # Parse only what the model wrote. With the full text, the echoed
            # prompt (its "[ <description> ]" template and any brackets inside
            # user_prompt) would be extracted as descriptions too.
            return_full_text=False,
        )[0]['generated_text']
    except Exception as e:
        # Best effort: a failed generation means "no descriptions", not a crash.
        print(f"Error generating descriptions: {e}")
        return None

    # Extract descriptions enclosed in brackets; keep those of at least 4 words.
    generated_descriptions = re.findall(r'\[([^\[\]]+)\]', generated_text)
    filtered_descriptions = [desc.strip() for desc in generated_descriptions if len(desc.split()) >= 4]
    return filtered_descriptions or None

def format_descriptions(descriptions):
    """Join descriptions into newline-separated text for display.

    Args:
        descriptions: Iterable of description strings, or None.

    Returns:
        str: one description per line; ``""`` when ``descriptions`` is None
        or empty.
    """
    # generate_description_prompt() signals "nothing found" with None; treat
    # it as an empty list instead of failing inside str.join.
    if descriptions is None:
        return ""
    return "\n".join(descriptions)

@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=4):
    """Collect up to MAX_IMAGES LLM-written descriptions of ``user_prompt``.

    Calls ``generate_description_prompt`` up to ``max_iterations`` times and
    stops early once MAX_IMAGES descriptions have been gathered. After each
    call that yields descriptions, one of them (chosen at random) is appended
    to the module-level ``seed_words``. The final descriptions are also
    appended to the module-level ``parsed_descriptions_queue``.

    Args:
        user_prompt: Subject passed to the LLM.
        seed_words_input: Not used by this function; kept for the UI signature.
        batch_size: Not used by this function.
        max_iterations: Maximum number of LLM calls.

    Returns:
        list[str]: this request's descriptions, at most MAX_IMAGES of them.
    """
    descriptions = []
    for _ in range(max_iterations):
        if len(descriptions) >= MAX_IMAGES:
            # Only the first MAX_IMAGES are kept, so more LLM calls would only
            # burn GPU time.
            break
        new_descriptions = generate_description_prompt(user_prompt, text_generator)
        if new_descriptions:
            descriptions.extend(new_descriptions)
            # Pick a random description to feed back into the seed bank.
            seed_words.append(random.choice(new_descriptions))

    descriptions = descriptions[:MAX_IMAGES]
    parsed_descriptions_queue.extend(descriptions)
    # Return this call's own results. The shared queue is never drained, so
    # slicing its head would keep returning the first request's descriptions.
    return descriptions

@spaces.GPU(duration=120)
def generate_images(parsed_descriptions, max_iterations=4):
    """Render one 1024x1024 image per description with the shared ``pipe``.

    At most MAX_IMAGES descriptions are used. Each prompt runs for 4
    inference steps; a prompt whose generation raises is reported with
    ``print`` and skipped.

    Args:
        parsed_descriptions: Sequence of text prompts.
        max_iterations: Not used by this function.

    Returns:
        list: the concatenated ``.images`` from every successful pipeline call.
    """
    prompts = (
        parsed_descriptions[:MAX_IMAGES]
        if len(parsed_descriptions) > MAX_IMAGES
        else parsed_descriptions
    )

    rendered = []
    for prompt in prompts:
        try:
            result = pipe(prompt, num_inference_steps=4, height=1024, width=1024)
            rendered.extend(result.images)
        except Exception as e:
            print(f"Error generating image for prompt '{prompt}': {e}")
    return rendered

def combined_function(user_prompt, seed_words_input):
    """Run the complete prompt -> descriptions -> images flow for one request.

    Returns:
        tuple: (newline-joined descriptions text, list of generated images).
    """
    descriptions = generate_descriptions(user_prompt, seed_words_input)
    # Tuple elements evaluate left to right: text formatting first, then images.
    return format_descriptions(descriptions), generate_images(descriptions)

if __name__ == '__main__':
    def generate_and_display(user_prompt, seed_words_input):
        """Gradio callback: return (descriptions text, images) for one request."""
        # combined_function implements exactly this flow; delegate to it so
        # the pipeline logic lives in one place.
        return combined_function(user_prompt, seed_words_input)

    interface = gr.Interface(
        fn=generate_and_display,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            # NOTE(review): this value reaches generate_descriptions() as
            # seed_words_input, which the visible code never reads.
            gr.Textbox(lines=2, placeholder='Enter example in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
        live=False,  # Run only on submit, not on every keystroke
        allow_flagging='never'  # Disable flagging
    )

    interface.launch(share=True)