import cv2
import numpy as np
import matplotlib.pyplot as plt
from dtaidistance import dtw
from easy_ViTPose.inference import VitInference
import os
import requests
from pathlib import Path
from scipy.signal import savgol_filter
from scipy.stats import mstats


def predict_keypoints_vitpose(
        video_path, 
        model_path, 
        model_name,
        detector_path, 
        display_video=False
):
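    """Run easy_ViTPose over every frame of a video.

    Returns an array of shape (num_frames, num_keypoints, 3) holding
    (y, x, confidence) rows for the first detected person (id 0); frames
    without a detection are skipped.
    """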

    model = VitInference(
        model=model_path, 
        yolo=detector_path, 
        model_name=model_name,
        det_class=None,
        dataset=None,
        yolo_size=320, 
        is_video=False,
        single_pose=False,
        yolo_step=1
    )

    cap = cv2.VideoCapture(video_path)
    detection_results = []
    while True:
        ret, frame = cap.read()
        if not ret:
            print(f"Keypoints were extracted from {video_path}")
            break
        
        frame = cv2.resize(frame, (1280, 720))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_keypoints = model.inference(frame)

        if 0 in frame_keypoints:
            detection_results.append(frame_keypoints[0])

        if display_video:
            # model.draw returns RGB; reverse the channels back to BGR for OpenCV.
            preview = model.draw(False, False, 0.5)[..., ::-1]
            cv2.imshow('preview', preview)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    cap.release()
    if display_video:
        cv2.destroyAllWindows()

    return np.array(detection_results)


def get_point_list_vitpose(detection_results):
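    """Drop the confidence column, keeping only the coordinate pairs."""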
    return np.array(detection_results)[:, :, :-1]


def get_edge_groups(connections):
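    """Build angle triples from a list of skeleton connections.

    Every pair of connections that shares exactly one keypoint defines an
    angle: [edge, center, edge], where center is the shared keypoint.
    Returns the sorted, de-duplicated list of such triples.
    """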

    all_pairs = []
    for i in range(len(connections)):
        pairs = []
        init_con = connections[i]
        for k in range(len(connections)):
            if k == i:
                continue
            candidate_con = connections[k]

            # Keep the pair if the two connections share at least one keypoint.
            if set(init_con) & set(candidate_con):
                pairs.append([init_con, candidate_con])
        all_pairs.append(pairs)

    all_point_for_edges = []
    for set_of_pairs in all_pairs:
        clean_pairs = []
        for pair_a, pair_b in set_of_pairs:
            # Two connections sharing exactly one keypoint span three distinct
            # keypoints and therefore define an angle.
            if len(set(pair_a) | set(pair_b)) == 3:
                center = int(list(set(pair_a) & set(pair_b))[0])
                edges = list(set(pair_a) ^ set(pair_b))
                points_for_edge = [edges[0], center, edges[1]]
                clean_pairs.append(points_for_edge)
        all_point_for_edges.extend(clean_pairs)

    unique_set = set()
    unique_list = []
    for sublist in all_point_for_edges:
        sublist_tuple = tuple(sublist)
        if sublist_tuple not in unique_set:
            unique_set.add(sublist_tuple)
            unique_list.append(sublist)
            
    unique_list.sort() 

    return unique_list


def calculate_angle(A, B, C):
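    """Return the angle ABC in degrees plus the length of the shorter ray."""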
   
    A = np.round(np.array(A), decimals=3)
    B = np.round(np.array(B), decimals=3)
    C = np.round(np.array(C), decimals=3)

    BA = A - B
    BC = C - B

    cosine_angle = np.dot(BA, BC) / (np.linalg.norm(BA) * np.linalg.norm(BC))
    # Guard against floating-point round-off pushing |cos| above 1 before arccos.
    cosine_angle = np.clip(cosine_angle, -1, 1)
    angle = np.arccos(cosine_angle)

    if np.isnan(angle):
        print(f"Invalid angle calculation.\n{A}\n{B}\n{C}")

    # Length of the shorter ray; a proxy for how reliable the angle is
    # (short segments are noisier).
    minimum = min(np.linalg.norm(BA), np.linalg.norm(BC))

    return np.degrees(angle), minimum


def compute_all_angles(keypoints, edge_groups):
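    """Compute [angle, shorter-ray length] for every edge group in one frame."""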

    all_angles = []
    for group in edge_groups:
        
        A = keypoints[group[0]]
        B = keypoints[group[1]]
        C = keypoints[group[2]]

        angle, minimum = calculate_angle(A, B, C)
        all_angles.append([angle, minimum])

    return np.array(all_angles)


def xy2phi(points_result, connections):
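    """Convert per-frame keypoint coordinates into per-frame joint angles.

    Output shape: (num_frames, num_edge_groups, 1).
    """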

    edge_groups = get_edge_groups(connections)
    new_array = np.zeros((points_result.shape[0], len(edge_groups), 1))

    for idx, frame in enumerate(points_result):
        all_angles = compute_all_angles(keypoints=frame, edge_groups=edge_groups)[:, 0]
        new_array[idx, :, :] = all_angles.reshape((len(edge_groups), 1))

    return new_array


def get_series(point_list, edge_groups):
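    """Build one angle time series per edge group.

    Output shape: (num_edge_groups, num_frames); each row tracks one joint
    angle across the whole video.
    """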

    list_of_series = []
    for edge_group in edge_groups:

        keypoint_1, keypoint_2, keypoint_3 = edge_group
        relevant_point_list = point_list[:, (keypoint_1, keypoint_2, keypoint_3), :]

        series = []
        for frame in relevant_point_list:
            angle, _ = calculate_angle(frame[0, :], frame[1, :], frame[2, :])
            series.append(angle)
        list_of_series.append(series)
        
    return np.array(list_of_series)


def plot_serieses(series_1, series_2):
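    """Plot two angle series against frame index for visual comparison."""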

    plt.figure(dpi=150, figsize=(12, 5))
    plt.plot(series_1, label='Video #1', lw=1)
    plt.plot(series_2, label='Video #2', lw=1)
    plt.axis("on") 
    plt.grid(True)  
    plt.xlabel("frames")  
    plt.ylabel("angles")
    plt.legend() 


def z_score_normalization(serieses, axis_for_znorm=1):
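    """Z-normalise the angle series so DTW compares shapes, not offsets.

    With the default axis_for_znorm=1, every row (one series) is normalised
    independently to zero mean and unit variance.
    """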

    serieses_mean = np.mean(serieses, axis=axis_for_znorm, keepdims=True)
    serieses_std = np.std(serieses, axis=axis_for_znorm, keepdims=True)
    serieses_normalized = (serieses - serieses_mean) / serieses_std

    return serieses_normalized 


def get_dtw_mean_path(serieses_teacher, serieses_student, dtw_mean, dtw_filter):
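    """Estimate one teacher-to-student frame alignment from many angle series.

    Runs DTW independently on every joint-angle series, then, for each
    student frame, takes the winsorized mean of all teacher frames the
    individual paths mapped to it, and finally smooths the result with a
    Savitzky-Golay filter. dtw_mean is the winsorize limit (fraction cut
    from each tail) and dtw_filter is the filter window length.
    """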
    
    # One DTW path per angle series (window limits how far frames may shift).
    list_of_paths = []
    for idx in range(len(serieses_teacher)):
        series_teacher = np.array(serieses_teacher[idx])
        series_student = np.array(serieses_student[idx])
        _, paths = dtw.warping_paths(series_teacher, series_student, window=50)
        path = dtw.best_path(paths)
        list_of_paths.append(path)

    all_dtw_tuples = []
    for path in list_of_paths:
        all_dtw_tuples.extend(path)

    # For every student frame, average the teacher frames that the individual
    # paths mapped to it; winsorizing tames outlier alignments.
    mean_path = []
    for student_frame in range(len(serieses_student[0])):
        frames_from_teacher = [
            teacher_frame for teacher_frame, student_idx in all_dtw_tuples
            if student_idx == student_frame
        ]
        winsorized = mstats.winsorize(np.array(frames_from_teacher), limits=[dtw_mean, dtw_mean])
        mean_path.append((int(winsorized.mean()), student_frame))

    path_array = np.array(mean_path)
    smoothed_data = savgol_filter(path_array, window_length=dtw_filter, polyorder=0, axis=0)
    path_array = np.array(smoothed_data).astype(int)

    alignments = np.unique(path_array, axis=0)  # TODO: verify deduplication after smoothing is correct

    return alignments


def modify_student_frame(
    detection_result_teacher,
    detection_result_student,
    detection_result_teacher_angles,
    detection_result_student_angles,
    video_teacher,
    video_student,
    alignment_frames,
    edge_groups,
    connections,
    thresholds,
    previously_triggered,
    previously_triggered_2,
    trigger_state,
    show_arrows,
    text_dictionary,
):
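    """Draw error feedback on one aligned student frame.

    An edge is flagged when its angle error exceeds its threshold, all of
    its keypoints were detected confidently, and it lies on the connection
    being inspected. trigger_state acts as a temporal debounce: "one" draws
    immediately, "two" only if the connection was also flagged on the
    previous aligned frame, "three" only after two previous flags.
    """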
    arrows_bgr = (175, 75, 190)
    arrows_sz = 3
    skeleton_bgr = (0, 0, 255)
    skeleton_sz = 3

    frame_copy = video_student[alignment_frames[1]]
    frame_teacher_copy = video_teacher[alignment_frames[0]]
    frame_errors = np.abs(detection_result_teacher_angles[alignment_frames[0]] - detection_result_student_angles[alignment_frames[1]])
    edge_groups_as_keys = [tuple(group) for group in edge_groups]
    edge_groups2errors = dict(zip(edge_groups_as_keys, frame_errors))
    edge_groups2thresholds = dict(zip(edge_groups_as_keys, thresholds))
    edge_groups_relevant = [edge_group[1:] for edge_group in edge_groups]

    text_info = []
    triggered_connections = []
    triggered_connections2 = []
    for connection in connections:

        edges_for_given_connection = [edge for edge in edge_groups2errors if connection[0] in edge or connection[1] in edge]

        for edge in edges_for_given_connection:

            check_threshold = edge_groups2errors[edge] > edge_groups2thresholds[edge]

            # Require every keypoint of the edge to be detected confidently.
            check_certain = True
            for keypoint in edge:
                prob = detection_result_student[alignment_frames[1], keypoint, -1]
                if prob < 0.7:
                    check_certain = False

            relevant_plane = (
                [connection[0], connection[1]] in edge_groups_relevant
                or [connection[1], connection[0]] in edge_groups_relevant
            )

            if check_threshold and check_certain and relevant_plane:

                point1, point2, point2_t = align_points(
                    detection_result_student, 
                    detection_result_teacher, 
                    alignment_frames, 
                    edge
                )

                arrow = get_arrow_direction(point2, point2_t)

                if triger_state == "one":

                    _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)

                    if show_arrows:
                        _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz) 

                    if (connection[0], connection[1]) in text_dictionary:
                        text_info.append((text_dictionary[(connection[0], connection[1])], arrow))
                                
                    if (connection[1], connection[0]) in text_dictionary:
                        text_info.append((text_dictionary[(connection[1], connection[0])], arrow))

                if triger_state == "two":

                    trigered_connections.append((connection[0], connection[1]))

                    if (connection[0], connection[1]) in previously_trigered:

                        _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)

                        if show_arrows:
                            _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz) 

                        if (connection[0], connection[1]) in text_dictionary:
                            text_info.append((text_dictionary[(connection[0], connection[1])], arrow))
                                
                        if (connection[1], connection[0]) in text_dictionary:
                            text_info.append((text_dictionary[(connection[1], connection[0])], arrow))

                if triger_state == "three":

                    trigered_connections.append((connection[0], connection[1]))

                    if (connection[0], connection[1]) in previously_trigered:

                        trigered_connections2.append((connection[0], connection[1]))

                        if (connection[0], connection[1]) in previously_trigered_2:

                            _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)

                            if show_arrows:
                                _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz) 

                            if (connection[0], connection[1]) in text_dictionary:
                                text_info.append((text_dictionary[(connection[0], connection[1])], arrow))

                            if (connection[1], connection[0]) in text_dictionary:
                                text_info.append((text_dictionary[(connection[1], connection[0])], arrow))
           
    return frame_copy, frame_teacher_copy, list(set(triggered_connections)), list(set(triggered_connections2)), text_info


def get_video_frames(video_path):
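    """Load a whole video into memory as an array of 1280x720 BGR frames."""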
    cap = cv2.VideoCapture(video_path)
    video = []
    while True:
        ret, frame = cap.read()
        if not ret:
            print(f"Video {video_path} was loaded")
            break
        frame = cv2.resize(frame, (1280, 720))
        video.append(frame)
    cap.release()

    return np.array(video)


def download_file(url, save_path):
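    """Stream a file from url to save_path in 8 KiB chunks."""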
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(save_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)


def check_and_download_models():
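    """Download the ViTPose-b and YOLOv8s weights into models/ if missing.

    The commented-out blocks keep the URLs for the s/l ViTPose variants.
    """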
    
    # vit_model_s_url = "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/torch/wholebody/vitpose-s-wholebody.pth?download=true"
    vit_model_b_url = "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/torch/wholebody/vitpose-b-wholebody.pth?download=true"
    # vit_model_l_url = "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/torch/wholebody/vitpose-l-wholebody.pth?download=true"

    yolo_model_url = "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/yolov8/yolov8s.pt?download=true"

    # vit_model_s_path = "models/vitpose-s-wholebody.pth"
    vit_model_b_path = "models/vitpose-b-wholebody.pth"
    # vit_model_l_path = "models/vitpose-l-wholebody.pth"

    yolo_model_path = "models/yolov8s.pt"

    # Path(os.path.dirname(vit_model_s_path)).mkdir(parents=True, exist_ok=True)
    Path(os.path.dirname(vit_model_b_path)).mkdir(parents=True, exist_ok=True)
    # Path(os.path.dirname(vit_model_l_path)).mkdir(parents=True, exist_ok=True)

    Path(os.path.dirname(yolo_model_path)).mkdir(parents=True, exist_ok=True)

    # if not os.path.exists(vit_model_s_path):
    #     print("Downloading ViT-Pose-s model...")
    #     download_file(vit_model_s_url, vit_model_s_path)
    #     print("ViT-Pose-s model was downloaded.")
    
    if not os.path.exists(vit_model_b_path):
        print("Downloading ViT-Pose-b model...")
        download_file(vit_model_b_url, vit_model_b_path)
        print("ViT-Pose-b model was downloaded.")

    # if not os.path.exists(vit_model_l_path):
    #     print("Downloading ViT-Pose-l model...")
    #     download_file(vit_model_l_url, vit_model_l_path)
    #     print("ViT-Pose-l model was downloaded.")
    
    if not os.path.exists(yolo_model_path):
        print("Downloading YOLO model...")
        download_file(yolo_model_url, yolo_model_path)
        print("YOLO model was downloaded.")


def generate_output_video(teacher_frames, student_frames, timestamp_str):
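    """Write teacher and student frames side by side to videos/pose_<timestamp>.mp4.

    Frames are resized to 1280x720 each, concatenated horizontally
    (2560x720 output) and written at 30 fps. Returns the output path.
    """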

    # Resize both streams to a common 1280x720 before concatenating side by side.
    teacher_frames_resized = np.array([cv2.resize(frame, (1280, 720)) for frame in teacher_frames])
    student_frames_resized = np.array([cv2.resize(frame, (1280, 720)) for frame in student_frames])

    concat_video = np.concatenate((teacher_frames_resized, student_frames_resized), axis=2)

    root_dir = "videos"
    os.makedirs(root_dir, exist_ok=True)

    video_path = f"{root_dir}/pose_{timestamp_str}.mp4"
    out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), 30, (1280 * 2, 720))
    for frame in concat_video:
        out.write(frame)
    out.release()

    return video_path


def generate_log(all_text_summaries):
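    """Collapse (comment, frame, arrow) tuples into a chronological summary.

    Duplicates are removed, entries are sorted by frame index, and frame
    numbers are converted to MM:SS.ss video time assuming 30 fps.
    """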

    all_text_summaries_clean = list(set(all_text_summaries))
    all_text_summaries_clean.sort(key=lambda x: x[1])

    general_summary = []
    for log in all_text_summaries_clean:
        comment, frame, arrow = log
        total_seconds = frame / 30  # assumes the 30 fps used when writing the video
        minutes, seconds = divmod(total_seconds, 60)
        general_summary.append(f"{comment}. Direction: {arrow}. Video time: {int(minutes):02d}:{seconds:05.2f}")

    general_summary = "\n".join(general_summary)

    return general_summary


def write_log(
    timestamp_str, 
    dtw_mean, 
    dtw_filter, 
    angles_sensitive, 
    angles_common, 
    angles_insensitive,
    trigger_state,
    general_summary
):
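    """Write the run settings and error summary to logs/log_<timestamp>.txt."""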

    logs_dir = "logs"
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)

    log_path = f"{logs_dir}/log_{timestamp_str}.txt"

    content = f"""
Settings:

Dynamic Time Warping:
- Winsorize mean: {dtw_mean}
- Savitzky-Golay Filter: {dtw_filter}

Thresholds:
- Sensitive: {angles_sensitive}
- Standard: {angles_common}
- Insensitive: {angles_insensitive}

Patience:
- trigger count: {trigger_state}


Error logs:

{general_summary}
"""

    with open(log_path, "w") as file:
        file.write(content)

    print(f"log {log_path} was created.")

    return log_path


def angle_between(v1, v2):
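    # Signed angle (radians) rotating v1 onto v2. The result may wrap past
    # +/-pi, but it is only fed to cos/sin downstream, so that is harmless.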
    return np.arctan2(v2[1], v2[0]) - np.arctan2(v1[1], v1[0])


def align_points(detection_result_student, detection_result_teacher, alignment_frames, edge):
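    """Overlay the teacher's angle edge onto the student's skeleton.

    Keypoints are stored (y, x), so [::-1] converts them to the (x, y)
    pixel order OpenCV expects. The teacher triple is first translated so
    the shared keypoint coincides with the student's, then rotated so the
    first bone lines up; the remaining offset of point2_t shows where the
    student's end keypoint should be.
    """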

    point0 = detection_result_student[alignment_frames[1], edge[0], :-1].astype(int)[::-1]
    point1 = detection_result_student[alignment_frames[1], edge[1], :-1].astype(int)[::-1]
    point2 = detection_result_student[alignment_frames[1], edge[2], :-1].astype(int)[::-1]

    point0_t = detection_result_teacher[alignment_frames[0], edge[0], :-1].astype(int)[::-1]
    point1_t = detection_result_teacher[alignment_frames[0], edge[1], :-1].astype(int)[::-1]
    point2_t = detection_result_teacher[alignment_frames[0], edge[2], :-1].astype(int)[::-1]

    translation = point0 - point0_t

    point0_t += translation
    point1_t += translation
    point2_t += translation

    BsA = point1 - point0
    BtA = point1_t - point0

    theta = angle_between(BtA, BsA)

    R = np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])

    point1_t = np.dot(R, (point1_t - point0).T).T + point0
    point2_t = np.dot(R, (point2_t - point0).T).T + point0

    point2_t = point2_t.astype(int)

    return point1, point2, point2_t


def get_arrow_direction(A, B):
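    """Map the direction of the vector from A to B to one of eight arrows.

    The circle is split into 45-degree sectors; the fallback case only
    catches non-finite angles (e.g. NaN).
    """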

    translation_vector = B - A
    angle_deg = np.degrees(np.arctan2(translation_vector[0], translation_vector[1]))

    match angle_deg:
        case angle if -22.5 <= angle < 22.5:
            arrow = "⬆"
        case angle if 22.5 <= angle < 67.5:
            arrow = "⬈"
        case angle if 67.5 <= angle < 112.5:
            arrow = "➡"
        case angle if 112.5 <= angle < 157.5:
            arrow = "⬊"
        case angle if 157.5 <= angle or angle < -157.5:
            arrow = "⬇"
        case angle if -157.5 <= angle < -112.5:
            arrow = "⬋"
        case angle if -112.5 <= angle < -67.5:
            arrow = "⬅"
        case angle if -67.5 <= angle < -22.5:
            arrow = "⬉"
        case _:
            arrow = ""

    return arrow
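

# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical end-to-end run of the helpers above: the video
# paths and the COCO-style `connections` list are placeholder assumptions,
# not part of this module; adjust both to your own data before running.
if __name__ == "__main__":
    check_and_download_models()

    # Assumed skeleton: a handful of COCO body connections (arms, shoulders,
    # legs, hips). Any list of (keypoint_a, keypoint_b) pairs works.
    connections = [
        (5, 7), (7, 9), (6, 8), (8, 10), (5, 6),
        (11, 13), (13, 15), (12, 14), (14, 16), (11, 12),
    ]

    detections_teacher = predict_keypoints_vitpose(
        video_path="videos/teacher.mp4",  # hypothetical input path
        model_path="models/vitpose-b-wholebody.pth",
        model_name="b",
        detector_path="models/yolov8s.pt",
    )
    detections_student = predict_keypoints_vitpose(
        video_path="videos/student.mp4",  # hypothetical input path
        model_path="models/vitpose-b-wholebody.pth",
        model_name="b",
        detector_path="models/yolov8s.pt",
    )

    edge_groups = get_edge_groups(connections)
    series_teacher = get_series(get_point_list_vitpose(detections_teacher), edge_groups)
    series_student = get_series(get_point_list_vitpose(detections_student), edge_groups)

    # Align the two videos; dtw_mean and dtw_filter are example settings
    # (winsorize 10% per tail, 31-frame smoothing window).
    alignments = get_dtw_mean_path(
        z_score_normalization(series_teacher),
        z_score_normalization(series_student),
        dtw_mean=0.1,
        dtw_filter=31,
    )
    print(f"Computed {len(alignments)} aligned frame pairs")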