"""WebDataset format handling for Video Model Studio."""

import os
import tarfile
import tempfile
import logging
from pathlib import Path
from typing import List, Dict, Tuple, Optional

from ..utils import is_image_file, is_video_file, extract_scene_info

logger = logging.getLogger(__name__)

def is_webdataset_file(file_path: Path) -> bool:
    """Check if file is a WebDataset tar file

    The check is purely name-based: the archive is not opened or validated.

    Args:
        file_path: Path to check

    Returns:
        bool: True if file has .tar extension (case-insensitive)
    """
    return file_path.suffix.lower() == '.tar'

def process_webdataset_shard(
    tar_path: Path,
    videos_output_dir: Path,
    staging_output_dir: Path
) -> Tuple[int, int]:
    """Process a WebDataset shard (tar file) extracting video/image and caption pairs

    Members of the archive are grouped into samples by their WebDataset key
    (directory plus the file name up to the first dot). For every sample that
    contains a media file, the media is written to the matching output
    directory and, if the sample also has a caption file, its text is saved
    next to the media with a ``.txt`` extension.

    Any exception raised while processing is logged and swallowed; the counts
    accumulated up to that point are returned.

    Args:
        tar_path: Path to the WebDataset tar file
        videos_output_dir: Directory to store videos for splitting
        staging_output_dir: Directory to store images and captions

    Returns:
        Tuple of (video_count, image_count)
    """
    # Local import: only needed here, for streaming media out of the archive.
    import shutil

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic')
    video_extensions = ('.mp4', '.webm')
    caption_extensions = ('.txt', '.caption', '.json', '.cls')

    video_count = 0
    image_count = 0

    print(f"videos_output_dir = {videos_output_dir}")
    print(f"staging_output_dir = {staging_output_dir}")

    try:
        # Dictionary to store grouped files by prefix
        grouped_files = {}

        with tarfile.open(tar_path, 'r') as tar:
            # First pass: collect and group files by prefix
            for member in tar.getmembers():
                if member.isdir():
                    continue

                # Skip hidden files
                if os.path.basename(member.name).startswith('.'):
                    continue

                file_path = Path(member.name)

                # For WebDataset, the prefix is everything up to the first dot
                # of the file name; the rest is the (possibly compound) extension.
                prefix_parts = file_path.name.split('.', 1)
                if len(prefix_parts) < 2:
                    # No extension, skip
                    continue

                prefix = prefix_parts[0]
                extension = '.' + prefix_parts[1]

                # Include directory in the prefix to keep samples grouped correctly
                if file_path.parent != Path('.'):
                    full_prefix = str(file_path.parent / prefix)
                else:
                    full_prefix = prefix

                grouped_files.setdefault(full_prefix, []).append((member, extension))

            # Second pass: extract and process grouped files
            for prefix, members in grouped_files.items():
                # Only the final path component is used for the output name, so
                # archive-internal directories never escape the target directory.
                safe_prefix = Path(prefix).name

                # Find media and caption files (the last match of each kind wins)
                media_file = None
                caption_file = None
                media_ext = None

                for member, ext in members:
                    lowered = ext.lower()
                    if lowered in image_extensions or lowered in video_extensions:
                        media_file = member
                        media_ext = ext
                    elif lowered in caption_extensions:
                        caption_file = member

                # Samples without media (e.g. an orphan caption) are ignored
                if media_file is None:
                    continue

                # extractfile() returns None for members that carry no data
                # (anything that is not a regular file or a link).
                media_stream = tar.extractfile(media_file)
                if media_stream is None:
                    logger.warning(
                        f"Skipping non-regular member {media_file.name} in {tar_path}"
                    )
                    continue

                # Videos go to the splitting directory, images to staging
                is_video = media_ext.lower() in video_extensions
                target_dir = videos_output_dir if is_video else staging_output_dir
                target_dir.mkdir(parents=True, exist_ok=True)

                target_path = target_dir / f"{safe_prefix}{media_ext}"

                # If file already exists, add number suffix
                counter = 1
                while target_path.exists():
                    target_path = target_dir / f"{safe_prefix}___{counter}{media_ext}"
                    counter += 1

                # Stream the media to disk instead of loading it fully in memory
                with media_stream as src, open(target_path, 'wb') as dst:
                    shutil.copyfileobj(src, dst)

                # If we have a caption file, extract it too
                if caption_file is not None:
                    caption_stream = tar.extractfile(caption_file)
                    if caption_stream is not None:
                        with caption_stream:
                            caption_text = caption_stream.read().decode(
                                'utf-8', errors='ignore'
                            )

                        # Caption shares the media file's stem, with a .txt suffix
                        caption_path = target_path.with_suffix('.txt')
                        with open(caption_path, 'w', encoding='utf-8') as f:
                            f.write(caption_text)

                # Update counters
                if is_video:
                    video_count += 1
                else:
                    image_count += 1

    except Exception as e:
        # Best-effort: report the failure and return whatever was processed
        logger.error(f"Error processing WebDataset file {tar_path}: {e}")

    return video_count, image_count