import project_path

import cv2
import numpy as np
from tqdm import tqdm

from lib.fish_eye.tracker import Tracker


# Version stamp written into the info pane of every rendered frame.
VERSION = "09/21"

# Colors are BGR tuples, the channel order OpenCV drawing functions use.
PRED_COLOR = (255, 0, 0)  # blue in BGR
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)

# Layout sizes, in pixels of the height-normalized output frame.
BORDER_PAD = 3      # inset of the info-pane text from the frame edges
LINE_HEIGHT = 22    # vertical spacing between info-pane text lines
VIDEO_HEIGHT = 700  # every sonar frame is resized to this height
INFO_PANE_WIDTH = 100  # NOTE(review): not referenced by the code in this file; verify other users

# Stroke and font settings shared by boxes, labels and axes.
BOX_THICKNESS = 2
FONT_SCALE = 0.65
FONT_THICKNESS = 1


def is_fourcc_available(codec):
    """Return True if this OpenCV build can open an .mp4 VideoWriter with ``codec``.

    The check opens a small (640x480, 30 fps) writer on a file inside a
    throwaway temporary directory. The writer is always released and the
    directory removed, so nothing is left behind in the working directory.

    Args:
        codec: four-character code string, e.g. ``"avc1"`` or ``"mp4v"``.

    Returns:
        bool: whether the writer opened. Any error raised while probing
        (e.g. a malformed code) is treated as "not available".
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as probe_dir:
        probe_path = os.path.join(probe_dir, 'probe.mp4')
        writer = None
        try:
            fourcc = cv2.VideoWriter_fourcc(*codec)
            writer = cv2.VideoWriter(probe_path, fourcc, 30, (640, 480), isColor=True)
            return writer.isOpened()
        except Exception:
            # Deliberate best-effort probe: any failure means "unavailable".
            return False
        finally:
            # Close the file handle before the temporary directory is deleted.
            if writer is not None:
                writer.release()

def generate_video_batches(didson, preds, frame_rate, video_out_path, gp=None, image_meter_width=None, image_meter_height=None, batch_size=1000):
    """Write a visualized video to video_out_path, given a didson object.

    Frames are loaded and rendered ``batch_size`` at a time, so only one batch
    is held in memory at once. The video writer is created lazily, after the
    first rendered batch reports the frame height/width. The writer is always
    released, even if rendering fails part-way. If there are no frames, no
    writer is created.

    Args:
        didson: source object exposing ``info['endframe']`` /
            ``info['numframes']`` and ``load_frames(start_frame=..., end_frame=...)``.
        preds: predictions forwarded to get_video_frames().
        frame_rate: frame rate of the output video (also forwarded).
        video_out_path: destination file path for the video.
        gp: optional progress callback, called as ``gp(fraction, message)``.
        image_meter_width, image_meter_height: real-world image extent,
            forwarded for axis labelling.
        batch_size: number of frames loaded and rendered per batch.
    """
    if gp:
        gp(0, "Generating results video...")
    # A falsy 'endframe' falls back to the total frame count.
    end_frame = didson.info['endframe'] or didson.info['numframes']
    out = None  # created once the first batch tells us the frame height/width

    try:
        with tqdm(total=end_frame, desc="Generating results video", ncols=0) as pbar:
            for batch_start in range(0, end_frame, batch_size):
                batch_end = min(end_frame, batch_start + batch_size)
                frames = didson.load_frames(start_frame=batch_start, end_frame=batch_end)
                vid_frames, h, w = get_video_frames(frames, preds, frame_rate, image_meter_width, image_meter_height, start_frame=batch_start)

                if out is None:
                    # Prefer 'avc1' when this OpenCV build can encode it, else fall back to 'mp4v'.
                    codec = cv2.VideoWriter_fourcc(*'avc1') if is_fourcc_available("avc1") else cv2.VideoWriter_fourcc(*'mp4v')
                    # Rendered frames are the w-wide sonar image plus an int(0.5*w) info pane.
                    out = cv2.VideoWriter(video_out_path, codec, frame_rate, (int(1.5 * w), h))

                for offset, frame in enumerate(vid_frames):
                    if gp:
                        gp((batch_start + offset) / end_frame, 'Generating results video...')
                    out.write(frame)
                    pbar.update(1)

                # Drop this batch before loading the next one to keep peak memory down.
                del frames
                del vid_frames
    finally:
        # out stays None when there was nothing to render (e.g. zero frames).
        if out is not None:
            out.release()
    
def get_video_frames(frames, preds, frame_rate, image_meter_width=None, image_meter_height=None, start_frame=0):
    """Get visualized video frames ready for output, given raw ARIS/DIDSON frames.

    Each output frame is built in four steps:
    - the raw single-channel frame is converted to BGR and resized to VIDEO_HEIGHT;
    - every predicted fish box is drawn with its id and length;
    - an axis is added via add_axis();
    - a black info pane (version, direction counts, frame number) is concatenated
      on the left.

    Warning: all frames in frames will be stored in memory - careful of OOM errors. Consider processing large files
    in batches, such as in generate_video_batches()

    Args:
        frames: sequence of 2-D frames; all are assumed to share ``frames[0]``'s shape.
        preds: prediction dict. ``preds['fish']`` entries carry ``id``,
            ``length`` and a ``'#rrggbb'`` hex ``color``. ``preds['frames']``
            entries carry a ``'fish'`` list with ``fish_id`` and a ``bbox`` of
            (xmin, ymin, xmax, ymax) given as fractions of the frame width/height.
        frame_rate: not used by this function; kept for interface compatibility.
        image_meter_width, image_meter_height: forwarded to add_axis().
        start_frame: index into ``preds['frames']`` that corresponds to ``frames[0]``.

    Returns:
        list(np.ndarray), height (int), width (int). Height/width are those of
        the resized sonar image; the info pane adds ``int(0.5*width)`` columns.
        Empty ``frames`` yields ``([], 0, 0)``.
    """
    # Per-track annotations do not change between frames, so compute them once.
    pred_lengths = {fish['id']: "%.2fm" % fish['length'] for fish in preds['fish']}
    clip_pr_counts = Tracker.count_dirs(preds)
    color_map = {fish['id']: _hex_to_bgr(fish['color']) for fish in preds['fish']}

    # Predictions are indexed from the start of the clip; align them with this batch.
    preds_frames = preds['frames'][start_frame:]

    vid_frames = []
    if len(frames) == 0:
        return vid_frames, 0, 0

    # Assumes all frames are the same size as the first one.
    h, w = frames[0].shape

    # Enforce a standard height so text/box thickness is consistent across sources.
    scale_factor = VIDEO_HEIGHT / h
    h = VIDEO_HEIGHT
    w = int(scale_factor * w)

    # Info-pane lines that are identical for every frame (the frame number is appended per frame).
    static_info = [
        f'VERSION: {VERSION}',
        f'Right count: {clip_pr_counts[0]}',
        f'Left count: {clip_pr_counts[1]}',
        f'Other fish: {clip_pr_counts[2]}',
    ]
    # np.concatenate copies, so one read-only pane can be shared by all frames.
    info_pane = np.zeros((h, int(0.5 * w), 3), dtype=np.uint8)

    # zip stops at the shorter of frames / predictions.
    for offset, (frame_raw, pred) in enumerate(zip(frames, preds_frames)):
        frame_raw = cv2.resize(cv2.cvtColor(frame_raw, cv2.COLOR_GRAY2BGR), (w, h))

        for fish in pred['fish']:
            xmin, ymin, xmax, ymax = fish['bbox']
            fish_id = fish['fish_id']
            draw_fish(
                frame_raw,
                int(round(xmin * w)), int(round(xmax * w)),
                int(round(ymin * h)), int(round(ymax * h)),
                color_map[fish_id], str(fish_id), pred_lengths[fish_id],
                anno_align="right",
            )

        frame_raw = add_axis(frame_raw, image_meter_width, image_meter_height)
        frame = np.concatenate((info_pane, frame_raw), axis=1)

        info_lines = static_info + [f'Frame: {start_frame + offset}']
        for row, text in enumerate(info_lines):
            # Lines stack up from the bottom-left corner; the last line sits lowest.
            y = h - BORDER_PAD - LINE_HEIGHT * (len(info_lines) - 1 - row)
            cv2.putText(frame, text, (BORDER_PAD, y), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)

        vid_frames.append(frame)

    return vid_frames, h, w


def _hex_to_bgr(hex_color):
    """Convert a ``'#rrggbb'`` hex string to the ``(b, g, r)`` tuple OpenCV draws with."""
    digits = hex_color.lstrip('#')
    r, g, b = (int(digits[k:k + 2], 16) for k in (0, 2, 4))
    return (b, g, r)

def draw_fish(frame, left, right, top, bottom, color, fish_id, fish_len, LINE_HEIGHT=18, anno_align="left"):
    """Draw one fish's bounding box and its two text labels onto ``frame`` in place.

    The id label's text origin is placed on the box's top edge. The length label
    goes half a ``LINE_HEIGHT`` below the bottom edge. ``anno_align="left"``
    anchors both labels at the box's left edge; any other value anchors them at
    the right edge. ``LINE_HEIGHT`` is a per-call value (default 18) and is
    independent of the module-level constant of the same name.
    """
    text_x = left if anno_align == "left" else right
    cv2.rectangle(frame, (left, top), (right, bottom), color, BOX_THICKNESS)

    label_rows = (
        (fish_id, top),
        (fish_len, bottom + int(LINE_HEIGHT / 2)),
    )
    for label, text_y in label_rows:
        cv2.putText(frame, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE, color, FONT_THICKNESS, cv2.LINE_AA, False)
    
def add_axis(img, image_meter_width=None, image_meter_height=None):
    """Draw x/y axes with tick marks and labels around ``img``.

    A black margin is added at the bottom (25 px), left (45 px) and right
    (25 px) to hold the axes and labels. The result is then resized back to
    the input's width/height, so the returned image has the same dimensions
    as ``img``.

    When ``image_meter_width`` / ``image_meter_height`` are positive, tick
    spacing and labels are in meters: 0.5 m, 1 m or 2 m steps depending on the
    extent, plus 3 m steps vertically above 12 m. Otherwise ticks are every
    100 px and labelled in pixels.
    """
    h, w = img.shape[:2]

    # Black margins that hold the axes and their labels.
    border_bottom = 25
    border_left = 45
    img = cv2.copyMakeBorder(
        img,
        bottom=border_bottom,
        top=0,
        left=border_left,
        right=25,  # this helps with text getting cut off
        borderType=cv2.BORDER_CONSTANT,
        value=BLACK
    )

    # Axis lines along the bottom and left edges of the original image area.
    axis_thickness = 1
    img = cv2.line(img, (border_left, h + axis_thickness // 2), (w + border_left, h + axis_thickness // 2), WHITE, axis_thickness)  # x
    img = cv2.line(img, (border_left - axis_thickness // 2, 0), (border_left - axis_thickness // 2, h), WHITE, axis_thickness)  # y

    # Distance between x ticks, in pixels.
    x_inc = 100
    if image_meter_width and image_meter_width > 0:
        x_inc = w / image_meter_width / 2  # 0.5m ticks
        if image_meter_width > 4:
            x_inc *= 2  # 1m ticks
        if image_meter_width > 8:
            x_inc *= 2  # 2m ticks

    # Distance between y ticks, in pixels.
    y_inc = 100
    if image_meter_height and image_meter_height > 0:
        y_inc = h / image_meter_height / 2  # 0.5m ticks
        if image_meter_height > 4:
            y_inc *= 2  # 1m ticks
        if image_meter_height > 8:
            y_inc *= 2  # 2m ticks
        if image_meter_height > 12:
            y_inc *= 3/2  # 3m ticks

    def x_label(x):
        """Label for the tick ``x`` pixels from the left edge."""
        if image_meter_width and image_meter_width > 0:
            # Ticks closer than 1 m (in pixels) need a decimal place to be told apart.
            if x_inc < w / image_meter_width:
                return "%.1fm" % (x / w * image_meter_width)
            return "%.0fm" % (x / w * image_meter_width)
        return str(x)  # pixels

    def y_label(y):
        """Label for the tick ``y`` pixels above the bottom edge."""
        if image_meter_height and image_meter_height > 0:
            # Same rule as x_label: compare tick spacing with one meter in pixels
            # (previously compared against the tick position y, which was a bug).
            if y_inc < h / image_meter_height:
                return "%.1fm" % (y / h * image_meter_height)
            return "%.0fm" % (y / h * image_meter_height)
        return str(y)  # pixels

    # Tick marks and labels.
    ticksize = 5
    x = 0
    while x < w:
        img = cv2.line(img, (int(border_left + x), h + axis_thickness // 2), (int(border_left + x), h + axis_thickness // 2 + ticksize), WHITE, axis_thickness)
        cv2.putText(img, x_label(x), (int(border_left + x), h + axis_thickness // 2 + LINE_HEIGHT), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE * 3 / 4, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
        x += x_inc
    y = 0
    while y < h:
        img = cv2.line(img, (border_left - axis_thickness // 2, int(h - y)), (border_left - axis_thickness // 2 - ticksize, int(h - y)), WHITE, axis_thickness)
        ylabel = y_label(y)
        # Rough per-character width so the label ends just left of the tick.
        txt_offset = 13 * len(ylabel)
        cv2.putText(img, ylabel, (border_left - axis_thickness // 2 - ticksize - txt_offset, int(h - y)), cv2.FONT_HERSHEY_SIMPLEX, FONT_SCALE * 3 / 4, WHITE, FONT_THICKNESS, cv2.LINE_AA, False)
        y += y_inc

    # Resize back to the original dimensions.
    return cv2.resize(img, (w, h))