import cv2
from ultralytics import YOLO
import os
from dotenv import load_dotenv
from pathlib import Path
import math
import json
import numpy as np

env_path = Path('.') / '.env'
load_dotenv(dotenv_path=env_path)

path = {
    'DET_MODEL_PATH': os.getenv('DET_MODEL_PATH'),
    'IMG_DIR_PATH': os.getenv('IMG_DIR_PATH'),
    'ACTIVITY_DET_MODEL_PATH': os.getenv('ACTIVITY_DET_MODEL_PATH'),
}
# Constants: assumed real-world object heights (metres) and an effective focal length (pixels)
PERSON_HEIGHT = 1.5
VEHICAL_HEIGHT = 1.35
ANIMAL_HEIGHT = 0.6
FOCAL_LENGTH = 6400
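# Distances are estimated later (see get_distances) with the pinhole relation
#   distance ≈ FOCAL_LENGTH * real_height / bbox_height_px
# e.g. a person whose box is 100 px tall gives 6400 * 1.5 / 100 = 96 m
# (illustrative numbers only).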

# Load the detection and activity-classification models
det_model = YOLO(path['DET_MODEL_PATH'])
activity_det_model = YOLO(path['ACTIVITY_DET_MODEL_PATH'])

# Order must match the class indices the activity model was trained with
activity_classes = ['Standing', 'Running', 'Sitting']

def object_detection(image):
    """
    Detect objects with the YOLO model, pad each box, and merge overlapping
    padded boxes into groups.

    Args:
        image (numpy array): 3-channel (BGR) image

    Returns:
        new_boxes: a list of dicts in the format below
        [
            {
                "actual_boundries": [
                    {
                        "top_left": [48, 215],
                        "bottom_right": [62, 245],
                        "class": "person"
                    }
                ],
                "updated_boundries": {
                    "top_left": [41, 199],
                    "bottom_right": [73, 269],
                    "person_count": 1,
                    "vehical_count": 0,
                    "animal_count": 0
                }
            }
        ]
    """

    # Detect objects with the YOLO model
    results = det_model(image)

    boxes = results[0].boxes.xyxy.tolist()
    classes = results[0].boxes.cls.tolist()
    names = results[0].names
    my_boxes = []  # each entry: raw box(es) under 'actual_boundries' plus a padded box with per-class counts

    for box, cls in zip(boxes, classes):
        x1, y1, x2, y2 = box
        name = names[int(cls)]
        my_obj = {"actual_boundries": [{"top_left": (int(x1), int(y1)),
                                        "bottom_right": (int(x2), int(y2)),
                                        "class": name}]}
        # Pad the box by half its width/height on each side, clamped to the image bounds
        w, h = x2 - x1, y2 - y1
        x1 = max(0, x1 - w / 2)
        y1 = max(0, y1 - h / 2)
        x2 = min(image.shape[1] - 1, x2 + w / 2)
        y2 = min(image.shape[0] - 1, y2 + h / 2)
        x1, y1, x2, y2 = math.floor(x1), math.floor(y1), math.ceil(x2), math.ceil(y2)
        # Class-name spellings ('vehical', 'animal') follow the detection model's labels
        my_obj["updated_boundries"] = {"top_left": (x1, y1),
                                       "bottom_right": (x2, y2),
                                       "person_count": 1 if name == 'person' else 0,
                                       "vehical_count": 1 if name == 'vehical' else 0,
                                       "animal_count": 1 if name == 'animal' else 0}
        my_boxes.append(my_obj)
    # Sort by padded top-left, then bottom-right, so overlapping boxes end up adjacent
    my_boxes.sort(key=lambda x: (x['updated_boundries']['top_left'], x['updated_boundries']['bottom_right']))

    new_boxes = []
    if my_boxes:
        new_boxes.append(my_boxes[0])

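    # Example (illustrative): padded boxes (10, 10)-(50, 50) and (40, 40)-(90, 90)
    # overlap, so they collapse into one group spanning (10, 10)-(90, 90) with
    # their person/vehical/animal counts summed.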
    # Greedily merge each subsequent box into the last group when their padded boxes overlap
    for box in my_boxes[1:]:
        top_left_last = new_boxes[-1]['updated_boundries']['top_left']
        bottom_right_last = new_boxes[-1]['updated_boundries']['bottom_right']
        top_left_curr = box['updated_boundries']['top_left']
        bottom_right_curr = box['updated_boundries']['bottom_right']

        # The sort guarantees top_left_curr[0] >= top_left_last[0], so x-overlap
        # needs one check; y-overlap must be checked in both directions
        if (bottom_right_last[0] >= top_left_curr[0]
                and bottom_right_last[1] >= top_left_curr[1]
                and bottom_right_curr[1] >= top_left_last[1]):
            # Take the union of the two padded boxes and sum the per-class counts
            new_x1 = min(top_left_last[0], top_left_curr[0])
            new_y1 = min(top_left_last[1], top_left_curr[1])
            new_x2 = max(bottom_right_last[0], bottom_right_curr[0])
            new_y2 = max(bottom_right_last[1], bottom_right_curr[1])

            new_boxes[-1]['actual_boundries'] += box['actual_boundries']
            new_boxes[-1]['updated_boundries'] = {"top_left": (new_x1, new_y1),
                                                  "bottom_right": (new_x2, new_y2),
                                                  "person_count": new_boxes[-1]['updated_boundries']['person_count'] + box['updated_boundries']['person_count'],
                                                  "vehical_count": new_boxes[-1]['updated_boundries']['vehical_count'] + box['updated_boundries']['vehical_count'],
                                                  "animal_count": new_boxes[-1]['updated_boundries']['animal_count'] + box['updated_boundries']['animal_count']}
        else:
            new_boxes.append(box)

    return new_boxes

def cropped_images(image, new_boxes):
    """Crop each merged region, and each detected person, out of the image.

    Args:
        image (numpy array): 3-channel (BGR) image
        new_boxes (json array): merged detections from object_detection()

    Returns:
        cropped_images_list (list of numpy array): one crop per merged region
        single_object_images (list of numpy array): one crop per detected person
    """
    cropped_images_list = []
    single_object_images = []

    for data in new_boxes:
        top_left = data['updated_boundries']['top_left']
        bottom_right = data['updated_boundries']['bottom_right']
        crop_image = image[top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]]
        cropped_images_list.append(crop_image)

        # Crop every detected person separately for activity classification
        for obj in data['actual_boundries']:
            if obj['class'] == 'person':
                crop_object = image[obj['top_left'][1]:obj['bottom_right'][1],
                                    obj['top_left'][0]:obj['bottom_right'][0]]
                single_object_images.append(crop_object)

    return cropped_images_list, single_object_images

def image_enhancements(cropped_images_list, single_object_images):
    """Upscale, brighten, and sharpen every crop.

    Args:
        cropped_images_list (list of numpy array): cropped region images
        single_object_images (list of numpy array): single-person images

    Returns:
        enhanced_images (list of numpy array): enhanced region crops
        enhanced_single_object_images (list of numpy array): enhanced person crops
    """
    def enhance(image):
        # Resize to a fixed height of 500 px, preserving the aspect ratio
        res = cv2.resize(image, (500 * image.shape[1] // image.shape[0], 500),
                         interpolation=cv2.INTER_CUBIC)

        # Linear brightness/contrast adjustment: dst = res * contrast + brightness
        brightness = 16
        contrast = 0.95
        res2 = cv2.addWeighted(res, contrast, np.zeros(res.shape, res.dtype), 0, brightness)

        # Sharpen with a standard 3x3 sharpening kernel
        kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
        return cv2.filter2D(res2, -1, kernel)

    # The same enhancement is applied to region crops and person crops
    enhanced_images = [enhance(image) for image in cropped_images_list]
    enhanced_single_object_images = [enhance(image) for image in single_object_images]

    return enhanced_images, enhanced_single_object_images
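
# Note: for one uint8 pixel the brightness/contrast step above computes, e.g.,
# 200 * 0.95 + 16 = 206 (values saturate at 255) before the sharpening pass.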


def detect_activity(single_object_images):
    """Classify the activity of each cropped person image.

    Args:
        single_object_images (list of numpy array): single-person images

    Returns:
        activities (list of strings): one activity label per person
    """
    activities = []

    for img in single_object_images:
        predictions = activity_det_model.predict(img)
        # Classification output: take the top-1 class of the single result
        class_index = predictions[0].probs.top1
        activities.append(activity_classes[class_index])

    return activities


def get_distances(new_boxes):
    """Estimate each object's distance from its bounding-box height.

    Args:
        new_boxes (json array): merged detections from object_detection()

    Returns:
        distance_list: list of distance strings (e.g. "96m"), one per object
    """
    distance_list = []
    for box in new_boxes:
        for actual_box in box['actual_boundries']:
            height = actual_box['bottom_right'][1] - actual_box['top_left'][1]

            # Pinhole estimate with the assumed real-world height for the class
            if actual_box['class'] == "person":
                distance = FOCAL_LENGTH * PERSON_HEIGHT / height
            elif actual_box['class'] == "vehical":
                distance = FOCAL_LENGTH * VEHICAL_HEIGHT / height
            else:
                distance = FOCAL_LENGTH * ANIMAL_HEIGHT / height

            distance_list.append(str(round(distance)) + "m")

    return distance_list


def get_json_data(json_data, enhanced_images, detected_activity, distances_list):
    """Assemble the per-region results consumed by the frontend.

    Args:
        json_data (json Array): merged detections from object_detection()
        enhanced_images (list of numpy array): enhanced region crops
        detected_activity (list of strings): per-person activity labels
        distances_list (list of strings): per-object distance strings

    Returns:
        results (json Array): one entry per merged region, each of the form
                            {'zoomed_img': np.array([]),
                             'actual_boxes': [],
                             'updated_boxes': {},
                            }
    """
    results = []
    object_count = 0
    activity_count = 0
    for idx, box in enumerate(json_data):
        final_json_output = {'zoomed_img': enhanced_images[idx],
                             'actual_boxes': [],
                             'updated_boxes': {"top_left": box['updated_boundries']['top_left'],
                                               "bottom_right": box['updated_boundries']['bottom_right']},
                             }

        for actual_box in box['actual_boundries']:
            temp = {"top_left": actual_box['top_left'],
                    "bottom_right": actual_box['bottom_right'],
                    "class": actual_box['class'],
                    "distance": distances_list[object_count],
                    "activity": 'none'}
            object_count += 1

            # Only people carry an activity label
            if temp['class'] == 'person':
                temp['activity'] = detected_activity[activity_count]
                activity_count += 1

            final_json_output['actual_boxes'].append(temp)

        # Smooth the distances once every box in this region has been added
        final_json_output = fix_distance(final_json_output)
        results.append(final_json_output)

    return results


def fix_distance(final_json_output):
    """Replace near-equal distances with their group mean.

    After sorting, objects whose distances lie within DIFF metres of the
    previous one are grouped, and every member of a group is assigned the
    group's mean distance.

    Args:
        final_json_output (json object): one region's result dict

    Returns:
        final_json_output (json object): the same dict with smoothed distances
    """
    DIFF = 90  # max gap (m) between consecutive sorted distances in one group

    distances = []
    for idx, box in enumerate(final_json_output['actual_boxes']):
        distances.append({'idx': idx, 'distance': int(box['distance'][:-1])})  # strip the "m" suffix

    if not distances:
        return final_json_output

    # Walk the sorted distances, accumulating a running sum per group
    sorted_dist = sorted(distances, key=lambda d: d['distance'])
    sum_dist = [{'sum': sorted_dist[0]['distance'], 'idxes': [sorted_dist[0]['idx']]}]

    for i in range(1, len(sorted_dist)):
        if abs(sorted_dist[i]['distance'] - sorted_dist[i - 1]['distance']) <= DIFF:
            sum_dist[-1]['sum'] += sorted_dist[i]['distance']
            sum_dist[-1]['idxes'].append(sorted_dist[i]['idx'])
        else:
            sum_dist.append({'sum': sorted_dist[i]['distance'], 'idxes': [sorted_dist[i]['idx']]})

    # Overwrite each member's distance with its group's mean
    for data in sum_dist:
        mean = data['sum'] // len(data['idxes'])
        for i in data['idxes']:
            final_json_output['actual_boxes'][i]['distance'] = str(mean) + 'm'

    return final_json_output
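

# Minimal end-to-end usage sketch. Assumes IMG_DIR_PATH contains at least one
# readable image; 'sample.jpg' is a hypothetical filename.
if __name__ == "__main__":
    sample_path = Path(path['IMG_DIR_PATH']) / 'sample.jpg'  # hypothetical file
    image = cv2.imread(str(sample_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image at {sample_path}")

    boxes = object_detection(image)
    crops, person_crops = cropped_images(image, boxes)
    enhanced_crops, enhanced_persons = image_enhancements(crops, person_crops)
    # Feed the enhanced person crops to the activity classifier (assumed pipeline order)
    activities = detect_activity(enhanced_persons)
    distances = get_distances(boxes)
    results = get_json_data(boxes, enhanced_crops, activities, distances)

    # 'zoomed_img' is a numpy array, so drop it before serialising for display
    print(json.dumps([{k: v for k, v in r.items() if k != 'zoomed_img'}
                      for r in results], indent=2))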