import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import pdb
import random
import shutil
import subprocess

import gradio as gr
from PIL import Image

from categories import categories, font_list

try:
    import pygsheets
except Exception as e:
    print("pygsheets not found", e)

# PyTorch CUDA caching-allocator setting. It is set in this process's
# environment, so child processes started later (e.g. txt2img.py) inherit it.
# max_split_size_mb caps the size of cached blocks the allocator may split.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:128"})

# Module-level UI state shared by the Gradio callbacks:
#   global_index  - index last reported by change_font
#   mode_global   - mode name consulted by update_time_cb
#   prompt_global - prompt prefix last built by my_main
global_index, mode_global, prompt_global = 0, "DS-Fusion-Express", ""

def log_data_to_sheet(pre_style, style_custom, glyph, attribute):
    """Append one row of user inputs to the "HuggingFace Logs" Google Sheet.

    Best-effort: any ordinary failure (pygsheets not installed, missing or bad
    credentials file, network error) is reported on stdout and never reaches
    the caller, so a logging problem cannot block image generation.

    Args:
        pre_style: style picked from the pre-defined categories dropdown.
        style_custom: free-text style override.
        glyph: raw character text entered by the user.
        attribute: additional style attribute text.

    Returns:
        None.
    """
    try:
        # Authorize with the service-account credentials file shipped next to the app.
        gc = pygsheets.authorize(service_file='./huggingface-connector-02adcb4cdf00.json')

        # Open the spreadsheet named 'HuggingFace Logs' and use its first worksheet.
        sh = gc.open('HuggingFace Logs')
        wks = sh[0]

        # Append the four fields as one row.
        wks.append_table(values=[pre_style, style_custom, glyph, attribute])
    except Exception as exc:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still work.
        # A NameError here means the optional pygsheets import failed at startup.
        print("log_data_to_sheet failed:", exc)

def change_font(evt: gr.SelectData):
    """Remember which font thumbnail the user selected.

    Gradio ``select`` callback: copies ``evt.index`` into the module-level
    ``global_index``. Its wiring in the UI is currently commented out.
    """
    global global_index
    selected_index = evt.index
    global_index = selected_index

def _clear_dir_contents(directory):
    """Delete every entry inside *directory* but keep the directory itself.

    Python equivalent of ``rm -r directory/*``: a missing directory is ignored,
    files and symlinks are unlinked, sub-directories are removed recursively.
    Unlike the shell glob, dot-files are removed too.
    """
    if not os.path.isdir(directory):
        return
    for entry in os.listdir(directory):
        entry_path = os.path.join(directory, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)


def _show_placeholder_images():
    """Reset ``out_cur`` so it holds only ``initial_show/1.png`` .. ``4.png``."""
    os.makedirs("out_cur", exist_ok=True)
    _clear_dir_contents("out_cur")
    for i in range(1, 5):
        placeholder = os.path.join("initial_show", f"{i}.png")
        # Missing placeholders are skipped, matching the old silent `cp` failure.
        if os.path.isfile(placeholder):
            shutil.copy(placeholder, os.path.join("out_cur", f"{i}.png"))


def my_main(pre_style, style_custom, glyph, attribute):
    """Generate four stylized renderings of one character with DS-Fusion-Express.

    Gradio callback for the "Let's Stylize It" button. Logs the inputs, shows
    the placeholder images in ``out_cur``, deletes ``data_style/`` and ``out/``,
    runs ``txt2img.py`` with the ``ckpt/all2.ckpt`` checkpoint, and finally
    replaces ``out_cur`` with copies of the generated samples.

    Args:
        pre_style: style category chosen from the dropdown.
        style_custom: free-text style; when non-blank it replaces ``pre_style``.
        glyph: user-entered text; only its first non-whitespace character is used.
        attribute: optional extra style words placed before the style.

    Returns:
        ``gr.update`` whose value is the list of sample paths under
        ``out/express/samples`` (at most four, sorted by file name).

    Raises:
        RuntimeError: if ``txt2img.py`` did not produce a samples directory.
    """
    global prompt_global

    log_data_to_sheet(pre_style, style_custom, glyph, attribute)

    # First visible character only; the old `glyph[0]` crashed on an empty box.
    glyph = (glyph or "").strip()[:1]
    style = (style_custom or "").strip() or pre_style
    attribute = (attribute or "").strip()

    # Show placeholders while generating and drop stale intermediate data.
    _show_placeholder_images()
    shutil.rmtree("data_style", ignore_errors=True)
    shutil.rmtree("out", ignore_errors=True)

    # Leading space kept for compatibility with the previous prompt_global format.
    prompt_global = " " + " ".join(part for part in (style, glyph) if part)
    prompt = " ".join(part for part in (attribute, style, glyph) if part)
    print(prompt)

    checkpoint = "ckpt/all2.ckpt"
    output_path = "out/express"
    samples_dir = os.path.join(output_path, "samples")

    # An argument list (no shell) keeps user text from being interpreted as
    # shell syntax; quotes or `;` in the prompt are passed through literally.
    command = [
        "python", "txt2img.py",
        "--ddim_eta", "1.0",
        "--n_samples", "4",
        "--n_iter", "1",
        "--ddim_steps", "50",
        "--scale", "5.0",
        "--H", "256",
        "--W", "256",
        "--outdir", output_path,
        "--ckpt", checkpoint,
        "--prompt", prompt,
    ]
    result = subprocess.run(command, check=False)

    # Check before touching out_cur so the placeholders survive a failed run.
    if not os.path.isdir(samples_dir):
        raise RuntimeError(
            f"txt2img.py (exit code {result.returncode}) produced no {samples_dir}"
        )

    _clear_dir_contents("out_cur")
    paths = []
    # Sorted for a stable order; slicing tolerates fewer than four samples.
    for name in sorted(os.listdir(samples_dir))[:4]:
        sample_path = os.path.join(samples_dir, name)
        paths.append(sample_path)
        shutil.copy(sample_path, os.path.join("out_cur", name))

    return gr.update(value=paths)



def rem_bg():
    """Remove backgrounds from the images in ``out_cur`` and flatten them on white.

    Clears ``out_bg``, runs ``script_step3.py --method rembg`` on ``out_cur``
    (only when it has files), copies each result from ``out_bg`` back over the
    same-named file in ``out_cur``, and composites every ``out_cur`` image onto
    an opaque white RGBA canvas in place.

    Returns:
        ``gr.update`` whose value is the list of processed ``out_cur`` paths.
    """
    # Clear stale results (equivalent of `rm -r out_bg/*`; missing dir is fine).
    if os.path.isdir("out_bg"):
        for entry in os.listdir("out_bg"):
            entry_path = os.path.join("out_bg", entry)
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            else:
                os.remove(entry_path)

    files = os.listdir("out_cur")
    if files:
        subprocess.run(
            ["python", "script_step3.py", "--input_dir", "out_cur", "--method", "rembg"],
            check=False,
        )

    for name in files:
        result_path = os.path.join("out_bg", name)
        # Best-effort like the old `cp`: skip images the script did not produce.
        if os.path.isfile(result_path):
            shutil.copy(result_path, os.path.join("out_cur", name))

    paths = []
    for name in files:
        file_path = os.path.join("out_cur", name)
        # `with` closes the file handle; convert() loads pixels before it closes.
        with Image.open(file_path) as image:
            # paste() needs an alpha-capable mask; RGB/P inputs would raise
            # "bad transparency mask", so normalize to RGBA first.
            foreground = image.convert("RGBA")
        background = Image.new("RGBA", foreground.size, "WHITE")
        background.paste(foreground, (0, 0), foreground)
        background.save(file_path, "PNG")
        paths.append(file_path)

    return gr.update(value=paths)

# Font reference images offered in Express mode; the trailing notes rate how
# well each font worked.
font_list_express = [
    'Caladea-Regular',  # works well
    # 'Garuda-Bold',  # works poorly
    'FreeSansOblique',  # works average
    "Purisa",  # works well
    "Uroob",  # works average
]

# Thumbnail path for each Express font, in the same order as font_list_express.
path_fonts_express = [f"font_list/fonts/{font_name}.png" for font_name in font_list_express]
    
def make_upper(value):
    """Return the first character of *value* upper-cased, or "" when *value* is empty."""
    # Slicing an empty string yields "", so no separate empty check is needed.
    return value[:1].upper()

def get_out_cur():
    """Return a ``gr.update`` listing every entry of the ``log_view`` directory.

    Despite its name, this reads ``log_view`` rather than ``out_cur``.
    """
    log_dir = "log_view"
    entries = [os.path.join(log_dir, name) for name in os.listdir(log_dir)]
    return gr.update(value=entries)


def update_time_mode(value):
    """Return a ``gr.update`` with the expected generation time for mode *value*."""
    label = (
        "Generation Time: ~5 mins"
        if value == 'DS-Fusion'
        else "Generation Time: ~30 seconds"
    )
    return gr.update(value=label)

def update_time_cb(value):
    """Return a ``gr.update`` with the expected generation time for a checkbox state.

    Only when ``mode_global`` is "DS-Fusion" does *value* matter (truthy:
    ~8 mins, falsy: ~5 mins); every other mode reports ~30 seconds.
    """
    checked = bool(value)
    if mode_global == "DS-Fusion":
        message = "Generation Time: ~8 mins" if checked else "Generation Time: ~5 mins"
    else:
        message = "Generation Time: ~30 seconds"
    return gr.update(value=message)
    
            

def load_img():
    """Return the path of every entry currently in ``out_cur`` (os.listdir order)."""
    folder = "out_cur"
    return [os.path.join(folder, name) for name in os.listdir(folder)]

# Custom CSS passed to gr.Blocks for the (currently disabled) #gallery_64 font
# gallery. CSS only understands /* ... */ comments: the previous HTML-style
# <!-- ... --> "comments" left their text in front of each selector, so every
# rule was invalid and the browser dropped it.
css = '''
/* Shrink the thumbnails inside the gallery. */
div#gallery_64 div.grid {
    height: 64px;
    width: 180px;
}

/* Shrink the gallery's own height. */
div#gallery_64 > div:nth-child(3) {
    min-height: 172px !important;
}

/* Keep the gallery short when one image is opened in the enlarged view. */
div#gallery_64 > div:nth-child(4) {
    min-height: 172px !important;
}
'''

# Gradio UI. Components are laid out in creation order inside the nested
# Row/Column context managers, so statement order here defines the layout:
# the left column holds the inputs and the generate button, the right column
# shows the generated images.
with gr.Blocks(css=css) as demo:

    with gr.Column():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # NOTE(review): the info text says only capitals work, but no
                    # upper-casing is applied (make_upper exists yet is not wired).
                    in4 = gr.Text(label="Character (A-Z, 0-9) to Stylize", info = "Only works with capitals. Will pick first letter if more than one", value = "R", interactive = True)
                    in2 = gr.Dropdown(categories, label="Pre-Defined Style Categories", info = "Categories used to train Express", value = "DRAGON", interactive = True)

                with gr.Row():
                    in3 = gr.Text(label="Override Style Category ", info="This will replace the pre-defined style value", value = "", interactive = True)
                    in5 = gr.Text(label="Additional Style Attribute ",info= "e.g. pixel, grayscale, etc", value = "", interactive = True)

                # Disabled multi-font / font-gallery controls (see change_font and css):
                # with gr.Row():
                #     with gr.Column():
                #         in8 = gr.Checkbox(label="MULTI FONT INPUT - font selection below is over-ridden", info="Select for more abstract results", value = False, interactive = True).style(container=True)
                #         gallery = gr.Gallery([], label="Select Font", show_label=True, elem_id="gallery_64").style(grid=[2,6],  preview=True, height="auto")

                with gr.Row():
                    # NOTE(review): this label says ~60 seconds while update_time_mode /
                    # update_time_cb report ~30 seconds for Express — confirm which is right.
                    btn = gr.Button("Let's Stylize It - Generation Time: ~60 seconds", interactive = True)
                    # Disabled background-removal button (would call rem_bg):
                    # btn_bg = gr.Button("Remove Background", interactive = True)

            with gr.Column():
                # 2x2 grid for the four generated samples.
                gallery_out = gr.Gallery(label="Generated images", elem_id="gallery_out").style(grid=[2,2], height="full")



    outputs = [gallery_out]
    # gallery.select(change_font, None, None)


    # Order must match my_main(pre_style, style_custom, glyph, attribute):
    # dropdown style, override style, character, extra attribute.
    inputs = [in2,in3,in4,in5]

    btn.click(my_main, inputs, outputs)
    # btn_bg.click(rem_bg, None, outputs)
    # btn_bg.click(rem_bg, None, outputs)




if __name__ == "__main__":
    # Start with only the placeholder images initial_show/1..4.png in out_cur.
    # Python file operations replace the old `rm`/`cp` shell calls; those
    # failed silently when out_cur did not exist, so create it first.
    os.makedirs("out_cur", exist_ok=True)
    for entry in os.listdir("out_cur"):
        entry_path = os.path.join("out_cur", entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)
    for i in range(1, 5):
        placeholder = os.path.join("initial_show", f"{i}.png")
        # Skip a missing placeholder rather than abort startup.
        if os.path.isfile(placeholder):
            shutil.copy(placeholder, os.path.join("out_cur", f"{i}.png"))

    # Enable Gradio's request queue; generation calls are long-running.
    demo.queue()
    demo.launch(share=False)