import gradio as gr
from PIL import Image
import os
import spaces
from OmniGen import OmniGenPipeline
# Load the OmniGen pipeline once at import time; generate_image() reuses this instance.
pipe = OmniGenPipeline.from_pretrained("shitao/tmp-preview")
@spaces.GPU
def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
    """Run the module-level OmniGen pipeline and return its first output.

    Args:
        text: Prompt passed to the pipeline as ``prompt``.
        img1, img2, img3: Optional reference images; ``None`` entries are
            dropped before the call. The UI in this file supplies them as
            file paths (``gr.Image(type="filepath")``).
        height, width: Forwarded unchanged as ``height`` / ``width``.
        guidance_scale: Forwarded unchanged as ``guidance_scale``.
        inference_steps: Forwarded as ``num_inference_steps``.
        seed: Forwarded unchanged as ``seed``.

    Returns:
        The first element of the pipeline's output.
    """
    # Keep only the image slots that were actually filled in.
    provided = [candidate for candidate in (img1, img2, img3) if candidate is not None]

    result = pipe(
        prompt=text,
        # The pipeline is handed None (not an empty list) when no image was given.
        input_images=provided or None,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        img_guidance_scale=1.6,
        num_inference_steps=inference_steps,
        separate_cfg_infer=True,
        use_kv_cache=False,
        seed=seed,
    )
    return result[0]
# def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
#     input_images = []
#     if img1:
#         input_images.append(Image.open(img1))
#     if img2:
#         input_images.append(Image.open(img2))
#     if img3:
#         input_images.append(Image.open(img3))
        
#     return input_images[0] if input_images else None
def get_example():
    """Return the example rows shown in the demo's example gallery.

    Each row is a list in the order
    ``[text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed]``,
    matching the parameter order of ``generate_image``.

    Returns:
        A list of example rows (lists).
    """
    # NOTE(review): the prompt literals used to be split across physical lines
    # right before each <|image_N|> placeholder, which is a syntax error. They
    # are joined here with the visible text unchanged - confirm the placeholder
    # spelling matches what the pipeline expects.
    cases = [
        [
            "A woman holds a bouquet of flowers and faces the camera. Thw woman is the one in <|image_1|>.",
            "./imgs/test_cases/liuyifei.png",
            None,
            None,
            1024,
            1024,
            3.0,
            20,
            42,
        ],
        [
            "Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <|image_1|>. The center zebras is from <|image_2|>. The left zebras is the zebras from <|image_3|>.",
            "./imgs/test_cases/img1.jpg",
            "./imgs/test_cases/img2.jpg",
            "./imgs/test_cases/img3.jpg",
            1024,
            1024,
            3.0,
            20,
            42,
        ],
    ]
    return cases
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
    """Callback passed to ``gr.Examples``; forwards every argument to ``generate_image``.

    Takes the same arguments, in the same order, as ``generate_image`` and
    returns whatever that function returns.
    """
    forwarded = (text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed)
    return generate_image(*forwarded)
# Gradio interface: input controls in the left column, generated image in the right.
# NOTE(review): the labels containing <|image_N|> used to be split across physical
# lines inside the string literal (a syntax error); they are joined here with the
# visible text unchanged.
with gr.Blocks() as demo:
    gr.Markdown("# OmniGen: Unified Image Generation")
    with gr.Row():
        with gr.Column():
            # Prompt text box
            prompt_input = gr.Textbox(
                label="Enter your prompt, use <|image_i|> tokens for images",
                placeholder="Type your prompt here...",
            )
            with gr.Row(equal_height=True):
                # Image upload slots; type="filepath" hands the callback a file path
                image_input_1 = gr.Image(label="<|image_1|>", type="filepath")
                image_input_2 = gr.Image(label="<|image_2|>", type="filepath")
                image_input_3 = gr.Image(label="<|image_3|>", type="filepath")
            # Output height and width sliders
            height_input = gr.Slider(
                label="Height", minimum=256, maximum=2048, value=1024, step=16
            )
            width_input = gr.Slider(
                label="Width", minimum=256, maximum=2048, value=1024, step=16
            )
            # Guidance scale slider
            guidance_scale_input = gr.Slider(
                label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
            )
            # Number of inference steps
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=1, maximum=50, value=50, step=1
            )
            # Random seed
            seed_input = gr.Slider(
                label="Seed", minimum=0, maximum=2147483647, value=42, step=1
            )
            # Generate button
            generate_button = gr.Button("Generate Image")
        with gr.Column():
            # Output image panel
            output_image = gr.Image(label="Output Image")

    # Components in the positional order expected by generate_image /
    # run_for_examples; defined once so the button and the examples cannot drift.
    generation_inputs = [
        prompt_input,
        image_input_1,
        image_input_2,
        image_input_3,
        height_input,
        width_input,
        guidance_scale_input,
        num_inference_steps,
        seed_input,
    ]

    # Button click handler
    generate_button.click(
        generate_image,
        inputs=generation_inputs,
        outputs=output_image,
    )
    # Example rows that populate the same inputs
    gr.Examples(
        examples=get_example(),
        fn=run_for_examples,
        inputs=generation_inputs,
        outputs=output_image,
    )
# Launch the app
demo.launch()