import math

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.framework.formats import landmark_pb2
from mediapipe import solutions
import numpy as np

# Changelog
# 2024-11-27: extract_landmark: add args; add get_pixel_xyz
# 2024-11-28: add get_normalized_xyz
# 2024-11-30: add get_normalized_landmarks, sort_triangles_by_depth, get_range_all, get_bbox
def calculate_distance(p1, p2):
    """Return the Euclidean distance between two points in the x/y plane.

    Only index 0 (x) and index 1 (y) of each point are read; any further
    components are ignored.
    """
    delta_x = p2[0] - p1[0]
    delta_y = p2[1] - p1[1]
    return math.sqrt(delta_x**2 + delta_y**2)

def get_range_all():
    """Return the range of landmark indices 0..467 (468 indices in total)."""
    landmark_count = 468
    return range(landmark_count)

def get_landmark_bbox(face_landmarks_list, w=1024, h=1024, margin_hw=0, margin_hh=0):
    """Return the bounding box over every index produced by get_range_all().

    Thin wrapper: all arguments are forwarded unchanged to get_bbox(), whose
    result is returned as-is.
    """
    all_indices = get_range_all()
    return get_bbox(face_landmarks_list, all_indices, w, h, margin_hw, margin_hh)
def get_bbox(face_landmarks_list, indices, w=1024, h=1024, margin_hw=0, margin_hh=0):
    """Return the pixel bounding box of selected landmarks as [x, y, width, height].

    Landmarks are read from ``face_landmarks_list[0]``. Each landmark's ``x``
    and ``y`` attribute is multiplied by ``w`` / ``h`` and clamped to
    ``[0, w]`` / ``[0, h]`` before the extremes are taken.

    Args:
        face_landmarks_list: Indexable whose first element is an indexable of
            objects with ``x`` and ``y`` attributes.
        indices: Iterable of landmark indices to include.
        w: Horizontal scale (and clamp limit) in pixels.
        h: Vertical scale (and clamp limit) in pixels.
        margin_hw: Subtracted from the left edge and added to the width.
        margin_hh: Subtracted from the top edge and added to the height.

    Returns:
        ``[x, y, width, height]``; x and y are floored at 0, width is capped
        at ``w`` and height at ``h``.

    NOTE(review): the margin is added to the size only once, so the far
    (right/bottom) edge does not move outward - confirm this is intended.
    """
    pixel_points = [
        (
            min(w, max(0, face_landmarks_list[0][index].x * w)),
            min(h, max(0, face_landmarks_list[0][index].y * h)),
        )
        for index in indices
    ]

    # Seeding the extremes with w/h and 0 keeps the result for an empty
    # index list identical to a loop that never runs.
    left = min([w] + [px for px, _ in pixel_points])
    top = min([h] + [py for _, py in pixel_points])
    right = max([0] + [px for px, _ in pixel_points])
    bottom = max([0] + [py for _, py in pixel_points])

    return [
        max(0, int(left) - margin_hw),
        max(0, int(top) - margin_hh),
        min(w, int(right - left) + margin_hw),
        min(h, int(bottom - top) + margin_hh),
    ]

def to_int_points(points):
    """Convert every point to an ``[int(x), int(y)]`` list.

    Only index 0 and index 1 of each point are used; ``int`` truncates toward
    zero. A new list is returned and the input is left untouched.
    """
    return [[int(point[0]), int(point[1])] for point in points]

debug = False  # when True, divide_line_to_points prints its intermediate values


def divide_line_to_points(points, divided):  # return divided + 1
    """Cut a polyline into ``divided`` pieces of equal path length.

    Args:
        points: Indexable sequence of points; only index 0 (x) and index 1 (y)
            of each point are used. Must contain at least one point.
        divided: Number of equal-length pieces. Expected to be a positive
            integer; 0 raises ZeroDivisionError.

    Returns:
        For a positive ``divided``, a list of ``divided + 1`` ``[x, y]`` points
        with coordinates truncated by ``int``: the first input point,
        ``divided - 1`` interpolated cut points, and the last input point.
    """
    total_length = 0
    line_length_list = []
    for i in range(len(points) - 1):
        pt_length = calculate_distance(points[i], points[i + 1])
        total_length += pt_length
        line_length_list.append(pt_length)

    splited_length = total_length / divided

    def get_new_point(index, lerp):
        """Interpolate between points[index] (lerp=0) and points[index + 1] (lerp=1)."""
        pt1 = points[index]
        pt2 = points[index + 1]
        diff = [pt2[0] - pt1[0], pt2[1] - pt1[1]]
        new_point = [pt1[0] + diff[0] * lerp, pt1[1] + diff[1] * lerp]
        if debug:
            print(f"pt1 ={pt1}  pt2 ={pt2} diff={diff} new_point={new_point}")

        return new_point

    if debug:
        print(f"{total_length} splitted = {splited_length} line-length-list = {len(line_length_list)}")
    splited_points = [points[0]]
    for i in range(1, divided):
        # Path length from the start at which the i-th cut point lies.
        need_length = splited_length * i
        if debug:
            print(f"{i} need length = {need_length}")
        current_length = 0
        for j, line_length in enumerate(line_length_list):
            current_length += line_length
            if current_length > need_length:
                if debug:
                    print(f"over need length index = {j} current={current_length}")
                # How far the end of segment j overshoots the target length.
                diff = current_length - need_length

                lerp_point = 1.0 - (diff / line_length)
                if debug:
                    print(f"over = {diff} lerp ={lerp_point}")
                splited_points.append(get_new_point(j, lerp_point))
                break
        else:
            # No segment reaches need_length (e.g. every point coincides, so
            # the total length is 0). Previously nothing was appended here and
            # the result came back shorter than divided + 1; fall back to the
            # end point so the promised length always holds.
            splited_points.append(points[-1])

    splited_points.append(points[-1])  # last one
    splited_points = to_int_points(splited_points)

    if debug:
        print(f"sp={len(splited_points)}")
    return splited_points



def expand_bbox(bbox, left=5, top=5, right=5, bottom=5):
    """Return a copy of ``bbox`` grown on each side by a percentage of its size.

    ``bbox`` is treated as ``[x, y, width, height]``: ``left`` / ``right`` are
    percentages of ``bbox[2]`` and ``top`` / ``bottom`` percentages of
    ``bbox[3]``. The origin moves up-left by the left/top amount and the size
    grows by the sum of both sides. The input is not modified; the result is
    a list.
    """
    def percent_of(size, percent):
        # Same arithmetic order as size * (float(percent) / 100).
        return size * (float(percent) / 100)

    grow_left = percent_of(bbox[2], left)
    grow_top = percent_of(bbox[3], top)
    grow_right = percent_of(bbox[2], right)
    grow_bottom = percent_of(bbox[3], bottom)

    expanded = list(bbox)
    expanded[0] -= grow_left
    expanded[1] -= grow_top
    expanded[2] += grow_left + grow_right
    expanded[3] += grow_top + grow_bottom
    return expanded

# Normalized values; for the landmark indices see mp_constants.
def get_normalized_cordinate(face_landmarks_list, index):
    """Return the unscaled ``(x, y)`` of landmark ``index`` in ``face_landmarks_list[0]``."""
    landmark = face_landmarks_list[0][index]
    return landmark.x, landmark.y

def get_normalized_xyz(face_landmarks_list, index):
    """Return the unscaled ``(x, y, z)`` of landmark ``index`` in ``face_landmarks_list[0]``."""
    landmark = face_landmarks_list[0][index]
    return landmark.x, landmark.y, landmark.z

def get_normalized_landmarks(face_landmarks_list):
    """Return a list of ``(x, y, z)`` tuples for landmark indices 0..467.

    Each tuple is produced by get_normalized_xyz(), so the values are read
    from ``face_landmarks_list[0]`` without scaling.
    """
    landmarks = []
    for landmark_index in range(468):
        landmarks.append(get_normalized_xyz(face_landmarks_list, landmark_index))
    return landmarks

def sort_triangles_by_depth(landmark_points, mesh_triangle_indices):
    """Sort ``mesh_triangle_indices`` in place by mean z, largest first.

    Args:
        landmark_points: Exactly 468 points (asserted); index 2 of each point
            is read as its z value.
        mesh_triangle_indices: List of index groups into ``landmark_points``;
            reordered in place.

    Returns:
        None - the list is sorted in place.
    """
    assert len(landmark_points) == 468

    def mean_depth(triangle):
        # Average z over the vertices referenced by this triangle.
        depth_total = sum(landmark_points[index][2] for index in triangle)
        return depth_total / len(triangle)

    mesh_triangle_indices.sort(key=mean_depth, reverse=True)
# z is normalized
def get_pixel_xyz(face_landmarks_list, landmark, width, height):
    """Return ``(pixel_x, pixel_y, z)`` for one landmark of ``face_landmarks_list[0]``.

    Args:
        face_landmarks_list: Indexable whose first element is an indexable of
            objects with ``x``, ``y`` and ``z`` attributes.
        landmark: Index of the landmark to read.
        width: Factor applied to ``x`` before truncating to int.
        height: Factor applied to ``y`` before truncating to int.

    Returns:
        A tuple ``(int(x * width), int(y * height), z)``; z is returned
        exactly as stored on the landmark, without scaling.
    """
    # Read the landmark once. The earlier chained assignment `z = y = ...`
    # left a dead local `y` that actually held the z value.
    point = face_landmarks_list[0][landmark]
    return int(point.x * width), int(point.y * height), point.z

def get_pixel_cordinate(face_landmarks_list, landmark, width, height):
    """Return ``(pixel_x, pixel_y)`` for one landmark.

    The (x, y) pair from get_normalized_cordinate() is multiplied by
    ``width`` / ``height`` and truncated to int.
    """
    normalized_x, normalized_y = get_normalized_cordinate(face_landmarks_list, landmark)
    pixel_x = int(normalized_x * width)
    pixel_y = int(normalized_y * height)
    return pixel_x, pixel_y

def get_pixel_cordinate_list(face_landmarks_list, indices, width, height):
    """Return the ``(pixel_x, pixel_y)`` tuple of every landmark in ``indices``.

    The output list follows the iteration order of ``indices``; each entry
    comes from get_pixel_cordinate().
    """
    return [
        get_pixel_cordinate(face_landmarks_list, landmark_index, width, height)
        for landmark_index in indices
    ]

def extract_landmark(
    image_data,
    model_path="face_landmarker.task",
    min_face_detection_confidence=0,
    min_face_presence_confidence=0,
    output_facial_transformation_matrixes=False,
):
    """Run the MediaPipe FaceLandmarker on a single image.

    Args:
        image_data: A ``str`` is passed to ``mp.Image.create_from_file``; any
            other value is converted with ``np.asarray`` and wrapped in an
            SRGB ``mp.Image``.
        model_path: Passed as ``model_asset_path`` to ``BaseOptions``.
        min_face_detection_confidence: Forwarded to ``FaceLandmarkerOptions``.
        min_face_presence_confidence: Forwarded to ``FaceLandmarkerOptions``.
        output_facial_transformation_matrixes: Forwarded to
            ``FaceLandmarkerOptions``.

    Returns:
        A tuple ``(mp_image, face_landmarker_result)``: the image handed to
        the landmarker and the value returned by ``landmarker.detect``.
    """
    landmarker_options = mp.tasks.vision.FaceLandmarkerOptions(
        base_options=mp.tasks.BaseOptions(model_asset_path=model_path),
        running_mode=mp.tasks.vision.RunningMode.IMAGE,
        min_face_detection_confidence=min_face_detection_confidence,
        min_face_presence_confidence=min_face_presence_confidence,
        output_facial_transformation_matrixes=output_facial_transformation_matrixes,
    )

    # A new landmarker is created for each call and closed when the with block exits.
    with mp.tasks.vision.FaceLandmarker.create_from_options(landmarker_options) as landmarker:
        mp_image = (
            mp.Image.create_from_file(image_data)
            if isinstance(image_data, str)
            else mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(image_data))
        )
        return mp_image, landmarker.detect(mp_image)