"""
WebDataset format handling for Video Model Studio
"""

import logging
import os
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import List, Dict, Tuple, Optional

from ..utils import is_image_file, is_video_file, extract_scene_info

logger = logging.getLogger(__name__)

def is_webdataset_file(file_path: Path) -> bool:
    """Tell whether a path names a WebDataset shard.

    Only the final suffix of the path is inspected, case-insensitively;
    the file is never opened.

    Args:
        file_path: Path to inspect

    Returns:
        bool: True when the final suffix is ``.tar``, False otherwise
    """
    suffix = file_path.suffix
    normalized_suffix = suffix.lower()
    return normalized_suffix == '.tar'

# Extensions (lower-case, leading dot included) recognised inside a shard.
_WDS_IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic'})
_WDS_VIDEO_EXTENSIONS = frozenset({'.mp4', '.webm'})
_WDS_CAPTION_EXTENSIONS = frozenset({'.txt', '.caption', '.json', '.cls'})


def _group_members_by_prefix(
    tar: tarfile.TarFile,
) -> Dict[str, List[Tuple[tarfile.TarInfo, str]]]:
    """Group the members of an open tar archive by sample prefix.

    The prefix is the member's directory plus its file name up to the first
    dot; everything from that dot on is treated as the extension.
    Directories, hidden files and names without a dot are left out.
    """
    grouped: Dict[str, List[Tuple[tarfile.TarInfo, str]]] = {}
    for member in tar.getmembers():
        if member.isdir():
            continue

        # Skip hidden files
        if os.path.basename(member.name).startswith('.'):
            continue

        file_path = Path(member.name)
        prefix, dot, remainder = file_path.name.partition('.')
        if not dot:
            # No extension, skip
            continue
        extension = '.' + remainder

        # Include the directory so same-named samples in different
        # directories of the archive are not merged into one group.
        if file_path.parent != Path('.'):
            full_prefix = str(file_path.parent / prefix)
        else:
            full_prefix = prefix

        grouped.setdefault(full_prefix, []).append((member, extension))
    return grouped


def _select_sample_files(
    members: List[Tuple[tarfile.TarInfo, str]],
) -> Tuple[Optional[tarfile.TarInfo], Optional[str], Optional[tarfile.TarInfo]]:
    """Pick the media member, its extension and the caption member of a sample.

    When several members of the group qualify, the last one in archive
    order wins.
    """
    media_member = None
    media_ext = None
    caption_member = None
    for member, ext in members:
        lowered = ext.lower()
        if lowered in _WDS_IMAGE_EXTENSIONS or lowered in _WDS_VIDEO_EXTENSIONS:
            media_member = member
            media_ext = ext
        elif lowered in _WDS_CAPTION_EXTENSIONS:
            caption_member = member
    return media_member, media_ext, caption_member


def _unique_target_path(target_dir: Path, stem: str, media_ext: str) -> Path:
    """Return a path in target_dir that clobbers neither a media file nor a caption.

    A ``___<n>`` suffix is appended to the stem until both the media path
    and its sibling ``.txt`` caption path are unused.
    """
    candidate = target_dir / f"{stem}{media_ext}"
    counter = 1
    # The caption path is checked too: two samples with the same stem but
    # different media extensions would otherwise share one caption file.
    while candidate.exists() or candidate.with_suffix('.txt').exists():
        candidate = target_dir / f"{stem}___{counter}{media_ext}"
        counter += 1
    return candidate


def process_webdataset_shard(
    tar_path: Path, 
    videos_output_dir: Path, 
    staging_output_dir: Path
) -> Tuple[int, int]:
    """Process a WebDataset shard (tar file) extracting video/image and caption pairs

    Members are grouped by prefix (directory plus file name up to the first
    dot). For every group that contains a media file, the media is written
    to ``videos_output_dir`` (``.mp4``/``.webm``) or ``staging_output_dir``
    (images) under the prefix's base name, with a ``___<n>`` suffix added
    when the media name or its caption name is already taken. If the group
    also has a caption member (``.txt``, ``.caption``, ``.json``, ``.cls``),
    its raw content is written next to the media file as ``<name>.txt``.
    Output directories are created on demand.

    Args:
        tar_path: Path to the WebDataset tar file
        videos_output_dir: Directory to store videos for splitting
        staging_output_dir: Directory to store images and captions

    Returns:
        Tuple of (video_count, image_count)

    Raises:
        Exception: any error raised while reading the archive or writing
            the output files is logged and re-raised unchanged.
    """
    video_count = 0
    image_count = 0

    videos_output_dir = Path(videos_output_dir)
    staging_output_dir = Path(staging_output_dir)
    logger.debug("videos_output_dir = %s", videos_output_dir)
    logger.debug("staging_output_dir = %s", staging_output_dir)

    try:
        # A single open is enough: the member index is built first, then the
        # same handle is used to extract the selected members.
        with tarfile.open(tar_path, 'r') as tar:
            for prefix, members in _group_members_by_prefix(tar).items():
                media_member, media_ext, caption_member = _select_sample_files(members)
                if media_member is None:
                    continue

                # extractfile() returns None for members that carry no data
                # (FIFOs, device nodes, ...); such a sample cannot be imported.
                media_source = tar.extractfile(media_member)
                if media_source is None:
                    logger.warning(
                        "Skipping non-regular member %s in %s",
                        media_member.name, tar_path,
                    )
                    continue

                is_video = media_ext.lower() in _WDS_VIDEO_EXTENSIONS
                target_dir = videos_output_dir if is_video else staging_output_dir
                target_dir.mkdir(parents=True, exist_ok=True)

                # Only the base name of the prefix is used, so archive
                # directories never leak into the output path.
                target_path = _unique_target_path(
                    target_dir, Path(prefix).name, media_ext
                )

                # Stream the payload instead of loading a whole video in memory.
                with media_source, open(target_path, 'wb') as media_out:
                    shutil.copyfileobj(media_source, media_out)

                if caption_member is not None:
                    caption_source = tar.extractfile(caption_member)
                    if caption_source is not None:
                        with caption_source:
                            caption_text = caption_source.read().decode(
                                'utf-8', errors='ignore'
                            )
                        # The caption shares the media file's stem.
                        caption_path = target_path.with_suffix('.txt')
                        with open(caption_path, 'w', encoding='utf-8') as caption_out:
                            caption_out.write(caption_text)

                if is_video:
                    video_count += 1
                else:
                    image_count += 1

    except Exception as e:
        logger.error("Error processing WebDataset file %s: %s", tar_path, e)
        raise

    return video_count, image_count