import datetime
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import math
import cv2
import os

# Every report page is drawn on a figure of this size (matplotlib figsize, inches).
STANDARD_FIG_SIZE = (16, 9)
# Destination of the generated report; make_pdf rewrites this file on each call.
OUT_PDF_FILE_NAME = 'tmp/fisheye_pdf.pdf'
# Create the report's parent directory ('tmp') up front so PdfPages can open the file.
os.makedirs(os.path.dirname(OUT_PDF_FILE_NAME), exist_ok=True)


def make_pdf(i, state, result, dataset, table_headers):
    """Render the fisheye report for entry ``i`` of ``result`` into OUT_PDF_FILE_NAME.

    Pages, in order: title page, global result summary, identified-fish table,
    combined velocity histograms, then the per-fish track pages.

    Args:
        i: Index selecting which entry of the per-file lists in ``result`` to render.
        state: Application state; ``state["version"]`` is shown on the title page.
        result: Mapping of per-file lists under the keys "fish_info",
            "fish_table", "json_result" and "datasets".
        dataset: Not used. The dataset actually used is ``result['datasets'][i]``.
            The parameter is kept so existing callers keep working.
        table_headers: Column headers for the identified-fish table.

    Side effects:
        Overwrites OUT_PDF_FILE_NAME and sets ``plt.rcParams['text.usetex']``.
        Via calculate_fish_paths, it also adds 'base_frame', 'scaled_frame_size'
        and 'path' to each fish dict in the json result's metadata['FISH'].
    """
    fish_info = result["fish_info"][i]
    fish_table = result["fish_table"][i]
    json_result = result['json_result'][i]
    # The per-file dataset stored in `result` supersedes the `dataset` argument.
    dataset = result['datasets'][i]
    metadata = json_result['metadata']
    fish_ids = range(len(json_result['fish']))

    with PdfPages(OUT_PDF_FILE_NAME) as pdf:
        plt.rcParams['text.usetex'] = False

        generate_title_page(pdf, metadata, state)
        generate_global_result(pdf, fish_info)
        generate_fish_list(pdf, table_headers, fish_table)

        # Every fish's path must exist before the combined velocity page is drawn.
        for fish_id in fish_ids:
            calculate_fish_paths(json_result, dataset, fish_id)

        draw_combined_fish_graphs(pdf, json_result)

        for fish_id in fish_ids:
            draw_fish_tracks(pdf, json_result, dataset, fish_id)

        # Document metadata describing this report (not example placeholder text).
        file_name = metadata["FILE_NAME"].split("/")[-1]
        now = datetime.datetime.today()
        info = pdf.infodict()
        info['Title'] = f'Fisheye report: {file_name}'
        info['Author'] = 'Oskar Åström'
        info['Subject'] = 'Fish detection and tracking results'
        info['Keywords'] = ''
        info['CreationDate'] = now
        info['ModDate'] = now


def generate_title_page(pdf, metadata, state):
    """Append the report's cover page to ``pdf``.

    The page shows:
        * the source file's base name (the text after the last "/");
        * duration, frame count and frame rate;
        * filming date with start/end times;
        * the web-app version;
        * the moment the PDF was generated.

    Args:
        pdf: Open ``PdfPages`` the page is saved into.
        metadata: Mapping read for FILE_NAME, TOTAL_TIME, TOTAL_FRAMES,
            FRAME_RATE, DATE, START and END.
        state: Mapping read for "version".
    """
    fig = plt.figure(figsize=STANDARD_FIG_SIZE)
    ax = fig.gca()
    ax.axis('off')

    heading_size = 40
    detail_size = 20

    # Text is placed in data coordinates of a hidden axes: x in [0, 1] and
    # y in [-1, 4]. The y axis is inverted, so a larger y sits lower on the page.
    ax.text(0.5, -0.5, f'{metadata["FILE_NAME"].split("/")[-1]}',
            fontsize=heading_size, horizontalalignment='center')

    detail_lines = (
        (0, 1, f'Duration: {metadata["TOTAL_TIME"]}'),
        (0, 1.5, f'Frames: {metadata["TOTAL_FRAMES"]}'),
        (0, 2, f'Frame Rate: {metadata["FRAME_RATE"]}'),
        (0.5, 1, f'Time of filming: {metadata["DATE"]} ({metadata["START"]} - {metadata["END"]})'),
        (0.5, 1.5, f'Web app version: {state["version"]}'),
    )
    for x_pos, y_pos, line in detail_lines:
        ax.text(x_pos, y_pos, line, fontsize=detail_size)

    generated_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    ax.text(1.1, 4.5, f'PDF generated on {generated_at}',
            fontsize=detail_size, horizontalalignment='right')

    ax.set_ylim([-1, 4])
    ax.set_xlim([0, 1])
    ax.invert_yaxis()

    pdf.savefig(fig)
    plt.close(fig)

def generate_global_result(pdf, fish_info):
    """Append the global-result page (two label/value columns) to ``pdf``.

    ``fish_info`` is iterated as rows where ``row[0]`` is a label and
    ``row[1]`` a value. Rows whose label is exactly "****" are section
    separators:
        * the 1st separator opens "Result";
        * the 2nd opens "Camera Info";
        * the 3rd opens "Hyperparameters".
    Both Result and Camera Info go in the left column; Hyperparameters goes in
    the right column. A 4th separator would raise IndexError.

    Rows before the first separator are ignored. Every later row, including the
    later separators themselves (whose cleaned label is empty, so they render
    as a spacer line), is listed with its label cleaned: "**" removed,
    "_" -> " ", lower-cased.
    """
    fig = plt.figure(figsize=STANDARD_FIG_SIZE)
    plt.axis('off')

    font_size = 18
    section_titles = ["Result", "Camera Info", "Hyperparameters"]

    # Each entry is [label, value, font weight].
    left_rows = []
    right_rows = []
    current = left_rows
    section = -1
    for entry in fish_info:
        name = entry[0]
        if section >= 0:
            pretty = name.replace("**", "").replace("_", " ").lower()
            current.append([pretty, entry[1], 'normal'])
        if name == "****":
            section += 1
            if section == 2:
                current = right_rows
            current.append([section_titles[section], "", 'bold'])

    # Spread each column's rows evenly over y in [-1, 4); x positions are
    # (label column, value column) for each half of the page.
    layout = ((left_rows, 0, 0.25), (right_rows, 0.5, 0.75))
    for rows, label_x, value_x in layout:
        for idx, (label, value, weight) in enumerate(rows):
            y_pos = -1 + 5 * idx / len(rows)
            plt.text(label_x, y_pos, label, fontsize=font_size, weight=weight)
            plt.text(value_x, y_pos, value, fontsize=font_size, weight=weight)

    # Data-coordinate canvas; inverted y so larger values sit lower on the page.
    plt.ylim([-1, 4])
    plt.xlim([0, 1])
    plt.gca().invert_yaxis()

    pdf.savefig(fig)
    plt.close(fig)

def generate_fish_list(pdf, table_headers, fish_table):
    """Append the "Identified Fish" table page to ``pdf``.

    Args:
        pdf: Open ``PdfPages`` the page is saved into.
        table_headers: Column names. "TOTAL" is shown as "ID" and
            "DETECTION_DROPOUT" as "frame drop rate". All names are
            lower-cased with "_" replaced by spaces.
        fish_table: Rows of cell values, one row per fish. Non-string cells in
            the DETECTION_DROPOUT column are treated as a fraction and shown as
            a rounded whole percentage. Other float cells (Python or numpy
            floats) are shown with 4 decimals. Any other cell is drawn as-is.
    """
    fig = plt.figure(figsize=STANDARD_FIG_SIZE)
    plt.axis('off')

    title_font_size = 40
    header_font_size = 12
    body_font_size = 20

    # Layout uses data coordinates: x in [0, 1] and y in [-1, 4]. The y axis is
    # inverted, so a larger y is lower on the page.
    plt.text(0.5, -1.3, 'Identified Fish', fontsize=title_font_size,
             horizontalalignment='center', weight='bold')

    row_h = 0.25
    col_start = 0
    row_l = 1
    rule_x = [col_start * 2, -col_start * 2 + row_l]

    # Header row. Remember which column holds the dropout rate so its cells
    # can be rendered as percentages.
    dropout_i = None
    for col_i, col in enumerate(table_headers):
        x = col_start + row_l * (col_i + 0.5) / len(table_headers)
        if col == "TOTAL":
            col = "ID"
        if col == "DETECTION_DROPOUT":
            col = "frame drop rate"
            dropout_i = col_i
        col = col.lower().replace("_", " ")
        plt.text(x, -1, col, fontsize=header_font_size,
                 horizontalalignment='center', weight="bold")
    if len(table_headers) > 0:
        # Rule under the header row: one line, not one identical line per column.
        plt.plot(rule_x, [-1 + 0.05, -1 + 0.05], color='black')

    for row_i, row in enumerate(fish_table):
        y = -1 + (row_i + 1) * row_h
        plt.plot(rule_x, [y + 0.05, y + 0.05], color='black')
        for col_i, col in enumerate(row):
            x = col_start + row_l * (col_i + 0.5) / len(row)
            if col_i == dropout_i and not isinstance(col, str):
                # round() instead of int(): int(0.29 * 100) would print "28%"
                # because 0.29 * 100 == 28.999999999999996 in binary floating point.
                col = f"{int(round(col * 100))}%"
            elif isinstance(col, (float, np.floating)):
                col = f"{col:.4f}"
            plt.text(x, y, col, fontsize=body_font_size, horizontalalignment='center')

    plt.ylim([-1, 4])
    plt.xlim([0, 1])
    plt.gca().invert_yaxis()

    pdf.savefig(fig)
    plt.close(fig)

def calculate_fish_paths(result, dataset, id):
    """Compute the track of fish ``id`` (0-based) and store it on its metadata entry.

    Reads:
        result['metadata']['FISH'][id]: START_FRAME and END_FRAME (inclusive).
        result['metadata']['FRAME_RATE'].
        result['image_meter_width'] and result['image_meter_height'].
        result['frames'][k]['fish']: annotations; this fish's are those with
            ``fish_id == id + 1``. If several match in one frame, the last is used.

    Writes onto that fish dict:
        'base_frame': the fish's first frame loaded from ``dataset``, or None
            when ``dataset`` is None.
        'scaled_frame_size': (2, scaled other dimension) used by the drawing
            code; (2, 1) when there is no dataset.
        'path': dict with
            'X', 'Y': bbox centres, one per frame from the first to the last
                detection. Frames without a detection in between are filled by
                linear interpolation.
            'certainty': True for detected frames, False for interpolated ones.
            'V': speed between consecutive points (one fewer entry than 'X').
                Interpolated frames reuse the average speed over their gap.
                NOTE(review): presumably m/s, assuming bbox coordinates are
                fractions of the image size; confirm.

    Frames before the first detection or after the last one are not part of the
    path. Leading misses are dropped because there is no earlier point to
    interpolate from.
    """
    fish = result['metadata']['FISH'][id]
    start_frame = fish['START_FRAME']
    end_frame = fish['END_FRAME']
    fps = result['metadata']['FRAME_RATE']

    # Base frame (first frame of this fish), rescaled so that one dimension
    # spans `frame_height` plot units.
    frame_height = 2
    base_frame = None
    scaled_size = (frame_height, 1)
    if dataset is not None:
        base_frame = dataset.didson.load_frames(start_frame=start_frame, end_frame=start_frame + 1)[0]
        rows, cols = base_frame.shape
        scaled_size = (frame_height, int(frame_height / cols * rows))
    fish['base_frame'] = base_frame
    fish['scaled_frame_size'] = scaled_size

    # For each frame in [start_frame, end_frame]: this fish's annotation, or None.
    target_id = id + 1
    detections = []
    for frame in result['frames'][start_frame:end_frame + 1]:
        match = None
        for ann in frame['fish']:
            if ann['fish_id'] == target_id:
                match = ann
        detections.append(match)

    X, Y, V, certainty = [], [], [], []
    missed = 0
    for ann in detections:
        if ann is None:
            missed += 1
            continue

        box = ann['bbox']
        x = (box[0] + box[2]) / 2
        y = (box[1] + box[3]) / 2

        if X:
            # Average per-frame displacement over the gap, scaled to metres,
            # times the frame rate gives the speed.
            last_x, last_y = X[-1], Y[-1]
            steps = missed + 1
            dx = result['image_meter_width'] * (last_x - x) / steps
            dy = result['image_meter_height'] * (last_y - y) / steps
            v = math.sqrt(dx * dx + dy * dy) * fps

            # Linearly interpolate the frames where the fish was not detected.
            for k in range(1, steps):
                p = k / steps
                X.append(last_x * (1 - p) + p * x)
                Y.append(last_y * (1 - p) + p * y)
                V.append(v)
                certainty.append(False)
            V.append(v)

        X.append(x)
        Y.append(y)
        certainty.append(True)
        missed = 0

    fish['path'] = {
        'X': X,
        'Y': Y,
        'certainty': certainty,
        'V': V,
    }


def draw_combined_fish_graphs(pdf, result):
    """Append a 2x2 page of fish-speed histograms to ``pdf``.

    Reads ``fish['path']['V']`` for every fish in ``result['metadata']['FISH']``,
    so calculate_fish_paths must have run for each of them first.

    Layout:
        top-left: 20-bin histogram of log(speed), all fish pooled.
        top-right: 20-bin histogram of speed, all fish pooled.
        bottom-left: one make_hist curve of log(speed) per fish.
        bottom-right: one make_hist curve of speed per fish.

    Log-speeds use only strictly positive speeds, because log is undefined at 0.
    """
    fishes = result['metadata']['FISH']
    per_fish_vel = [fish['path']['V'] for fish in fishes]
    per_fish_log_vel = [[math.log(v) for v in speeds if v > 0] for speeds in per_fish_vel]
    vel = [v for speeds in per_fish_vel for v in speeds]
    log_vel = [lv for logs in per_fish_log_vel for lv in logs]

    def has_spread(values):
        # A per-fish curve needs at least one sample and a non-zero value range,
        # since make_hist derives its bin width from max - min. Fish that fail
        # this are skipped instead of aborting the whole report.
        return len(values) > 0 and min(values) != max(values)

    fig, axs = plt.subplots(2, 2, sharex=False, sharey=False, figsize=STANDARD_FIG_SIZE)

    fig.suptitle('Fish velocities', fontsize=40, horizontalalignment='center', weight='bold')

    axs[0, 0].hist(log_vel, bins=20)
    axs[0, 0].set_title('Fish Log-Velocities between frames')
    axs[0, 0].set_xlabel("Log-Velocity (log(m/s))")

    axs[0, 1].hist(vel, bins=20)
    axs[0, 1].set_title('Fish Velocities between frames')
    axs[0, 1].set_xlabel("Velocity (m/s)")

    for data in per_fish_log_vel:
        if has_spread(data):
            n, bin_c = make_hist(data)
            axs[1, 0].plot(bin_c, n)
    axs[1, 0].set_title('Fish Log-Velocities between frames (per fish)')
    axs[1, 0].set_xlabel("Log-Velocity (log(m/s))")

    for data in per_fish_vel:
        if has_spread(data):
            n, bin_c = make_hist(data)
            axs[1, 1].plot(bin_c, n)
    axs[1, 1].set_title('Fish Velocities between frames (per fish)')
    axs[1, 1].set_xlabel("Velocity (m/s)")

    pdf.savefig(fig)
    plt.close(fig)
    

def make_hist(data, pts_per_bin=6):
    '''Histogram ``data`` and return ``(counts, bin_centers)`` for line plotting.

    The bin width is chosen so that, on average, ``pts_per_bin`` samples fall
    in each bin: width = (max - min) / (len(data) / pts_per_bin). The bin
    edges span roughly from ``min - width`` to ``max + width``, so the curve
    is padded on both sides.

    Degenerate inputs return a result instead of raising:
        * empty ``data`` gives two empty arrays;
        * all-equal values give a single bin, centred on that value, holding
          every sample.

    Args:
        data: 1-D sequence of numbers.
        pts_per_bin: Target average number of samples per bin (default 6).

    Returns:
        (counts, bin_centers): numpy arrays of equal length.
    '''
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        return np.zeros(0, dtype=np.intp), np.zeros(0)

    min_bin = values.min()
    max_bin = values.max()
    if max_bin == min_bin:
        # A zero-width range would give np.arange a zero step.
        return np.array([values.size]), np.array([min_bin])

    bin_sz = (max_bin - min_bin) / (values.size / pts_per_bin)
    bins = np.arange(min_bin - bin_sz, max_bin + 2 * bin_sz, bin_sz)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    counts, _ = np.histogram(values, bins=bins, density=False)
    return counts, bin_centers

def draw_fish_tracks(pdf, result, dataset, id):
    """Append the per-fish pages for fish ``id`` (0-based) to ``pdf``.

    Always adds a track page (see _draw_fish_path_page). When ``dataset`` is
    not None, also adds a page of cropped snapshots (see
    _draw_fish_snapshot_page).

    calculate_fish_paths must have run for this fish first: this reads the
    'path', 'base_frame' and 'scaled_frame_size' entries it stores.
    """
    fish = result['metadata']['FISH'][id]
    _draw_fish_path_page(pdf, fish, id)
    if dataset is not None:
        _draw_fish_snapshot_page(pdf, result, dataset, id)


def _draw_fish_path_page(pdf, fish, fish_index):
    """Draw one fish's path over its rotated base frame and save it as a page."""
    start_frame = fish['START_FRAME']
    end_frame = fish['END_FRAME']

    fig, ax = plt.subplots(figsize=STANDARD_FIG_SIZE)
    ax.axis('off')

    # (vertical, horizontal) extent, in plot units, that the rotated base
    # frame occupies on this page.
    plot_height, plot_width = fish['scaled_frame_size']
    if fish['base_frame'] is not None:
        img = cv2.rotate(fish['base_frame'], cv2.ROTATE_90_CLOCKWISE)
        ax.imshow(img, extent=(0, plot_width, 0, plot_height), cmap=plt.colormaps['Greys'])

    ax.text(plot_width / 2, 2, f'Fish {fish_index+1} (frames {start_frame}-{end_frame})',
            fontsize=40, color="black", horizontalalignment='center', zorder=5)

    X = fish['path']['X']
    Y = fish['path']['Y']
    certainty = fish['path']['certainty']

    # Map each normalised track point (x, y) onto the rotated frame.
    page_x = [plot_width * (1 - y) for y in Y]
    page_y = [plot_height * (1 - x) for x in X]

    if X:
        ax.text(page_x[0], page_y[0], "Start", fontsize=15, weight="bold")
        ax.text(page_x[-1], page_y[-1], "End", fontsize=15, weight="bold")

    # Segment k joins points k-1 and k. It is yellow when both ends were
    # detected, darkorange when one was, and orangered when both were
    # interpolated. Point k is marked in the colour of the segment ending there.
    colors = [None] * len(X)
    for k in range(1, len(X)):
        prev_ok, cur_ok = certainty[k - 1], certainty[k]
        if prev_ok and cur_ok:
            color = 'yellow'
        elif prev_ok or cur_ok:
            color = 'darkorange'
        else:
            color = 'orangered'
        colors[k] = color
        ax.plot([page_x[k - 1], page_x[k]], [page_y[k - 1], page_y[k]], color=color, linewidth=1)

    # Markers go in a second pass so they are drawn on top of every segment.
    for k in range(1, len(X)):
        ax.plot(page_x[k], page_y[k], color=colors[k], marker='o', markersize=3)

    ax.set_ylim([0, plot_height])
    ax.set_xlim([0, plot_width])
    pdf.savefig(fig)
    plt.close(fig)


def _draw_fish_snapshot_page(pdf, result, dataset, fish_index):
    """Save a page of crops around one fish at four frames across its lifetime."""
    fish = result['metadata']['FISH'][fish_index]
    start_frame = fish['START_FRAME']
    end_frame = fish['END_FRAME']
    frames = result['frames']
    target_id = fish_index + 1  # annotations use 1-based fish ids

    # Start, about one third, about two thirds, end.
    indices = [start_frame, int(2/3*start_frame + end_frame/3), int(1/3*start_frame + 2/3*end_frame), end_frame]
    fig, axs = plt.subplots(2, len(indices), sharex=False, sharey=False, figsize=STANDARD_FIG_SIZE)
    # NOTE(review): only the top row of axes is drawn into; the bottom row stays empty.

    for col, wanted in enumerate(indices):
        # The fish may be unannotated at exactly `wanted`, so take the first
        # annotated frame within the next 10.
        match = next(
            ((fi, ann['bbox'])
             for fi in range(wanted, min(wanted + 10, len(frames)))
             for ann in frames[fi]['fish']
             if ann['fish_id'] == target_id),
            None,
        )
        if match is None:
            continue
        frame_index, box = match

        # Load the same frame the bounding box belongs to.
        img = dataset.didson.load_frames(start_frame=frame_index, end_frame=frame_index + 1)[0]
        img_h, img_w = img.shape
        x1, x2 = int(box[0] * img_w), int(box[2] * img_w)
        y1, y2 = int(box[1] * img_h), int(box[3] * img_h)
        cx, cy = int((x2 + x1) / 2), int((y2 + y1) / 2)
        # Half-size of a square crop around the box centre: 2.5x the box's
        # larger side, clipped so the crop stays inside the image.
        s = min(int(max(x2 - x1, y2 - y1) * 5 / 2), cx, cy, img_w - cx, img_h - cy)

        ax = axs[0, col]
        if s > 0:
            # extent is (left, right, bottom, top). Using bottom > top keeps the
            # y axis in image-row order, so tick labels and the red box use the
            # same pixel coordinates as the crop.
            ax.imshow(img[cy - s:cy + s, cx - s:cx + s], extent=(cx - s, cx + s, cy + s, cy - s),
                      cmap=plt.colormaps['Greys_r'])
        ax.plot([x1, x1, x2, x2, x1], [y1, y2, y2, y1, y1], color="red")
        ax.set_title('Frame ' + str(frame_index))

    pdf.savefig(fig)
    plt.close(fig)