import gradio as gr
import io
from PIL import Image
import numpy as np

# from config import setResoluton
from models import make_image_controlnet, make_inpainting
from preprocessing import get_mask

def image_to_byte_array(image: "Image.Image", fmt: str = "png") -> bytes:
    """Encode an image into an in-memory byte string.

    Args:
        image: The image to encode; it must provide ``save(fp, format=...)``
            like ``PIL.Image.Image``.
        fmt: Format name forwarded to ``image.save`` (default ``"png"``,
            matching the previous hard-coded behavior).

    Returns:
        The encoded image bytes.
    """
    # BytesIO is an in-memory file object, so image.save can write to it
    # without touching disk; the context manager releases the buffer.
    with io.BytesIO() as buffer:
        image.save(buffer, format=fmt)
        return buffer.getvalue()

def predict(input_img1,
            input_img2,
            positive_prompt,
            negative_prompt,
            num_of_images,
            resolution
            ):
    """Inpaint ``input_img1`` using the mask taken from ``input_img2``.

    Both images are resized to a ``resolution`` x ``resolution`` square. The
    mask is derived from the resized second image via ``get_mask``, and the
    images returned by ``make_inpainting`` are fitted to exactly
    ``max_outputs`` slots (one per ``gr.Image`` output of the Interface).

    Args:
        input_img1: Source image (PIL image; the Interface uses ``type="pil"``).
        input_img2: Mask image (PIL image), converted to a numpy array before
            being handed to ``get_mask``.
        positive_prompt: Prompt describing the desired result.
        negative_prompt: Prompt describing what to avoid.
        num_of_images: How many images to request from ``make_inpainting``.
            ``gr.Number`` may deliver a float, so it is coerced to ``int``.
        resolution: Side length, in pixels, of the square working size.
            Coerced to ``int`` because PIL's ``resize`` needs integer sizes.

    Returns:
        A list of exactly ``max_outputs`` (10) items: the generated images
        (surplus ones dropped) followed by ``None`` padding.

    Raises:
        gr.Error: If an image or a numeric input is missing, so the user sees
            a readable message instead of a stack trace.
    """
    # Must match the number of gr.Image output components in the Interface.
    max_outputs = 10

    if input_img1 is None or input_img2 is None:
        raise gr.Error("Please upload both an image and a mask.")
    if num_of_images is None or resolution is None:
        raise gr.Error("Please provide num_of_images and resolution.")

    # gr.Number returns floats unless precision=0; PIL and the model need ints.
    resolution = int(resolution)
    num_of_images = int(num_of_images)

    size = (resolution, resolution)
    input_img1 = input_img1.resize(size)
    input_img2 = input_img2.resize(size)

    mask = get_mask(np.array(input_img2))

    # list() so the padding below also works if a tuple is returned.
    results = list(make_inpainting(positive_prompt=positive_prompt,
                                   image=input_img1,
                                   mask_image=mask,
                                   negative_prompt=negative_prompt,
                                   num_of_images=num_of_images,
                                   resolution=resolution))

    # Gradio needs exactly one value per output component: drop any surplus
    # results and pad empty slots with None (rendered as empty images).
    results = results[:max_outputs]
    results.extend([None] * (max_outputs - len(results)))
    return results


# Gradio UI: an image, a mask and two prompts in; up to 10 generated images out.
# precision=0 makes gr.Number deliver ints rather than floats, which is what
# PIL's resize and the model call in predict expect.
app = gr.Interface(
    predict,
    inputs=[
        gr.Image(label="img", sources=['upload'], type="pil"),
        gr.Image(label="mask", sources=['upload'], type="pil"),
        gr.Textbox(label="positive_prompt", value="empty room"),
        gr.Textbox(label="negative_prompt", value=""),
        # Bounded to the 10 output slots below; predict pads unused ones.
        gr.Number(label="num_of_images", value=2, precision=0,
                  minimum=1, maximum=10),
        gr.Number(label="resolution", value=512, precision=0),
    ],
    # One output component per possible result image (labels resp0..resp9).
    outputs=[gr.Image(label=f"resp{i}") for i in range(10)],
    title="rem fur 1",
)

# share=True also creates a public link to this app.
app.launch(share=True)

#


# gr.Interface(
#     test1,
#     inputs=[gr.Textbox(label="param1")],
#     outputs= gr.Textbox(label="result"),
#     title="rem fur 1",
# ).launch(share=True)