import time
import os
import logging


import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer

from utils.download import download_file
from utils.turn import get_ice_servers

from mtcnn import MTCNN  # Import MTCNN for face detection
from PIL import Image, ImageDraw  # Import PIL for image processing
from transformers import pipeline  # Import Hugging Face transformers pipeline

import requests
from io import BytesIO  # Import for handling byte streams
import yt_dlp


# CHANGE THE CODE BELOW HERE TO REPLACE IT WITH THE ANALYSIS YOU WANT.
# Update the string below to set the display title of the analysis

# Appropriate imports needed for analysis

# Initialize MTCNN for face detection
mtcnn = MTCNN()
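# MTCNN.detect_faces() returns a list of dicts, each with a "box" ([x, y, width, height]),
# a "confidence" score, and facial "keypoints"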

# Initialize the Hugging Face pipeline for facial emotion detection
emotion_pipeline = pipeline("image-classification",
                            model="trpakov/vit-face-expression")


# Default title - "Facial Sentiment Analysis"

ANALYSIS_TITLE = "Facial Sentiment Analysis"

# CHANGE THE CONTENTS OF THIS FUNCTION TO REPLACE IT WITH THE ANALYSIS YOU WANT.
#
#
# Function to analyze an input frame and generate an analyzed frame
# This function takes an input video frame, detects faces in it using MTCNN,
# then for each detected face, it analyzes the sentiment (emotion) using the analyze_sentiment function,
# draws a rectangle around the face, and overlays the detected emotion on the frame.
# It also records the time taken to process the frame and stores it in a global container.
# Constants for text and line size in the output image
TEXT_SIZE = 1  # OpenCV font scale used for the overlay text
LINE_SIZE = 2  # Thickness (in pixels) of bounding boxes and text strokes


# Set analysis results in img_container and result queue for display
# img_container["input"] - holds the input frame contents - of type np.ndarray
# img_container["analyzed"] - holds the analyzed frame with any added annotations - of type np.ndarray
# img_container["analysis_time"] - holds how long the analysis has taken in miliseconds
# img_container["detections"] - holds the analysis metadata results
def analyze_frame(frame: np.ndarray):
    start_time = time.time()  # Start timing the analysis
    img_container["input"] = frame  # Store the input frame
    frame = frame.copy()  # Create a copy of the frame to modify

    results = mtcnn.detect_faces(frame)  # Detect faces in the frame
    for result in results:
        x, y, w, h = result["box"]  # Get the bounding box of the detected face
        # Clamp to the frame since MTCNN can return slightly negative coordinates
        x, y = max(0, x), max(0, y)
        face = frame[y: y + h, x: x + w]  # Extract the face from the frame
        # Analyze the sentiment of the face
        sentiment = analyze_sentiment(face)
        result["label"] = sentiment
        # Draw a rectangle around the face
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), LINE_SIZE)
        # Measure the sentiment text so a filled background box can be sized to fit it
        text_size = cv2.getTextSize(
            sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2)[0]
        text_x = x
        text_y = y - 10  # Place the text just above the bounding box
        background_tl = (text_x, text_y - text_size[1])
        background_br = (text_x + text_size[0], text_y + 5)
        # Draw a black background for the text
        cv2.rectangle(frame, background_tl, background_br,
                      (0, 0, 0), cv2.FILLED)
        # Put the sentiment text on the image
        cv2.putText(
            frame,
            sentiment,
            (text_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            TEXT_SIZE,
            (255, 255, 255),
            2,
        )

    end_time = time.time()  # End timing the analysis
    execution_time_ms = round(
        (end_time - start_time) * 1000, 2
    )  # Calculate execution time in milliseconds
    # Store the execution time
    img_container["analysis_time"] = execution_time_ms

    # store the detections
    img_container["detections"] = results

    img_container["analyzed"] = frame  # Store the analyzed frame

    return  # End of the function


# Function to analyze the sentiment (emotion) of a detected face
# The face crop is already in RGB format (every frame is converted to RGB before analysis),
# so this function converts it to a PIL image, runs the pre-trained emotion detection model,
# and returns the most dominant emotion detected.
def analyze_sentiment(face):
    pil_image = Image.fromarray(face)  # Convert the RGB face crop to a PIL image
    results = emotion_pipeline(pil_image)  # Run emotion detection on the image
    # Pick the label with the highest confidence score
    dominant_emotion = max(results, key=lambda x: x["score"])["label"]
    return dominant_emotion  # Return the detected emotion
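# Illustrative example (not from a real run): emotion_pipeline(pil_image) might return
# [{"label": "happy", "score": 0.93}, {"label": "neutral", "score": 0.04}, ...],
# in which case analyze_sentiment returns "happy"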


#
#
# DO NOT TOUCH THE BELOW CODE (NO CHANGES NEEDED)
#
#

# Suppress FFmpeg logs
os.environ["FFMPEG_LOG_LEVEL"] = "quiet"

# Suppress Streamlit logs using the logging module
logging.getLogger("streamlit").setLevel(logging.ERROR)

# Container to hold image data and analysis results
img_container = {"input": None, "analyzed": None,
                 "analysis_time": None, "detections": None}

# Logger for debugging and information
logger = logging.getLogger(__name__)


# Callback function to process video frames
# This function is called for each video frame in the WebRTC stream.
# It converts the frame to a numpy array in RGB format, analyzes the frame,
# and returns the original frame.
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    # Convert frame to numpy array in RGB format
    img = frame.to_ndarray(format="rgb24")
    analyze_frame(img)  # Analyze the frame
    return frame  # Return the original frame
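# The callback returns the original frame unchanged; the annotated result is published to the
# UI from img_container rather than being sent back over the WebRTC connection (the streamer
# below runs in SENDONLY mode)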


# Get ICE servers for WebRTC
ice_servers = get_ice_servers()

# Streamlit UI configuration
st.set_page_config(layout="wide")

# Custom CSS for the Streamlit page
st.markdown(
    """
    <style>
        .main {
            padding: 2rem;
        }
        h1, h2, h3 {
            font-family: 'Arial', sans-serif;
        }
        h1 {
            font-weight: 700;
            font-size: 2.5rem;
        }
        h2 {
            font-weight: 600;
            font-size: 2rem;
        }
        h3 {
            font-weight: 500;
            font-size: 1.5rem;
        }
    </style>
    """,
    unsafe_allow_html=True,
)

# Streamlit page title and subtitle
st.title("Computer Vision Playground")

# Add a link to the README file
st.markdown(
    """
    <div style="text-align: left;">
        <p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md" 
        target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
    </div>
    """,
    unsafe_allow_html=True,
)

st.subheader(ANALYSIS_TITLE)

# Columns for input and output streams
col1, col2 = st.columns(2)

with col1:
    st.header("Input Stream")
    input_subheader = st.empty()
    input_placeholder = st.empty()  # Placeholder for input frame
    st.subheader("Input Options")
    # WebRTC streamer to get video input from the webcam
    webrtc_ctx = webrtc_streamer(
        key="input-webcam",
        mode=WebRtcMode.SENDONLY,  # Browser only sends video; nothing is streamed back
        rtc_configuration=ice_servers,
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},  # Video only, no audio
        async_processing=True,
    )

    # File uploader for images
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"])

    # Text input for image URL
    st.subheader("Or Enter Image URL")
    image_url = st.text_input("Image URL")

    # Text input for YouTube URL
    st.subheader("Enter a YouTube URL")
    youtube_url = st.text_input("YouTube URL")

    # File uploader for videos
    st.subheader("Upload a Video")
    uploaded_video = st.file_uploader(
        "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
    )

    # Text input for video URL
    st.subheader("Or Enter Video Download URL")
    video_url = st.text_input("Video URL")

# Streamlit footer
st.markdown(
    """
    <div style="text-align: center; margin-top: 2rem;">
        <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
    </div>
    """,
    unsafe_allow_html=True
)

# Function to initialize the analysis UI
# This function sets up the placeholders and UI elements in the analysis section.
# It creates placeholders for input and output frames, analysis time, and detected labels.


def analysis_init():
    global analysis_time, show_labels, labels_placeholder, input_subheader, input_placeholder, output_placeholder

    with col2:
        st.header("Analysis")
        input_subheader.subheader("Input Frame")

        st.subheader("Output Frame")
        output_placeholder = st.empty()  # Placeholder for output frame
        analysis_time = st.empty()  # Placeholder for analysis time
        show_labels = st.checkbox(
            "Show the detected labels", value=True
        )  # Checkbox to show/hide labels
        labels_placeholder = st.empty()  # Placeholder for labels


# Function to publish frames and results to the Streamlit UI
# This function retrieves the latest frames and results from the global container and result queue,
# and updates the placeholders in the Streamlit UI with the current input frame, analyzed frame, analysis time, and detected labels.
def publish_frame():

    img = img_container["input"]
    if img is None:
        return
    input_placeholder.image(img, channels="RGB")  # Display the input frame

    analyzed = img_container["analyzed"]
    if analyzed is None:
        return
    # Display the analyzed frame
    output_placeholder.image(analyzed, channels="RGB")

    time_ms = img_container["analysis_time"]
    if time_ms is None:
        return
    # Display the analysis time
    analysis_time.text(f"Analysis Time: {time_ms} ms")

    detections = img_container["detections"]
    if detections is None:
        return

    if show_labels:
        labels_placeholder.table(
            detections
        )  # Display labels if the checkbox is checked


# If the WebRTC streamer is playing, initialize and publish frames
if webrtc_ctx.state.playing:
    analysis_init()  # Initialize the analysis UI
    while True:
        publish_frame()  # Publish the frames and results
        time.sleep(0.1)  # Delay to control frame rate


# If an image is uploaded or a URL is provided, process the image
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI

    if uploaded_file is not None:
        image = Image.open(uploaded_file)  # Open the uploaded image
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
    else:
        response = requests.get(image_url)  # Download the image from the URL
        # Open the downloaded image
        image = Image.open(BytesIO(response.content))
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format

    analyze_frame(img)  # Analyze the image
    publish_frame()  # Publish the results


# Function to process video files
# This function reads frames from a video file, analyzes each frame for face detection and sentiment analysis,
# and updates the Streamlit UI with the current input frame, analyzed frame, and detected labels.
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)  # Open the video file
    while cap.isOpened():
        ret, frame = cap.read()  # Read a frame from the video
        if not ret:
            break  # Exit the loop if no more frames are available

        # Convert the frame from BGR to RGB format
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Analyze the frame for face detection and sentiment analysis
        analyze_frame(rgb_frame)

        publish_frame()  # Publish the results

    cap.release()  # Release the video capture object


# Function to get the video stream URL from YouTube using yt-dlp
def get_youtube_stream_url(youtube_url):
    ydl_opts = {
        'format': 'best[ext=mp4]',  # Prefer the best available MP4 stream
        'quiet': True,  # Suppress yt-dlp console output
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(youtube_url, download=False)
        stream_url = info_dict['url']
    return stream_url
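# The returned stream_url points directly at the media, so cv2.VideoCapture in process_video()
# can typically open it like a local file (assuming OpenCV was built with FFmpeg support)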


# If a YouTube URL is provided, process the video
if youtube_url:
    analysis_init()  # Initialize the analysis UI

    stream_url = get_youtube_stream_url(youtube_url)

    process_video(stream_url)  # Process the video


# If a video is uploaded or a URL is provided, process the video
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI

    if uploaded_video is not None:
        video_path = uploaded_video.name  # Get the name of the uploaded video
        with open(video_path, "wb") as f:
            # Save the uploaded video to a file
            f.write(uploaded_video.getbuffer())
    else:
        # Download the video from the URL
        video_path = download_file(video_url)

    process_video(video_path)  # Process the video