import cv2
import base64
import gradio as gr
import json
import numpy as np

# Height in pixels that init_frames resizes every frame to (width is scaled
# to keep the aspect ratio), so drawn text/box thickness stays consistent.
VIDEO_HEIGHT: int = 700


# annotation_btn.click - switches to annotation tab and starts load_annotation
def prepare_annotation(state, result, result_index):
    """Reset the annotation session for one result and open the annotation tab.

    Stores ``result_index`` as ``state['annotation_index']`` and rewinds
    ``state['frame_index']`` to 0, so that load_annotation starts loading
    from the first frame.

    Returns:
        A two-element list of updates for ``[annotation_progress, master_tabs]``.
        When ``result["aris_input"][result_index]`` is falsy, both updates are
        no-ops. Otherwise the hidden ``annotation_info`` element is filled with
        an empty JSON list (which js_store_frame_info treats as "reset the
        store"), and tab 2 is selected.
    """
    state['annotation_index'] = result_index
    state['frame_index'] = 0

    if not result["aris_input"][result_index]:
        # No input for this result: leave both components untouched.
        return [gr.update(), gr.update()]

    # A random HTML comment makes the value differ on every call, presumably
    # so that annotation_progress.change fires even for a repeated result.
    nonce_comment = "<!--" + str(np.random.rand()) + "-->"
    reset_html = "<p id='annotation_info' style='display:none'>[]</p>" + nonce_comment

    progress_update = gr.update(value=reset_html, visible=True)
    tabs_update = gr.update(selected=2)
    return [progress_update, tabs_update]

# annotation_progress.change - loads annotation frames in batches - called after prepare_annotation
def load_annotation(state, result, progress_bar):
    """Load the next batch of annotation frames, or open the editor when done.

    Each call does one of two things:

    * If frames remain (``state['frame_index']`` is below the number of
      predicted frames), it loads the next batch via init_frames and writes
      it as JSON into the hidden ``annotation_info`` element. It also advances
      ``state['frame_index']`` to the returned end index.
    * Otherwise it renders the annotation editor HTML and hides the progress
      element.

    Args:
        state: session dict; reads ``'annotation_index'`` and reads/advances
            ``'frame_index'`` (both set by prepare_annotation).
        result: dict holding per-result lists ``'aris_input'`` and
            ``'json_result'``; the latter's entries have a ``'frames'`` list.
        progress_bar: callable invoked as ``progress_bar(fraction, desc=msg)``,
            or ``None`` to disable progress reporting.

    Returns:
        A two-element list of updates for ``[annotation_editor, annotation_progress]``.
    """
    result_index = state['annotation_index']
    predictions = result['json_result'][result_index]

    # Only build a callback when there is something to report to. Passing None
    # lets the `if gp:` guards in init_frames skip reporting as well.
    if progress_bar is None:
        set_progress = None
    else:
        def set_progress(pct, msg):
            return progress_bar(pct, desc=msg)

    if state['frame_index'] == 0 and set_progress is not None:
        set_progress(0, "Loading Frames")

    # Frames remain to be loaded: send the next batch to the browser.
    if state['frame_index'] < len(predictions['frames']):
        annotation_info, state['frame_index'] = init_frames(
            result["aris_input"][result_index],
            predictions,
            state['frame_index'],
            gp=set_progress,
        )

        # Save as a hidden HTML element, which js_store_frame_info reads.
        annotation_content = "<p id='annotation_info' style='display:none'>" + json.dumps(annotation_info) + "</p>"

        return [gr.update(), gr.update(value=annotation_content)]

    # All frames loaded: start the annotation editor.
    return [gr.update(value=_build_annotation_editor_html(), visible=True), gr.update(visible=False)]


def _build_annotation_editor_html():
    """Return the static annotation-editor markup (header, canvas, display, hidden img)."""
    return "".join([
        # Header
        "<div id='annotation_header'>",
        "     <h1 id='annotation_frame_nbr'>Frame 0/100</h1>",
        "     <p id='annotation_edited'>(edited)</p>",
        "</div>",
        # Annotation body
        "<div style='display:flex'>",
        "     <canvas id='canvas' style='width:50%' onmousedown='mouse_down(event)' onmousemove='mouse_move(event)' onmouseup='mouse_up()' onmouseleave='mouse_up()'></canvas>",
        "     <div id='annotation_display' style='width:50%'></div>",
        "</div>",
        # Dummy objects
        "<img id='annotation_img' onload='draw()' style='display:none'></img>",
        # Random comment so the value differs on every call (presumably to
        # force the component to refresh).
        "<!--" + str(np.random.rand()) + "-->",
    ])

# called by load_annotation - reads frames from the dataloader and formats tracks
def init_frames(dataset, preds, index, gp=None, batch_size=1000):
    """Load one batch of frames and their predicted boxes for annotation editing.

    Each frame is converted from grayscale to BGR and resized to VIDEO_HEIGHT
    pixels tall, with the width scaled by the same factor. The resized frame
    is then JPEG-encoded as a base64 string.

    Args:
        dataset: object exposing
            ``dataset.didson.load_frames(start_frame=..., end_frame=...)``.
            The frames it returns are expected to be 2-D grayscale arrays,
            since they are passed through ``cv2.COLOR_GRAY2BGR``.
        preds: dict whose ``'frames'`` list holds one dict per frame with a
            ``'fish'`` list. Each fish provides ``'bbox'`` as
            ``(xmin, ymin, xmax, ymax)`` relative to the frame size (the
            values are multiplied by width/height), plus ``'fish_id'`` and
            ``'conf'``.
        index: position in ``preds['frames']`` of the first frame to load.
        gp: optional progress callback, invoked as ``gp(fraction, message)``.
        batch_size: maximum number of frames loaded by one call.

    Returns:
        A tuple ``(annotations, end_index)``. ``end_index`` is the position
        just past the last frame loaded; pass it back as ``index`` for the
        next batch. ``annotations`` is a list of::

            {
                'annotations': [{
                    'bbox': {'left', 'right', 'top', 'bottom'} as ints in
                            pixels of the resized frame,
                    'id': fish id as str,
                    'conf': the prediction's 'conf' value, unchanged,
                }, ...],
                'base64': JPEG of the resized frame as a base64 str,
            }

    Raises:
        RuntimeError: if OpenCV fails to JPEG-encode a frame.
    """
    # Read a single frame just to learn the native size; all frames are
    # assumed to share it.
    first_frame = dataset.didson.load_frames(start_frame=0, end_frame=1)[0]
    native_h, native_w = first_frame.shape

    # Enforce a standard size so that text/box thickness is consistent.
    scale_factor = VIDEO_HEIGHT / native_h
    h = VIDEO_HEIGHT
    w = int(scale_factor * native_w)

    all_frames = preds['frames']
    n_frames = len(all_frames)
    # Computed unconditionally so an empty prediction list returns cleanly
    # instead of raising UnboundLocalError.
    end_index = min(index + batch_size, n_frames)

    annotations = []

    if gp:
        gp(0, "Extracting Frames")

    for frame_nbr, frame_info in enumerate(all_frames[index:end_index], start=index):
        if gp:
            gp(frame_nbr / n_frames, "Extracting Frames")

        # Extract and encode the frame image.
        img_raw = dataset.didson.load_frames(start_frame=frame_nbr, end_frame=frame_nbr + 1)[0]
        image = cv2.resize(cv2.cvtColor(img_raw, cv2.COLOR_GRAY2BGR), (w, h))
        ok, buffer = cv2.imencode('.jpg', image)
        if not ok:
            raise RuntimeError(f"Failed to JPEG-encode frame {frame_nbr}")

        # Convert relative bboxes to pixel coordinates of the resized frame.
        fish_annotations = []
        for fish in frame_info['fish']:
            xmin, ymin, xmax, ymax = fish['bbox']
            fish_annotations.append({
                'bbox': {
                    'left': int(round(xmin * w)),
                    'right': int(round(xmax * w)),
                    'top': int(round(ymin * h)),
                    'bottom': int(round(ymax * h)),
                },
                'id': str(fish['fish_id']),
                'conf': fish['conf'],
            })

        annotations.append({
            'annotations': fish_annotations,
            'base64': base64.b64encode(buffer).decode("utf-8"),
        })

    return annotations, end_index

# JavaScript that reads the JSON batch load_annotation writes into the hidden
# #annotation_info element and appends it to window.annotation_info.
# An empty list resets the store (and returns false); a non-empty batch is
# appended (and returns true).
# textContent is used instead of innerHTML because innerHTML re-escapes
# &, < and > as HTML entities, which would corrupt the JSON text.
js_store_frame_info = """
    () => {
        const info_string = document.getElementById("annotation_info").textContent;
        const info = JSON.parse(info_string);
        if (info.length == 0) {
            window.annotation_info = [];
            return false;
        }
        window.annotation_info = (window.annotation_info || []).concat(info);
        return true;
    }
"""

# Styles for the annotation editor header built in load_annotation:
# the frame counter is centred, and the "(edited)" marker is pinned to the right.
annotation_css = """
#annotation_header {
    height: 40px;
}
#annotation_frame_nbr {
    left: calc(50% - 100px);
    position: absolute;
    width: 200px;
    text-align: center;
    font-size: x-large;
}
#annotation_edited {
    right: 0px;
    position: absolute;
    margin-top: 5px;
}
"""