import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline
from PIL import Image

# Setup the model
# SDXL-Turbo is an SDXL checkpoint (second text encoder, extra text/time
# conditioning in the UNet), so it cannot be driven by StableDiffusionPipeline.
# AutoPipelineForText2Image reads the repo's model_index.json and instantiates
# the matching pipeline class, as shown in the model card.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "stabilityai/sdxl-turbo"
pipe = AutoPipelineForText2Image.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    # On GPU, fetch the half-precision weight files (the model card's "fp16"
    # variant) instead of downloading full precision only to cast it down.
    variant="fp16" if device == "cuda" else None,
)
pipe = pipe.to(device)

# Generate T-shirt design function
def generate_tshirt_design(style, color, graphics, text=None, *,
                           num_inference_steps=1, guidance_scale=0.0):
    """Generate a T-shirt design image from the UI selections.

    Builds a comma-separated text prompt from the inputs and runs it through
    the module-level ``pipe``.

    Args:
        style: T-shirt style label, e.g. "Casual".
        color: T-shirt colour name.
        graphics: Free-text description of the graphics/logo.
        text: Optional text to mention in the prompt. ``None``, empty, or
            whitespace-only values are left out of the prompt.
        num_inference_steps: Denoising steps passed to the pipeline. The
            SDXL-Turbo model card recommends 1 step; the pipeline default
            (50) wastes compute for this distilled model.
        guidance_scale: Classifier-free guidance scale. SDXL-Turbo is meant
            to run without guidance, so 0.0 disables it (per the model card).

    Returns:
        The first image of the pipeline output. This is presumably a PIL image,
        since diffusers' default ``output_type`` is "pil".
    """
    prompt_parts = [
        "T-shirt design",
        f"style: {style}",
        f"color: {color}",
        f"graphics: {graphics}",
    ]
    # A blank Gradio Textbox typically arrives as "" rather than None; skip
    # blank or whitespace-only text instead of adding an empty "text:" clause.
    if text and str(text).strip():
        prompt_parts.append(f"text: {text}")
    prompt = ", ".join(prompt_parts)

    output = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
    return output.images[0]

# Example rows for gr.Examples. Each row gives one value per input component,
# in the order (style, color, graphics, text); None means "no text".
examples = [
    list(row)
    for row in (
        ("Casual", "White", "Logo: MyBrand", None),
        ("Formal", "Black", "Text: Hello World", "Custom text"),
        ("Sports", "Red", "Graphic: Team logo", None),
    )
]

# Page CSS: centre the main column and cap its width.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 520px;\n"
    "}\n"
)

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Page title.
        gr.Markdown("""
        # T-shirt Mockup Generator with Rookus AI
        """)

        # Top row: style picker beside the generate button.
        with gr.Row():
            style = gr.Dropdown(
                choices=["Casual", "Formal", "Sports"],
                value="Casual",
                label="T-shirt Style",
                container=False,
            )
            run_button = gr.Button("Generate Mockup", scale=0)

        # Output slot for the composited mockup.
        result = gr.Image(show_label=False, label="Mockup")

        # Secondary design inputs, collapsed until the user opens them.
        with gr.Accordion("Design Options", open=False):
            color = gr.Radio(
                choices=["White", "Black", "Blue", "Red", "Green"],
                value="White",
                label="T-shirt Color",
            )
            graphics = gr.Textbox(
                placeholder="Enter graphics or logo details",
                label="Graphics/Logo",
                visible=True,
            )
            text = gr.Textbox(
                placeholder="Enter optional text",
                label="Text (optional)",
                visible=True,
            )

        # No fn/outputs are given, so selecting an example only populates
        # these four inputs.
        gr.Examples(inputs=[style, color, graphics, text], examples=examples)

    def generate_tshirt_mockup(style, color, graphics, text=None,
                               template_path="path/to/your/mockup/template.jpg"):
        """Generate a design and alpha-blend it onto a blank T-shirt template.

        Args:
            style, color, graphics, text: Forwarded unchanged to
                ``generate_tshirt_design``.
            template_path: Blank T-shirt mockup image. The default is a
                placeholder and must point at a real file before deploying.

        Returns:
            PIL.Image.Image: The RGB template with the design blended into a
            region half the template's width and height, offset a quarter of
            the way in from the top-left corner.

        Raises:
            ValueError: If the template is smaller than 2x2 pixels, because
                the half-size design would then have a zero dimension.
        """
        design_image = generate_tshirt_design(style, color, graphics, text)

        # The context manager releases the file handle. Converting to RGB
        # guarantees an (H, W, 3) uint8 array even for grayscale, palette or
        # RGBA templates, so the channel slicing below is always valid.
        with Image.open(template_path) as template:
            mockup_np = np.array(template.convert("RGB"))

        # RGBA conversion gives every design a 4th (alpha) channel. Opaque RGB
        # designs get alpha=255, i.e. they are pasted over the template.
        design_np = np.array(design_image.convert("RGBA"))

        height, width = mockup_np.shape[:2]
        if width < 2 or height < 2:
            raise ValueError(
                f"Mockup template must be at least 2x2 pixels, got {width}x{height}"
            )

        # cv2.resize takes dsize as (width, height).
        design_resized = cv2.resize(design_np, (width // 2, height // 2))

        top, left = height // 4, width // 4
        bottom = top + design_resized.shape[0]
        right = left + design_resized.shape[1]

        # Alpha in [0, 1], kept as a trailing axis so it broadcasts over RGB.
        alpha = design_resized[:, :, 3:4].astype(np.float64) / 255.0
        region = mockup_np[top:bottom, left:right].astype(np.float64)
        blended = alpha * design_resized[:, :, :3] + (1.0 - alpha) * region
        # Round rather than truncate when returning to uint8.
        mockup_np[top:bottom, left:right] = np.rint(blended).astype(np.uint8)

        # Convert back to PIL image for Gradio output
        return Image.fromarray(mockup_np)

    run_button.click(
        fn=generate_tshirt_mockup,
        inputs=[style, color, graphics, text],
        outputs=[result]
    )

demo.queue().launch()