import os
import string
import subprocess
import tempfile

import gradio
import torch, torch.backends.cudnn, torch.backends.cuda
from emoji import demojize
from min_dalle import MinDalle
from PIL import Image


def filename_from_text(text: str) -> str:
    """Turn a free-form prompt into a short, filesystem-safe slug.

    Emoji are first spelled out as words (with no delimiters around them).
    The text is then lower-cased and reduced to ASCII letters and spaces.
    It is cut to its first 64 characters, and each run of whitespace
    becomes a single hyphen. When nothing survives, ``'blank'`` is
    returned instead.
    """
    permitted = set(string.ascii_lowercase + ' ')
    spelled_out = demojize(text, delimiters=['', ''])
    ascii_lower = spelled_out.lower().encode('ascii', errors='ignore').decode()
    letters_and_spaces = ''.join(ch for ch in ascii_lower if ch in permitted)
    words = letters_and_spaces[:64].split()
    slug = '-'.join(words)
    return slug or 'blank'

def log_gpu_memory():
    """Print the output of ``nvidia-smi`` to stdout, as a best-effort diagnostic.

    On hosts where ``nvidia-smi`` is missing (e.g. CPU-only machines) or
    exits with an error, the failure is printed instead of raised. A
    diagnostic call therefore can never abort the caller's real work.
    """
    try:
        output = subprocess.check_output(['nvidia-smi'])
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers FileNotFoundError / PermissionError when the binary
        # is absent or not executable; CalledProcessError covers a non-zero exit.
        print('nvidia-smi unavailable:', err)
        return
    print(output.decode('utf-8', errors='replace'))

# Optional GPU memory snapshot before the model is loaded (disabled).
# log_gpu_memory()

# A single model instance is built at import time and used by every
# run_model call.
# NOTE(review): is_reusable=True presumably keeps the loaded sub-models in
# memory between generate_image calls — confirm against min_dalle.
model = MinDalle(
    device='cpu',
    is_mega=False,
    is_reusable=True,
    # dtype=torch.float32  (not passed, so the library default is used)
)

# Optional GPU memory snapshot after the model is loaded (disabled).
# log_gpu_memory()

def _coerce_param(convert, value, is_valid, message):
    """Return ``convert(value)`` when conversion succeeds and ``is_valid`` accepts it.

    Raises:
        ValueError: carrying ``message`` when conversion fails or the check
            is false. Raising explicitly (rather than using ``assert``) keeps
            validation active under ``python -O``.
    """
    try:
        converted = convert(value)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(message) from err
    if not is_valid(converted):
        raise ValueError(message)
    return converted


def run_model(
    text: str,
    grid_size: int,
    is_seamless: bool,
    save_as_png: bool,
    temperature: float,
    supercondition: str,
    top_k: str
) -> str:
    """Generate an image grid for ``text`` with the module-level model and save it.

    Args:
        text: Prompt passed to ``model.generate_image``.
        grid_size: Side length of the image grid. It is coerced with
            ``int`` and must be between 1 and 5.
        is_seamless: Forwarded as ``bool`` to ``generate_image``.
        save_as_png: Save as PNG when truthy, otherwise as JPEG.
        temperature: Sampling temperature. It is coerced with ``float``
            and must be greater than 1e-6.
        supercondition: Supercondition factor, coerced with ``float``.
        top_k: Coerced with ``int``; must be between 1 and 16384.

    Returns:
        Path of the saved image. The file is created in the system temp
        directory. Its name starts with a slug of ``text``, followed by a
        random suffix so that no two requests share a file.

    Raises:
        ValueError: if a parameter cannot be converted or is out of range.
    """
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

    print('text:', text)
    print('grid_size:', grid_size)
    print('is_seamless:', is_seamless)
    print('temperature:', temperature)
    print('supercondition:', supercondition)
    print('top_k:', top_k)

    temperature = _coerce_param(
        float, temperature, lambda t: t > 1e-6,
        'Temperature must be a positive nonzero number',
    )
    grid_size = _coerce_param(
        int, grid_size, lambda g: 1 <= g <= 5,
        'Grid size must be between 1 and 5',
    )
    top_k = _coerce_param(
        int, top_k, lambda k: 1 <= k <= 16384,
        'Top k must be between 1 and 16384',
    )
    supercondition_factor = _coerce_param(
        float, supercondition, lambda s: True,
        'Super condition must be a number',
    )

    with torch.no_grad():
        image = model.generate_image(
            text = text,
            seed = -1,
            grid_size = grid_size,
            is_seamless = bool(is_seamless),
            temperature = temperature,
            supercondition_factor = supercondition_factor,
            top_k = top_k,
            is_verbose = True
        )

    # nvidia-smi only exists on NVIDIA hosts. Calling it unconditionally made
    # every request fail on CPU-only machines after the image was generated.
    if torch.cuda.is_available():
        log_gpu_memory()

    ext = 'png' if save_as_png else 'jpg'
    # Use a unique file per request. A name derived only from the prompt lets
    # concurrent requests with the same prompt overwrite each other, and it
    # can clobber files in the working directory (e.g. the prompt "cat" with
    # PNG output would overwrite the example image cat.png).
    fd, image_path = tempfile.mkstemp(
        prefix=filename_from_text(text) + '-',
        suffix='.' + ext,
    )
    os.close(fd)
    image.save(image_path)

    return image_path

# Top-level Gradio UI. Components are laid out in the order they are created
# inside the ``with`` contexts below, so statement order here defines layout.
# NOTE(review): analytics_enabled=True presumably opts into Gradio's usage
# analytics — confirm that is intended for this deployment.
demo = gradio.Blocks(analytics_enabled=True)

with demo:
    # One row with two columns: prompt / button / result on the left,
    # generation settings on the right.
    with gradio.Row():
        with gradio.Column():
            input_text = gradio.Textbox(
                label='Input Text', 
                value='Portrait of a basset hound, 8k, photograph',
                lines=3
            )
            run_button = gradio.Button(value='Generate Image').style(full_width=True)
            # Initially shows '8k dog.png', the same file the first Examples row
            # pairs with the default prompt above. The relative path presumably
            # resolves against the working directory — confirm it ships with the app.
            # NOTE(review): confirm type='file' is supported by the pinned Gradio version.
            output_image = gradio.Image(
                value='8k dog.png',
                label='Output Image',
                type='file',
                interactive=False
            )

        with gradio.Column():
            gradio.Markdown('## Settings')
            with gradio.Row():
                # Bounds match the 1..5 range that run_model enforces.
                grid_size = gradio.Slider(
                    label='Grid Size',
                    value=3,
                    minimum=1, 
                    maximum=5,
                    step=1
                )
                # When checked, run_model saves PNG; otherwise JPEG.
                save_as_png = gradio.Checkbox(
                    label='Output PNG',
                    value=False
                )
                is_seamless = gradio.Checkbox(
                    label='Seamless',
                    value=False
                )
            gradio.Markdown('#### Advanced')
            with gradio.Row():
                # Free-form number; run_model rejects values <= 1e-6.
                temperature = gradio.Number(
                    label='Temperature',
                    value=1
                )
                # String choices '1', '2', '4', ..., '16384' (2**0 .. 2**14), which
                # matches run_model's 1..16384 check; run_model converts with int().
                top_k = gradio.Dropdown(
                    label='Top-k',
                    choices=[str(2 ** i) for i in range(15)],
                    value='128'
                )
                # String choices '4' .. '64' (2**2 .. 2**6); run_model converts with float().
                supercondition = gradio.Dropdown(
                    label='Super Condition',
                    choices=[str(2 ** i) for i in range(2, 7)],
                    value='16'
                )

            # Help text describing each setting.
            # NOTE(review): the "3x3 takes about 15 seconds" estimate likely assumes
            # a GPU, while the model is constructed with device='cpu' — confirm.
            gradio.Markdown(
                """
                #### Parameter
                - **Input Text**: For long prompts, only the first 64 text tokens will be used to generate the image.
                - **Grid Size**: Size of the image grid. 3x3 takes about 15 seconds.
                - **Seamless**: Tile images in image token space instead of pixel space.
                - **Temperature**: High temperature increases the probability of sampling low scoring image tokens.
                - **Top-k**: Each image token is sampled from the top-k scoring tokens.
                - **Super Condition**: Higher values can result in better agreement with the text.

                #### 
                """
            )

    # Each example row supplies values for (input_text, grid_size, output_image),
    # in that order. The image entries are relative file paths, presumably
    # pre-rendered results that ship alongside this script.
    gradio.Examples(
        examples=[
            ['Portrait of a basset hound, 8k, photograph', 3, '8k dog.png'],
            ['A diorama of Puppy cloud ,8k, photograph', 3, 'puppy.png'],
            ['A dragon that looks like a cream', 3, 'cream.png'],
            ['A photo of a sleeping orange tabby cat', 3, 'tabby.png'],
            ['A diorama of a bunny family sitting around the table having dinner ,8k, photograph', 3, 'table.png'],
            ['A white cat with golden sunglasses on, pink background, studio lighting, 4k, award winning photography', 2, 'cat.png'],
            ['an astronaut dancing on the moon’s surface, close-up photo', 2, 'astronaut.png'],
            ['A photo of a Samoyed dog with its tongue out hugging a white Siamese cat', 5, 'dog.png'],
            ['Dragons of Earth, Wind, Fire, powering up a huge sphere of compressed energy, digital art', 2, 'dragon.png'],
            ['A snowboarder jumping in the air while coming down a ski mountain, concept art, artstation, unreal engine, 3d render, HD, Bokeh', 3, 'snow.png'],
            ['Antique photo of  a  dragon fire', 3, 'fire.png'],
            ['A space parrot flying through the cosmos, digital art', 3, 'parrot.png'],
        ],
        inputs=[
            input_text,
            grid_size,
            output_image
        ],
        examples_per_page=20
    )

    # Component values are passed to run_model positionally, so this list must
    # stay in the same order as run_model's parameters (text, grid_size,
    # is_seamless, save_as_png, temperature, supercondition, top_k). The path
    # run_model returns is shown in output_image.
    run_button.click(
        fn=run_model, 
        inputs=[
            input_text,
            grid_size,
            is_seamless,
            save_as_png,
            temperature,
            supercondition,
            top_k
        ], 
        outputs=[
            output_image
        ]
    )


# Start the Gradio server. This is not guarded by ``if __name__ == '__main__'``,
# so it also runs whenever this module is imported.
demo.launch()