################################################################################
# Copyright (C) 2023 Jiayi Guo, Xingqian Xu, Manushree Vasu - All Rights Reserved                         #
################################################################################

import gradio as gr
import os
import os.path as osp
import PIL
from PIL import Image
import numpy as np
from collections import OrderedDict
from easydict import EasyDict as edict
from functools import partial

import torch
import torchvision.transforms as tvtrans
import time
import argparse
import json
import hashlib
import copy
from tqdm import tqdm

from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from app_utils import auto_dropdown

from huggingface_hub import hf_hub_download
import spaces

version = "Smooth Diffusion Demo v1.0"
refresh_symbol = "\U0001f504" # 🔄
recycle_symbol = '\U0000267b' #

##############
# model_book #
##############

choices = edict()
choices.diffuser = OrderedDict([
    ['SD-v1-5' , "runwayml/stable-diffusion-v1-5"],
    ['OJ-v4' , "prompthero/openjourney-v4"],
    ['RR-v2', "SG161222/Realistic_Vision_V2.0"],
])

choices.lora = OrderedDict([
    ['empty', ""],
    ['Smooth-LoRA-v1', hf_hub_download('shi-labs/smooth-diffusion-lora', 'smooth_lora.safetensors')],
])

choices.scheduler = OrderedDict([
    ['DDIM', DDIMScheduler],
])

choices.inversion = OrderedDict([
    ['NTI', 'NTI'],
    ['DDIM w/o text', 'DDIM w/o text'],
    ['DDIM', 'DDIM'], 
])

default = edict()
default.diffuser = 'RR-v2'
default.scheduler = 'DDIM'
default.lora = 'Smooth-LoRA-v1'
default.inversion = 'NTI'
default.step = 50
default.cfg_scale = 7.5
default.framen = 24
default.fps = 16
default.nullinv_inner_step = 10
default.threshold = 0.8
default.variation = 0.8

##########
# helper #
##########

def lerp(t, v0, v1):
    if isinstance(t, float):
        return v0*(1-t) + v1*t
    elif isinstance(t, (list, np.ndarray)):
        return [v0*(1-ti) + v1*ti for ti in t]

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # mostly copied from
    # https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    v0_unit = v0 / np.linalg.norm(v0)
    v1_unit = v1 / np.linalg.norm(v1)
    dot = np.sum(v0_unit * v1_unit)
    if np.abs(dot) > DOT_THRESHOLD:
        return lerp(t, v0, v1)
    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t

    if isinstance(t, float):
        tlist = [t]
    elif isinstance(t, (list, np.ndarray)):
        tlist = t

    v2_list = []

    for ti in tlist:
        theta_t = theta_0 * ti
        sin_theta_t = np.sin(theta_t)
        # Finish the slerp algorithm
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1
        v2_list.append(v2)

    if isinstance(t, float):
        return v2_list[0]
    else:
        return v2_list

def offset_resize(image, width=512, height=512, left=0, right=0, top=0, bottom=0):
   
    image = np.array(image)[:, :, :3]
    h, w, c = image.shape
    left = min(left, w-1)
    right = min(right, w - left - 1)
    top = min(top, h - left - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h-bottom, left:w-right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = Image.fromarray(image).resize((width, height))
    return image

def auto_dtype_device_shape(tlist, v0, v1, func,):
    vshape = v0.shape
    assert v0.shape == v1.shape
    assert isinstance(tlist, (list, np.ndarray))
    
    if isinstance(v0, torch.Tensor):
        is_torch = True
        dtype, device = v0.dtype, v0.device
        v0 = v0.to('cpu').numpy().astype(float).flatten()
        v1 = v1.to('cpu').numpy().astype(float).flatten()
    else:
        is_torch = False
        dtype = v0.dtype
        assert isinstance(v0, np.ndarray)
        assert isinstance(v1, np.ndarray)
        v0 = v0.astype(float).flatten()
        v1 = v1.astype(float).flatten()

    r = func(tlist, v0, v1)

    if is_torch:
        r = [torch.Tensor(ri).view(*vshape).to(dtype).to(device) for ri in r]
    else:
        r = [ri.astype(dtype) for ri in r]
    return r

auto_lerp = partial(auto_dtype_device_shape, func=lerp)
auto_slerp = partial(auto_dtype_device_shape, func=slerp)

def frames2mp4(vpath, frames, fps):
    import moviepy.editor as mpy
    frames = [np.array(framei) for framei in frames]
    clip = mpy.ImageSequenceClip(frames, fps=fps)
    clip.write_videofile(vpath, fps=fps)

def negseed_to_rndseed(seed):
    if seed < 0:
        seed = np.random.randint(0, np.iinfo(np.uint32).max-100)
    return seed

def regulate_image(pilim):
    w, h = pilim.size
    w = int(round(w/64)) * 64
    h = int(round(h/64)) * 64
    return pilim.resize([w, h], resample=PIL.Image.BILINEAR)

def txt_to_emb(model, prompt):
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",)
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    return text_embeddings

def hash_pilim(pilim):
    hasha = hashlib.md5(pilim.tobytes()).hexdigest()
    return hasha

def hash_cfgdict(cfgdict):
    hashb = hashlib.md5(json.dumps(cfgdict, sort_keys=True).encode('utf-8')).hexdigest()
    return hashb

def remove_earliest_file(path, max_allowance=500, remove_ratio=0.1, ext=None):
    if len(os.listdir(path)) <= max_allowance:
        return
    def get_mtime(fname):
        return osp.getmtime(osp.join(path, fname))
    if ext is None:
        flist = sorted(os.listdir(path), key=get_mtime)
    else:
        flist = [fi for fi in os.listdir(path) if fi.endswith(ext)]
        flist = sorted(flist, key=get_mtime)
    exceedn = max(len(flist)-max_allowance, 0)
    removen = int(max_allowance*remove_ratio)
    removen = max(1, removen) + exceedn
    for fi in flist[0:removen]:
        os.remove(osp.join(path, fi))

def remove_decoupled_file(path, exta='.mp4', extb='.json'):
    tag_a = [osp.splitext(fi)[0] for fi in os.listdir(path) if fi.endswith(exta)]
    tag_b = [osp.splitext(fi)[0] for fi in os.listdir(path) if fi.endswith(extb)]
    tag_a_extra = set(tag_a) - set(tag_b)
    tag_b_extra = set(tag_b) - set(tag_a)
    [os.remove(osp.join(path, tagi+exta)) for tagi in tag_a_extra]
    [os.remove(osp.join(path, tagi+extb)) for tagi in tag_b_extra]

@spaces.GPU()
@torch.no_grad()
def t2i_core(model, xt, emb, nemb, step=30, cfg_scale=7.5, return_list=False):
    from nulltxtinv_wrapper import diffusion_step, latent2image
    model.scheduler.set_timesteps(step)
    xi = xt
    emb = txt_to_emb(model, "") if emb is None else emb
    nemb = txt_to_emb(model, "") if nemb is None else nemb
    if return_list:
        xi_list = [xi.clone()]
    for i, t in enumerate(tqdm(model.scheduler.timesteps)):
        embi = emb[i] if isinstance(emb, list) else emb
        nembi = nemb[i] if isinstance(nemb, list) else nemb
        context = torch.cat([nembi, embi])
        xi = diffusion_step(model, xi, context, t, cfg_scale, low_resource=False)
        if return_list:
            xi_list.append(xi.clone())
    x0 = xi
    im = latent2image(model.vae, x0, return_type='pil')

    if return_list:
        return im, xi_list
    else:
        return im

########
# main #
########

class wrapper(object):
    
    def __init__(self, 
                 fp16=False, 
                 tag_diffuser=None, 
                 tag_lora=None,
                 tag_scheduler=None,):

        self.device = "cuda" #if torch.cuda.is_available() else "cpu"
        if fp16:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.load_all(tag_diffuser, tag_lora, tag_scheduler)

        self.image_latent_dim = 4
        self.batchsize = 8
        self.seed = {}

        self.cache_video_folder = "temp/video"
        self.cache_video_maxn = 500
        self.cache_image_folder = "temp/image"
        self.cache_image_maxn = 500
        self.cache_inverse_folder = "temp/inverse"
        self.cache_inverse_maxn = 500

    
    def load_all(self, tag_diffuser, tag_lora, tag_scheduler):
        self.load_diffuser_lora(tag_diffuser, tag_lora)
        self.load_scheduler(tag_scheduler)
        return tag_diffuser, tag_lora, tag_scheduler

    
    def load_diffuser_lora(self, tag_diffuser, tag_lora):
        self.net = StableDiffusionPipeline.from_pretrained(
            choices.diffuser[tag_diffuser], torch_dtype=self.torch_dtype).to(self.device)
        self.net.safety_checker = None
        if tag_lora != 'empty':
            self.net.unet.load_attn_procs(
                choices.lora[tag_lora], use_safetensors=True,)
        self.tag_diffuser = tag_diffuser
        self.tag_lora = tag_lora
        return tag_diffuser, tag_lora

    
    def load_scheduler(self, tag_scheduler):
        self.net.scheduler = choices.scheduler[tag_scheduler].from_config(self.net.scheduler.config)
        self.tag_scheduler = tag_scheduler
        return tag_scheduler

    def reset_seed(self, which='ltintp'):
        return -1

    def recycle_seed(self, which='ltintp'):
        if which not in self.seed:
            return self.reset_seed(which=which)
        else:
            return self.seed[which]

    ##########
    # helper #
    ##########
    
    def precheck_model(self, tag_diffuser, tag_lora, tag_scheduler):
        if (tag_diffuser != self.tag_diffuser) or (tag_lora != self.tag_lora):
            self.load_all(tag_diffuser, tag_lora, tag_scheduler)
        if tag_scheduler != self.tag_scheduler:
            self.load_scheduler(tag_scheduler)

    ########
    # main #
    ########

    @spaces.GPU()
    def ddiminv(self, img, cfgdict):
        txt, step, cfg_scale = cfgdict['txt'], cfgdict['step'], cfgdict['cfg_scale']
        from nulltxtinv_wrapper import NullInversion
        null_inversion_model = NullInversion(self.net, step, cfg_scale)
        with torch.no_grad():
            emb = txt_to_emb(self.net, txt)
            nemb = txt_to_emb(self.net, "")
        xt = null_inversion_model.ddim_invert(img, txt)
        data = {
            'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
            'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
            'xt': xt, 'emb': emb, 'nemb': nemb,}
        return data

    @spaces.GPU()
    def nullinv_or_loadcache(self, img, cfgdict, force_reinvert=False):
        hash = hash_pilim(img) + "--" + hash_cfgdict(cfgdict)
        cdir = self.cache_inverse_folder
        cfname = osp.join(cdir, hash+'.pth')

        if osp.isfile(cfname) and (not force_reinvert):
            cache_data = torch.load(cfname)
            dtype = next(self.net.unet.parameters()).dtype
            device = next(self.net.unet.parameters()).device
            cache_data['xt'] = cache_data['xt'].to(device=device, dtype=dtype)
            cache_data['emb'] = cache_data['emb'].to(device=device, dtype=dtype)
            cache_data['nemb'] = [
                nembi.to(device=device, dtype=dtype)
                    for nembi in cache_data['nemb']]
            return cache_data
        else:
            txt, step, cfg_scale = cfgdict['txt'], cfgdict['step'], cfgdict['cfg_scale']
            inner_step = cfgdict['inner_step']
            from nulltxtinv_wrapper import NullInversion
            null_inversion_model = NullInversion(self.net, step, cfg_scale)
            with torch.no_grad():
                emb = txt_to_emb(self.net, txt)
            xt, nemb = null_inversion_model.null_invert(img, txt, num_inner_steps=inner_step)
            cache_data = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt.to('cpu'),
                'emb' : emb.to('cpu'),
                'nemb' : [nembi.to('cpu') for nembi in nemb],}
            os.makedirs(cdir, exist_ok=True)
            remove_earliest_file(cdir, max_allowance=self.cache_inverse_maxn)
            torch.save(cache_data, cfname)
            data = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt, 'emb' : emb, 'nemb' : nemb,}
            return data

    @spaces.GPU()
    def nullinvdual_or_loadcachedual(self, img0, img1, cfgdict, force_reinvert=False):
        hash = hash_pilim(img0) + "--" + hash_pilim(img1) + "--" + hash_cfgdict(cfgdict)
        cdir = self.cache_inverse_folder
        cfname = osp.join(cdir, hash+'.pth')

        if osp.isfile(cfname) and (not force_reinvert):
            cache_data = torch.load(cfname)
            dtype = next(self.net.unet.parameters()).dtype
            device = next(self.net.unet.parameters()).device
            cache_data['xt0'] = cache_data['xt0'].to(device=device, dtype=dtype)
            cache_data['xt1'] = cache_data['xt1'].to(device=device, dtype=dtype)
            cache_data['emb0'] = cache_data['emb0'].to(device=device, dtype=dtype)
            cache_data['emb1'] = cache_data['emb1'].to(device=device, dtype=dtype)
            cache_data['nemb'] = [
                nembi.to(device=device, dtype=dtype)
                    for nembi in cache_data['nemb']]

            cache_data_a = copy.deepcopy(cache_data)
            cache_data_a['xt'] = cache_data_a.pop('xt0')
            cache_data_a['emb'] = cache_data_a.pop('emb0')
            cache_data_a.pop('xt1'); cache_data_a.pop('emb1')

            cache_data_b = cache_data
            cache_data_b['xt'] = cache_data_b.pop('xt1')
            cache_data_b['emb'] = cache_data_b.pop('emb1')
            cache_data_b.pop('xt0'); cache_data_b.pop('emb0')

            return cache_data_a, cache_data_b
        else:
            txt0, txt1, step, cfg_scale, inner_step = \
                cfgdict['txt0'], cfgdict['txt1'], cfgdict['step'], \
                cfgdict['cfg_scale'], cfgdict['inner_step']
            
            from nulltxtinv_wrapper import NullInversion
            null_inversion_model = NullInversion(self.net, step, cfg_scale)
            with torch.no_grad():
                emb0 = txt_to_emb(self.net, txt0)
                emb1 = txt_to_emb(self.net, txt1)
            
            xt0, xt1, nemb = null_inversion_model.null_invert_dual(
                img0, img1, txt0, txt1, num_inner_steps=inner_step)
            cache_data = {
                'step' : step, 'cfg_scale' : cfg_scale, 
                'txt0' : txt0, 'txt1' : txt1,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt0' : xt0.to('cpu'), 'xt1' : xt1.to('cpu'),
                'emb0' : emb0.to('cpu'), 'emb1' : emb1.to('cpu'),
                'nemb' : [nembi.to('cpu') for nembi in nemb],}
            os.makedirs(cdir, exist_ok=True)
            remove_earliest_file(cdir, max_allowance=self.cache_inverse_maxn)
            torch.save(cache_data, cfname)
            data0 = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt0,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt0, 'emb' : emb0, 'nemb' : nemb,}
            data1 = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt1,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt1, 'emb' : emb1, 'nemb' : nemb,}
            return data0, data1

    @spaces.GPU()
    def image_inversion(
            self, img, txt, 
            cfg_scale, step, 
            inversion, inner_step, force_reinvert):
        from nulltxtinv_wrapper import text2image_ldm
        if inversion == 'DDIM w/o text':
            txt = ''
        if not inversion == 'NTI':
            data = self.ddiminv(img, {'txt':txt, 'step':step, 'cfg_scale':cfg_scale,})
        else:
            data = self.nullinv_or_loadcache(
                img, {'txt':txt, 'step':step,
                      'cfg_scale':cfg_scale, 'inner_step':inner_step,
                      'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
        
        if inversion == 'NTI':
            img_inv, _ = text2image_ldm(
                self.net, [txt], step, cfg_scale, 
                latent=data['xt'], uncond_embeddings=data['nemb'])
        else:
            img_inv, _ = text2image_ldm(
            self.net, [txt], step, cfg_scale,
            latent=data['xt'], uncond_embeddings=None)
            
        return img_inv

    @spaces.GPU()
    def image_editing(
        self, img, txt_0, txt_1,
        cfg_scale, step, thresh, 
        inversion, inner_step, force_reinvert):
        from nulltxtinv_wrapper import text2image_ldm_imedit
        if inversion == 'DDIM w/o text':
            txt_0 = ''
        if not inversion == 'NTI':
            data = self.ddiminv(img, {'txt':txt_0, 'step':step, 'cfg_scale':cfg_scale,})
            img_edited, _ = text2image_ldm_imedit(
                self.net, thresh, [txt_0], [txt_1], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=None)
        else:
            data = self.nullinv_or_loadcache(
                img, {'txt':txt_0, 'step':step,
                      'cfg_scale':cfg_scale, 'inner_step':inner_step,
                      'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
            img_edited, _ = text2image_ldm_imedit(
                self.net, thresh, [txt_0], [txt_1], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=data['nemb'])
        
        return img_edited

    @spaces.GPU()
    def general_interpolation(
            self, xset0, xset1,
            cfg_scale, step, tlist,):

        xt0, emb0, nemb0 = xset0['xt'], xset0['emb'], xset0['nemb']
        xt1, emb1, nemb1 = xset1['xt'], xset1['emb'], xset1['nemb']
        framen = len(tlist)

        xt_list = auto_slerp(tlist, xt0, xt1)
        emb_list = auto_lerp(tlist, emb0, emb1)
        
        if isinstance(nemb0, list) and isinstance(nemb1, list):
            assert len(nemb0) == len(nemb1)
            nemb_list = [auto_lerp(tlist, e0, e1) for e0, e1 in zip(nemb0, nemb1)]
            nemb_islist = True
        else:
            nemb_list = auto_lerp(tlist, nemb0, nemb1)
            nemb_islist = False

        im_list = []
        for frameidx in range(0, len(xt_list), self.batchsize):
            xt_batch = [xt_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
            xt_batch = torch.cat(xt_batch, dim=0)
            emb_batch = [emb_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
            emb_batch = torch.cat(emb_batch, dim=0)
            if nemb_islist:
                nemb_batch = []
                for nembi in nemb_list:
                    nembi_batch = [nembi[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
                    nembi_batch = torch.cat(nembi_batch, dim=0)
                    nemb_batch.append(nembi_batch)
            else:
                nemb_batch = [nemb_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
                nemb_batch = torch.cat(nemb_batch, dim=0)

            im = t2i_core(
                self.net, xt_batch, emb_batch, nemb_batch, step, cfg_scale)
            im_list += im if isinstance(im, list) else [im]

        return im_list
    
    @spaces.GPU()
    def run_iminvs(
            self, img, text, 
            cfg_scale, step, 
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler, ):
        
        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
        
        if force_resize:
            img = offset_resize(img, width, height)
        else:
            img = regulate_image(img)

        recon_output = self.image_inversion(
            img, text, cfg_scale, step, 
            inversion, inner_step, force_reinvert)

        idir = self.cache_image_folder
        os.makedirs(idir, exist_ok=True)
        remove_earliest_file(idir, max_allowance=self.cache_image_maxn)
        sname = "time{}_iminvs_{}_{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora,)
        ipath = osp.join(idir, sname+'.png')
        recon_output.save(ipath)
        
        return [recon_output]
    
    @spaces.GPU()
    def run_imedit(
            self, img, txt_0,txt_1, 
            threshold, cfg_scale, step, 
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler, ):
        
        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
        if force_resize:
            img = offset_resize(img, width, height)
        else:
            img = regulate_image(img)

        edited_img= self.image_editing(
            img, txt_0,txt_1, cfg_scale, step, threshold,
            inversion, inner_step, force_reinvert)

        idir = self.cache_image_folder
        os.makedirs(idir, exist_ok=True)
        remove_earliest_file(idir, max_allowance=self.cache_image_maxn)
        sname = "time{}_imedit_{}_{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora,)
        ipath = osp.join(idir, sname+'.png')
        edited_img.save(ipath)
     
        return [edited_img]

    @spaces.GPU()
    def run_imintp(
            self, 
            img0, img1, txt0, txt1,
            cfg_scale, step, 
            framen, fps, 
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler,):
        
        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
        if txt1 == '':
            txt1 = txt0
        if force_resize:
            img0 = offset_resize(img0, width, height)
            img1 = offset_resize(img1, width, height)
        else:
            img0 = regulate_image(img0)
            img1 = regulate_image(img1)

        if inversion == 'DDIM':
            data0 = self.ddiminv(img0, {'txt':txt0, 'step':step, 'cfg_scale':cfg_scale,})
            data1 = self.ddiminv(img1, {'txt':txt1, 'step':step, 'cfg_scale':cfg_scale,})
        elif inversion == 'DDIM w/o text':
            data0 = self.ddiminv(img0, {'txt':"", 'step':step, 'cfg_scale':cfg_scale,})
            data1 = self.ddiminv(img1, {'txt':"", 'step':step, 'cfg_scale':cfg_scale,})
        else:
            data0, data1 = self.nullinvdual_or_loadcachedual(
                img0, img1, {'txt0':txt0, 'txt1':txt1, 'step':step,
                             'cfg_scale':cfg_scale, 'inner_step':inner_step,
                             'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)

        tlist = np.linspace(0.0, 1.0, framen)

        iminv0 = t2i_core(self.net, data0['xt'], data0['emb'], data0['nemb'], step, cfg_scale)
        iminv1 = t2i_core(self.net, data1['xt'], data1['emb'], data1['nemb'], step, cfg_scale)
        frames = self.general_interpolation(data0, data1, cfg_scale, step, tlist)

        vdir = self.cache_video_folder
        os.makedirs(vdir, exist_ok=True)
        remove_earliest_file(vdir, max_allowance=self.cache_video_maxn)
        sname = "time{}_imintp_{}_{}_framen{}_fps{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora, framen, fps)
        vpath = osp.join(vdir, sname+'.mp4')
        frames2mp4(vpath, frames, fps)
        jpath = osp.join(vdir, sname+'.json')
        cfgdict = {
            "method" : "image_interpolation",
            "txt0" : txt0, "txt1" : txt1,
            "cfg_scale" : cfg_scale, "step" : step, 
            "framen" : framen, "fps" : fps,
            "force_resize" : force_resize, "width" : width, "height" : height,
            "inversion" : inversion, "inner_step" : inner_step, 
            "force_reinvert" : force_reinvert, 
            "tag_diffuser" : tag_diffuser, "tag_lora" : tag_lora, "tag_scheduler" : tag_scheduler,}
        with open(jpath, 'w') as f:
            json.dump(cfgdict, f, indent=4)

        return frames, vpath, [iminv0, iminv1]

#################
# get examples #
#################
cache_examples = False
def get_imintp_example():
    case = [
        [
            'assets/images/interpolation/cityview1.png', 
            'assets/images/interpolation/cityview2.png', 
            'A city view',],
        [
            'assets/images/interpolation/woman1.png', 
            'assets/images/interpolation/woman2.png', 
            'A woman face',],
        [
            'assets/images/interpolation/land1.png', 
            'assets/images/interpolation/land2.png', 
            'A beautiful landscape',],
        [
            'assets/images/interpolation/dog1.png', 
            'assets/images/interpolation/dog2.png', 
            'A realistic dog',],
        [
            'assets/images/interpolation/church1.png', 
            'assets/images/interpolation/church2.png', 
            'A church',],
        [
            'assets/images/interpolation/rabbit1.png', 
            'assets/images/interpolation/rabbit2.png', 
            'A cute rabbit',],
        [
            'assets/images/interpolation/horse1.png', 
            'assets/images/interpolation/horse2.png', 
            'A robot horse',],
    ]
    return case

def get_iminvs_example():
    case = [
        [
            'assets/images/inversion/000000560011.jpg', 
            'A mouse is next to a keyboard on a desk',],
        [
            'assets/images/inversion/000000029596.jpg', 
            'A room with a couch, table set with dinnerware and a television.',],
    ]
    return case


def get_imedit_example():
    case = [
        [
            'assets/images/editing/rabbit.png', 
            'A rabbit is eating a watermelon on the table', 
            'A cat is eating a watermelon on the table', 
            0.7,],
        [
            'assets/images/editing/cake.png', 
            'A chocolate cake with cream on it', 
            'A chocolate cake with strawberries on it', 
            0.9,],
        [
            'assets/images/editing/banana.png', 
            'A banana on the table', 
            'A banana and an apple on the table', 
            0.8,],
        
    ]
    return case


#################
# sub interface #
#################


def interface_imintp(wrapper_obj):
    with gr.Row():
        with gr.Column():
            img0 = gr.Image(label="Image Input 0", type='pil',  elem_id='customized_imbox')
        with gr.Column():
            img1 = gr.Image(label="Image Input 1", type='pil', elem_id='customized_imbox')
        with gr.Column():
            video_output = gr.Video(label="Video Result", format='mp4', elem_id='customized_imbox')
    with gr.Row(): 
        with gr.Column():      
            txt0 = gr.Textbox(label='Text Input', lines=1, placeholder="Input prompt...", )
        with gr.Column(): 
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
                    

    with gr.Row():
        with gr.Column(): 
            with gr.Row():
                framen = gr.Slider(label="Frame Number", minimum=8, maximum=default.framen, value=default.framen, step=1)
                fps = gr.Slider(label="Video FPS", minimum=4, maximum=default.fps, value=default.fps, step=4)
            with gr.Row():
                button_run = gr.Button("Run") 
            

        with gr.Column():
            with gr.Accordion('Frame Results', open=False):
                frame_output = gr.Gallery(label="Frames", elem_id='customized_imbox')
            with gr.Accordion("Inversion Results", open=False):
                inv_output = gr.Gallery(label="Inversion Results", elem_id='customized_imbox')
            with gr.Accordion('Advanced Settings', open=False):
                with gr.Row():
                    tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
                    tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
                    tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
                with gr.Row():
                    cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
                    step = gr.Number(default.step, label="Step", precision=0)
                with gr.Row():
                    force_resize = gr.Checkbox(label="Force Resize", value=True)
                    inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                    inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
                with gr.Row():
                    txt1 = gr.Textbox(label='Optional Different Text Input for Image Input 1', lines=1, placeholder="Input prompt...", )
            

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_lora.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs = [tag_scheduler],
        outputs = [tag_scheduler],)

    button_run.click(
        wrapper_obj.run_imintp,
        inputs=[img0, img1, txt0, txt1,
                cfg_scale, step, 
                framen, fps, 
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert,
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[frame_output, video_output, inv_output])

    gr.Examples(
        label='Examples', 
        examples=get_imintp_example(), 
        fn=wrapper_obj.run_imintp,
        inputs=[img0, img1, txt0,],
        outputs=[frame_output, video_output, inv_output],
        cache_examples=cache_examples,)

def interface_iminvs(wrapper_obj):
    with gr.Row():
        image_input = gr.Image(label="Image input", type='pil', elem_id='customized_imbox')
        recon_output = gr.Gallery(label="Reconstruction output", elem_id='customized_imbox')
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label='Text Input', lines=1, placeholder="Input prompt...", )
            with gr.Row():
                button_run = gr.Button("Run")
            
            
        with gr.Column():
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
            with gr.Accordion('Advanced Settings', open=False):
                with gr.Row():
                    tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
                    tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
                    tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
                with gr.Row():
                    cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
                    step = gr.Number(default.step, label="Step", precision=0)
                with gr.Row():
                    force_resize = gr.Checkbox(label="Force Resize", value=True)
                    inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                    inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
            

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_lora.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs = [tag_scheduler],
        outputs = [tag_scheduler],)

    button_run.click(
        wrapper_obj.run_iminvs,
        inputs=[image_input, prompt,  
                cfg_scale, step, 
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert, 
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[recon_output])
    
    gr.Examples(
        label='Examples', 
        examples=get_iminvs_example(), 
        fn=wrapper_obj.run_iminvs,
        inputs=[image_input, prompt,],
        outputs=[recon_output],
        cache_examples=cache_examples,)


def interface_imedit(wrapper_obj):
    with gr.Row():
        image_input = gr.Image(label="Image input", type='pil', elem_id='customized_imbox')
        edited_output = gr.Gallery(label="Edited output", elem_id='customized_imbox')
    with gr.Row():
        with gr.Column():
            prompt_0 = gr.Textbox(label='Source Text', lines=1, placeholder="Source prompt...", )
            prompt_1 = gr.Textbox(label='Target Text', lines=1, placeholder="Target prompt...", )
            with gr.Row():
                button_run = gr.Button("Run")
            
        with gr.Column():
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
                threshold = gr.Slider(label="Threshold", minimum=0, maximum=1, value=default.threshold, step=0.1)
            with gr.Accordion('Advanced Settings', open=False):
                with gr.Row():
                    tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
                    tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
                    tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
                with gr.Row():
                    cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
                    step = gr.Number(default.step, label="Step", precision=0)
                with gr.Row():
                    force_resize = gr.Checkbox(label="Force Resize", value=True)
                    inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                    inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
            

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_lora.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs = [tag_scheduler],
        outputs = [tag_scheduler],)

    button_run.click(
        wrapper_obj.run_imedit,
        inputs=[image_input, prompt_0, prompt_1, 
                threshold, cfg_scale, step, 
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert, 
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[edited_output])
    
    gr.Examples(
        label='Examples', 
        examples=get_imedit_example(), 
        fn=wrapper_obj.run_imedit,
        inputs=[image_input, prompt_0, prompt_1, threshold,],
        outputs=[edited_output],
        cache_examples=cache_examples,)
        

#############
# Interface #
#############

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--port', type=int, default=None)
    args = parser.parse_args()
    from app_utils import css_empty, css_version_4_11_0
    # css = css_empty
    css = css_version_4_11_0

    wrapper_obj = wrapper(
        fp16=False, 
        tag_diffuser=default.diffuser,
        tag_lora=default.lora,
        tag_scheduler=default.scheduler)

    if True:
        with gr.Blocks(css=css) as demo:
            gr.HTML(
                """
                <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
                <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
                    {}
                </h1>
                <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
                <b>Smooth Diffusion</b> is a new category of diffusion models that is simultaneously high-performing and smooth. <br>
                Our method formally introduces latent space smoothness to diffusion models like Stable Diffusion. This smoothness dramatically aids in: 1) improving the continuity of transitions in image interpolation, 2) reducing approximation errors in image inversion, and 3) better preserving unedited contents in image editing.
                </h2>
                <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                <a href="https://www.jiayiguo.net/" target="_blank">Jiayi Guo</a>, <a href="https://www.linkedin.com/in/xingqian-xu-97b46526/" target="_blank">Xingqian Xu</a>, 
                <a href="https://scholar.google.com/citations?user=oM9rnYQAAAAJ&hl=en" target="_blank">Yifan Pu</a>, <a href="https://scholar.google.com/citations?user=Yibz_asAAAAJ&hl=en" target="_blank">Zanlin Ni</a>, 
                <a href="https://scholar.google.com/citations?user=-hwGMHcAAAAJ&hl=en" target="_blank">Chaofei Wang</a>, <a href="https://in.linkedin.com/in/v-manushree" target="_blank">Manushree Vasu</a>, 
                <a href="https://www.au.tsinghua.edu.cn/info/1103/1553.htm" target="_blank">Shiji Song</a>, <a href="https://www.gaohuang.net/" target="_blank">Gao Huang</a> 
                and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a> 
                [<a href="https://arxiv.org/abs/2312.04410" style="color:blue;">arXiv</a>] 
                [<a href="https://github.com/SHI-Labs/Smooth-Diffusion" style="color:blue;">GitHub</a>]
                </h3>
                </div>
                """.format(version))

            with gr.Tab('Image Interpolation'):
                interface_imintp(wrapper_obj)
            with gr.Tab('Image Inversion'):
                interface_iminvs(wrapper_obj)
            with gr.Tab('Image Editing'):
                interface_imedit(wrapper_obj)

            gr.Markdown(r"""
            If you find our work helpful, please **star 🌟** the <a href='https://github.com/SHI-Labs/Smooth-Diffusion' target='_blank'>Github Repo</a>. Thanks for your support! 
            [![GitHub Stars](https://img.shields.io/github/stars/SHI-Labs/Smooth-Diffusion?style=social)](https://github.com/SHI-Labs/Smooth-Diffusion)
            ---
            📑 **Citation**
            <br>
            If our work is useful for your research, please consider citing:
            ```bibtex
            @InProceedings{guo2024smooth,
              title={Smooth Diffusion: Crafting Smooth Latent Spaces in Diffusion Models},
              author={Jiayi Guo and Xingqian Xu and Yifan Pu and Zanlin Ni and Chaofei Wang and Manushree Vasu and Shiji Song and Gao Huang and Humphrey Shi},
              booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
              year={2024}
            }
            ```
            """)

        demo.queue()
        demo.launch()