import cv2
import mediapipe as mp
import numpy as np
import gradio as gr

# MediaPipe Pose, configured for a continuous webcam stream:
#   - static_image_mode=False: inputs are treated as consecutive video frames.
#   - model_complexity=1: the middle of MediaPipe's 0/1/2 complexity levels.
#   - enable_segmentation=False: only landmarks are needed, no person mask.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    static_image_mode=False,
    enable_segmentation=False,
    model_complexity=1,
    # Confidence thresholds for the initial detection and for frame-to-frame tracking.
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

# MediaPipe Face Mesh, configured for a single face in a continuous webcam stream:
#   - max_num_faces=1: only one face is ever returned.
#   - refine_landmarks=True: requests the refined (eye/lip/iris) landmark model.
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,
    refine_landmarks=True,
    static_image_mode=False,
    # Confidence thresholds for the initial detection and for frame-to-frame tracking.
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

def process_frame(image):
    """
    Processes a frame by:
      1. Converting RGB to BGR for OpenCV.
      2. Flipping the frame horizontally for a mirror view.
      3. Creating a black background.
      4. Drawing body landmarks and computing shoulder center.
      5. Drawing facial mesh and extracting chin point.
      6. Drawing a neck line from shoulder center to chin.
      7. Converting the result back to RGB.
    """
    frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    frame = cv2.flip(frame, 1)  # Flip horizontally for mirror effect
    output = np.zeros_like(frame)
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # --- Body Posture Analysis ---
    pose_results = pose.process(rgb_frame)
    shoulder_center = None
    if pose_results.pose_landmarks:
        h, w, _ = frame.shape
        landmarks = [(int(lm.x * w), int(lm.y * h)) for lm in pose_results.pose_landmarks.landmark]

        for connection in mp_pose.POSE_CONNECTIONS:
            start_idx, end_idx = connection
            if start_idx >= 11 and end_idx >= 11:
                if start_idx < len(landmarks) and end_idx < len(landmarks):
                    cv2.line(output, landmarks[start_idx], landmarks[end_idx], (255, 255, 0), 2)

        for i, pt in enumerate(landmarks):
            if i >= 11:
                cv2.circle(output, pt, 3, (255, 255, 0), -1)

        if len(landmarks) > 12:
            left_shoulder, right_shoulder = landmarks[11], landmarks[12]
            shoulder_center = ((left_shoulder[0] + right_shoulder[0]) // 2,
                               (left_shoulder[1] + right_shoulder[1]) // 2)
            cv2.circle(output, shoulder_center, 4, (0, 255, 255), -1)

    # --- Facial Mesh Analysis ---
    chin_point = None
    fm_results = face_mesh.process(rgb_frame)
    if fm_results.multi_face_landmarks:
        for face_landmarks in fm_results.multi_face_landmarks:
            h, w, _ = frame.shape
            fm_points = [(int(lm.x * w), int(lm.y * h)) for lm in face_landmarks.landmark]

            for connection in mp_face_mesh.FACEMESH_TESSELATION:
                start_idx, end_idx = connection
                if start_idx < len(fm_points) and end_idx < len(fm_points):
                    cv2.line(output, fm_points[start_idx], fm_points[end_idx], (0, 0, 255), 1)

            for pt in fm_points:
                cv2.circle(output, pt, 2, (0, 255, 0), -1)

            if len(face_landmarks.landmark) > 152:
                lm = face_landmarks.landmark[152]
                chin_point = (int(lm.x * w), int(lm.y * h))
                cv2.circle(output, chin_point, 4, (0, 0, 255), -1)
            break  # Process only the first detected face.

    # --- Draw Neck Line ---
    if shoulder_center and chin_point:
        cv2.line(output, shoulder_center, chin_point, (0, 255, 255), 2)

    return cv2.cvtColor(output, cv2.COLOR_BGR2RGB)


# --- Gradio Interface for Live Webcam Inference ---
# Webcam frames are streamed into process_frame; each returned overlay is shown
# in a numpy-typed output image. live=True re-runs the function as input changes.
iface = gr.Interface(
    process_frame,
    gr.Image(sources=["webcam"], streaming=True, label="Webcam Input"),
    gr.Image(type="numpy", label="Processed Output"),
    title="Live Body Posture & Neck Analysis (Mirror View)",
    description="Real-time webcam analysis using MediaPipe Pose and Face Mesh with live inference and mirrored camera view.",
    live=True,
)

iface.launch()