import numpy as np
import time
from pathlib import Path
import torch
import imageio

from my.utils import tqdm
from my.utils.seed import seed_everything

from run_img_sampling import SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig

from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis

from run_sjc import render_one_view, tsr_stats
from highres_final_vis import highres_render_one_view

import gradio as gr
import gc
import os

# Device the voxel grid is moved to inside `submit` (via `vox.to(device_glb)`).
# NOTE(review): hard-coded to CUDA; there is no CPU fallback in this file.
device_glb = torch.device("cuda")

def vis_routine(y, depth):
    """Produce the visualisation artefacts for one rendered view.

    Args:
        y: rendered sample batch, forwarded to ``nerf_vis`` and
            ``torch_samps_to_imgs``.
        depth: depth tensor for the same view; it must support
            ``.cpu().numpy()``.

    Returns:
        A ``(pane, im, depth)`` tuple: the ``nerf_vis`` pane built with
        ``final_H=256``, the first image returned by
        ``torch_samps_to_imgs(y)``, and ``depth`` converted to a NumPy array
        on the CPU.
    """
    # The pane is built first so that nerf_vis still receives the tensor
    # form of `depth`, not the NumPy copy created for the return value.
    pane = nerf_vis(y, depth, final_H=256)
    images = torch_samps_to_imgs(y)
    return pane, images[0], depth.cpu().numpy()

# Custom stylesheet handed to `gr.Blocks(css=css)` below.
# NOTE(review): the `#component-N` selectors look like Gradio's auto-numbered
# element ids, so they presumably depend on the order in which components are
# created - confirm before adding or reordering components.
css = '''
    .instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
    .arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
    #component-4, #component-3, #component-10{min-height: 0}
    .duplicate-button img{margin: 0}
'''

# Build the Gradio UI. Components are declared in the order they appear here.
# NOTE(review): the stylesheet in `css` targets ids such as `#component-4`;
# presumably those follow creation order, so verify before reordering.
with gr.Blocks(css=css) as demo:
    # title
    gr.Markdown('# [Score Jacobian Chaining](https://github.com/pals-ttic/sjc): Lifting Pretrained 2D Diffusion Models for 3D Generation')

    # Static notice about the long runtime, with a "Duplicate Space" badge.
    # (The f-prefix has no placeholders; the text is emitted verbatim.)
    gr.HTML(f'''
                <div class="gr-prose" style="max-width: 80%">
                <h2>Attention - This Space takes over 30min to run!</h2>
                <p>If the Queue is too long you can run locally or duplicate the Space and run it on your own profile using a (paid) private T4 GPU for training. As each T4 costs US$0.60/h, it should cost < US$1 to train most models using default settings!&nbsp;&nbsp;<a style='display:inline-block' href='https://huggingface.co/spaces/MirageML/sjc?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a></p>
                </div>
            ''')
    # inputs: these three values are passed to `submit` in this order.
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
    iters = gr.Slider(label="Iters", minimum=100, maximum=20000, value=10000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # outputs: `submit` updates these while it runs. The video starts hidden
    # and is only made visible by the final update.
    image = gr.Image(label="image", visible=True)
    # depth = gr.Image(label="depth", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    def submit(prompt, iters, seed):
        """Optimise a voxel NeRF against Stable Diffusion scores, streaming progress.

        This generator is the click handler of the "Generate" button. It
        trains for ``iters`` steps, then renders a sequence of test views and
        writes them to an mp4 file.

        Args:
            prompt: text prompt handed to the Stable Diffusion config.
            iters: number of optimisation steps (coerced to ``int``).
            seed: value passed to ``seed_everything`` (coerced to ``int``).

        Yields:
            Dicts mapping the ``image``, ``video`` and ``logs`` components to
            their updates: one per training step, one per rendered test
            frame, and a final one that hides the image and shows the video.
        """
        import tempfile  # local import: only needed for the output video path

        start_t = time.time()

        # Coerce defensively: range() below needs an int even if the slider
        # hands its value over as a float.
        n_steps = int(iters)
        seed_everything(int(seed))

        # Only prompt / iters / seed come from the UI; everything else is
        # fixed here.
        pose = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
        poser = pose.make()
        sd_model = SD(variant='v1', v2_highres=False, prompt=prompt, scale=100.0, precision='autocast')
        model = sd_model.make()
        vox = VoxConfig(
            model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
            blend_bg_texture=True, bg_texture_hw=4,
            bbox_len=1.0,
        ).make()

        lr = 0.05
        emptiness_scale = 10
        emptiness_weight = 10000
        emptiness_step = 0.5
        emptiness_multiplier = 20.0
        depth_weight = 0
        var_red = True

        assert model.samps_centered()
        _, target_H, target_W = model.data_shape()
        bs = 1
        aabb = vox.aabb.T.cpu().numpy()
        vox = vox.to(device_glb)
        opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

        H, W = poser.H, poser.W
        Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

        ts = model.us[30:-10]

        # Never read afterwards, but drawing it advances the torch RNG, so it
        # is kept to leave the random stream for a given seed unchanged.
        same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)

        is_sd = isinstance(model, StableDiffusion)

        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                p = f"{prompt_prefixes[i]} {model.prompt}"
                score_conds = model.prompts_emb([p])

                y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)

                # Non-Stable-Diffusion models get the render resized to the
                # model's data shape; Stable Diffusion uses it as is.
                if not is_sd:
                    y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

                opt.zero_grad()

                with torch.no_grad():
                    sigmas = np.random.choice(ts, bs, replace=False)
                    sigmas = sigmas.reshape(-1, 1, 1, 1)
                    sigmas = torch.as_tensor(sigmas, device=model.device, dtype=torch.float32)

                    noise = torch.randn(bs, *y.shape[1:], device=model.device)

                    zs = y + sigmas * noise
                    Ds = model.denoise(zs, sigmas, **score_conds)

                    # var_red measures the update from the clean render `y`
                    # rather than from the noised input `zs`.
                    if var_red:
                        grad = (Ds - y) / sigmas
                    else:
                        grad = (Ds - zs) / sigmas

                    grad = grad.mean(0, keepdim=True)

                # retain_graph: further backward passes over this step's
                # render follow below.
                y.backward(-grad, retain_graph=True)

                if depth_weight > 0:
                    center_depth = depth[7:-7, 7:-7]
                    # Border pixel count derived from the tensors themselves
                    # (equals 64*64 - 50*50 for the 64x64 render used here).
                    n_border = depth.numel() - center_depth.numel()
                    border_depth_mean = (depth.sum() - center_depth.sum()) / n_border
                    center_depth_mean = center_depth.mean()
                    depth_diff = center_depth_mean - border_depth_mean
                    depth_loss = - torch.log(depth_diff + 1e-12)
                    depth_loss = depth_weight * depth_loss
                    depth_loss.backward(retain_graph=True)

                emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
                emptiness_loss = emptiness_weight * emptiness_loss
                # The emptiness penalty is scaled up once the configured
                # fraction of the schedule has passed.
                if emptiness_step * n_steps <= i:
                    emptiness_loss *= emptiness_multiplier
                emptiness_loss.backward()

                opt.step()

                # Build the preview under no_grad, but yield only AFTER
                # leaving the context: grad mode is thread-local, and
                # suspending a generator inside no_grad can leave the calling
                # thread with gradients disabled if it is resumed elsewhere.
                # NOTE(review): decoding and yielding on every step is
                # presumably costly; `every` is imported but unused -
                # consider throttling the preview.
                with torch.no_grad():
                    if is_sd:
                        y = model.decode(y)
                    _, img, depth = vis_routine(y, depth)
                    step_stats = str(tsr_stats(y))

                yield {
                    image: gr.update(value=img, visible=True),
                    video: gr.update(visible=False),
                    # i + 1: this update is sent after step i has completed.
                    logs: f"Steps: {i + 1}/{n_steps}: \n" + step_stats,
                }

                pbar.update()
                pbar.set_description(p)

        # ---- final high-resolution render of the test trajectory ----
        factor = 4
        with torch.no_grad():
            H, W = poser.H, poser.W
            vox.eval()
            K, test_poses = poser.sample_test(200)
            test_poses = test_poses[60:]  # skip the full overhead view; not interesting

            aabb = vox.aabb.T.cpu().numpy()
            vox = vox.to(device_glb)

        all_images = []
        for test_pose in tqdm(test_poses):
            # Same rule as in training: no yield while no_grad is active.
            with torch.no_grad():
                y, depth = highres_render_one_view(vox, aabb, H, W, K, test_pose, f=factor)
                if is_sd:
                    y = model.decode(y)
                _, img, _ = vis_routine(y, depth)
                frame_stats = str(tsr_stats(y))

            all_images.append(img)

            yield {
                image: gr.update(value=img, visible=True),
                video: gr.update(visible=False),
                logs: frame_stats,
            }

        # A unique file per run: a fixed path would be overwritten by the
        # next run while an earlier result may still be referenced.
        # delete=False because the path is handed to the video component;
        # files are not cleaned up here.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            output_video = tmp_file.name

        imageio.mimwrite(output_video, all_images, quality=8, fps=10)

        end_t = time.time()

        yield {
            image: gr.update(value=img, visible=False),
            video: gr.update(value=output_video, visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }

    # Run `submit` when the button is pressed: the three input components
    # supply its arguments and the three output components receive the
    # updates it yields.
    button.click(
        fn=submit,
        inputs=[prompt, iters, seed],
        outputs=[image, video, logs],
    )

# concurrency_count=1: only allow ONE job to run at a time; otherwise the GPU
# will run out of memory (OOM).
demo.queue(concurrency_count=1)
# Launch the app.
demo.launch()