import spaces
import gradio as gr
import subprocess
from PIL import Image,ImageOps,ImageDraw,ImageFilter
import json
import os
import time

from mp_utils import get_pixel_cordinate_list,extract_landmark,get_pixel_cordinate
from glibvision.draw_utils import points_to_box,box_to_xy,plus_point
import mp_constants
import mp_box
import io
import numpy as np
from glibvision.pil_utils import fill_points,create_color_image,draw_points,draw_box

from gradio_utils import save_image,save_buffer,clear_old_files ,read_file
import opencvinpaint

'''
Face detection based on MediaPipe face landmark detection.
https://ai.google.dev/edge/mediapipe/solutions/vision/face_landmarker
from model card
https://storage.googleapis.com/mediapipe-assets/MediaPipe%20BlazeFace%20Model%20Card%20(Short%20Range).pdf
Licensed Apache License, Version 2.0
Trained on Google's dataset (see the model card for more details).

'''

def picker_color_to_rgba(picker_color):
    """Convert a Gradio ColorPicker value to an ``[r, g, b, a]`` list of ints.

    Accepted forms:
      * ``"rgba(r, g, b, a)"`` - r/g/b in 0-255 (floats allowed, truncated),
        ``a`` is the CSS 0-1 alpha and is scaled to 0-255.
      * ``"rgb(r, g, b)"`` - alpha defaults to 255.
      * ``"#rgb"``, ``"#rgba"``, ``"#rrggbb"``, ``"#rrggbbaa"`` hex strings.

    Every channel is clamped to 0-255, so an out-of-range alpha such as the
    ``255`` in ``"rgba(20,20,20,255)"`` yields 255 instead of 65025.

    Args:
        picker_color: colour string as produced by ``gr.ColorPicker``.

    Returns:
        list[int]: ``[r, g, b, a]`` with each value in 0-255.

    Raises:
        ValueError: if the string is not one of the accepted forms.
    """
    def clamp(value):
        return max(0, min(255, value))

    text = picker_color.strip()
    if text.startswith("#"):
        digits = text[1:]
        if len(digits) in (3, 4):
            # short hex form: every digit is doubled ("#f80" -> "#ff8800")
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) not in (6, 8):
            raise ValueError(f"unsupported color value: {picker_color!r}")
        # int(..., 16) raises ValueError on non-hex characters
        color_value = [int(digits[i:i + 2], 16) for i in range(0, len(digits), 2)]
        if len(color_value) == 3:
            color_value.append(255)
    else:
        # strip("rgba()") removes the "rgba(" / "rgb(" prefix and ")" suffix
        parts = text.strip("rgba()").split(",")
        if len(parts) not in (3, 4):
            raise ValueError(f"unsupported color value: {picker_color!r}")
        color_value = [clamp(int(float(part))) for part in parts[:3]]
        alpha = float(parts[3]) if len(parts) == 4 else 1.0
        color_value.append(clamp(int(alpha * 255)))
    print(f"picker_color = {picker_color} color_value={color_value}")
    return color_value

#@spaces.GPU(duration=120)
'''

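    innner_eyes_blur_ratio - Gaussian blur radius for the inner-eye area, relative to the left iris radius
    iris_mask_blur_ratio - Gaussian blur radius for the final iris mask edge, relative to the left iris radius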
'''


def process_images(image,eyes_slide_x_ratio,eyes_slide_y_ratio,innner_eyes_blur_ratio=0.1,iris_mask_blur_ratio=0.1,pupil_offset_ratio=0.08,draw_eye_pupil_ratio=1.1,iris_color_value="rgba(20,20,20,255)",eyes_white_erode_ratio=0.2,eyes_ball_mask_ratio=0.9,
                   output_important_only=True,progress=gr.Progress(track_tqdm=True)):
    """Slide the irises of both eyes in ``image`` and composite the result.

    Pipeline: detect face landmarks, cut the eye-white regions out of the
    face, inpaint each eye box with the current irises masked out, redraw
    the irises (outline in ``iris_color_value``), paste the eyeball layer
    back shifted by the slide offset, blur the inner-eye layer, and finally
    paste the untouched face ("eyes socket") back on top.

    Sizes are relative to the detected iris radius, taken as half of
    ``max(box[2], box[3])`` of the iris box from ``points_to_box``
    (presumably width/height). The slide, blur and erode amounts use the
    left iris only.

    Args:
        image: input ``PIL.Image`` (the UI passes RGB), or None.
        eyes_slide_x_ratio: horizontal iris shift in iris diameters
            (negative = left, positive = right).
        eyes_slide_y_ratio: vertical iris shift in iris diameters
            (negative = up, positive = down).
        innner_eyes_blur_ratio: Gaussian blur radius for the inner-eye layer.
        iris_mask_blur_ratio: Gaussian blur radius for the final iris mask edge.
        pupil_offset_ratio: vertical offset added to both detected pupil
            centres, relative to the right iris radius.
        draw_eye_pupil_ratio: scale of the redrawn iris circles.
        iris_color_value: colour string for the iris outline, parsed by
            ``picker_color_to_rgba``.
        eyes_white_erode_ratio: MinFilter erosion of the eye-white mask,
            relative to ``box[3]`` of the left iris box.
        eyes_ball_mask_ratio: radius of the (shifted) eyeball circles that are
            excluded from the final blur mask.
        output_important_only: when true, only images flagged important are
            saved to the gallery.
        progress: Gradio progress tracker.

    Returns:
        tuple: ``(result_image, galleries)`` where ``galleries`` is a list of
        ``(file_path, label)`` tuples for the saved intermediate images.

    Raises:
        gr.Error: if no image was given or no face landmarks were detected.
    """
    clear_old_files()
    if image is None:
        raise gr.Error("Need Image")

    iris_color = tuple(picker_color_to_rgba(iris_color_value))

    # TODO resize max 2048

    white_image = create_color_image(image.width,image.height,(255,255,255))

    ## Mediapipe landmark
    progress(0, desc="Start Making Animation")
    _mp_image,face_landmarker_result = extract_landmark(image)
    larndmarks=face_landmarker_result.face_landmarks
    # Without a detected face every landmark lookup below would fail with an
    # opaque error; report it to the user instead.
    if not larndmarks:
        raise gr.Error("No face detected")

    ## eyes cordinates
    left_iris_points = get_pixel_cordinate_list(larndmarks,mp_constants.LINE_LEFT_IRIS,image.width,image.height)
    right_iris_points = get_pixel_cordinate_list(larndmarks,mp_constants.LINE_RIGHT_IRIS,image.width,image.height)

    left_box = points_to_box(left_iris_points)
    right_box = points_to_box(right_iris_points)
    left_eye_radius = max(left_box[2],left_box[3])/2
    right_eye_radius = max(right_box[2],right_box[3])/2

    # blur and slide amounts are derived from the left iris size only
    innner_eyes_blur = int(left_eye_radius*innner_eyes_blur_ratio)
    iris_mask_blur = int(left_eye_radius*iris_mask_blur_ratio)

    eyes_slide_x = int(left_eye_radius*2 * eyes_slide_x_ratio)
    eyes_slide_y = int(left_eye_radius*2 * eyes_slide_y_ratio)

    pupil_offset_y = right_eye_radius * pupil_offset_ratio

    point_right_pupil = get_pixel_cordinate(larndmarks,mp_constants.POINT_RIGHT_PUPIL,image.width,image.height)
    point_right_pupil = plus_point(point_right_pupil,[0,pupil_offset_y])

    point_left_pupil = get_pixel_cordinate(larndmarks,mp_constants.POINT_LEFT_PUPIL,image.width,image.height)
    point_left_pupil = plus_point(point_left_pupil,[0,pupil_offset_y])

    # NOTE(review): left/right constants are swapped here; these two lists are
    # only drawn on the debug "eyes-line" image, both in the same colour.
    left_inner_eyes = get_pixel_cordinate_list(larndmarks,mp_constants.LINE_RIGHT_UPPER_INNER_EYE+mp_constants.LINE_RIGHT_LOWER_INNER_EYE,image.width,image.height)
    right_inner_eyes = get_pixel_cordinate_list(larndmarks,mp_constants.LINE_LEFT_UPPER_INNER_EYE+mp_constants.LINE_LEFT_LOWER_INNER_EYE ,image.width,image.height)

    # white_image becomes the "socket" mask: white everywhere except the
    # eye-white polygons, which are filled black
    left_white_eyes = get_pixel_cordinate_list(larndmarks,mp_constants.LINE_LEFT_EYES_WHITE,image.width,image.height)
    fill_points(white_image,left_white_eyes,(0,0,0))

    right_white_eyes = get_pixel_cordinate_list(larndmarks,mp_constants.LINE_RIGHT_EYES_WHITE ,image.width,image.height)
    fill_points(white_image,right_white_eyes,(0,0,0))

    left_eyes_box = points_to_box(left_white_eyes)
    right_eyes_box = points_to_box(right_white_eyes)

    # black_image: white bounding boxes around each eye-white region
    black_image = create_color_image(image.width,image.height,(0,0,0))
    draw_box(black_image,left_eyes_box,fill=(255,255,255))
    draw_box(black_image,right_eyes_box,fill=(255,255,255))

    eyes_mask_area_image = black_image.convert("L")
    eyes_mask_image = white_image.convert("L") #eyes-white-area-hole painted black

    galleries = []

    progressed = 0

    def add_webp(add_image,label,important=False):
        """Advance the progress bar and, unless filtered out, save ``add_image`` to the gallery."""
        nonlocal progressed
        progressed += .038
        progress(progressed)
        if not important and output_important_only:
            return

        file_path = save_image(add_image,"webp")
        galleries.append((file_path,label))

    # Create EYE LINE IMAGE (debug overlay of the detected contours)
    eyes_line_image = image.copy()
    draw_points(eyes_line_image,left_inner_eyes,outline=(200,200,255),fill=None,width=3)
    draw_points(eyes_line_image,right_inner_eyes,outline=(200,200,255),fill=None,width=3)

    draw_points(eyes_line_image,left_white_eyes,outline=(255,0,0),fill=None,width=4)
    draw_points(eyes_line_image,right_white_eyes,outline=(255,0,0),fill=None,width=4)
    draw_points(eyes_line_image,left_iris_points,outline=(0,255,0),fill=None,width=4)
    draw_points(eyes_line_image,right_iris_points,outline=(0,255,0),fill=None,width=4)
    add_webp(eyes_line_image,"eyes-line",True)

    # eyes socket(face) image: the original face with the eye-white areas transparent
    rgba_image = image.convert("RGBA")
    rgba_image.putalpha(eyes_mask_image)
    eyes_socket = rgba_image
    add_webp(eyes_socket,"eyes-socket",True)
    eyes_socket_mask = eyes_mask_image

    # Save Eyes mask and area
    eyes_white_mask = ImageOps.invert(eyes_mask_image)
    add_webp(eyes_white_mask,"eyes-mask")
    add_webp(eyes_mask_area_image,"eyes-box")

    # Remove Edge: erode the eye-white mask (MinFilter needs an odd size)
    erode_size = int(left_box[3]*eyes_white_erode_ratio) # eyes-height base #TODO take care right eyes
    if erode_size%2==0:
        erode_size+=1
    eyes_white_mask=eyes_white_mask.filter(ImageFilter.MinFilter(erode_size))

    # eyes_only_image inner-white-eyes - erode
    rgba_image = image.convert("RGBA")
    rgba_image.putalpha(eyes_white_mask)
    eyes_only_image = rgba_image
    add_webp(eyes_only_image,"eyes-only")
    eyes_only_image_mask = eyes_white_mask.copy()

    # Remove the (scaled) iris discs from the kept eye-white pixels so the
    # current irises are filled by the inpainting below. Each eye uses its own
    # radius, matching the iris circles drawn further down.
    draw = ImageDraw.Draw(eyes_white_mask)
    draw.circle(point_right_pupil,right_eye_radius*draw_eye_pupil_ratio,fill=(0))
    draw.circle(point_left_pupil,left_eye_radius*draw_eye_pupil_ratio,fill=(0))

    rgba_image = image.convert("RGBA")
    rgba_image.putalpha(eyes_white_mask)
    add_webp(rgba_image,"white-inapint-image",True)

    # inpaint mask: the eye boxes minus the kept eye-white pixels
    eyes_mask_area_image.paste(ImageOps.invert(eyes_white_mask),None,mask=eyes_white_mask)
    add_webp(eyes_mask_area_image,"white-inapint-mask")

    cropped_right_eye = rgba_image.crop(box_to_xy(right_eyes_box))
    add_webp(cropped_right_eye,"right-eye")
    cropped_right_eye_mask = eyes_mask_area_image.crop(box_to_xy(right_eyes_box))
    add_webp(cropped_right_eye_mask,"right-eye-mask")

    cropped_left_eye = rgba_image.crop(box_to_xy(left_eyes_box))
    add_webp(cropped_left_eye,"left-eye")
    cropped_left_eye_mask = eyes_mask_area_image.crop(box_to_xy(left_eyes_box))
    add_webp(cropped_left_eye_mask,"left-eye-mask")

    inpaint_radius  = 20
    blur_radius = 15
    edge_expand = 4

    inpaint_mode = "Telea"
    inner_eyes_image = create_color_image(image.width,image.height,color=(0,0,0,0))
    inpaint_right,_ = opencvinpaint.process_cvinpaint(cropped_right_eye,cropped_right_eye_mask.convert("RGB"),inpaint_radius,blur_radius,edge_expand,inpaint_mode)
    add_webp(inpaint_right,"right-eye")
    inpaint_left,_ = opencvinpaint.process_cvinpaint(cropped_left_eye,cropped_left_eye_mask.convert("RGB"),inpaint_radius,blur_radius,edge_expand,inpaint_mode)
    add_webp(inpaint_left,"left-eye")

    inner_eyes_image.paste(inpaint_right,box_to_xy(right_eyes_box))
    inner_eyes_image.paste(inpaint_left,box_to_xy(left_eyes_box))
    add_webp(inner_eyes_image,"inpainted-eyes",True)
    eyes_blank = inner_eyes_image.copy()
    eyes_blank.paste(eyes_socket,None,mask=eyes_socket_mask)
    add_webp(eyes_blank,"eyes_blank",True)

    eyes_move_pt = (eyes_slide_x,eyes_slide_y)
    draw_pupil_border=2

    draw_left_pupil_radius = int(left_eye_radius*draw_eye_pupil_ratio)
    draw_right_pupil_radius = int(right_eye_radius*draw_eye_pupil_ratio)

    # eyeball layer: eroded eye-white alpha plus an iris outline per eye
    eyes_ball_image = eyes_only_image.convert("RGBA")
    draw = ImageDraw.Draw(eyes_ball_image)
    draw.circle(point_right_pupil,draw_right_pupil_radius,outline=iris_color,width=draw_pupil_border)
    draw.circle(point_left_pupil,draw_left_pupil_radius,outline=iris_color,width=draw_pupil_border)
    add_webp(eyes_ball_image,"eyes-ball-inpaint-base",True)

    # mask of the iris discs inside the drawn outlines
    eyes_ball_image_mask = create_color_image(image.width,image.height,color=(0,0,0))
    draw = ImageDraw.Draw(eyes_ball_image_mask)
    draw.circle(point_right_pupil,draw_right_pupil_radius-draw_pupil_border,fill=(255,255,255))
    draw.circle(point_left_pupil,draw_left_pupil_radius-draw_pupil_border,fill=(255,255,255))
    add_webp(eyes_ball_image_mask,"eyes-ball-image-mask")
    eyes_ball_image_mask = eyes_ball_image_mask.convert("L")

    # iris discs minus the kept eye-white pixels
    eyes_ball_image_inpaint_mask = eyes_ball_image_mask.copy()
    eyes_ball_image_inpaint_mask.paste(ImageOps.invert(eyes_only_image_mask),mask=eyes_only_image_mask)
    add_webp(eyes_ball_image_inpaint_mask,"eyes_ball_image_inpaint_mask")

    ### create inpaint and replace
    pupil_inpaint_radius  = 5
    pupil_blur_radius = 2
    pupil_edge_expand = 5

    inpaint_eyes_ball_image,_ = opencvinpaint.process_cvinpaint(eyes_ball_image,eyes_ball_image_inpaint_mask.convert("RGB"),pupil_inpaint_radius,pupil_blur_radius,pupil_edge_expand,inpaint_mode)
    inpaint_eyes_ball_image.putalpha(eyes_ball_image_mask)
    add_webp(inpaint_eyes_ball_image,"inpaint_eyes_ball_image")
    eyes_ball_image = inpaint_eyes_ball_image

    eyes_and_ball_mask = eyes_only_image_mask.copy()
    eyes_and_ball_mask.paste(eyes_ball_image_mask,mask=eyes_ball_image_mask)
    add_webp(eyes_and_ball_mask,"eyes-ball-mask")

    eyes_ball_image.paste(eyes_only_image,mask=eyes_only_image_mask)
    add_webp(eyes_ball_image,"eyes-ball",True)

    # paste the eyeball layer shifted by the slide offset, then restore the
    # face around the eyes
    inner_eyes_image.paste(eyes_ball_image,eyes_move_pt,mask=eyes_and_ball_mask)
    add_webp(inner_eyes_image,"inner-eyes")

    inner_eyes_image.paste(eyes_socket,None,mask=eyes_socket_mask)

    filtered_image = inner_eyes_image.filter(ImageFilter.GaussianBlur(radius=innner_eyes_blur))
    add_webp(filtered_image,"bluerd_inner_face",True)

    ### create innner mask minus eyeballs (circles at the shifted pupil positions)
    white_image = create_color_image(image.width,image.height,color=(255,255,255))
    draw = ImageDraw.Draw(white_image)
    right_eyes_xy = get_pixel_cordinate(larndmarks,mp_constants.POINT_RIGHT_PUPIL,image.width,image.height)
    left_eyes_xy = get_pixel_cordinate(larndmarks,mp_constants.POINT_LEFT_PUPIL,image.width,image.height)
    draw.circle(plus_point(left_eyes_xy,eyes_move_pt),left_eye_radius*eyes_ball_mask_ratio,fill=(0,0,0,255))
    draw.circle(plus_point(right_eyes_xy,eyes_move_pt),right_eye_radius*eyes_ball_mask_ratio,fill=(0,0,0,255))
    add_webp(white_image,"eyes_ball_mask")

    eyes_socket_mask_invert = ImageOps.invert(eyes_socket_mask)
    eyes_socket_mask_invert.paste(white_image,None,mask=eyes_socket_mask_invert)
    add_webp(eyes_socket_mask_invert,"inner_mask_without_eyesball")

    ### final paste eyes-ball and outer-faces on blured inner
    eyes_socket_mask_invert = eyes_socket_mask_invert.filter(ImageFilter.GaussianBlur(radius=iris_mask_blur))
    add_webp(eyes_socket_mask_invert,"inner_mask_without_eyesball-blur",True)

    filtered_image.paste(inner_eyes_image,None,mask=ImageOps.invert(eyes_socket_mask_invert))
    filtered_image.paste(eyes_socket,None,mask=eyes_socket_mask)
    # the final result is also persisted to a file, like the gallery images
    save_image(filtered_image,"webp")

    return filtered_image,galleries
    



# Custom stylesheet passed to gr.Blocks(css=css) below.
# #col-left / #col-right: centred columns capped at 640px wide.
# .grid-container: centred flex row with a 10px gap; .image: 128px cover
# thumbnails; .text: 16px font.
# NOTE(review): none of these ids/classes are used as elem_id in this file;
# presumably they style markup inside the demo_*.html fragments - verify.
css="""
#col-left {
    margin: 0 auto;
    max-width: 640px;
}
#col-right {
    margin: 0 auto;
    max-width: 640px;
}
.grid-container {
  display: flex;
  align-items: center;
  justify-content: center;
  gap:10px
}

.image {
  width: 128px; 
  height: 128px; 
  object-fit: cover; 
}

.text {
  font-size: 16px;
}
"""

#css=css,



# Gradio UI. Left column: input image, run button, slide controls and
# advanced settings. Right column: result image and intermediate-image
# gallery. `demo` stays a module-level name so hosting environments can
# import it.
with gr.Blocks(css=css, elem_id="demo-container") as demo:
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
    with gr.Row():
        with gr.Column():
            image = gr.Image(height=800,sources=['upload','clipboard'],image_mode='RGB',elem_id="image_upload", type="pil", label="Image")
            with gr.Row(elem_id="prompt-container",  equal_height=False):
                with gr.Row():
                    btn = gr.Button("Slide Eyes Direction", elem_id="run_button",variant="primary")

            with gr.Accordion(label="Eyes Slide", open=True):
                with gr.Row(equal_height=False):
                    eyes_slide_x_ratio = gr.Slider(
                        label="Horizontal Slide (Iris based size)",
                        minimum=-2,
                        maximum=2,
                        step=0.01,
                        value=0,info="Based iris size minus to left,plus to right")
                    eyes_slide_y_ratio = gr.Slider(
                        label="Vertical Slide (Iris based size)",
                        minimum=-1.5,
                        maximum=1.5,
                        step=0.01,
                        value=0.25,info="Based iris size minus to up,plus to down")

            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row( equal_height=True):
                    innner_eyes_blur_ratio = gr.Slider(
                        label="Inner Eyes Blur Ratio",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.2,info="increse valueinnser eyes make flat")
                    iris_mask_blur_ratio = gr.Slider(
                        label="Iris Mask Blur Ratio",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.15,info="mask edge smooth")
                with gr.Row( equal_height=True):
                    pupil_offset_ratio = gr.Slider(
                        label="Pupil center Offset Y",
                        minimum=-0.5,
                        maximum=0.5,
                        step=0.01,
                        value=-0.08,info="mediapipe detection is not middle")
                    draw_eye_pupil_ratio = gr.Slider(
                        label="Draw Pupil radius ratio",
                        minimum=0.5,
                        maximum=1.5,
                        step=0.01,
                        value=1.1,info="mediapipe detection is usually small")
                    iris_color_value = gr.ColorPicker(value="rgba(20,20,20,1)",label="Iris Border Color")
                with gr.Row( equal_height=True):
                    eyes_white_erode_ratio = gr.Slider(
                        label="Eye Erode erode ratio",
                        minimum=0,
                        maximum=0.5,
                        step=0.01,
                        value=0.1,info="eyes edge is pink")
                    eyes_ball_mask_ratio = gr.Slider(
                        label="Eye Ball Mask ratio",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.9,info="iris blur and mask for img2img")
                with gr.Row( equal_height=True):
                    output_important_only=gr.Checkbox(label="output important image only",value=True)

        with gr.Column():
            animation_out = gr.Image(height=760,label="Result", elem_id="output-animation")
            image_out = gr.Gallery(label="Output", elem_id="output-img",preview=True)

    # The inputs list order must match process_images' positional parameters.
    btn.click(fn=process_images, inputs=[image,eyes_slide_x_ratio,eyes_slide_y_ratio,innner_eyes_blur_ratio,iris_mask_blur_ratio,
                                         pupil_offset_ratio,draw_eye_pupil_ratio,iris_color_value,eyes_white_erode_ratio,eyes_ball_mask_ratio,output_important_only
                                         ],outputs=[animation_out,image_out] ,api_name='infer')
    gr.Examples(
        examples=[
            ["examples/02316230.jpg"],
            ["examples/00003245_00.jpg"],
            ["examples/00827009.jpg"],
            ["examples/00002062.jpg"],
            ["examples/00824008.jpg"],
            ["examples/00825000.jpg"],
            ["examples/00826007.jpg"],
            ["examples/00824006.jpg"],
            ["examples/00828003.jpg"],
            ["examples/00002200.jpg"],
            ["examples/00005259.jpg"],
            ["examples/00018022.jpg"],
            ["examples/img-above.jpg"],
            ["examples/00100265.jpg"],
        ],
        inputs=[image],examples_per_page=5
    )
    gr.HTML(read_file("demo_footer.html"))

# Launch only after the `with gr.Blocks()` context has exited, so the layout
# and event wiring are finalized before the server starts.
if __name__ == "__main__":
    demo.launch()