import spaces
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

import numpy as np
import os
import cv2
from PIL import Image, ImageDraw
import insightface
from insightface.app import FaceAnalysis
import time

# --- Diffusion pipeline -----------------------------------------------------
# Base checkpoint and the LoRA attention weights layered on top of its UNet.
model_base = "runwayml/stable-diffusion-v1-5"
lora_model_path = "./loralucy6/checkpoint-145000"

# Load the base pipeline in half precision; no safety checker is passed in.
pipe = StableDiffusionPipeline.from_pretrained(
    model_base,
    torch_dtype=torch.float16,
    use_safetensors=True,
    safety_checker=None,
)

# Replace the default scheduler with DPM-Solver multistep, built from the
# configuration of the scheduler the pipeline was loaded with.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Attach the LoRA attention processors to the UNet, then move the whole
# pipeline to the GPU.
pipe.unet.load_attn_procs(lora_model_path)
pipe.to("cuda")


# --- Insightface model ------------------------------------------------------
# Face analysis bundle used by face_swap() for face detection.
app = FaceAnalysis(name='buffalo_l')
app.prepare(ctx_id=0, det_size=(640, 640))

# Lazily created inswapper model, shared by every face_swap() call so the
# ONNX file is read from disk once instead of on each request.
_face_swapper = None


def _get_face_swapper():
    """Return the inswapper model, loading it on first use only."""
    global _face_swapper
    if _face_swapper is None:
        _face_swapper = insightface.model_zoo.get_model(
            'inswapper_128.onnx', download=False, download_zip=False
        )
    return _face_swapper


def face_swap(src_img, dest_img):
    """Paste the face from a stored reference photo onto every face in an image.

    Args:
        src_img: Base name (without extension) of the reference photo; the
            file ``./images/<src_img>.JPG`` is opened.
        dest_img: PIL image whose detected faces are replaced.

    Returns:
        A new RGB PIL image with each detected face in ``dest_img`` replaced
        by the first face detected in the reference photo. If no face is
        detected in either image, the RGB-converted ``dest_img`` is returned
        without any swap.

    Raises:
        FileNotFoundError: If the reference photo does not exist.
    """
    src_path = './images/' + src_img + '.JPG'

    # convert() loads the pixel data and returns a new image, so the file
    # handle can be closed as soon as the conversion is done.
    with Image.open(src_path) as src_file:
        src_rgb = src_file.convert(mode='RGB')
    dest_rgb = dest_img.convert(mode='RGB')

    # Convert to arrays for insightface.
    # NOTE(review): both arrays are RGB here; insightface examples feed
    # cv2-style BGR arrays -- confirm the channel order is intended.
    src_img_arr = np.asarray(src_rgb)
    dest_img_arr = np.asarray(dest_rgb)

    # Face detection
    src_faces = app.get(src_img_arr)
    dest_faces = app.get(dest_img_arr)

    # Nothing to swap from or onto: return the image untouched rather than
    # failing with an IndexError on src_faces[0].
    if not src_faces or not dest_faces:
        print('face_swap: no face detected, returning image unchanged')
        return dest_rgb

    swapper = _get_face_swapper()
    reference_face = src_faces[0]

    # Swap every detected face, feeding each result into the next swap.
    res = dest_img_arr.copy()
    for face in dest_faces:
        res = swapper.get(res, face, reference_face, paste_back=True)

    # Convert back to a PIL image
    return Image.fromarray(np.uint8(res)).convert('RGB')

@spaces.GPU(enable_queue=True)
def greet(description, color, features, occasion, type_, face):
    """Generate an image from the form inputs and optionally swap in a face.

    The inputs are flattened into one tagged prompt of the form
    ``description:<words-joined-by-dashes> color:a,b features:... occasion:...
    type:...`` and passed to the diffusion pipeline.

    Args:
        description: Free-text description; spaces are replaced with ``-``.
            ``None`` is treated as an empty string.
        color: Selected colors (list of strings). ``None`` is treated as an
            empty selection; a single string is treated as one selection.
        features: Selected features, handled like ``color``.
        occasion: Selected occasions, handled like ``color``.
        type_: Selected garment types, handled like ``color``.
        face: Name of the reference face passed to ``face_swap``; ``'Normal'``
            or an empty value skips the swap.

    Returns:
        The generated PIL image, face-swapped when a face was requested.
    """
    start = time.time()

    def join_tags(values):
        # An empty multiselect may arrive as None, and a lone string would
        # otherwise be joined character by character.
        if values is None:
            return ''
        if isinstance(values, str):
            return values
        return ','.join(values)

    # Parse input
    prompt = (
        'description:' + (description or '').replace(' ', '-')
        + ' color:' + join_tags(color)
        + ' features:' + join_tags(features)
        + ' occasion:' + join_tags(occasion)
        + ' type:' + join_tags(type_)
    )

    print('prompt:', prompt)
    pipe.to("cuda")
    image = pipe(
        prompt,
        negative_prompt='deformed face,bad anatomy',
        width=312,
        height=512,
        num_inference_steps=100,
        guidance_scale=7.5,
        cross_attention_kwargs={"scale": 1.0},
    ).images[0]

    # 'Normal' keeps the generated face; a cleared dropdown (None/empty) is
    # treated the same way instead of failing inside face_swap.
    if face and face != 'Normal':
        image = face_swap(face, image)

    end = time.time()
    print('time:', end - start)

    return image

# Choice lists for the form's dropdowns.
COLOR_CHOICES = [
    'Beige', 'Black', 'Blue', 'Brown', 'Green', 'Grey',
    'Orange', 'Pink', 'Purple', 'Red', 'White', 'Yellow',
]
FEATURE_CHOICES = [
    '3/4-sleeve', 'Babydoll', 'Closed-Back', 'Corset', 'Crochet', 'Cutouts',
    'Draped', 'Floral', 'Gloves', 'Halter', 'Lace', 'Long', 'Long-Sleeve',
    'Midi', 'No-Slit', 'Off-The-Shoulder', 'One-Shoulder', 'Open-Back',
    'Pockets', 'Print', 'Puff-Sleeve', 'Ruched', 'Satin', 'Sequins',
    'Shimmer', 'Short', 'Short-Sleeve', 'Side-Slit', 'Square-Neck',
    'Strapless', 'Sweetheart-Neck', 'Tight', 'V-Neck', 'Velvet', 'Wrap',
]
OCCASION_CHOICES = [
    'Homecoming', 'Casual', 'Wedding-Guest', 'Festival', 'Sorority', 'Day',
    'Vacation', 'Summer', 'Pool-Party', 'Birthday', 'Date-Night', 'Party',
    'Holiday', 'Winter-Formal', 'Valentines-Day', 'Prom', 'Graduation',
]
TYPE_CHOICES = [
    'Mini-Dresses', 'Midi-Dresses', 'Maxi-Dresses', 'Two-Piece-Sets',
    'Rompers', 'Jeans', 'Jumpsuits', 'Pants', 'Tops', 'Jumpers/Cardigans',
    'Skirts', 'Shorts', 'Bodysuits', 'Swimwear',
]
FACE_CHOICES = ['Normal', 'Cat', 'Lisa', 'Mila']

# One pre-filled example row, in the same order as greet()'s parameters.
EXAMPLE_ROW = [
    'Kailani  mesh sequins two piece maxi dress pink',
    ['Pink'],
    ['Cutouts', 'Long-Sleeve', 'Sequins', 'Side-Slit'],
    ['Festival', 'Party', 'Prom'],
    ['Maxi-Dresses', 'Two-Piece-Sets'],
    'Cat',
]


def _multiselect(label, choices):
    """Build an interactive multi-select dropdown with the given label."""
    return gr.Dropdown(interactive=True, label=label, choices=choices, multiselect=True)


# Input components, ordered to match greet(description, color, features,
# occasion, type_, face).
form_inputs = [
    gr.Textbox(label='Description'),
    _multiselect('Color', COLOR_CHOICES),
    _multiselect('Features', FEATURE_CHOICES),
    _multiselect('Occasion', OCCASION_CHOICES),
    _multiselect('Type', TYPE_CHOICES),
    gr.Dropdown(interactive=True, label='Face', choices=FACE_CHOICES, value='Cat'),
]

# Output image component, sized to the 312x512 images greet() generates.
result_image = gr.Image(
    type="pil",
    label="Final Image",
    width=312,
    height=512,
    show_share_button=False,
)

iface = gr.Interface(
    fn=greet,
    inputs=form_inputs,
    outputs=result_image,
    examples=[EXAMPLE_ROW],
    title='Lucy in the Sky: Text to Image',
    description="""
                    Design your own [Lucy in the Sky](https://www.lucyinthesky.com/) dress with text!
                    """,
)
iface.launch()