from diffusers import DiffusionPipeline
import gradio as gr
import torch
import time
import psutil


# Wall-clock reference taken at import; the build time is printed once the UI is defined.
start_time = time.time()

# Human-readable hardware label shown in the UI. `inference` checks for the
# "GPU" substring in this label to choose the torch device.
if torch.cuda.is_available():
    device = "GPU 🔥"
else:
    device = "CPU 🥶"


def error_str(error, title="Error"):
    """Render ``error`` as a Markdown snippet under a ``####`` heading.

    Args:
        error: The error to display; it is interpolated with an f-string.
            Any falsy value means "no error".
        title: Text placed after the ``####`` heading marker.

    Returns:
        ``""`` when ``error`` is falsy. Otherwise the heading line, a newline,
        and then ``error`` indented by 12 spaces.
    """
    if not error:
        return ""
    # NOTE(review): a 4+ space indent after a heading may render as a Markdown
    # code block; confirm that is the intended presentation.
    body_indent = " " * 12
    return f"#### {title}\n{body_indent}{error}"


def _pr_revision(discuss_nr):
    """Map a discussion/PR number to the Hub revision passed to ``from_pretrained``.

    Args:
        discuss_nr: The PR number as entered by the user (e.g. ``"171"``).
            ``None``, an empty string, or a blank string selects the default revision.

    Returns:
        ``"refs/pr/<nr>"`` for a non-blank number (surrounding whitespace
        stripped), otherwise ``None``.
    """
    if discuss_nr is None:
        return None
    nr = str(discuss_nr).strip()
    return f"refs/pr/{nr}" if nr else None


def inference(
    repo_id,
    discuss_nr,
    prompt,
):
    """Load ``repo_id`` (optionally at PR ``discuss_nr``) and generate images for ``prompt``.

    Args:
        repo_id: Hub repository id of the model, e.g. ``"CompVis/stable-diffusion-v1-4"``.
        discuss_nr: PR/discussion number whose ``refs/pr/<nr>`` revision is loaded;
            empty/blank/``None`` loads the default revision.
        prompt: Text prompt passed to the pipeline.

    Returns:
        ``(images, status)``. On success, ``images`` is the pipeline output's
        ``.images`` and ``status`` is ``"Done. Seed: <seed>"``. On any exception,
        ``images`` is ``None`` and ``status`` is Markdown built by ``error_str``,
        including a link to the discussion.
    """
    print(psutil.virtual_memory())  # print memory usage

    # Fixed seed; it is echoed back in the success status message.
    seed = 0
    torch_device = "cuda" if "GPU" in device else "cpu"

    generator = torch.Generator(torch_device).manual_seed(seed)

    # Half precision is used only on CUDA; CPU runs in float32.
    dtype = torch.float16 if torch_device == "cuda" else torch.float32

    try:
        revision = _pr_revision(discuss_nr)
        pipe = DiffusionPipeline.from_pretrained(repo_id, revision=revision, torch_dtype=dtype)
        pipe.to(torch_device)

        return pipe(prompt, generator=generator, num_inference_steps=25).images, f"Done. Seed: {seed}"
    except Exception as e:  # broad on purpose: any failure is reported back to the UI
        url = f"https://huggingface.co/{repo_id}/discussions/{discuss_nr}"
        message = f"There is a problem with your diffusers weights of the PR: {url}. Error message: \n"
        # str(e) is required: `str + Exception` raises TypeError inside this handler.
        return None, error_str(message + str(e))


with gr.Blocks(css="style.css") as demo:
    # Header banner showing the hardware label (`device`) computed at import time.
    gr.HTML(
        f"""
            <div class="diffusion">
              <p>
               Space to test whether `diffusers` PRs work.
              </p>
              <p>
               Running on <b>{device}</b>
              </p>
            </div>
        """
    )
    with gr.Row():

        with gr.Column(scale=55):
            with gr.Group():
                # Model repository on the Hub to load.
                repo_id = gr.Textbox(
                    label="Repo id on Hub",
                    placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4 for https://huggingface.co/CompVis/stable-diffusion-v1-4",
                )
                # PR/discussion number whose revision should be tested.
                discuss_nr = gr.Textbox(
                    label="Discussion number",
                    placeholder="Number of the discussion that should be checked, e.g. 171 for https://huggingface.co/CompVis/stable-diffusion-v1-4/discussions/171",
                )
                # Gradio 3 components take their initial content via `value`.
                # `default` is not a Textbox parameter in that API, so the
                # intended pre-filled prompt never appeared.
                prompt = gr.Textbox(
                    label="Prompt",
                    value="An astronaut riding a horse on Mars.",
                    placeholder="Enter prompt.",
                )
                gallery = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery"
                ).style(grid=[2], height="auto")

            # Receives the second output of `inference`: the success status
            # text, or the error Markdown on failure.
            error_output = gr.Markdown()

            generate = gr.Button(value="Generate").style(
                rounded=(False, True, True, False)
            )

    # Submitting the prompt box and clicking the button trigger the same callback.
    inputs = [
        repo_id,
        discuss_nr,
        prompt,
    ]
    outputs = [gallery, error_output]
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)

print(f"Space built in {time.time() - start_time:.2f} seconds")

# Process queued requests with a single worker (Gradio 3 `concurrency_count`).
demo.queue(concurrency_count=1)
demo.launch()