import base64
import os
import threading
from io import BytesIO

import gradio as gr
from diffusers import AutoPipelineForText2Image
from PIL import Image

from generate_prompts import generate_prompt

# Maximum number of calls to `inference` the Gradio interface will run at the
# same time (passed to gr.Interface(concurrency_limit=...)).
CONCURRENCY_LIMIT: int = 10

def load_model():
    """Load the "stabilityai/sdxl-turbo" text-to-image pipeline.

    Returns:
        The pipeline built by AutoPipelineForText2Image.from_pretrained, or
        None if loading raised. The error is printed, not re-raised, so
        callers must check for None.
    """
    print("Loading the Stable Diffusion model...")
    pipeline = None
    try:
        pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
        print("Model loaded successfully.")
    except Exception as exc:
        print(f"Error loading model: {exc}")
        # Discard anything partially obtained; failure is signalled by None.
        pipeline = None
    return pipeline

# Pipeline shared by all generate_image() calls. It is loaded lazily on first
# use; a failed load leaves it as None so the next call retries.
_MODEL = None
# Guards both the lazy load and inference on the shared pipeline.
_MODEL_LOCK = threading.Lock()


def generate_image(prompt):
    """Render *prompt* with the text-to-image pipeline as a base64 JPEG.

    The pipeline is obtained from load_model() once and reused by later
    calls, instead of being reloaded for every image.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        A tuple ``(img_str, error)``. On success, ``img_str`` is the JPEG
        bytes base64-encoded and decoded as UTF-8, and ``error`` is None. On
        any failure, ``img_str`` is None and ``error`` is the exception
        message. Exceptions are never propagated to the caller.
    """
    global _MODEL
    try:
        # The interface may run several requests concurrently, and they all
        # share one pipeline. Diffusers pipelines are not documented as safe
        # for concurrent calls (scheduler state is updated per call), so
        # serialize both the lazy load and the inference call.
        with _MODEL_LOCK:
            if _MODEL is None:
                _MODEL = load_model()
            model = _MODEL
            if model is None:
                raise ValueError("Model not loaded properly.")

            print(f"Generating image with prompt: {prompt}")
            # SDXL-Turbo is sampled with a single step and guidance disabled.
            output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")

        if output is None:
            raise ValueError("Model returned None")
        if not getattr(output, "images", None):
            print("No images found in model output")
            raise ValueError("No images found in model output")

        print("Image generated successfully")
        image = output.images[0]
        # JPEG cannot store alpha or palette modes; convert so save() cannot
        # fail on an RGBA/P image.
        if image.mode != "RGB":
            image = image.convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded to base64")
        print(f"img_str: {img_str[:100]}...")  # only a snippet; the full string is large
        return img_str, None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)

def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph.

    Args:
        sentence_mapping: Mapping of paragraph number to that paragraph's
            sentences. Each value may be an iterable of sentence strings
            (joined with spaces) or a single string (used as-is).
        character_dict: Passed unchanged to generate_prompt; its contents are
            not inspected here.
        selected_style: Style label passed unchanged to generate_prompt.

    Returns:
        A dict mapping each paragraph number to the base64 image string from
        generate_image, or to "Error: <message>" if that image failed. If
        anything else raises (e.g. sentence_mapping has no .items(), or
        generate_prompt fails), returns {"error": <message>} instead.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}, type: {type(sentence_mapping)}")
        print(f"Received character_dict: {character_dict}, type: {type(character_dict)}")
        print(f"Received selected_style: {selected_style}, type: {type(selected_style)}")

        images = {}
        for paragraph_number, sentences in sentence_mapping.items():
            # A bare string would otherwise be joined character by character
            # (" ".join("abc") == "a b c"), garbling the prompt.
            if isinstance(sentences, str):
                combined_sentence = sentences
            else:
                combined_sentence = " ".join(sentences)
            prompt = generate_prompt(combined_sentence, character_dict, selected_style)
            print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
            img_str, error = generate_image(prompt)
            # Compare against None: an exception with an empty message still
            # signals failure and must not be stored as an image.
            if error is not None:
                images[paragraph_number] = f"Error: {error}"
            else:
                images[paragraph_number] = img_str
        return images
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return {"error": str(e)}

def _build_interface():
    """Assemble the Gradio UI that wraps inference().

    Inputs are two JSON fields plus a style dropdown, in the same order as
    inference()'s parameters. The output is rendered as JSON.
    """
    input_components = [
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ]
    return gr.Interface(
        fn=inference,
        inputs=input_components,
        outputs="json",
        concurrency_limit=CONCURRENCY_LIMIT,
    )


gradio_interface = _build_interface()

if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()