Commit 246c64e by jbilcke-hf
Parent(s): b239757

fix issue with scene splitting
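
The commit threads a single `enable_splitting` flag from the UI checkbox down through every import path (file upload, ZIP/TAR extraction, YouTube download, Hub dataset download). A minimal sketch of the routing rule the diff applies everywhere; the concrete directory values are stand-ins, only the constant names come from the repo:

```python
from pathlib import Path

# Stand-in values for the project's configured directories
VIDEOS_TO_SPLIT_PATH = Path("data/videos_to_split")  # queued for scene detection
STAGING_PATH = Path("data/staging")                  # used directly, no splitting

def route_video(enable_splitting: bool) -> Path:
    # The core fix: imported videos only land in the scene-splitting
    # queue when the user actually asked for automatic splitting
    return VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
```
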
vms/ui/app_ui.py CHANGED
@@ -396,11 +396,14 @@ class AppUI:
             model_version_val = available_model_versions[0]
             logger.info(f"Using first available model version: {model_version_val}")
 
-            # IMPORTANT: Update the dropdown choices directly in the UI component
-            # This is essential to avoid the error when loading the UI
+            # IMPORTANT: Create a new list of simple strings for the dropdown choices
+            # This ensures each choice is a single string, not a tuple or other structure
+            simple_choices = [str(version) for version in available_model_versions]
+
+            # Update the dropdown choices directly in the UI component
             try:
-                self.project_tabs["train_tab"].components["model_version"].choices = available_model_versions
-                logger.info(f"Updated model_version dropdown choices: {len(available_model_versions)} options")
+                self.project_tabs["train_tab"].components["model_version"].choices = simple_choices
+                logger.info(f"Updated model_version dropdown choices: {len(simple_choices)} options")
             except Exception as e:
                 logger.error(f"Error updating model_version dropdown: {str(e)}")
         else:
@@ -410,7 +413,7 @@ class AppUI:
                 self.project_tabs["train_tab"].components["model_version"].choices = []
             except Exception as e:
                 logger.error(f"Error setting empty model_version choices: {str(e)}")
-
+
         # Ensure training_type is a valid display name
         training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
         if training_type_val not in TRAINING_TYPES:
vms/ui/project/services/importing/file_upload.py CHANGED
@@ -22,20 +22,23 @@ logger = logging.getLogger(__name__)
 class FileUploadHandler:
     """Handles processing of uploaded files"""
 
-    def process_uploaded_files(self, file_paths: List[str]) -> str:
+    def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
         """Process uploaded file (ZIP, TAR, MP4, or image)
 
         Args:
             file_paths: File paths to the uploaded files from Gradio
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
         """
+        print(f"process_uploaded_files called with enable_splitting={enable_splitting} and file_paths = {str(file_paths)}")
         if not file_paths or len(file_paths) == 0:
             logger.warning("No files provided to process_uploaded_files")
             return "No files provided"
-
+
         for file_path in file_paths:
+            print(f" - {str(file_path)}")
             file_path = Path(file_path)
             try:
                 original_name = file_path.name
@@ -45,11 +48,11 @@ class FileUploadHandler:
                 file_ext = file_path.suffix.lower()
 
                 if file_ext == '.zip':
-                    return self.process_zip_file(file_path)
+                    return self.process_zip_file(file_path, enable_splitting)
                 elif file_ext == '.tar':
-                    return self.process_tar_file(file_path)
+                    return self.process_tar_file(file_path, enable_splitting)
                 elif file_ext == '.mp4' or file_ext == '.webm':
-                    return self.process_mp4_file(file_path, original_name)
+                    return self.process_mp4_file(file_path, original_name, enable_splitting)
                 elif is_image_file(file_path):
                     return self.process_image_file(file_path, original_name)
                 else:
@@ -60,56 +63,12 @@ class FileUploadHandler:
                 logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
                 raise gr.Error(f"Error processing file: {str(e)}")
 
-    def process_image_file(self, file_path: Path, original_name: str) -> str:
-        """Process a single image file
-
-        Args:
-            file_path: Path to the image
-            original_name: Original filename
-
-        Returns:
-            Status message string
-        """
-        try:
-            # Create a unique filename with configured extension
-            stem = Path(original_name).stem
-            target_path = STAGING_PATH / f"{stem}.{NORMALIZE_IMAGES_TO}"
-
-            # If file already exists, add number suffix
-            counter = 1
-            while target_path.exists():
-                target_path = STAGING_PATH / f"{stem}___{counter}.{NORMALIZE_IMAGES_TO}"
-                counter += 1
-
-            logger.info(f"Processing image file: {original_name} -> {target_path}")
-
-            # Convert to normalized format and remove black bars
-            success = normalize_image(file_path, target_path)
-
-            if not success:
-                logger.error(f"Failed to process image: {original_name}")
-                raise gr.Error(f"Failed to process image: {original_name}")
-
-            # Handle caption
-            src_caption_path = file_path.with_suffix('.txt')
-            if src_caption_path.exists():
-                caption = src_caption_path.read_text()
-                caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
-                target_path.with_suffix('.txt').write_text(caption)
-
-            logger.info(f"Successfully stored image: {target_path.name}")
-            gr.Info(f"Successfully stored image: {target_path.name}")
-            return f"Successfully stored image: {target_path.name}"
-
-        except Exception as e:
-            logger.error(f"Error processing image file: {str(e)}", exc_info=True)
-            raise gr.Error(f"Error processing image file: {str(e)}")
-
-    def process_zip_file(self, file_path: Path) -> str:
+    def process_zip_file(self, file_path: Path, enable_splitting: bool) -> str:
         """Process uploaded ZIP file containing media files or WebDataset tar files
 
         Args:
             file_path: Path to the uploaded ZIP file
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
@@ -143,17 +102,18 @@ class FileUploadHandler:
                             logger.info(f"Processing WebDataset archive from ZIP: {file}")
                             # Process WebDataset shard
                             vid_count, img_count = webdataset_handler.process_webdataset_shard(
-                                file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
+                                file_path, VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH, STAGING_PATH
                             )
                             video_count += vid_count
                             image_count += img_count
                             tar_count += 1
                         elif is_video_file(file_path):
-                            # Copy video to videos_to_split
-                            target_path = VIDEOS_TO_SPLIT_PATH / file_path.name
+                            # Choose target directory based on auto-splitting setting
+                            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
+                            target_path = target_dir / file_path.name
                             counter = 1
                             while target_path.exists():
-                                target_path = VIDEOS_TO_SPLIT_PATH / f"{file_path.stem}___{counter}{file_path.suffix}"
+                                target_path = target_dir / f"{file_path.stem}___{counter}{file_path.suffix}"
                                 counter += 1
                             shutil.copy2(file_path, target_path)
                             logger.info(f"Extracted video from ZIP: {file} -> {target_path.name}")
@@ -208,11 +168,12 @@ class FileUploadHandler:
             logger.error(f"Error processing ZIP: {str(e)}", exc_info=True)
             raise gr.Error(f"Error processing ZIP: {str(e)}")
 
-    def process_tar_file(self, file_path: Path) -> str:
+    def process_tar_file(self, file_path: Path, enable_splitting: bool) -> str:
         """Process a WebDataset tar file
 
         Args:
             file_path: Path to the uploaded tar file
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
@@ -220,7 +181,7 @@ class FileUploadHandler:
         try:
            logger.info(f"Processing WebDataset TAR file: {file_path}")
            video_count, image_count = webdataset_handler.process_webdataset_shard(
-                file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
+                file_path, VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH, STAGING_PATH
            )
 
            # Generate status message
@@ -243,25 +204,30 @@ class FileUploadHandler:
             logger.error(f"Error processing WebDataset tar file: {str(e)}", exc_info=True)
             raise gr.Error(f"Error processing WebDataset tar file: {str(e)}")
 
-    def process_mp4_file(self, file_path: Path, original_name: str) -> str:
+    def process_mp4_file(self, file_path: Path, original_name: str, enable_splitting: bool) -> str:
         """Process a single video file
 
         Args:
             file_path: Path to the file
             original_name: Original filename
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
         """
+        print(f"process_mp4_file(self, file_path={str(file_path)}, original_name={str(original_name)}, enable_splitting={enable_splitting})")
         try:
+            # Choose target directory based on auto-splitting setting
+            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
+            print(f"target_dir = {target_dir}")
             # Create a unique filename
-            target_path = VIDEOS_TO_SPLIT_PATH / original_name
+            target_path = target_dir / original_name
 
             # If file already exists, add number suffix
             counter = 1
             while target_path.exists():
                 stem = Path(original_name).stem
-                target_path = VIDEOS_TO_SPLIT_PATH / f"{stem}___{counter}.mp4"
+                target_path = target_dir / f"{stem}___{counter}.mp4"
                 counter += 1
 
             logger.info(f"Processing video file: {original_name} -> {target_path}")
vms/ui/project/services/importing/hub_dataset.py CHANGED
@@ -168,7 +168,7 @@ class HubDatasetBrowser:
         self,
         dataset_id: str,
         file_type: str,
-        enable_splitting: bool = True,
+        enable_splitting: bool,
         progress_callback: Optional[Callable] = None
     ) -> str:
         """Download all files of a specific type from the dataset
@@ -328,7 +328,7 @@
     async def download_dataset(
         self,
         dataset_id: str,
-        enable_splitting: bool = True,
+        enable_splitting: bool,
         progress_callback: Optional[Callable] = None
     ) -> Tuple[str, str]:
         """Download a dataset and process its video/image content
vms/ui/project/services/importing/import_service.py CHANGED
@@ -28,32 +28,37 @@ class ImportingService:
         self.youtube_handler = YouTubeDownloader()
         self.hub_browser = HubDatasetBrowser(self.hf_api)
 
-    def process_uploaded_files(self, file_paths: List[str]) -> str:
+    def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
         """Process uploaded file (ZIP, TAR, MP4, or image)
 
         Args:
             file_paths: File paths to the uploaded files from Gradio
+            enable_splitting: Whether to enable automatic video splitting
 
         Returns:
             Status message string
         """
+        print(f"process_uploaded_files(..., enable_splitting={enable_splitting})")
         if not file_paths or len(file_paths) == 0:
             logger.warning("No files provided to process_uploaded_files")
             return "No files provided"
 
-        return self.file_handler.process_uploaded_files(file_paths)
+        print(f"process_uploaded_files(..., enable_splitting={enable_splitting})")
+        print(f"process_uploaded_files: calling self.file_handler.process_uploaded_files")
+        return self.file_handler.process_uploaded_files(file_paths, enable_splitting)
 
-    def download_youtube_video(self, url: str, progress=None) -> str:
+    def download_youtube_video(self, url: str, enable_splitting: bool, progress=None) -> str:
         """Download a video from YouTube
 
         Args:
             url: YouTube video URL
+            enable_splitting: Whether to enable automatic video splitting
             progress: Optional Gradio progress indicator
 
         Returns:
             Status message string
         """
-        return self.youtube_handler.download_video(url, progress)
+        return self.youtube_handler.download_video(url, enable_splitting, progress)
 
     def search_datasets(self, query: str) -> List[List[str]]:
         """Search for datasets on the Hugging Face Hub
@@ -80,7 +85,7 @@ class ImportingService:
     async def download_dataset(
         self,
         dataset_id: str,
-        enable_splitting: bool = True,
+        enable_splitting: bool,
         progress_callback: Optional[Callable] = None
     ) -> Tuple[str, str]:
         """Download a dataset and process its video/image content
@@ -99,7 +104,7 @@ class ImportingService:
         self,
         dataset_id: str,
         file_type: str,
-        enable_splitting: bool = True,
+        enable_splitting: bool,
         progress_callback: Optional[Callable] = None
     ) -> str:
         """Download a group of files (videos or WebDatasets)
vms/ui/project/services/importing/youtube.py CHANGED
@@ -17,11 +17,12 @@ logger = logging.getLogger(__name__)
 class YouTubeDownloader:
     """Handles downloading videos from YouTube"""
 
-    def download_video(self, url: str, progress: Optional[Callable] = None) -> str:
+    def download_video(self, url: str, enable_splitting: bool, progress: Optional[Callable] = None) -> str:
         """Download a video from YouTube
 
         Args:
             url: YouTube video URL
+            enable_splitting: Whether to enable automatic video splitting
             progress: Optional Gradio progress indicator
 
         Returns:
@@ -40,7 +41,10 @@
                         if progress else None)
 
             video_id = yt.video_id
-            output_path = VIDEOS_TO_SPLIT_PATH / f"{video_id}.mp4"
+
+            # Choose target directory based on auto-splitting setting
+            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
+            output_path = target_dir / f"{video_id}.mp4"
 
             # Download highest quality progressive MP4
             if progress:
@@ -58,7 +62,7 @@
                 logger.info("Starting YouTube video download...")
                 progress(0, desc="Starting download...")
 
-            video.download(output_path=str(VIDEOS_TO_SPLIT_PATH), filename=f"{video_id}.mp4")
+            video.download(output_path=str(target_dir), filename=f"{video_id}.mp4")
 
             # Update UI
             if progress:
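
Note: the handler's `yt.video_id` / `video.download(output_path=..., filename=...)` calls match the pytube API, so for context here is a minimal standalone version of the flow, assuming pytube (the source does not name the library) and the stand-in directories from above:

```python
from pathlib import Path
from pytube import YouTube

def download_youtube(url: str, enable_splitting: bool) -> Path:
    # Same routing as the patched handler: split queue vs. staging
    target_dir = Path("data/videos_to_split") if enable_splitting else Path("data/staging")
    target_dir.mkdir(parents=True, exist_ok=True)

    yt = YouTube(url)
    # Highest-resolution progressive MP4 stream, as in the handler
    video = (yt.streams
               .filter(progressive=True, file_extension="mp4")
               .order_by("resolution")
               .desc()
               .first())
    return Path(video.download(output_path=str(target_dir), filename=f"{yt.video_id}.mp4"))
```
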
vms/ui/project/services/splitting.py CHANGED
@@ -63,7 +63,7 @@ class SplittingService:
         """Process a single video file to detect and split scenes"""
         try:
             self._processing_status[video_path.name] = f'Processing video "{video_path.name}"...'
-
+            print(f'Going to split scenes for video "{video_path.name}"...')
             parent_caption_path = video_path.with_suffix('.txt')
             # Create output path for split videos
             base_name, _ = extract_scene_info(video_path.name)
@@ -180,6 +180,7 @@
 
     async def start_processing(self, enable_splitting: bool) -> None:
         """Start background processing of unprocessed videos"""
+        #print(f"start_processing(enable_splitting={enable_splitting}), self.processing = {self.processing}")
         if self.processing:
             return
 
@@ -188,6 +189,8 @@
             # Process each video
             for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
                 self._current_file = video_file.name
+                #print(f"calling await self.process_video(video_file, {enable_splitting})")
+
                 await self.process_video(video_file, enable_splitting)
 
         finally:
vms/ui/project/tabs/import_tab/import_tab.py CHANGED
@@ -90,25 +90,37 @@ class ImportTab(BaseTab):
         self.youtube_tab.connect_events()
         self.hub_tab.connect_events()
 
-    def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
+    def on_import_success(
+        self,
+        enable_splitting: bool,
+        enable_automatic_content_captioning: bool,
+        prompt_prefix: str
+    ):
         """Handle successful import of files"""
+        #print(f"on_import_success(self, enable_splitting={enable_splitting}, enable_automatic_content_captioning={enable_automatic_content_captioning}, prompt_prefix={prompt_prefix})")
         # If splitting is disabled, we need to directly move videos to staging
-        if not enable_splitting:
-            # Copy files without splitting
-            self._start_copy_to_staging_bg()
-            msg = "Copying videos to staging directory without splitting..."
-        else:
+        if enable_splitting:
+            #print("on_import_success: -> splitting enabled!")
             # Start scene detection if not already running and there are videos to process
             if not self.app.splitting.is_processing():
+                #print("on_import_success: -> calling self._start_scene_detection_bg(enable_splitting)")
                 # Start the scene detection in a separate thread
                 self._start_scene_detection_bg(enable_splitting)
                 msg = "Starting automatic scene detection..."
             else:
                 msg = "Scene detection already running..."
 
-        # Copy files to training directory
-        self.app.tabs["caption_tab"].copy_files_to_training_dir(prompt_prefix)
-
+            # Copy files to training directory
+            self.app.tabs["caption_tab"].copy_files_to_training_dir(prompt_prefix)
+        else:
+            #print("on_import_success: -> splitting NOT enabled")
+            # Copy files without splitting
+            self._start_copy_to_staging_bg()
+            msg = "Copying videos to staging directory without splitting..."
+
+            # Also immediately copy to training directory
+            self.app.tabs["caption_tab"].copy_files_to_training_dir(prompt_prefix)
+
         # Start auto-captioning if enabled
         if enable_automatic_content_captioning:
             self._start_captioning_bg(DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, prompt_prefix)
@@ -122,8 +134,9 @@ class ImportTab(BaseTab):
             logger.warning("Cannot switch tabs - project_tabs_component not available")
         return None, msg
 
-    def _start_scene_detection_bg(self, enable_splitting):
+    def _start_scene_detection_bg(self, enable_splitting: bool):
         """Start scene detection in a background thread"""
+        print(f"_start_scene_detection_bg(enable_splitting={enable_splitting})")
         def run_async_in_thread():
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
@@ -207,11 +220,13 @@ class ImportTab(BaseTab):
         thread.daemon = True
         thread.start()
 
-    async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
+    async def update_titles_after_import(self, enable_splitting: bool, enable_automatic_content_captioning: bool, prompt_prefix: str):
         """Handle post-import updates including titles"""
         # Call the non-async version since we need to return immediately for the UI
         tabs, status_msg = self.on_import_success(
-            enable_splitting, enable_automatic_content_captioning, prompt_prefix
+            enable_splitting,
+            enable_automatic_content_captioning,
+            prompt_prefix
        )
 
        # Get updated titles
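
Note: `_start_scene_detection_bg` runs the async splitting service on a private event loop inside a daemon thread, so the Gradio handler can return immediately. A generic sketch of that pattern (helper name is ours):

```python
import asyncio
import threading

def run_coroutine_in_background(coro_factory) -> threading.Thread:
    # Pattern used by _start_scene_detection_bg: give the job its own
    # event loop in a daemon thread so the UI thread is never blocked
    def runner():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(coro_factory())
        finally:
            loop.close()

    thread = threading.Thread(target=runner, daemon=True)
    thread.start()
    return thread

# e.g. run_coroutine_in_background(lambda: splitting.start_processing(True))
```
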
vms/ui/project/tabs/import_tab/upload_tab.py CHANGED
@@ -62,11 +62,22 @@ class UploadTab(BaseTab):
             logger.warning("import_status component is not set in UploadTab")
             return
 
-        # File upload event
+        # File upload event with enable_splitting parameter
         upload_event = self.components["files"].upload(
-            fn=lambda x: self.app.importing.process_uploaded_files(x),
-            inputs=[self.components["files"]],
+            fn=self.app.importing.process_uploaded_files,
+            inputs=[self.components["files"], self.components["enable_automatic_video_split"]],
             outputs=[self.components["import_status"]]
+        ).success(
+            fn=self.app.tabs["import_tab"].on_import_success,
+            inputs=[
+                self.components["enable_automatic_video_split"],
+                self.components["enable_automatic_content_captioning"],
+                self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
+            ],
+            outputs=[
+                self.app.project_tabs_component,
+                self.components["import_status"]
+            ]
         )
 
         # Only add success handler if all required components exist
@@ -102,4 +113,4 @@ class UploadTab(BaseTab):
             )
         except (AttributeError, KeyError) as e:
             logger.error(f"Error connecting event handlers in UploadTab: {str(e)}")
-            # Continue without the success handler
+            # Continue without the success handler
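
Note: the rewired upload event relies on Gradio's event chaining; `.success()` only fires after the upload handler returns without raising, which is what lets `on_import_success` trust that the files are already on disk. A minimal runnable sketch of the same wiring (handler bodies are placeholders):

```python
import gradio as gr

def process_uploaded_files(files, enable_splitting):
    return f"Imported {len(files or [])} file(s) (splitting={enable_splitting})"

def on_import_success(enable_splitting):
    return "Scene detection started" if enable_splitting else "Copied to staging"

with gr.Blocks() as demo:
    files = gr.File(file_count="multiple")
    enable_split = gr.Checkbox(label="Automatically split videos into scenes", value=True)
    status = gr.Textbox(label="Import status")

    # .success() runs only if the upload handler completed without error
    files.upload(
        fn=process_uploaded_files,
        inputs=[files, enable_split],
        outputs=status,
    ).success(
        fn=on_import_success,
        inputs=[enable_split],
        outputs=status,
    )

demo.launch()
```
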
vms/ui/project/tabs/import_tab/youtube_tab.py CHANGED
@@ -83,8 +83,8 @@ class YouTubeTab(BaseTab):
 
         # YouTube download event
         download_event = self.components["youtube_download_btn"].click(
-            fn=self.app.importing.download_youtube_video,
-            inputs=[self.components["youtube_url"]],
+            fn=self.download_youtube_with_splitting,
+            inputs=[self.components["youtube_url"], self.components["enable_automatic_video_split"]],
             outputs=[self.components["import_status"]]
         )
 
@@ -106,4 +106,8 @@ class YouTubeTab(BaseTab):
             )
         except (AttributeError, KeyError) as e:
             logger.error(f"Error connecting success handler in YouTubeTab: {str(e)}")
-            # Continue without the success handler
+            # Continue without the success handler
+
+    def download_youtube_with_splitting(self, url, enable_splitting):
+        """Download YouTube video with splitting option"""
+        return self.app.importing.download_youtube_video(url, enable_splitting, gr.Progress())
vms/ui/project/tabs/preview_tab.py CHANGED
@@ -200,8 +200,10 @@ class PreviewTab(BaseTab):
         # Return just the model IDs as a list of simple strings
         version_ids = list(MODEL_VERSIONS.get(internal_type, {}).keys())
         logger.info(f"Found {len(version_ids)} versions for {model_type}: {version_ids}")
-        return version_ids
-
+
+        # Ensure they're all strings
+        return [str(version) for version in version_ids]
+
     def get_default_model_version(self, model_type: str) -> str:
         """Get default model version for the given model type"""
         # Convert UI display name to internal name
vms/ui/project/tabs/train_tab.py CHANGED
@@ -462,12 +462,15 @@ class TrainTab(BaseTab):
             # Update UI state with proper model_type first
             self.app.update_ui_state(model_type=model_type)
 
+            # Ensure model_versions is a simple list of strings
+            model_versions = [str(version) for version in model_versions]
+
             # Create a new dropdown with the updated choices
             if not model_versions:
                 logger.warning(f"No model versions available for {model_type}, using empty list")
                 # Return empty dropdown to avoid errors
                 return gr.Dropdown(choices=[], value=None)
-
+
             # Ensure default_version is in model_versions
             if default_version not in model_versions and model_versions:
                 default_version = model_versions[0]
@@ -481,8 +484,7 @@ class TrainTab(BaseTab):
             logger.error(f"Error in update_model_versions: {str(e)}")
             # Return empty dropdown to avoid errors
             return gr.Dropdown(choices=[], value=None)
-
-
+
     def handle_training_start(
         self, preset, model_type, model_version, training_type,
         lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
@@ -561,7 +563,9 @@ class TrainTab(BaseTab):
         # Return just the model IDs as a list of simple strings
         version_ids = list(MODEL_VERSIONS.get(internal_type, {}).keys())
         logger.info(f"Found {len(version_ids)} versions for {model_type}: {version_ids}")
-        return version_ids
+
+        # Ensure they're all strings
+        return [str(version) for version in version_ids]
 
     def get_default_model_version(self, model_type: str) -> str:
         """Get default model version for the given model type"""
@@ -749,9 +753,6 @@ class TrainTab(BaseTab):
         model_versions = self.get_model_version_choices(model_display_name)
         default_model_version = self.get_default_model_version(model_display_name)
 
-        # Create the model version dropdown update
-        model_version_update = gr.Dropdown(choices=model_versions, value=default_model_version)
-
         # Ensure we have valid choices and values
         if not model_versions:
             logger.warning(f"No versions found for {model_display_name}, using empty list")
@@ -761,6 +762,12 @@ class TrainTab(BaseTab):
             default_model_version = model_versions[0]
             logger.info(f"Reset default version to first available: {default_model_version}")
 
+        # Ensure model_versions is a simple list of strings
+        model_versions = [str(version) for version in model_versions]
+
+        # Create the model version dropdown update
+        model_version_update = gr.Dropdown(choices=model_versions, value=default_model_version)
+
         # Return values in the same order as the output components
         return (
             model_display_name,
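
Note: the repeated `[str(version) for version in ...]` guards relate to how `gr.Dropdown` interprets its `choices`: a plain string is a single option, while a tuple is treated as a `(label, value)` pair, so any non-string leaking in changes the dropdown's behavior. A small helper capturing the pattern (the helper name is ours, not the repo's):

```python
import gradio as gr

def string_dropdown(choices, default=None) -> gr.Dropdown:
    # Coerce every choice to a plain string so none is mistaken
    # for a (label, value) tuple by Gradio
    simple = [str(choice) for choice in choices]
    if default not in simple:
        default = simple[0] if simple else None
    return gr.Dropdown(choices=simple, value=default)
```
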
vms/utils/webdataset_handler.py CHANGED
@@ -41,7 +41,9 @@ def process_webdataset_shard(
     """
     video_count = 0
     image_count = 0
-
+
+    print(f"videos_output_dir = {videos_output_dir}")
+    print(f"staging_output_dir = {staging_output_dir}")
     try:
         # Dictionary to store grouped files by prefix
         grouped_files = {}