"""
Video Keyword Finder with Gradio Interface
==========================================

A Python script that finds and timestamps specific keywords within video files.
It transcribes the audio using Whisper AI and finds all occurrences of specified keywords
with their timestamps and surrounding context. Supports both local video files and YouTube URLs.

Requirements
------------
- Python 3.8 or higher
- ffmpeg (must be installed and accessible in system PATH)
- GPU recommended but not required
"""

import whisper_timestamped as whisper
import datetime
import os
import tempfile
import yt_dlp
import re
import logging
import math
import subprocess
import glob
import gradio as gr

# Configure root logging: INFO and above, timestamped, sent to a stream handler.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format=_LOG_FORMAT,
    level=logging.INFO,
)

def parse_time(time_str):
    """Convert a time string to a whole number of seconds.

    Accepted formats are ``HH:MM:SS``, ``MM:SS`` and plain seconds.
    Surrounding whitespace is ignored.

    Args:
        time_str: The time to parse. A falsy value (``None``, ``""``) or a
            whitespace-only string means "no time given". Non-string values
            are converted with ``str()`` before parsing.

    Returns:
        The number of seconds as an ``int``, or ``None`` when no time was given.

    Raises:
        ValueError: If the value has more than three ``:``-separated fields,
            a field is not an integer, or a field is negative.
    """
    if not time_str:
        return None
    text = str(time_str).strip()
    if not text:
        return None

    error_message = "Time must be in HH:MM:SS, MM:SS, or seconds format"
    parts = text.split(':')
    if len(parts) > 3:
        raise ValueError(error_message)
    try:
        values = [int(part) for part in parts]
    except ValueError:
        # Hide the low-level int() message; the format hint is more useful.
        raise ValueError(error_message) from None
    # A negative field (e.g. "1:-5") would silently subtract from the total.
    if any(value < 0 for value in values):
        raise ValueError(error_message)

    # Fold left to right so one loop covers SS, MM:SS and HH:MM:SS.
    total_seconds = 0
    for value in values:
        total_seconds = total_seconds * 60 + value
    return total_seconds

def format_time(seconds):
    """Render a number of seconds as an ``H:MM:SS`` string.

    Fractional seconds are truncated before formatting. The text is the
    standard string form of ``datetime.timedelta``.
    """
    whole_seconds = int(seconds)
    elapsed = datetime.timedelta(seconds=whole_seconds)
    return str(elapsed)

def is_youtube_url(url):
    """Check whether the provided string looks like a YouTube video URL.

    Surrounding whitespace is ignored, the scheme is optional, and the
    ``www.``, ``m.`` and ``music.`` host prefixes are accepted. The check is
    a prefix match, so trailing query parameters after the 11-character
    video id do not cause a rejection.

    Args:
        url: Candidate URL string; ``None`` or a blank string is rejected.

    Returns:
        ``True`` if the string matches the YouTube URL pattern, else ``False``.
    """
    if not url:
        return False
    # Pasted URLs often carry stray whitespace, which would defeat re.match.
    candidate = url.strip()
    if not candidate:
        return False
    youtube_regex = (
        r'(https?://)?((www|m|music)\.)?'
        r'(youtube|youtu|youtube-nocookie)\.(com|be)/'
        r'(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})'
    )
    return bool(re.match(youtube_regex, candidate))

def download_youtube_video(url):
    """Download a YouTube video to a uniquely named temporary MP4 file.

    The file is created in the system temporary directory. The caller owns
    the returned file and is responsible for deleting it.

    Args:
        url: URL of the video to download.

    Returns:
        Path of the downloaded file.

    Raises:
        FileNotFoundError: If the download finished without producing a file.
        Exception: Any error raised by yt-dlp is logged and re-raised.
    """
    # Function-scope import: uuid is only needed here.
    import uuid

    # A fixed name would let a stale or concurrent download be returned as if
    # it were the requested video, so every call gets its own path.
    temp_file = os.path.join(
        tempfile.gettempdir(),
        f'youtube_video_{uuid.uuid4().hex}.mp4'
    )

    ydl_opts = {
        'format': 'best[ext=mp4]',
        'outtmpl': temp_file,
        'quiet': False,
        'progress_hooks': [lambda d: logging.info(f"Download progress: {d.get('status', 'unknown')}")],
        'socket_timeout': 30,
    }

    try:
        logging.info(f"Starting download of YouTube video: {url}")
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            logging.info(f"Video info extracted: {info.get('title', 'Unknown title')}")

        if not os.path.exists(temp_file):
            raise FileNotFoundError("Download completed but file not found")

        file_size = os.path.getsize(temp_file)
        logging.info(f"Download complete. File size: {file_size / (1024*1024):.2f} MB")
        return temp_file
    except Exception as e:
        logging.error(f"Error downloading YouTube video: {str(e)}")
        # Best-effort removal of partial output so failed downloads do not
        # accumulate; the '.part' name is assumed to be yt-dlp's
        # in-progress file.
        for leftover in (temp_file, temp_file + '.part'):
            try:
                os.remove(leftover)
            except OSError:
                pass
        raise

def get_video_duration(video_path):
    """Return the duration of *video_path* in seconds, as reported by ffprobe.

    A failing ffprobe run is logged and its ``CalledProcessError`` re-raised.
    """
    ffprobe_cmd = (
        'ffprobe -v error -show_entries format=duration '
        '-of default=noprint_wrappers=1:nokey=1'
    ).split()
    ffprobe_cmd.append(video_path)

    try:
        raw_output = subprocess.check_output(ffprobe_cmd)
    except subprocess.CalledProcessError as exc:
        logging.error(f"Error getting video duration: {str(exc)}")
        raise

    duration_text = raw_output.decode().strip()
    return float(duration_text)

def split_video(video_path, segment_duration=120):
    """Split a video into fixed-length segments using ffmpeg.

    Segments are written to the system temporary directory as
    ``segment_NNN.mp4`` using stream copy (no re-encoding). Any
    ``segment_*.mp4`` files already in that directory are removed first.

    NOTE(review): the fixed file names in a shared directory mean two
    simultaneous calls would interfere with each other.

    Args:
        video_path: Path of the video to split.
        segment_duration: Target length of each segment in seconds; must be
            greater than 0.

    Returns:
        List of segment file paths ordered by segment number.

    Raises:
        ValueError: If ``segment_duration`` is not greater than 0.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    if segment_duration <= 0:
        raise ValueError("segment_duration must be greater than 0")

    def segment_sort_key(path):
        # Sort numerically: '%03d' produces 4+ digits beyond segment 999,
        # which a plain string sort would misorder.
        match = re.search(r'segment_(\d+)\.mp4$', path)
        return (int(match.group(1)) if match else -1, path)

    try:
        temp_dir = tempfile.gettempdir()
        segment_glob = os.path.join(temp_dir, 'segment_*.mp4')
        segment_pattern = os.path.join(temp_dir, 'segment_%03d.mp4')

        # Remove any existing segments
        for old_segment in glob.glob(segment_glob):
            try:
                os.remove(old_segment)
            except OSError as e:
                logging.warning(f"Could not remove old segment {old_segment}: {e}")

        # Split video into segments
        cmd = [
            'ffmpeg',
            # Never wait on stdin, and overwrite a segment that could not be
            # removed above instead of blocking on ffmpeg's overwrite prompt.
            '-nostdin',
            '-y',
            '-i', video_path,
            '-f', 'segment',
            '-segment_time', str(segment_duration),
            '-c', 'copy',
            '-reset_timestamps', '1',
            segment_pattern
        ]

        logging.info("Splitting video into segments...")
        subprocess.run(cmd, check=True, capture_output=True)

        # Get list of generated segments
        segments = sorted(glob.glob(segment_glob), key=segment_sort_key)
        logging.info(f"Created {len(segments)} segments")

        return segments

    except subprocess.CalledProcessError as e:
        logging.error(f"Error splitting video: {str(e)}")
        # The output was captured, so surface ffmpeg's own diagnostics.
        if e.stderr:
            logging.error(f"ffmpeg output: {e.stderr.decode(errors='replace').strip()}")
        raise
    except Exception as e:
        logging.error(f"Error splitting video: {str(e)}")
        raise

def process_segments(segments, keywords, segment_duration=120):
    """Transcribe each segment in turn and collect keyword matches.

    A keyword matches a transcribed segment when it occurs in the segment's
    text as a case-insensitive substring. Each segment file is deleted after
    it has been processed, whether or not processing succeeded. A segment
    that fails is logged and skipped.

    Args:
        segments: Paths of segment files whose names contain
            ``segment_<number>``.
        keywords: Keywords to look for. Blank keywords are never matched and
            duplicates are only counted once.
        segment_duration: Length in seconds of each segment, used to convert
            a segment-relative time into a time in the full video.

    Returns:
        Dict mapping each keyword to a list of ``(timestamp, text)`` tuples,
        where ``timestamp`` is in seconds from the start of the full video.
    """
    results = {keyword: [] for keyword in keywords}

    # Nothing to transcribe: skip the model load entirely.
    if not segments:
        return results

    # Lowercase each distinct keyword once. Blank keywords are dropped
    # because an empty string is a substring of every text.
    needles = [(keyword, keyword.lower()) for keyword in results if keyword.strip()]

    # Load whisper model once
    logging.info("Loading Whisper model...")
    model = whisper.load_model("base")

    for i, segment_path in enumerate(segments):
        try:
            # Match on the file name only so digits in a directory name that
            # happens to contain 'segment_' cannot be picked up.
            match = re.search(r'segment_(\d+)', os.path.basename(segment_path))
            if match is None:
                raise ValueError("segment number not found in file name")
            segment_num = int(match.group(1))
            start_time = segment_num * segment_duration

            logging.info(f"Processing segment {i+1}/{len(segments)} (starting at {format_time(start_time)})")

            # Transcribe segment
            audio = whisper.load_audio(segment_path)
            transcription = whisper.transcribe(model, audio)

            # Process results
            for segment in transcription['segments']:
                text = segment['text'].lower()
                timestamp = segment['start'] + start_time  # Adjust timestamp relative to full video

                for keyword, needle in needles:
                    if needle in text:
                        results[keyword].append((
                            timestamp,
                            segment['text']
                        ))
                        logging.info(f"Found keyword '{keyword}' at {format_time(timestamp)}: {segment['text']}")

        except Exception as e:
            logging.error(f"Error processing segment {segment_path}: {str(e)}")
            continue

        finally:
            # Clean up segment file (best effort)
            try:
                os.remove(segment_path)
            except OSError as e:
                logging.warning(f"Could not remove segment {segment_path}: {e}")

    return results

def find_keywords_in_video(video_path, keywords, begin_time=None, end_time=None):
    """Find timestamps for keywords in a video's transcription.

    The video is split into 2-minute segments and each one is transcribed.
    Only matches whose timestamp lies within ``[begin_time, end_time]`` are
    returned.

    NOTE: segments outside the requested window are still transcribed; the
    window is applied to the results afterwards.

    Args:
        video_path: Path of the video file.
        keywords: A list of keywords or a comma-separated string. Blank
            entries are ignored.
        begin_time: Optional window start as a time string; defaults to the
            start of the video.
        end_time: Optional window end as a time string; defaults to, and is
            capped at, the video duration.

    Returns:
        Dict mapping each keyword to a list of ``(timestamp, text)`` tuples.
        An empty dict is returned when no usable keyword was given.

    Raises:
        ValueError: If a time string is malformed or the window is empty.
        Exception: Errors from probing or splitting the video are logged and
            re-raised.
    """
    try:
        logging.info(f"Processing video: {video_path}")
        logging.info(f"Searching for keywords: {keywords}")

        # Convert keywords string to list and drop blank entries, which
        # would otherwise match every transcribed segment.
        if isinstance(keywords, str):
            keywords = [k.strip() for k in keywords.split(',')]
        keywords = [k for k in keywords if k and k.strip()]
        if not keywords:
            # Nothing to search for: avoid the expensive transcription.
            return {}

        # Get video duration
        duration = get_video_duration(video_path)
        logging.info(f"Video duration: {duration:.2f} seconds")

        # Set time bounds. parse_time returns None for whitespace-only input,
        # so fall back to the defaults instead of comparing against None.
        start = parse_time(begin_time) if begin_time else None
        if start is None:
            start = 0
        end = parse_time(end_time) if end_time else None
        end = duration if end is None else min(end, duration)

        if start >= end:
            raise ValueError("End time must be greater than start time")

        # Split video into segments
        segment_duration = 120  # 2 minutes per segment
        segments = split_video(video_path, segment_duration)

        # Process segments sequentially
        results = process_segments(segments, keywords)

        # Apply the requested window, which was previously computed but
        # never used.
        return {
            keyword: [
                (timestamp, text)
                for timestamp, text in matches
                if start <= timestamp <= end
            ]
            for keyword, matches in results.items()
        }

    except Exception as e:
        logging.error(f"Error processing video: {str(e)}")
        raise

def process_video(video_file, youtube_url, keywords_input, start_time="0:00", end_time=None):
    """Process a video file or YouTube URL and report keyword occurrences.

    This is the handler behind the "Find Keywords" button. It never raises:
    every problem is returned as a human-readable message.

    Args:
        video_file: Path of an uploaded video, or ``None``.
        youtube_url: YouTube URL; takes precedence over ``video_file`` when
            non-blank.
        keywords_input: Comma-separated keywords. Blank entries and exact
            duplicates are ignored.
        start_time: Optional start of the search window as a time string.
        end_time: Optional end of the search window as a time string.

    Returns:
        A text report of the matches, or a message describing what went wrong.
    """
    try:
        # Convert keywords string to list. A missing value is treated as
        # empty, and exact duplicates are dropped (order kept) so a repeated
        # keyword cannot inflate the match count.
        raw_keywords = (keywords_input or "").split(',')
        keywords = list(dict.fromkeys(k.strip() for k in raw_keywords if k.strip()))
        if not keywords:
            return "Please enter at least one keyword (separated by commas)"

        # Handle input source. Strip the URL so a pasted value with stray
        # whitespace still validates and downloads.
        youtube_url = youtube_url.strip() if youtube_url else ""
        if youtube_url:
            if not is_youtube_url(youtube_url):
                return "Invalid YouTube URL. Please provide a valid URL."
            video_path = download_youtube_video(youtube_url)
            cleanup_needed = True
        elif video_file is not None:
            video_path = video_file
            cleanup_needed = False
        else:
            return "Please provide either a video file or a YouTube URL"

        try:
            # Find keywords
            results = find_keywords_in_video(
                video_path=video_path,
                keywords=keywords,
                begin_time=start_time if start_time else None,
                end_time=end_time if end_time else None
            )

            # Format results
            output = []
            total_matches = sum(len(matches) for matches in results.values())

            output.append(f"Total matches found: {total_matches}\n")

            if total_matches == 0:
                output.append("No matches found for any keywords.")
            else:
                for keyword, matches in results.items():
                    if matches:
                        output.append(f"\nResults for '{keyword}':")
                        for timestamp, context in matches:
                            output.append(f"[{format_time(timestamp)}] {context}")
                    else:
                        output.append(f"\nNo occurrences found for '{keyword}'")

            return "\n".join(output)

        finally:
            # Cleanup temporary files. Only downloaded videos are removed;
            # an uploaded file is left untouched.
            if cleanup_needed and os.path.exists(video_path):
                try:
                    os.remove(video_path)
                except OSError as e:
                    logging.warning(f"Could not remove temporary file {video_path}: {e}")

    except Exception as e:
        logging.error(f"Error processing video: {str(e)}")
        return f"Error processing video: {str(e)}"

# Create Gradio interface
# Builds the module-level `demo` app: input controls in the first column, a
# results box in the second, and a button that runs process_video.
with gr.Blocks(title="Video Keyword Finder", theme=gr.themes.Soft()) as demo:
    # Page header and short description.
    gr.Markdown("""
    # 🎥 Video Keyword Finder
    Find timestamps for specific keywords in your videos using AI transcription.
    
    Upload a video file or provide a YouTube URL, then enter the keywords you want to find.
    """)
    
    with gr.Row():
        # First column: video source, keywords and optional time window.
        with gr.Column():
            # process_video uses this component's value directly as the video path.
            video_input = gr.File(
                label="Upload Video File",
                file_types=["video"],
                type="filepath"
            )
            # process_video prefers a non-blank URL over the uploaded file.
            youtube_url = gr.Textbox(
                label="Or Enter YouTube URL",
                placeholder="https://www.youtube.com/watch?v=..."
            )
            keywords = gr.Textbox(
                label="Keywords (comma-separated)",
                placeholder="enter, your, keywords, here"
            )
            
            # Start and end time inputs share one row.
            with gr.Row():
                start_time = gr.Textbox(
                    label="Start Time (HH:MM:SS)",
                    placeholder="0:00",
                    value="0:00"
                )
                end_time = gr.Textbox(
                    label="End Time (HH:MM:SS)",
                    placeholder="Optional"
                )
            
            submit_btn = gr.Button("Find Keywords", variant="primary")
        
        # Second column: text report returned by process_video.
        with gr.Column():
            output = gr.Textbox(
                label="Results",
                placeholder="Keywords and timestamps will appear here...",
                lines=20
            )
    
    # Usage instructions shown below the controls.
    gr.Markdown("""
    ### Instructions:
    1. Upload a video file or paste a YouTube URL
    2. Enter keywords separated by commas (e.g., "hello, world, python")
    3. Optionally set start and end times (format: HH:MM:SS)
    4. Click "Find Keywords" and wait for results
    
    Note: Processing time depends on video length. A 1-hour video typically takes 15-30 minutes.
    """)
    
    # The inputs list order must match process_video's positional parameters:
    # (video_file, youtube_url, keywords_input, start_time, end_time).
    submit_btn.click(
        fn=process_video,
        inputs=[video_input, youtube_url, keywords, start_time, end_time],
        outputs=output
    )

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    # NOTE(review): share=True presumably publishes the app through a public
    # Gradio link and debug=True enables debug mode — confirm both are
    # intended outside local development.
    demo.launch(share=True, debug=True)