# utils/ai_generator.py
import gradio as gr
import os
import time
#from turtle import width  # Added for implementing delays
from torch import cuda
import random
from utils.ai_generator_diffusers_flux import generate_ai_image_local
#from pathlib import Path
from huggingface_hub import InferenceClient
import requests
import io
from PIL import Image
from tempfile import NamedTemporaryFile
import utils.constants as constants

def generate_image_from_text(text, model_name="flax-community/dalle-mini", image_width=768, image_height=512, progress=gr.Progress(track_tqdm=True)):
    """Generate an image for ``text`` with the Hugging Face Inference API and resize it.

    Args:
        text: Prompt describing the desired image.
        model_name: Hugging Face model id used for text-to-image generation.
        image_width: Width in pixels of the returned image.
        image_height: Height in pixels of the returned image.
        progress: Gradio progress tracker (not used directly; kept for the Gradio signature).

    Returns:
        A PIL image resized to ``(image_width, image_height)``.

    Raises:
        Whatever ``InferenceClient.text_to_image`` raises on network/API failures.
    """
    client = InferenceClient()
    # InferenceClient objects are not callable; text_to_image performs the request
    # and already decodes the response into a PIL image, so no BytesIO step is needed.
    image = client.text_to_image(text, model=model_name)
    return image.resize((image_width, image_height))

def generate_ai_image(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model,
    lora_weights=None,
    conditioned_image=None,
    pipeline="FluxPipeline",
    width=912,
    height=512,
    strength=0.5,
    seed=0,
    progress=gr.Progress(track_tqdm=True),
    *args,
    **kwargs
):
    """Generate an image on a local CUDA GPU when one is present, otherwise remotely.

    A ``seed`` equal to 0 is replaced by a random seed in ``[0, constants.MAX_SEED]``.

    Local path: when ``conditioned_image`` is given, the pipeline is forced to
    ``"FluxImg2ImgPipeline"`` regardless of ``pipeline``. All generation options
    are forwarded to ``generate_ai_image_local``.

    Remote path: only the prompts, model, size and seed are forwarded to
    ``generate_ai_image_remote``. ``lora_weights``, ``conditioned_image``,
    ``pipeline`` and ``strength`` are not sent.

    Extra ``*args`` / ``**kwargs`` are accepted and ignored.

    Returns:
        Whatever the selected backend function returns.
    """
    # Resolve the seed before picking a backend so both paths see the same value.
    effective_seed = random.randint(0, constants.MAX_SEED) if seed == 0 else seed
    has_local_gpu = cuda.is_available() and cuda.device_count() >= 1

    if not has_local_gpu:
        print("No local GPU available. Sending request to Hugging Face API.")
        return generate_ai_image_remote(
            map_option,
            prompt_textbox_value,
            neg_prompt_textbox_value,
            model,
            height=height,
            width=width,
            seed=effective_seed,
        )

    print("Local GPU available. Generating image locally.")
    # A conditioning image implies image-to-image generation.
    pipeline_name = pipeline if conditioned_image is None else "FluxImg2ImgPipeline"
    return generate_ai_image_local(
        map_option,
        prompt_textbox_value,
        neg_prompt_textbox_value,
        model,
        lora_weights=lora_weights,
        seed=effective_seed,
        conditioned_image=conditioned_image,
        pipeline_name=pipeline_name,
        strength=strength,
        height=height,
        width=width,
    )

def generate_ai_image_remote(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model,
    height=512,
    width=912,
    num_inference_steps=30,
    guidance_scale=3.5,
    seed=777,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image through the Hugging Face Inference API and save it as a temp PNG.

    When the environment variable ``IS_SHARED_SPACE`` equals ``"True"``,
    ``huggingface_hub.InferenceClient`` is used. Otherwise a raw HTTP POST is
    sent to ``https://api-inference.huggingface.co/models/{model}``. Both paths
    are retried up to 3 times with exponential backoff (4s, then 8s) on
    timeouts and request errors. The raw-HTTP path also retries on status
    429/504 and gives up immediately on 400.

    Args:
        map_option: Key into ``constants.PROMPTS`` / ``constants.NEGATIVE_PROMPTS``.
            The value ``"Prompt"`` means the textbox values are used instead.
        prompt_textbox_value: Prompt used when ``map_option == "Prompt"``.
        neg_prompt_textbox_value: Comma-separated negative prompt used when
            ``map_option == "Prompt"``; may be empty or ``None``.
        model: Hugging Face model id.
        height: Requested image height in pixels.
        width: Requested image width in pixels.
        num_inference_steps: Number of sampling steps requested.
        guidance_scale: Guidance scale requested.
        seed: Sampling seed requested.
        progress: Gradio progress tracker (not used directly).

    Returns:
        Path of the saved PNG file (also appended to ``constants.temp_files``),
        or ``None`` if generation failed. All exceptions are caught and logged.
    """
    max_retries = 3
    retry_delay = 4  # Initial delay in seconds; doubled after each retried failure

    try:
        if map_option != "Prompt":
            prompt = constants.PROMPTS[map_option]
            # Convert the comma-separated negative prompt string to a list
            negative_prompt_str = constants.NEGATIVE_PROMPTS.get(map_option, "")
            negative_prompt = [p.strip() for p in negative_prompt_str.split(',') if p.strip()]
        else:
            prompt = prompt_textbox_value
            # Convert the comma-separated negative prompt string to a list
            negative_prompt = [p.strip() for p in neg_prompt_textbox_value.split(',') if p.strip()] if neg_prompt_textbox_value else []

        # NOTE: negative_prompt is only logged; neither request path below sends it.
        print("Remotely Generating image with the following parameters:")
        print(f"Prompt: {prompt}")
        print(f"Negative Prompt: {negative_prompt}")
        print(f"Height: {height}")
        print(f"Width: {width}")
        print(f"Number of Inference Steps: {num_inference_steps}")
        print(f"Guidance Scale: {guidance_scale}")
        print(f"Seed: {seed}")

        for attempt in range(1, max_retries + 1):
            try:
                if os.getenv("IS_SHARED_SPACE") == "True":
                    client = InferenceClient(
                        model,
                        token=constants.HF_API_TOKEN
                    )
                    # text_to_image takes the prompt positionally and generation settings
                    # as keyword arguments (it has no `inputs`/`parameters` arguments).
                    # It returns a decoded PIL image.
                    image = client.text_to_image(
                        prompt,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        width=width,
                        height=height,
                        seed=seed,
                    )
                    # Without this break the loop would run again and, after the last
                    # attempt, fall into the for-else branch and return None.
                    break  # Exit the retry loop on success
                else:
                    API_URL = f"https://api-inference.huggingface.co/models/{model}"
                    headers = {
                        "Authorization": f"Bearer {constants.HF_API_TOKEN}",
                        "Content-Type": "application/json"
                    }
                    payload = {
                        "inputs": prompt,
                        "parameters": {
                            "guidance_scale": guidance_scale,
                            "num_inference_steps": num_inference_steps,
                            "width": width,
                            "height": height,
                            "max_sequence_length": 512,
                            # Optional: Add 'scheduler' and 'seed' if needed
                            "seed": seed
                        }
                    }

                    print(f"Attempt {attempt}: Sending POST request to Hugging Face API...")
                    response = requests.post(API_URL, headers=headers, json=payload, timeout=300)  # Generous 300-second timeout: generation can be slow
                    if response.status_code == 200:
                        image_bytes = response.content
                        image = Image.open(io.BytesIO(image_bytes))
                        break  # Exit the retry loop on success
                    elif response.status_code == 400:
                        # Handle 400 Bad Request specifically
                        print(f"Bad Request (400): {response.text}")
                        print("Check your request parameters and payload format.")
                        return None  # Do not retry on 400 errors
                    elif response.status_code in [429, 504]:
                        print(f"Received status code {response.status_code}. Retrying in {retry_delay} seconds...")
                        if attempt < max_retries:
                            time.sleep(retry_delay)
                            retry_delay *= 2  # Exponential backoff
                        else:
                            response.raise_for_status()  # Raise exception after max retries
                    else:
                        print(f"Received unexpected status code {response.status_code}: {response.text}")
                        response.raise_for_status()
            except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as timeout_error:
                print(f"Timeout occurred: {timeout_error}. Retrying in {retry_delay} seconds...")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    raise  # Re-raise the exception after max retries
            except requests.exceptions.RequestException as req_error:
                # Also covers HTTPError raised by raise_for_status() above.
                print(f"Request exception: {req_error}. Retrying in {retry_delay} seconds...")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    raise  # Re-raise the exception after max retries

        else:
            # Only reached when no attempt hit `break`, i.e. every attempt failed
            print("Max retries exceeded. Failed to generate image.")
            return None

        with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            # Write through the already-open handle: reopening tmp.name while it is
            # still open fails on Windows.
            image.save(tmp, format="PNG")
        constants.temp_files.append(tmp.name)
        print(f"Image saved to {tmp.name}")
        return tmp.name

    except Exception as e:
        print(f"Error generating AI image: {e}")
        return None