""" WebDataset format handling for Video Model Studio """ import os import tarfile import tempfile import logging from pathlib import Path from typing import List, Dict, Tuple, Optional from ..utils import is_image_file, is_video_file, extract_scene_info logger = logging.getLogger(__name__) def is_webdataset_file(file_path: Path) -> bool: """Check if file is a WebDataset tar file Args: file_path: Path to check Returns: bool: True if file has .tar extension """ return file_path.suffix.lower() == '.tar' def process_webdataset_shard( tar_path: Path, videos_output_dir: Path, staging_output_dir: Path ) -> Tuple[int, int]: """Process a WebDataset shard (tar file) extracting video/image and caption pairs Args: tar_path: Path to the WebDataset tar file videos_output_dir: Directory to store videos for splitting staging_output_dir: Directory to store images and captions Returns: Tuple of (video_count, image_count) """ video_count = 0 image_count = 0 print(f"videos_output_dir = {videos_output_dir}") print(f"staging_output_dir = {staging_output_dir}") try: # Dictionary to store grouped files by prefix grouped_files = {} # First pass: collect and group files by prefix with tarfile.open(tar_path, 'r') as tar: for member in tar.getmembers(): if member.isdir(): continue # Skip hidden files if os.path.basename(member.name).startswith('.'): continue # Extract file prefix (everything up to the first dot after the last slash) file_path = Path(member.name) file_name = file_path.name # Get prefix (filename without extensions) # For WebDataset, the prefix is everything up to the first dot prefix_parts = file_name.split('.', 1) if len(prefix_parts) < 2: # No extension, skip continue prefix = prefix_parts[0] extension = '.' + prefix_parts[1] # Include directory in the prefix to keep samples grouped correctly full_prefix = str(file_path.parent / prefix) if file_path.parent != Path('.') else prefix if full_prefix not in grouped_files: grouped_files[full_prefix] = [] grouped_files[full_prefix].append((member, extension)) # Second pass: extract and process grouped files with tarfile.open(tar_path, 'r') as tar: for prefix, members in grouped_files.items(): # Create safe filename from prefix safe_prefix = Path(prefix).name # Find media and caption files media_file = None caption_file = None media_ext = None for member, ext in members: if ext.lower() in ['.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic']: media_file = member media_ext = ext elif ext.lower() in ['.mp4', '.webm']: media_file = member media_ext = ext elif ext.lower() in ['.txt', '.caption', '.json', '.cls']: caption_file = member # If we have a media file, process it if media_file: # Determine if it's video or image is_video = media_ext.lower() in ['.mp4', '.webm'] # Choose target directory based on media type target_dir = videos_output_dir if is_video else staging_output_dir # Create target filename target_filename = f"{safe_prefix}{media_ext}" target_path = target_dir / target_filename # If file already exists, add number suffix counter = 1 while target_path.exists(): target_path = target_dir / f"{safe_prefix}___{counter}{media_ext}" counter += 1 # Extract media file with open(target_path, 'wb') as f: f.write(tar.extractfile(media_file).read()) # If we have a caption file, extract it too if caption_file: caption_text = tar.extractfile(caption_file).read().decode('utf-8', errors='ignore') # Save caption with media file extension caption_path = target_path.with_suffix('.txt') with open(caption_path, 'w', encoding='utf-8') as f: f.write(caption_text) # 
Update counters if is_video: video_count += 1 else: image_count += 1 except Exception as e: logger.error(f"Error processing WebDataset file {tar_path}: {e}") raise return video_count, image_count
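

# Usage sketch (illustrative only, not part of the original module). It shows
# how the two helpers above might be combined to ingest a directory of shards.
# The directory paths below are assumptions chosen for demonstration; because
# this module uses a relative import, run it as a module from the project root
# (e.g. `python -m <package>.<module>`) or adapt it into your own entry point.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    shard_dir = Path("data/webdataset_shards")   # hypothetical input directory
    videos_dir = Path("output/videos")           # hypothetical video output directory
    staging_dir = Path("output/staging")         # hypothetical staging directory
    videos_dir.mkdir(parents=True, exist_ok=True)
    staging_dir.mkdir(parents=True, exist_ok=True)

    total_videos = total_images = 0
    for shard in sorted(shard_dir.iterdir()):
        if is_webdataset_file(shard):
            videos, images = process_webdataset_shard(shard, videos_dir, staging_dir)
            total_videos += videos
            total_images += images

    logger.info("Imported %d video(s) and %d image(s)", total_videos, total_images)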