|  | """ | 
					
						
						|  | WebDataset format handling for Video Model Studio | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import os | 
					
						
						|  | import tarfile | 
					
						
						|  | import tempfile | 
					
						
						|  | import logging | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from typing import List, Dict, Tuple, Optional | 
					
						
						|  |  | 
					
						
						|  | from ..utils import is_image_file, is_video_file, extract_scene_info | 
					
						
						|  |  | 
					
						
						|  | logger = logging.getLogger(__name__) | 
					
						
						|  |  | 
					
						
						|  | def is_webdataset_file(file_path: Path) -> bool: | 
					
						
						|  | """Check if file is a WebDataset tar file | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | file_path: Path to check | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | bool: True if file has .tar extension | 
					
						
						|  | """ | 
					
						
						|  | return file_path.suffix.lower() == '.tar' | 
					
						
						|  |  | 
					
						
						|  | def process_webdataset_shard( | 
					
						
						|  | tar_path: Path, | 
					
						
						|  | videos_output_dir: Path, | 
					
						
						|  | staging_output_dir: Path | 
					
						
						|  | ) -> Tuple[int, int]: | 
					
						
						|  | """Process a WebDataset shard (tar file) extracting video/image and caption pairs | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | tar_path: Path to the WebDataset tar file | 
					
						
						|  | videos_output_dir: Directory to store videos for splitting | 
					
						
						|  | staging_output_dir: Directory to store images and captions | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Tuple of (video_count, image_count) | 
					
						
						|  | """ | 
					
						
						|  | video_count = 0 | 
					
						
						|  | image_count = 0 | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | grouped_files = {} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with tarfile.open(tar_path, 'r') as tar: | 
					
						
						|  | for member in tar.getmembers(): | 
					
						
						|  | if member.isdir(): | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if os.path.basename(member.name).startswith('.'): | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | file_path = Path(member.name) | 
					
						
						|  | file_name = file_path.name | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | prefix_parts = file_name.split('.', 1) | 
					
						
						|  | if len(prefix_parts) < 2: | 
					
						
						|  |  | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | prefix = prefix_parts[0] | 
					
						
						|  | extension = '.' + prefix_parts[1] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | full_prefix = str(file_path.parent / prefix) if file_path.parent != Path('.') else prefix | 
					
						
						|  |  | 
					
						
						|  | if full_prefix not in grouped_files: | 
					
						
						|  | grouped_files[full_prefix] = [] | 
					
						
						|  |  | 
					
						
						|  | grouped_files[full_prefix].append((member, extension)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with tarfile.open(tar_path, 'r') as tar: | 
					
						
						|  | for prefix, members in grouped_files.items(): | 
					
						
						|  |  | 
					
						
						|  | safe_prefix = Path(prefix).name | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | media_file = None | 
					
						
						|  | caption_file = None | 
					
						
						|  | media_ext = None | 
					
						
						|  |  | 
					
						
						|  | for member, ext in members: | 
					
						
						|  | if ext.lower() in ['.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic']: | 
					
						
						|  | media_file = member | 
					
						
						|  | media_ext = ext | 
					
						
						|  | elif ext.lower() in ['.mp4', '.webm']: | 
					
						
						|  | media_file = member | 
					
						
						|  | media_ext = ext | 
					
						
						|  | elif ext.lower() in ['.txt', '.caption', '.json', '.cls']: | 
					
						
						|  | caption_file = member | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if media_file: | 
					
						
						|  |  | 
					
						
						|  | is_video = media_ext.lower() in ['.mp4', '.webm'] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | target_dir = videos_output_dir if is_video else staging_output_dir | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | target_filename = f"{safe_prefix}{media_ext}" | 
					
						
						|  | target_path = target_dir / target_filename | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | counter = 1 | 
					
						
						|  | while target_path.exists(): | 
					
						
						|  | target_path = target_dir / f"{safe_prefix}___{counter}{media_ext}" | 
					
						
						|  | counter += 1 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with open(target_path, 'wb') as f: | 
					
						
						|  | f.write(tar.extractfile(media_file).read()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if caption_file: | 
					
						
						|  | caption_text = tar.extractfile(caption_file).read().decode('utf-8', errors='ignore') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | caption_path = target_path.with_suffix('.txt') | 
					
						
						|  | with open(caption_path, 'w', encoding='utf-8') as f: | 
					
						
						|  | f.write(caption_text) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if is_video: | 
					
						
						|  | video_count += 1 | 
					
						
						|  | else: | 
					
						
						|  | image_count += 1 | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Error processing WebDataset file {tar_path}: {e}") | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  | return video_count, image_count |