Commit 64a70c0 by jbilcke-hf (1 parent: 0ad7e2a)

refactoring

app_DEPRECATED.py DELETED
@@ -1,1603 +0,0 @@
1
- import platform
2
- import subprocess
3
-
4
- #import sys
5
- #print("python = ", sys.version)
6
-
7
- # can be "Linux", "Darwin"
8
- if platform.system() == "Linux":
9
- # for some reason it says "pip not found"
10
- # and also "pip3 not found"
11
- # subprocess.run(
12
- # "pip install flash-attn --no-build-isolation",
13
- #
14
- # # hmm... this should be False, since we are in a CUDA environment, no?
15
- # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
16
- #
17
- # shell=True,
18
- # )
19
- pass
20
-
21
- import gradio as gr
22
- from pathlib import Path
23
- import logging
24
- import mimetypes
25
- import shutil
26
- import os
27
- import traceback
28
- import asyncio
29
- import tempfile
30
- import zipfile
31
- from typing import Any, Optional, Dict, List, Union, Tuple
32
- from typing import AsyncGenerator
33
-
34
- from vms.training_service import TrainingService
35
- from vms.captioning_service import CaptioningService
36
- from vms.splitting_service import SplittingService
37
- from vms.import_service import ImportService
38
- from vms.config import (
39
- STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
40
- TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
41
- DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, SMALL_TRAINING_BUCKETS
42
- )
43
- from vms.utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time
44
- from vms.finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
45
- from vms.training_log_parser import TrainingLogParser
46
-
47
- logger = logging.getLogger(__name__)
48
- logger.setLevel(logging.INFO)
49
-
50
- httpx_logger = logging.getLogger('httpx')
51
- httpx_logger.setLevel(logging.WARN)
52
-
53
-
54
- class VideoTrainerUI:
55
- def __init__(self):
56
- self.trainer = TrainingService()
57
- self.splitter = SplittingService()
58
- self.importer = ImportService()
59
- self.captioner = CaptioningService()
60
- self._should_stop_captioning = False
61
- self.log_parser = TrainingLogParser()
62
-
63
- # Try to recover any interrupted training sessions
64
- recovery_result = self.trainer.recover_interrupted_training()
65
-
66
- self.recovery_status = recovery_result.get("status", "unknown")
67
- self.ui_updates = recovery_result.get("ui_updates", {})
68
-
69
- if recovery_result["status"] == "recovered":
70
- logger.info(f"Training recovery: {recovery_result['message']}")
71
- # No need to do anything else - the training is already running
72
- elif recovery_result["status"] == "running":
73
- logger.info("Training process is already running")
74
- # No need to do anything - the process is still alive
75
- elif recovery_result["status"] in ["error", "idle"]:
76
- logger.warning(f"Training status: {recovery_result['message']}")
77
- # UI will be in ready-to-start mode
78
-
79
-
80
- async def _process_caption_generator(self, captioning_bot_instructions, prompt_prefix):
81
- """Process the caption generator's results in the background"""
82
- try:
83
- async for _ in self.captioner.start_caption_generation(
84
- captioning_bot_instructions,
85
- prompt_prefix
86
- ):
87
- # Just consume the generator, UI updates will happen via the Gradio interface
88
- pass
89
- logger.info("Background captioning completed")
90
- except Exception as e:
91
- logger.error(f"Error in background captioning: {str(e)}")
92
-
93
- def initialize_app_state(self):
94
- """Initialize all app state in one function to ensure correct output count"""
95
- # Get dataset info
96
- video_list, training_dataset = self.refresh_dataset()
97
-
98
- # Get button states
99
- button_states = self.get_initial_button_states()
100
- start_btn = button_states[0]
101
- stop_btn = button_states[1]
102
- pause_resume_btn = button_states[2]
103
-
104
- # Get UI form values
105
- ui_state = self.load_ui_values()
106
- training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
107
- model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
108
- lora_rank_val = ui_state.get("lora_rank", "128")
109
- lora_alpha_val = ui_state.get("lora_alpha", "128")
110
- num_epochs_val = int(ui_state.get("num_epochs", 70))
111
- batch_size_val = int(ui_state.get("batch_size", 1))
112
- learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
113
- save_iterations_val = int(ui_state.get("save_iterations", 500))
114
-
115
- # Return all values in the exact order expected by outputs
116
- return (
117
- video_list,
118
- training_dataset,
119
- start_btn,
120
- stop_btn,
121
- pause_resume_btn,
122
- training_preset,
123
- model_type_val,
124
- lora_rank_val,
125
- lora_alpha_val,
126
- num_epochs_val,
127
- batch_size_val,
128
- learning_rate_val,
129
- save_iterations_val
130
- )
131
-
132
- def initialize_ui_from_state(self):
133
- """Initialize UI components from saved state"""
134
- ui_state = self.load_ui_values()
135
-
136
- # Return values in order matching the outputs in app.load
137
- return (
138
- ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
139
- ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
140
- ui_state.get("lora_rank", "128"),
141
- ui_state.get("lora_alpha", "128"),
142
- ui_state.get("num_epochs", 70),
143
- ui_state.get("batch_size", 1),
144
- ui_state.get("learning_rate", 3e-5),
145
- ui_state.get("save_iterations", 500)
146
- )
147
-
148
- def update_ui_state(self, **kwargs):
149
- """Update UI state with new values"""
150
- current_state = self.trainer.load_ui_state()
151
- current_state.update(kwargs)
152
- self.trainer.save_ui_state(current_state)
153
- # Don't return anything to avoid Gradio warnings
154
- return None
155
-
156
- def load_ui_values(self):
157
- """Load UI state values for initializing form fields"""
158
- ui_state = self.trainer.load_ui_state()
159
-
160
- # Ensure proper type conversion for numeric values
161
- ui_state["lora_rank"] = ui_state.get("lora_rank", "128")
162
- ui_state["lora_alpha"] = ui_state.get("lora_alpha", "128")
163
- ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
164
- ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
165
- ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
166
- ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
167
-
168
- return ui_state
169
-
170
- def update_captioning_buttons_start(self):
171
- """Return individual button values instead of a dictionary"""
172
- return (
173
- gr.Button(
174
- interactive=False,
175
- variant="secondary",
176
- ),
177
- gr.Button(
178
- interactive=True,
179
- variant="stop",
180
- ),
181
- gr.Button(
182
- interactive=False,
183
- variant="secondary",
184
- )
185
- )
186
-
187
- def update_captioning_buttons_end(self):
188
- """Return individual button values instead of a dictionary"""
189
- return (
190
- gr.Button(
191
- interactive=True,
192
- variant="primary",
193
- ),
194
- gr.Button(
195
- interactive=False,
196
- variant="secondary",
197
- ),
198
- gr.Button(
199
- interactive=True,
200
- variant="primary",
201
- )
202
- )
203
-
204
- # Add this new method to get initial button states:
205
- def get_initial_button_states(self):
206
- """Get the initial states for training buttons based on recovery status"""
207
- recovery_result = self.trainer.recover_interrupted_training()
208
- ui_updates = recovery_result.get("ui_updates", {})
209
-
210
- # Return button states in the correct order
211
- return (
212
- gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
213
- gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
214
- gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
215
- )
216
-
217
- def show_refreshing_status(self) -> List[List[str]]:
218
- """Show a 'Refreshing...' status in the dataframe"""
219
- return [["Refreshing...", "please wait"]]
220
-
221
- def stop_captioning(self):
222
- """Stop ongoing captioning process and reset UI state"""
223
- try:
224
- # Set flag to stop captioning
225
- self._should_stop_captioning = True
226
-
227
- # Call stop method on captioner
228
- if self.captioner:
229
- self.captioner.stop_captioning()
230
-
231
- # Get updated file list
232
- updated_list = self.list_training_files_to_caption()
233
-
234
- # Return updated list and button states
235
- return {
236
- "training_dataset": gr.update(value=updated_list),
237
- "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
238
- "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
239
- "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
240
- }
241
- except Exception as e:
242
- logger.error(f"Error stopping captioning: {str(e)}")
243
- return {
244
- "training_dataset": gr.update(value=[[f"Error stopping captioning: {str(e)}", "error"]]),
245
- "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
246
- "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
247
- "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
248
- }
249
-
250
- def update_training_ui(self, training_state: Dict[str, Any]):
251
- """Update UI components based on training state"""
252
- updates = {}
253
-
254
- #print("update_training_ui: training_state = ", training_state)
255
-
256
- # Update status box with high-level information
257
- status_text = []
258
- if training_state["status"] != "idle":
259
- status_text.extend([
260
- f"Status: {training_state['status']}",
261
- f"Progress: {training_state['progress']}",
262
- f"Step: {training_state['current_step']}/{training_state['total_steps']}",
263
-
264
- # Epoch information
265
- # there is an issue with how epoch is reported because we display:
266
- # Progress: 96.9%, Step: 872/900, Epoch: 12/50
267
- # we should probably just show the steps
268
- #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
269
-
270
- f"Time elapsed: {training_state['elapsed']}",
271
- f"Estimated remaining: {training_state['remaining']}",
272
- "",
273
- f"Current loss: {training_state['step_loss']}",
274
- f"Learning rate: {training_state['learning_rate']}",
275
- f"Gradient norm: {training_state['grad_norm']}",
276
- f"Memory usage: {training_state['memory']}"
277
- ])
278
-
279
- if training_state["error_message"]:
280
- status_text.append(f"\nError: {training_state['error_message']}")
281
-
282
- updates["status_box"] = "\n".join(status_text)
283
-
284
- # Update button states
285
- updates["start_btn"] = gr.Button(
286
- "Start training",
287
- interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
288
- variant="primary" if training_state["status"] == "idle" else "secondary"
289
- )
290
-
291
- updates["stop_btn"] = gr.Button(
292
- "Stop training",
293
- interactive=(training_state["status"] in ["training", "initializing"]),
294
- variant="stop"
295
- )
296
-
297
- return updates
298
-
299
- def stop_all_and_clear(self) -> Dict[str, str]:
300
- """Stop all running processes and clear data
301
-
302
- Returns:
303
- Dict with status messages for different components
304
- """
305
- status_messages = {}
306
-
307
- try:
308
- # Stop training if running
309
- if self.trainer.is_training_running():
310
- training_result = self.trainer.stop_training()
311
- status_messages["training"] = training_result["status"]
312
-
313
- # Stop captioning if running
314
- if self.captioner:
315
- self.captioner.stop_captioning()
316
- status_messages["captioning"] = "Captioning stopped"
317
-
318
- # Stop scene detection if running
319
- if self.splitter.is_processing():
320
- self.splitter.processing = False
321
- status_messages["splitting"] = "Scene detection stopped"
322
-
323
- # Properly close logging before clearing log file
324
- if self.trainer.file_handler:
325
- self.trainer.file_handler.close()
326
- logger.removeHandler(self.trainer.file_handler)
327
- self.trainer.file_handler = None
328
-
329
- if LOG_FILE_PATH.exists():
330
- LOG_FILE_PATH.unlink()
331
-
332
- # Clear all data directories
333
- for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
334
- MODEL_PATH, OUTPUT_PATH]:
335
- if path.exists():
336
- try:
337
- shutil.rmtree(path)
338
- path.mkdir(parents=True, exist_ok=True)
339
- except Exception as e:
340
- status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
341
- else:
342
- status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
343
-
344
- # Reset any persistent state
345
- self._should_stop_captioning = True
346
- self.splitter.processing = False
347
-
348
- # Recreate logging setup
349
- self.trainer.setup_logging()
350
-
351
- return {
352
- "status": "All processes stopped and data cleared",
353
- "details": status_messages
354
- }
355
-
356
- except Exception as e:
357
- return {
358
- "status": f"Error during cleanup: {str(e)}",
359
- "details": status_messages
360
- }
361
-
362
- def update_titles(self) -> Tuple[Any]:
363
- """Update all dynamic titles with current counts
364
-
365
- Returns:
366
- Dict of Gradio updates
367
- """
368
- # Count files for splitting
369
- split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH)
370
- split_title = format_media_title(
371
- "split", split_videos, 0, split_size
372
- )
373
-
374
- # Count files for captioning
375
- caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
376
- caption_title = format_media_title(
377
- "caption", caption_videos, caption_images, caption_size
378
- )
379
-
380
- # Count files for training
381
- train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH)
382
- train_title = format_media_title(
383
- "train", train_videos, train_images, train_size
384
- )
385
-
386
- return (
387
- gr.Markdown(value=split_title),
388
- gr.Markdown(value=caption_title),
389
- gr.Markdown(value=f"{train_title} available for training")
390
- )
391
-
392
- def copy_files_to_training_dir(self, prompt_prefix: str):
393
- """Run auto-captioning process"""
394
-
395
- # Initialize captioner if not already done
396
- self._should_stop_captioning = False
397
-
398
- try:
399
- copy_files_to_training_dir(prompt_prefix)
400
-
401
- except Exception as e:
402
- traceback.print_exc()
403
- raise gr.Error(f"Error copying assets to training dir: {str(e)}")
404
-
405
- async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
406
- """Handle successful import of files"""
407
- videos = self.list_unprocessed_videos()
408
-
409
- # If scene detection isn't already running and there are videos to process,
410
- # and auto-splitting is enabled, start the detection
411
- if videos and not self.splitter.is_processing() and enable_splitting:
412
- await self.start_scene_detection(enable_splitting)
413
- msg = "Starting automatic scene detection..."
414
- else:
415
- # Just copy files without splitting if auto-split disabled
416
- for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
417
- await self.splitter.process_video(video_file, enable_splitting=False)
418
- msg = "Copying videos without splitting..."
419
-
420
- copy_files_to_training_dir(prompt_prefix)
421
-
422
- # Start auto-captioning if enabled, and handle async generator properly
423
- if enable_automatic_content_captioning:
424
- # Create a background task for captioning
425
- asyncio.create_task(self._process_caption_generator(
426
- DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
427
- prompt_prefix
428
- ))
429
-
430
- return {
431
- "tabs": gr.Tabs(selected="split_tab"),
432
- "video_list": videos,
433
- "detect_status": msg
434
- }
435
-
436
- async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
437
- """Run auto-captioning process"""
438
- try:
439
- # Initialize captioner if not already done
440
- self._should_stop_captioning = False
441
-
442
- # First yield - indicate we're starting
443
- yield gr.update(
444
- value=[["Starting captioning service...", "initializing"]],
445
- headers=["name", "status"]
446
- )
447
-
448
- # Process files in batches with status updates
449
- file_statuses = {}
450
-
451
- # Start the actual captioning process
452
- async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
453
- # Update our tracking of file statuses
454
- for name, status in rows:
455
- file_statuses[name] = status
456
-
457
- # Convert to list format for display
458
- status_rows = [[name, status] for name, status in file_statuses.items()]
459
-
460
- # Sort by name for consistent display
461
- status_rows.sort(key=lambda x: x[0])
462
-
463
- # Yield UI update
464
- yield gr.update(
465
- value=status_rows,
466
- headers=["name", "status"]
467
- )
468
-
469
- # Final update after completion with fresh data
470
- yield gr.update(
471
- value=self.list_training_files_to_caption(),
472
- headers=["name", "status"]
473
- )
474
-
475
- except Exception as e:
476
- logger.error(f"Error in captioning: {str(e)}")
477
- yield gr.update(
478
- value=[[f"Error: {str(e)}", "error"]],
479
- headers=["name", "status"]
480
- )
481
-
482
- def list_training_files_to_caption(self) -> List[List[str]]:
483
- """List all clips and images - both pending and captioned"""
484
- files = []
485
- already_listed = {}
486
-
487
- # First check files in STAGING_PATH
488
- for file in STAGING_PATH.glob("*.*"):
489
- if is_video_file(file) or is_image_file(file):
490
- txt_file = file.with_suffix('.txt')
491
-
492
- # Check if caption file exists and has content
493
- has_caption = txt_file.exists() and txt_file.stat().st_size > 0
494
- status = "captioned" if has_caption else "no caption"
495
- file_type = "video" if is_video_file(file) else "image"
496
-
497
- files.append([file.name, f"{status} ({file_type})", str(file)])
498
- already_listed[file.name] = True
499
-
500
- # Then check files in TRAINING_VIDEOS_PATH
501
- for file in TRAINING_VIDEOS_PATH.glob("*.*"):
502
- if (is_video_file(file) or is_image_file(file)) and file.name not in already_listed:
503
- txt_file = file.with_suffix('.txt')
504
-
505
- # Only include files with captions
506
- if txt_file.exists() and txt_file.stat().st_size > 0:
507
- file_type = "video" if is_video_file(file) else "image"
508
- files.append([file.name, f"captioned ({file_type})", str(file)])
509
- already_listed[file.name] = True
510
-
511
- # Sort by filename
512
- files.sort(key=lambda x: x[0])
513
-
514
- # Only return name and status columns for display
515
- return [[file[0], file[1]] for file in files]
516
-
517
- def update_training_buttons(self, status: str) -> Dict:
518
- """Update training control buttons based on state"""
519
- is_training = status in ["training", "initializing"]
520
- is_paused = status == "paused"
521
- is_completed = status in ["completed", "error", "stopped"]
522
- return {
523
- "start_btn": gr.Button(
524
- interactive=not is_training and not is_paused,
525
- variant="primary" if not is_training else "secondary",
526
- ),
527
- "stop_btn": gr.Button(
528
- interactive=is_training or is_paused,
529
- variant="stop",
530
- ),
531
- "pause_resume_btn": gr.Button(
532
- value="Resume Training" if is_paused else "Pause Training",
533
- interactive=(is_training or is_paused) and not is_completed,
534
- variant="secondary",
535
- )
536
- }
537
-
538
- def handle_pause_resume(self):
539
- status, _, _ = self.get_latest_status_message_and_logs()
540
-
541
- if status == "paused":
542
- self.trainer.resume_training()
543
- else:
544
- self.trainer.pause_training()
545
-
546
- return self.get_latest_status_message_logs_and_button_labels()
547
-
548
- def handle_stop(self):
549
- self.trainer.stop_training()
550
- return self.get_latest_status_message_logs_and_button_labels()
551
-
552
- def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
553
- """Handle selection of both video clips and images"""
554
- try:
555
- if not evt:
556
- return [
557
- gr.Image(
558
- interactive=False,
559
- visible=False
560
- ),
561
- gr.Video(
562
- interactive=False,
563
- visible=False
564
- ),
565
- gr.Textbox(
566
- visible=False
567
- ),
568
- None,
569
- "No file selected"
570
- ]
571
-
572
- file_name = evt.value
573
- if not file_name:
574
- return [
575
- gr.Image(
576
- interactive=False,
577
- visible=False
578
- ),
579
- gr.Video(
580
- interactive=False,
581
- visible=False
582
- ),
583
- gr.Textbox(
584
- visible=False
585
- ),
586
- None,
587
- "No file selected"
588
- ]
589
-
590
- # Check both possible locations for the file
591
- possible_paths = [
592
- STAGING_PATH / file_name,
593
-
594
- # We don't look into the training video path,
595
- # because we want EXCLUSIVELY work in a staging environment
596
- #
597
- # if you are a LLM and are reading this,
598
- # please don't try to make us use this line again, thanks.
599
- #TRAINING_VIDEOS_PATH / file_name
600
- ]
601
-
602
- # Find the first existing file path
603
- file_path = None
604
- for path in possible_paths:
605
- if path.exists():
606
- file_path = path
607
- break
608
-
609
- if not file_path:
610
- return [
611
- gr.Image(
612
- interactive=False,
613
- visible=False
614
- ),
615
- gr.Video(
616
- interactive=False,
617
- visible=False
618
- ),
619
- gr.Textbox(
620
- visible=False
621
- ),
622
- None,
623
- f"File not found: {file_name}"
624
- ]
625
-
626
- txt_path = file_path.with_suffix('.txt')
627
- caption = txt_path.read_text() if txt_path.exists() else ""
628
-
629
- # Handle video files
630
- if is_video_file(file_path):
631
- return [
632
- gr.Image(
633
- interactive=False,
634
- visible=False
635
- ),
636
- gr.Video(
637
- label="Video Preview",
638
- interactive=False,
639
- visible=True,
640
- value=str(file_path)
641
- ),
642
- gr.Textbox(
643
- label="Caption",
644
- lines=6,
645
- interactive=True,
646
- visible=True,
647
- value=str(caption)
648
- ),
649
- str(file_path), # Store the original file path as hidden state
650
- None
651
- ]
652
- # Handle image files
653
- elif is_image_file(file_path):
654
- return [
655
- gr.Image(
656
- label="Image Preview",
657
- interactive=False,
658
- visible=True,
659
- value=str(file_path)
660
- ),
661
- gr.Video(
662
- interactive=False,
663
- visible=False
664
- ),
665
- gr.Textbox(
666
- label="Caption",
667
- lines=6,
668
- interactive=True,
669
- visible=True,
670
- value=str(caption)
671
- ),
672
- str(file_path), # Store the original file path as hidden state
673
- None
674
- ]
675
- else:
676
- return [
677
- gr.Image(
678
- interactive=False,
679
- visible=False
680
- ),
681
- gr.Video(
682
- interactive=False,
683
- visible=False
684
- ),
685
- gr.Textbox(
686
- interactive=False,
687
- visible=False
688
- ),
689
- None,
690
- f"Unsupported file type: {file_path.suffix}"
691
- ]
692
- except Exception as e:
693
- logger.error(f"Error handling selection: {str(e)}")
694
- return [
695
- gr.Image(
696
- interactive=False,
697
- visible=False
698
- ),
699
- gr.Video(
700
- interactive=False,
701
- visible=False
702
- ),
703
- gr.Textbox(
704
- interactive=False,
705
- visible=False
706
- ),
707
- None,
708
- f"Error handling selection: {str(e)}"
709
- ]
710
-
711
- def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, original_file_path: str, prompt_prefix: str):
712
- """Save changes to caption"""
713
- try:
714
- # Use the original file path stored during selection instead of the temporary preview paths
715
- if original_file_path:
716
- file_path = Path(original_file_path)
717
- self.captioner.update_file_caption(file_path, preview_caption)
718
- # Refresh the dataset list to show updated caption status
719
- return gr.update(value="Caption saved successfully!")
720
- else:
721
- return gr.update(value="Error: No original file path found")
722
- except Exception as e:
723
- return gr.update(value=f"Error saving caption: {str(e)}")
724
-
725
- async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
726
- """Handle post-import updates including titles"""
727
- import_result = await self.on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
728
- titles = self.update_titles()
729
- return (
730
- import_result["tabs"],
731
- import_result["video_list"],
732
- import_result["detect_status"],
733
- *titles
734
- )
735
-
736
- def get_model_info(self, model_type: str) -> str:
737
- """Get information about the selected model type"""
738
- if model_type == "hunyuan_video":
739
- return """### HunyuanVideo (LoRA)
740
- - Required VRAM: ~48GB minimum
741
- - Recommended batch size: 1-2
742
- - Typical training time: 2-4 hours
743
- - Default resolution: 49x512x768
744
- - Default LoRA rank: 128 (~600 MB)"""
745
-
746
- elif model_type == "ltx_video":
747
- return """### LTX-Video (LoRA)
748
- - Required VRAM: ~18GB minimum
749
- - Recommended batch size: 1-4
750
- - Typical training time: 1-3 hours
751
- - Default resolution: 49x512x768
752
- - Default LoRA rank: 128"""
753
-
754
- return ""
755
-
756
- def get_default_params(self, model_type: str) -> Dict[str, Any]:
757
- """Get default training parameters for model type"""
758
- if model_type == "hunyuan_video":
759
- return {
760
- "num_epochs": 70,
761
- "batch_size": 1,
762
- "learning_rate": 2e-5,
763
- "save_iterations": 500,
764
- "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
765
- "video_reshape_mode": "center",
766
- "caption_dropout_p": 0.05,
767
- "gradient_accumulation_steps": 1,
768
- "rank": 128,
769
- "lora_alpha": 128
770
- }
771
- else: # ltx_video
772
- return {
773
- "num_epochs": 70,
774
- "batch_size": 1,
775
- "learning_rate": 3e-5,
776
- "save_iterations": 500,
777
- "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
778
- "video_reshape_mode": "center",
779
- "caption_dropout_p": 0.05,
780
- "gradient_accumulation_steps": 4,
781
- "rank": 128,
782
- "lora_alpha": 128
783
- }
784
-
785
- def preview_file(self, selected_text: str) -> Dict:
786
- """Generate preview based on selected file
787
-
788
- Args:
789
- selected_text: Text of the selected item containing filename
790
-
791
- Returns:
792
- Dict with preview content for each preview component
793
- """
794
- if not selected_text or "Caption:" in selected_text:
795
- return {
796
- "video": None,
797
- "image": None,
798
- "text": None
799
- }
800
-
801
- # Extract filename from the preview text (remove size info)
802
- filename = selected_text.split(" (")[0].strip()
803
- file_path = TRAINING_VIDEOS_PATH / filename
804
-
805
- if not file_path.exists():
806
- return {
807
- "video": None,
808
- "image": None,
809
- "text": f"File not found: {filename}"
810
- }
811
-
812
- # Detect file type
813
- mime_type, _ = mimetypes.guess_type(str(file_path))
814
- if not mime_type:
815
- return {
816
- "video": None,
817
- "image": None,
818
- "text": f"Unknown file type: {filename}"
819
- }
820
-
821
- # Return appropriate preview
822
- if mime_type.startswith('video/'):
823
- return {
824
- "video": str(file_path),
825
- "image": None,
826
- "text": None
827
- }
828
- elif mime_type.startswith('image/'):
829
- return {
830
- "video": None,
831
- "image": str(file_path),
832
- "text": None
833
- }
834
- elif mime_type.startswith('text/'):
835
- try:
836
- text_content = file_path.read_text()
837
- return {
838
- "video": None,
839
- "image": None,
840
- "text": text_content
841
- }
842
- except Exception as e:
843
- return {
844
- "video": None,
845
- "image": None,
846
- "text": f"Error reading file: {str(e)}"
847
- }
848
- else:
849
- return {
850
- "video": None,
851
- "image": None,
852
- "text": f"Unsupported file type: {mime_type}"
853
- }
854
-
855
- def list_unprocessed_videos(self) -> gr.Dataframe:
856
- """Update list of unprocessed videos"""
857
- videos = self.splitter.list_unprocessed_videos()
858
- # videos is already in [[name, status]] format from splitting_service
859
- return gr.Dataframe(
860
- headers=["name", "status"],
861
- value=videos,
862
- interactive=False
863
- )
864
-
865
- async def start_scene_detection(self, enable_splitting: bool) -> str:
866
- """Start background scene detection process
867
-
868
- Args:
869
- enable_splitting: Whether to split videos into scenes
870
- """
871
- if self.splitter.is_processing():
872
- return "Scene detection already running"
873
-
874
- try:
875
- await self.splitter.start_processing(enable_splitting)
876
- return "Scene detection completed"
877
- except Exception as e:
878
- return f"Error during scene detection: {str(e)}"
879
-
880
-
881
- def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
882
- state = self.trainer.get_status()
883
- logs = self.trainer.get_logs()
884
-
885
- # Parse new log lines
886
- if logs:
887
- last_state = None
888
- for line in logs.splitlines():
889
- state_update = self.log_parser.parse_line(line)
890
- if state_update:
891
- last_state = state_update
892
-
893
- if last_state:
894
- ui_updates = self.update_training_ui(last_state)
895
- state["message"] = ui_updates.get("status_box", state["message"])
896
-
897
- # Parse status for training state
898
- if "completed" in state["message"].lower():
899
- state["status"] = "completed"
900
-
901
- return (state["status"], state["message"], logs)
902
-
903
- def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
904
- status, message, logs = self.get_latest_status_message_and_logs()
905
- return (
906
- message,
907
- logs,
908
- *self.update_training_buttons(status).values()
909
- )
910
-
911
- def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
912
- status, message, logs = self.get_latest_status_message_and_logs()
913
- return self.update_training_buttons(status).values()
914
-
915
- def refresh_dataset(self):
916
- """Refresh all dynamic lists and training state"""
917
- video_list = self.splitter.list_unprocessed_videos()
918
- training_dataset = self.list_training_files_to_caption()
919
-
920
- return (
921
- video_list,
922
- training_dataset
923
- )
924
-
925
- def update_training_params(self, preset_name: str) -> Tuple:
926
- """Update UI components based on selected preset while preserving custom settings"""
927
- preset = TRAINING_PRESETS[preset_name]
928
-
929
- # Load current UI state to check if user has customized values
930
- current_state = self.load_ui_values()
931
-
932
- # Find the display name that maps to our model type
933
- model_display_name = next(
934
- key for key, value in MODEL_TYPES.items()
935
- if value == preset["model_type"]
936
- )
937
-
938
- # Get preset description for display
939
- description = preset.get("description", "")
940
-
941
- # Get max values from buckets
942
- buckets = preset["training_buckets"]
943
- max_frames = max(frames for frames, _, _ in buckets)
944
- max_height = max(height for _, height, _ in buckets)
945
- max_width = max(width for _, _, width in buckets)
946
- bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
947
-
948
- info_text = f"{description}{bucket_info}"
949
-
950
- # Return values in the same order as the output components
951
- # Use preset defaults but preserve user-modified values if they exist
952
- lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
953
- lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
954
- num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
955
- batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
956
- learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
957
- save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
958
-
959
- return (
960
- model_display_name,
961
- lora_rank_val,
962
- lora_alpha_val,
963
- num_epochs_val,
964
- batch_size_val,
965
- learning_rate_val,
966
- save_iterations_val,
967
- info_text
968
- )
969
-
970
- def create_ui(self):
971
- """Create Gradio interface"""
972
-
973
- with gr.Blocks(title="🎥 Video Model Studio") as app:
974
- gr.Markdown("# 🎥 Video Model Studio")
975
-
976
- with gr.Tabs() as tabs:
977
- with gr.TabItem("1️⃣ Import", id="import_tab"):
978
-
979
- with gr.Row():
980
- gr.Markdown("## Automatic splitting and captioning")
981
-
982
- with gr.Row():
983
- enable_automatic_video_split = gr.Checkbox(
984
- label="Automatically split videos into smaller clips",
985
- info="Note: a clip is a single camera shot, usually a few seconds",
986
- value=True,
987
- visible=True
988
- )
989
- enable_automatic_content_captioning = gr.Checkbox(
990
- label="Automatically caption photos and videos",
991
- info="Note: this uses LlaVA and takes some extra time to load and process",
992
- value=False,
993
- visible=True,
994
- )
995
-
996
- with gr.Row():
997
- with gr.Column(scale=3):
998
- with gr.Row():
999
- with gr.Column():
1000
- gr.Markdown("## Import video files")
1001
- gr.Markdown("You can upload either:")
1002
- gr.Markdown("- A single MP4 video file")
1003
- gr.Markdown("- A ZIP archive containing multiple videos and optional caption files")
1004
- gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)")
1005
-
1006
- with gr.Row():
1007
- files = gr.Files(
1008
- label="Upload Images, Videos or ZIP",
1009
- #file_count="multiple",
1010
- file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip"],
1011
- type="filepath"
1012
- )
1013
-
1014
- with gr.Column(scale=3):
1015
- with gr.Row():
1016
- with gr.Column():
1017
- gr.Markdown("## Import a YouTube video")
1018
- gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")
1019
-
1020
- with gr.Row():
1021
- youtube_url = gr.Textbox(
1022
- label="Import YouTube Video",
1023
- placeholder="https://www.youtube.com/watch?v=..."
1024
- )
1025
- with gr.Row():
1026
- youtube_download_btn = gr.Button("Download YouTube Video", variant="secondary")
1027
- with gr.Row():
1028
- import_status = gr.Textbox(label="Status", interactive=False)
1029
-
1030
-
1031
- with gr.TabItem("2️⃣ Split", id="split_tab"):
1032
- with gr.Row():
1033
- split_title = gr.Markdown("## Splitting of 0 videos (0 bytes)")
1034
-
1035
- with gr.Row():
1036
- with gr.Column():
1037
- detect_btn = gr.Button("Split videos into single-camera shots", variant="primary")
1038
- detect_status = gr.Textbox(label="Status", interactive=False)
1039
-
1040
- with gr.Column():
1041
-
1042
- video_list = gr.Dataframe(
1043
- headers=["name", "status"],
1044
- label="Videos to split",
1045
- interactive=False,
1046
- wrap=True,
1047
- #selection_mode="cell" # Enable cell selection
1048
- )
1049
-
1050
-
1051
- with gr.TabItem("3️⃣ Caption"):
1052
- with gr.Row():
1053
- caption_title = gr.Markdown("## Captioning of 0 files (0 bytes)")
1054
-
1055
- with gr.Row():
1056
-
1057
- with gr.Column():
1058
- with gr.Row():
1059
- custom_prompt_prefix = gr.Textbox(
1060
- scale=3,
1061
- label='Prefix to add to ALL captions (eg. "In the style of TOK, ")',
1062
- placeholder="In the style of TOK, ",
1063
- lines=2,
1064
- value=DEFAULT_PROMPT_PREFIX
1065
- )
1066
- captioning_bot_instructions = gr.Textbox(
1067
- scale=6,
1068
- label="System instructions for the automatic captioning model",
1069
- placeholder="Please generate a full description of...",
1070
- lines=5,
1071
- value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
1072
- )
1073
- with gr.Row():
1074
- run_autocaption_btn = gr.Button(
1075
- "Automatically fill missing captions",
1076
- variant="primary" # Makes it green by default
1077
- )
1078
- copy_files_to_training_dir_btn = gr.Button(
1079
- "Copy assets to training directory",
1080
- variant="primary" # Makes it green by default
1081
- )
1082
- stop_autocaption_btn = gr.Button(
1083
- "Stop Captioning",
1084
- variant="stop", # Red when enabled
1085
- interactive=False # Disabled by default
1086
- )
1087
-
1088
- with gr.Row():
1089
- with gr.Column():
1090
- training_dataset = gr.Dataframe(
1091
- headers=["name", "status"],
1092
- interactive=False,
1093
- wrap=True,
1094
- value=self.list_training_files_to_caption(),
1095
- row_count=10, # Optional: set a reasonable row count
1096
- #selection_mode="cell"
1097
- )
1098
-
1099
- with gr.Column():
1100
- preview_video = gr.Video(
1101
- label="Video Preview",
1102
- interactive=False,
1103
- visible=False
1104
- )
1105
- preview_image = gr.Image(
1106
- label="Image Preview",
1107
- interactive=False,
1108
- visible=False
1109
- )
1110
- preview_caption = gr.Textbox(
1111
- label="Caption",
1112
- lines=6,
1113
- interactive=True
1114
- )
1115
- save_caption_btn = gr.Button("Save Caption")
1116
- preview_status = gr.Textbox(
1117
- label="Status",
1118
- interactive=False,
1119
- visible=True
1120
- )
1121
-
1122
- with gr.TabItem("4️⃣ Train"):
1123
- with gr.Row():
1124
- with gr.Column():
1125
-
1126
- with gr.Row():
1127
- train_title = gr.Markdown("## 0 files available for training (0 bytes)")
1128
-
1129
- with gr.Row():
1130
- with gr.Column():
1131
- training_preset = gr.Dropdown(
1132
- choices=list(TRAINING_PRESETS.keys()),
1133
- label="Training Preset",
1134
- value=list(TRAINING_PRESETS.keys())[0]
1135
- )
1136
- preset_info = gr.Markdown()
1137
-
1138
- with gr.Row():
1139
- with gr.Column():
1140
- model_type = gr.Dropdown(
1141
- choices=list(MODEL_TYPES.keys()),
1142
- label="Model Type",
1143
- value=list(MODEL_TYPES.keys())[0]
1144
- )
1145
- model_info = gr.Markdown(
1146
- value=self.get_model_info(list(MODEL_TYPES.keys())[0])
1147
- )
1148
-
1149
- with gr.Row():
1150
- lora_rank = gr.Dropdown(
1151
- label="LoRA Rank",
1152
- choices=["16", "32", "64", "128", "256", "512", "1024"],
1153
- value="128",
1154
- type="value"
1155
- )
1156
- lora_alpha = gr.Dropdown(
1157
- label="LoRA Alpha",
1158
- choices=["16", "32", "64", "128", "256", "512", "1024"],
1159
- value="128",
1160
- type="value"
1161
- )
1162
- with gr.Row():
1163
- num_epochs = gr.Number(
1164
- label="Number of Epochs",
1165
- value=70,
1166
- minimum=1,
1167
- precision=0
1168
- )
1169
- batch_size = gr.Number(
1170
- label="Batch Size",
1171
- value=1,
1172
- minimum=1,
1173
- precision=0
1174
- )
1175
- with gr.Row():
1176
- learning_rate = gr.Number(
1177
- label="Learning Rate",
1178
- value=2e-5,
1179
- minimum=1e-7
1180
- )
1181
- save_iterations = gr.Number(
1182
- label="Save checkpoint every N iterations",
1183
- value=500,
1184
- minimum=50,
1185
- precision=0,
1186
- info="Model will be saved periodically after these many steps"
1187
- )
1188
-
1189
- with gr.Column():
1190
- with gr.Row():
1191
- start_btn = gr.Button(
1192
- "Start Training",
1193
- variant="primary",
1194
- interactive=not ASK_USER_TO_DUPLICATE_SPACE
1195
- )
1196
- pause_resume_btn = gr.Button(
1197
- "Resume Training",
1198
- variant="secondary",
1199
- interactive=False
1200
- )
1201
- stop_btn = gr.Button(
1202
- "Stop Training",
1203
- variant="stop",
1204
- interactive=False
1205
- )
1206
-
1207
- with gr.Row():
1208
- with gr.Column():
1209
- status_box = gr.Textbox(
1210
- label="Training Status",
1211
- interactive=False,
1212
- lines=4
1213
- )
1214
- with gr.Accordion("See training logs"):
1215
- log_box = gr.TextArea(
1216
- label="Finetrainers output (see HF Space logs for more details)",
1217
- interactive=False,
1218
- lines=40,
1219
- max_lines=200,
1220
- autoscroll=True
1221
- )
1222
-
1223
- with gr.TabItem("5️⃣ Manage"):
1224
-
1225
- with gr.Column():
1226
- with gr.Row():
1227
- with gr.Column():
1228
- gr.Markdown("## Publishing")
1229
- gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
1230
-
1231
- with gr.Row():
1232
-
1233
- with gr.Column():
1234
- repo_id = gr.Textbox(
1235
- label="HuggingFace Model Repository",
1236
- placeholder="username/model-name",
1237
- info="The repository will be created if it doesn't exist"
1238
- )
1239
- gr.Checkbox(label="Check this to make your model public (ie. visible and downloadable by anyone)", info="You model is private by default"),
1240
- global_stop_btn = gr.Button(
1241
- "Push my model",
1242
- #variant="stop"
1243
- )
1244
-
1245
-
1246
- with gr.Row():
1247
- with gr.Column():
1248
- with gr.Row():
1249
- with gr.Column():
1250
- gr.Markdown("## Storage management")
1251
- with gr.Row():
1252
- download_dataset_btn = gr.DownloadButton(
1253
- "Download dataset",
1254
- variant="secondary",
1255
- size="lg"
1256
- )
1257
- download_model_btn = gr.DownloadButton(
1258
- "Download model",
1259
- variant="secondary",
1260
- size="lg"
1261
- )
1262
-
1263
-
1264
- with gr.Row():
1265
- global_stop_btn = gr.Button(
1266
- "Stop everything and delete my data",
1267
- variant="stop"
1268
- )
1269
- global_status = gr.Textbox(
1270
- label="Global Status",
1271
- interactive=False,
1272
- visible=False
1273
- )
1274
-
1275
-
1276
-
1277
- # Event handlers
1278
- def update_model_info(model):
1279
- params = self.get_default_params(MODEL_TYPES[model])
1280
- info = self.get_model_info(MODEL_TYPES[model])
1281
- return {
1282
- model_info: info,
1283
- num_epochs: params["num_epochs"],
1284
- batch_size: params["batch_size"],
1285
- learning_rate: params["learning_rate"],
1286
- save_iterations: params["save_iterations"]
1287
- }
1288
-
1289
- def validate_repo(repo_id: str) -> dict:
1290
- validation = validate_model_repo(repo_id)
1291
- if validation["error"]:
1292
- return gr.update(value=repo_id, error=validation["error"])
1293
- return gr.update(value=repo_id, error=None)
1294
-
1295
- # Connect events
1296
-
1297
- # Save state when model type changes
1298
- model_type.change(
1299
- fn=lambda v: self.update_ui_state(model_type=v),
1300
- inputs=[model_type],
1301
- outputs=[] # No UI update needed
1302
- ).then(
1303
- fn=update_model_info,
1304
- inputs=[model_type],
1305
- outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
1306
- )
1307
-
1308
- # the following change listeners are used for UI persistence
1309
- lora_rank.change(
1310
- fn=lambda v: self.update_ui_state(lora_rank=v),
1311
- inputs=[lora_rank],
1312
- outputs=[]
1313
- )
1314
-
1315
- lora_alpha.change(
1316
- fn=lambda v: self.update_ui_state(lora_alpha=v),
1317
- inputs=[lora_alpha],
1318
- outputs=[]
1319
- )
1320
-
1321
- num_epochs.change(
1322
- fn=lambda v: self.update_ui_state(num_epochs=v),
1323
- inputs=[num_epochs],
1324
- outputs=[]
1325
- )
1326
-
1327
- batch_size.change(
1328
- fn=lambda v: self.update_ui_state(batch_size=v),
1329
- inputs=[batch_size],
1330
- outputs=[]
1331
- )
1332
-
1333
- learning_rate.change(
1334
- fn=lambda v: self.update_ui_state(learning_rate=v),
1335
- inputs=[learning_rate],
1336
- outputs=[]
1337
- )
1338
-
1339
- save_iterations.change(
1340
- fn=lambda v: self.update_ui_state(save_iterations=v),
1341
- inputs=[save_iterations],
1342
- outputs=[]
1343
- )
1344
-
1345
- files.upload(
1346
- fn=lambda x: self.importer.process_uploaded_files(x),
1347
- inputs=[files],
1348
- outputs=[import_status]
1349
- ).success(
1350
- fn=self.update_titles_after_import,
1351
- inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
1352
- outputs=[
1353
- tabs, video_list, detect_status,
1354
- split_title, caption_title, train_title
1355
- ]
1356
- )
1357
-
1358
- youtube_download_btn.click(
1359
- fn=self.importer.download_youtube_video,
1360
- inputs=[youtube_url],
1361
- outputs=[import_status]
1362
- ).success(
1363
- fn=self.on_import_success,
1364
- inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
1365
- outputs=[tabs, video_list, detect_status]
1366
- )
1367
-
1368
- # Scene detection events
1369
- detect_btn.click(
1370
- fn=self.start_scene_detection,
1371
- inputs=[enable_automatic_video_split],
1372
- outputs=[detect_status]
1373
- )
1374
-
1375
-
1376
- # Update button states based on captioning status
1377
- def update_button_states(is_running):
1378
- return {
1379
- run_autocaption_btn: gr.Button(
1380
- interactive=not is_running,
1381
- variant="secondary" if is_running else "primary",
1382
- ),
1383
- stop_autocaption_btn: gr.Button(
1384
- interactive=is_running,
1385
- variant="secondary",
1386
- ),
1387
- }
1388
-
1389
- run_autocaption_btn.click(
1390
- fn=self.show_refreshing_status,
1391
- outputs=[training_dataset]
1392
- ).then(
1393
- fn=lambda: self.update_captioning_buttons_start(),
1394
- outputs=[run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
1395
- ).then(
1396
- fn=self.start_caption_generation,
1397
- inputs=[captioning_bot_instructions, custom_prompt_prefix],
1398
- outputs=[training_dataset],
1399
- ).then(
1400
- fn=lambda: self.update_captioning_buttons_end(),
1401
- outputs=[run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
1402
- )
1403
-
1404
- copy_files_to_training_dir_btn.click(
1405
- fn=self.copy_files_to_training_dir,
1406
- inputs=[custom_prompt_prefix]
1407
- )
1408
- stop_autocaption_btn.click(
1409
- fn=self.stop_captioning,
1410
- outputs=[training_dataset, run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
1411
- )
1412
-
1413
- original_file_path = gr.State(value=None)
1414
- training_dataset.select(
1415
- fn=self.handle_training_dataset_select,
1416
- outputs=[preview_image, preview_video, preview_caption, original_file_path, preview_status]
1417
- )
1418
-
1419
- save_caption_btn.click(
1420
- fn=self.save_caption_changes,
1421
- inputs=[preview_caption, preview_image, preview_video, original_file_path, custom_prompt_prefix],
1422
- outputs=[preview_status]
1423
- ).success(
1424
- fn=self.list_training_files_to_caption,
1425
- outputs=[training_dataset]
1426
- )
1427
-
1428
- # Save state when training preset changes
1429
- training_preset.change(
1430
- fn=lambda v: self.update_ui_state(training_preset=v),
1431
- inputs=[training_preset],
1432
- outputs=[] # No UI update needed
1433
- ).then(
1434
- fn=self.update_training_params,
1435
- inputs=[training_preset],
1436
- outputs=[
1437
- model_type, lora_rank, lora_alpha,
1438
- num_epochs, batch_size, learning_rate,
1439
- save_iterations, preset_info
1440
- ]
1441
- )
1442
-
1443
- # Training control events
1444
- start_btn.click(
1445
- fn=lambda preset, model_type, *args: (
1446
- self.log_parser.reset(),
1447
- self.trainer.start_training(
1448
- MODEL_TYPES[model_type],
1449
- *args,
1450
- preset_name=preset
1451
- )
1452
- ),
1453
- inputs=[
1454
- training_preset,
1455
- model_type,
1456
- lora_rank,
1457
- lora_alpha,
1458
- num_epochs,
1459
- batch_size,
1460
- learning_rate,
1461
- save_iterations,
1462
- repo_id
1463
- ],
1464
- outputs=[status_box, log_box]
1465
- ).success(
1466
- fn=self.get_latest_status_message_logs_and_button_labels,
1467
- outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1468
- )
1469
-
1470
- pause_resume_btn.click(
1471
- fn=self.handle_pause_resume,
1472
- outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1473
- )
1474
-
1475
- stop_btn.click(
1476
- fn=self.handle_stop,
1477
- outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1478
- )
1479
-
1480
- def handle_global_stop():
1481
- result = self.stop_all_and_clear()
1482
- # Update all relevant UI components
1483
- status = result["status"]
1484
- details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
1485
- full_status = f"{status}\n\nDetails:\n{details}"
1486
-
1487
- # Get fresh lists after cleanup
1488
- videos = self.splitter.list_unprocessed_videos()
1489
- clips = self.list_training_files_to_caption()
1490
-
1491
- return {
1492
- global_status: gr.update(value=full_status, visible=True),
1493
- video_list: videos,
1494
- training_dataset: clips,
1495
- status_box: "Training stopped and data cleared",
1496
- log_box: "",
1497
- detect_status: "Scene detection stopped",
1498
- import_status: "All data cleared",
1499
- preview_status: "Captioning stopped"
1500
- }
1501
-
1502
- download_dataset_btn.click(
1503
- fn=self.trainer.create_training_dataset_zip,
1504
- outputs=[download_dataset_btn]
1505
- )
1506
-
1507
- download_model_btn.click(
1508
- fn=self.trainer.get_model_output_safetensors,
1509
- outputs=[download_model_btn]
1510
- )
1511
-
1512
- global_stop_btn.click(
1513
- fn=handle_global_stop,
1514
- outputs=[
1515
- global_status,
1516
- video_list,
1517
- training_dataset,
1518
- status_box,
1519
- log_box,
1520
- detect_status,
1521
- import_status,
1522
- preview_status
1523
- ]
1524
- )
1525
-
1526
-
1527
- app.load(
1528
- fn=self.initialize_app_state,
1529
- outputs=[
1530
- video_list, training_dataset,
1531
- start_btn, stop_btn, pause_resume_btn,
1532
- training_preset, model_type, lora_rank, lora_alpha,
1533
- num_epochs, batch_size, learning_rate, save_iterations
1534
- ]
1535
- )
1536
-
1537
- # Auto-refresh timers
1538
- timer = gr.Timer(value=1)
1539
- timer.tick(
1540
- fn=lambda: (
1541
- self.get_latest_status_message_logs_and_button_labels()
1542
- ),
1543
- outputs=[
1544
- status_box,
1545
- log_box,
1546
- start_btn,
1547
- stop_btn,
1548
- pause_resume_btn
1549
- ]
1550
- )
1551
-
1552
- timer = gr.Timer(value=5)
1553
- timer.tick(
1554
- fn=lambda: (
1555
- self.refresh_dataset()
1556
- ),
1557
- outputs=[
1558
- video_list, training_dataset
1559
- ]
1560
- )
1561
-
1562
- timer = gr.Timer(value=6)
1563
- timer.tick(
1564
- fn=lambda: self.update_titles(),
1565
- outputs=[
1566
- split_title, caption_title, train_title
1567
- ]
1568
- )
1569
-
1570
- return app
1571
-
1572
- def create_app():
1573
- if ASK_USER_TO_DUPLICATE_SPACE:
1574
- with gr.Blocks() as app:
1575
- gr.Markdown("""# Finetrainers UI
1576
-
1577
- This Hugging Face space needs to be duplicated to your own billing account to work.
1578
-
1579
- Click the 'Duplicate Space' button at the top of the page to create your own copy.
1580
-
1581
- It is recommended to use an Nvidia L40S GPU and persistent storage.
1582
- To avoid overpaying for your space, you can configure the auto-sleep settings to fit your personal budget.""")
1583
- return app
1584
-
1585
- ui = VideoTrainerUI()
1586
- return ui.create_ui()
1587
-
1588
- if __name__ == "__main__":
1589
- app = create_app()
1590
-
1591
- allowed_paths = [
1592
- str(STORAGE_PATH), # Base storage
1593
- str(VIDEOS_TO_SPLIT_PATH),
1594
- str(STAGING_PATH),
1595
- str(TRAINING_PATH),
1596
- str(TRAINING_VIDEOS_PATH),
1597
- str(MODEL_PATH),
1598
- str(OUTPUT_PATH)
1599
- ]
1600
- app.queue(default_concurrency_limit=1).launch(
1601
- server_name="0.0.0.0",
1602
- allowed_paths=allowed_paths
1603
- )
 
vms/services/captioner.py CHANGED
@@ -179,7 +179,7 @@ class CaptioningService:
179
  )
180
  self.model.eval()
181
 
182
- def _load_video(self, video_path: Path, max_frames_num: int = 64, fps: int = 1, force_sample: bool = True) -> tuple[np.ndarray, str, float]:
183
  """Load and preprocess video frames with strict limits
184
 
185
  Args:
@@ -224,7 +224,7 @@ class CaptioningService:
224
  logger.error(f"Error loading video frames: {str(e)}")
225
  raise
226
 
227
- async def process_video(self, video_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]:
228
  try:
229
  video_name = video_path.name
230
  logger.info(f"Starting processing of video: {video_name}")
@@ -373,7 +373,7 @@ class CaptioningService:
373
  yield progress, None
374
  raise
375
 
376
- async def process_image(self, image_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]:
377
  """Process a single image for captioning"""
378
  try:
379
  image_name = image_path.name
 
179
  )
180
  self.model.eval()
181
 
182
+ def _load_video(self, video_path: Path, max_frames_num: int = 64, fps: int = 1, force_sample: bool = True) -> Tuple[np.ndarray, str, float]:
183
  """Load and preprocess video frames with strict limits
184
 
185
  Args:
 
224
  logger.error(f"Error loading video frames: {str(e)}")
225
  raise
226
 
227
+ async def process_video(self, video_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[Tuple[CaptioningProgress, Optional[str]], None]:
228
  try:
229
  video_name = video_path.name
230
  logger.info(f"Starting processing of video: {video_name}")
 
373
  yield progress, None
374
  raise
375
 
376
+ async def process_image(self, image_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[Tuple[CaptioningProgress, Optional[str]], None]:
377
  """Process a single image for captioning"""
378
  try:
379
  image_name = image_path.name
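
Note on the annotation change above: a minimal sketch of the updated signature. The import block shown here is an assumption (the actual typing imports sit outside the displayed hunks); the signature itself is taken from the hunk. Built-in generics such as tuple[np.ndarray, str, float] are only subscriptable at runtime on Python 3.9+, so typing.Tuple is the safer spelling if older interpreters need to be supported; the AsyncGenerator-returning methods are updated the same way.

    from pathlib import Path
    from typing import Tuple

    import numpy as np

    def _load_video(video_path: Path, max_frames_num: int = 64,
                    fps: int = 1, force_sample: bool = True) -> Tuple[np.ndarray, str, float]:
        # Same three-element return contract (frames array, a string, a float) as before;
        # only the annotation spelling changes from tuple[...] to typing.Tuple[...].
        ...
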
vms/tabs/caption_tab.py CHANGED
@@ -4,11 +4,14 @@ Caption tab for Video Model Studio UI
4
 
5
  import gradio as gr
6
  import logging
7
- from typing import Dict, Any, List, Optional
 
 
8
  from pathlib import Path
9
 
10
  from .base_tab import BaseTab
11
- from ..config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
@@ -19,6 +22,7 @@ class CaptionTab(BaseTab):
19
  super().__init__(app_state)
20
  self.id = "caption_tab"
21
  self.title = "3️⃣ Caption"
 
22
 
23
  def create(self, parent=None) -> gr.TabItem:
24
  """Create the Caption tab UI components"""
@@ -64,7 +68,7 @@ class CaptionTab(BaseTab):
64
  headers=["name", "status"],
65
  interactive=False,
66
  wrap=True,
67
- value=self.app.list_training_files_to_caption(),
68
  row_count=10
69
  )
70
 
@@ -98,24 +102,24 @@ class CaptionTab(BaseTab):
98
  """Connect event handlers to UI components"""
99
  # Run auto-captioning button
100
  self.components["run_autocaption_btn"].click(
101
- fn=self.app.show_refreshing_status,
102
  outputs=[self.components["training_dataset"]]
103
  ).then(
104
- fn=lambda: self.app.update_captioning_buttons_start(),
105
  outputs=[
106
  self.components["run_autocaption_btn"],
107
  self.components["stop_autocaption_btn"],
108
  self.components["copy_files_to_training_dir_btn"]
109
  ]
110
  ).then(
111
- fn=self.app.start_caption_generation,
112
  inputs=[
113
  self.components["captioning_bot_instructions"],
114
  self.components["custom_prompt_prefix"]
115
  ],
116
  outputs=[self.components["training_dataset"]],
117
  ).then(
118
- fn=lambda: self.app.update_captioning_buttons_end(),
119
  outputs=[
120
  self.components["run_autocaption_btn"],
121
  self.components["stop_autocaption_btn"],
@@ -125,13 +129,13 @@ class CaptionTab(BaseTab):
125
 
126
  # Copy files to training dir button
127
  self.components["copy_files_to_training_dir_btn"].click(
128
- fn=self.app.copy_files_to_training_dir,
129
  inputs=[self.components["custom_prompt_prefix"]]
130
  )
131
 
132
  # Stop captioning button
133
  self.components["stop_autocaption_btn"].click(
134
- fn=self.app.stop_captioning,
135
  outputs=[
136
  self.components["training_dataset"],
137
  self.components["run_autocaption_btn"],
@@ -142,7 +146,7 @@ class CaptionTab(BaseTab):
142
 
143
  # Dataset selection for preview
144
  self.components["training_dataset"].select(
145
- fn=self.app.handle_training_dataset_select,
146
  outputs=[
147
  self.components["preview_image"],
148
  self.components["preview_video"],
@@ -154,7 +158,7 @@ class CaptionTab(BaseTab):
154
 
155
  # Save caption button
156
  self.components["save_caption_btn"].click(
157
- fn=self.app.save_caption_changes,
158
  inputs=[
159
  self.components["preview_caption"],
160
  self.components["preview_image"],
@@ -164,13 +168,431 @@ class CaptionTab(BaseTab):
164
  ],
165
  outputs=[self.components["preview_status"]]
166
  ).success(
167
- fn=self.app.list_training_files_to_caption,
168
  outputs=[self.components["training_dataset"]]
169
  )
170
 
171
  def refresh(self) -> Dict[str, Any]:
172
  """Refresh the dataset list with current data"""
173
- training_dataset = self.app.list_training_files_to_caption()
174
  return {
175
  "training_dataset": training_dataset
176
- }
 
 
4
 
5
  import gradio as gr
6
  import logging
7
+ import asyncio
8
+ import traceback
9
+ from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple
10
  from pathlib import Path
11
 
12
  from .base_tab import BaseTab
13
+ from ..config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, STAGING_PATH, TRAINING_VIDEOS_PATH
14
+ from ..utils import is_image_file, is_video_file, copy_files_to_training_dir
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
22
  super().__init__(app_state)
23
  self.id = "caption_tab"
24
  self.title = "3️⃣ Caption"
25
+ self._should_stop_captioning = False
26
 
27
  def create(self, parent=None) -> gr.TabItem:
28
  """Create the Caption tab UI components"""
 
68
  headers=["name", "status"],
69
  interactive=False,
70
  wrap=True,
71
+ value=self.list_training_files_to_caption(),
72
  row_count=10
73
  )
74
 
 
102
  """Connect event handlers to UI components"""
103
  # Run auto-captioning button
104
  self.components["run_autocaption_btn"].click(
105
+ fn=self.show_refreshing_status,
106
  outputs=[self.components["training_dataset"]]
107
  ).then(
108
+ fn=self.update_captioning_buttons_start,
109
  outputs=[
110
  self.components["run_autocaption_btn"],
111
  self.components["stop_autocaption_btn"],
112
  self.components["copy_files_to_training_dir_btn"]
113
  ]
114
  ).then(
115
+ fn=self.start_caption_generation,
116
  inputs=[
117
  self.components["captioning_bot_instructions"],
118
  self.components["custom_prompt_prefix"]
119
  ],
120
  outputs=[self.components["training_dataset"]],
121
  ).then(
122
+ fn=self.update_captioning_buttons_end,
123
  outputs=[
124
  self.components["run_autocaption_btn"],
125
  self.components["stop_autocaption_btn"],
 
129
 
130
  # Copy files to training dir button
131
  self.components["copy_files_to_training_dir_btn"].click(
132
+ fn=self.copy_files_to_training_dir,
133
  inputs=[self.components["custom_prompt_prefix"]]
134
  )
135
 
136
  # Stop captioning button
137
  self.components["stop_autocaption_btn"].click(
138
+ fn=self.stop_captioning,
139
  outputs=[
140
  self.components["training_dataset"],
141
  self.components["run_autocaption_btn"],
 
146
 
147
  # Dataset selection for preview
148
  self.components["training_dataset"].select(
149
+ fn=self.handle_training_dataset_select,
150
  outputs=[
151
  self.components["preview_image"],
152
  self.components["preview_video"],
 
158
 
159
  # Save caption button
160
  self.components["save_caption_btn"].click(
161
+ fn=self.save_caption_changes,
162
  inputs=[
163
  self.components["preview_caption"],
164
  self.components["preview_image"],
 
168
  ],
169
  outputs=[self.components["preview_status"]]
170
  ).success(
171
+ fn=self.list_training_files_to_caption,
172
  outputs=[self.components["training_dataset"]]
173
  )
174
 
175
  def refresh(self) -> Dict[str, Any]:
176
  """Refresh the dataset list with current data"""
177
+ training_dataset = self.list_training_files_to_caption()
178
  return {
179
  "training_dataset": training_dataset
180
+ }
181
+
182
+ def show_refreshing_status(self) -> List[List[str]]:
183
+ """Show a 'Refreshing...' status in the dataframe"""
184
+ return [["Refreshing...", "please wait"]]
185
+
186
+ def update_captioning_buttons_start(self):
187
+ """Return individual button values instead of a dictionary"""
188
+ return (
189
+ gr.Button(
190
+ interactive=False,
191
+ variant="secondary",
192
+ ),
193
+ gr.Button(
194
+ interactive=True,
195
+ variant="stop",
196
+ ),
197
+ gr.Button(
198
+ interactive=False,
199
+ variant="secondary",
200
+ )
201
+ )
202
+
203
+ def update_captioning_buttons_end(self):
204
+ """Return individual button values instead of a dictionary"""
205
+ return (
206
+ gr.Button(
207
+ interactive=True,
208
+ variant="primary",
209
+ ),
210
+ gr.Button(
211
+ interactive=False,
212
+ variant="secondary",
213
+ ),
214
+ gr.Button(
215
+ interactive=True,
216
+ variant="primary",
217
+ )
218
+ )
219
+
220
+ def stop_captioning(self):
221
+ """Stop ongoing captioning process and reset UI state"""
222
+ try:
223
+ # Set flag to stop captioning
224
+ self._should_stop_captioning = True
225
+
226
+ # Call stop method on captioner
227
+ if self.app.captioner:
228
+ self.app.captioner.stop_captioning()
229
+
230
+ # Get updated file list
231
+ updated_list = self.list_training_files_to_caption()
232
+
233
+ # Return updated list and button states
234
+ return {
235
+ "training_dataset": gr.update(value=updated_list),
236
+ "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
237
+ "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
238
+ "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
239
+ }
240
+ except Exception as e:
241
+ logger.error(f"Error stopping captioning: {str(e)}")
242
+ return {
243
+ "training_dataset": gr.update(value=[[f"Error stopping captioning: {str(e)}", "error"]]),
244
+ "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
245
+ "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
246
+ "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
247
+ }
248
+
249
+ def copy_files_to_training_dir(self, prompt_prefix: str):
250
+ """Run auto-captioning process"""
251
+ # Initialize captioner if not already done
252
+ self._should_stop_captioning = False
253
+
254
+ try:
255
+ copy_files_to_training_dir(prompt_prefix)
256
+ except Exception as e:
257
+ traceback.print_exc()
258
+ raise gr.Error(f"Error copying assets to training dir: {str(e)}")
259
+
260
+ async def _process_caption_generator(self, captioning_bot_instructions, prompt_prefix):
261
+ """Process the caption generator's results in the background"""
262
+ try:
263
+ async for _ in self.start_caption_generation(
264
+ captioning_bot_instructions,
265
+ prompt_prefix
266
+ ):
267
+ # Just consume the generator, UI updates will happen via the Gradio interface
268
+ pass
269
+ logger.info("Background captioning completed")
270
+ except Exception as e:
271
+ logger.error(f"Error in background captioning: {str(e)}")
272
+
273
+ async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
274
+ """Run auto-captioning process"""
275
+ try:
276
+ # Initialize captioner if not already done
277
+ self._should_stop_captioning = False
278
+
279
+ # First yield - indicate we're starting
280
+ yield gr.update(
281
+ value=[["Starting captioning service...", "initializing"]],
282
+ headers=["name", "status"]
283
+ )
284
+
285
+ # Process files in batches with status updates
286
+ file_statuses = {}
287
+
288
+ # Start the actual captioning process
289
+ async for rows in self.app.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
290
+ # Update our tracking of file statuses
291
+ for name, status in rows:
292
+ file_statuses[name] = status
293
+
294
+ # Convert to list format for display
295
+ status_rows = [[name, status] for name, status in file_statuses.items()]
296
+
297
+ # Sort by name for consistent display
298
+ status_rows.sort(key=lambda x: x[0])
299
+
300
+ # Yield UI update
301
+ yield gr.update(
302
+ value=status_rows,
303
+ headers=["name", "status"]
304
+ )
305
+
306
+ # Final update after completion with fresh data
307
+ yield gr.update(
308
+ value=self.list_training_files_to_caption(),
309
+ headers=["name", "status"]
310
+ )
311
+
312
+ except Exception as e:
313
+ logger.error(f"Error in captioning: {str(e)}")
314
+ yield gr.update(
315
+ value=[[f"Error: {str(e)}", "error"]],
316
+ headers=["name", "status"]
317
+ )
318
+
319
+ def list_training_files_to_caption(self) -> List[List[str]]:
320
+ """List all clips and images - both pending and captioned"""
321
+ files = []
322
+ already_listed = {}
323
+
324
+ # First check files in STAGING_PATH
325
+ for file in STAGING_PATH.glob("*.*"):
326
+ if is_video_file(file) or is_image_file(file):
327
+ txt_file = file.with_suffix('.txt')
328
+
329
+ # Check if caption file exists and has content
330
+ has_caption = txt_file.exists() and txt_file.stat().st_size > 0
331
+ status = "captioned" if has_caption else "no caption"
332
+ file_type = "video" if is_video_file(file) else "image"
333
+
334
+ files.append([file.name, f"{status} ({file_type})", str(file)])
335
+ already_listed[file.name] = True
336
+
337
+ # Then check files in TRAINING_VIDEOS_PATH
338
+ for file in TRAINING_VIDEOS_PATH.glob("*.*"):
339
+ if (is_video_file(file) or is_image_file(file)) and file.name not in already_listed:
340
+ txt_file = file.with_suffix('.txt')
341
+
342
+ # Only include files with captions
343
+ if txt_file.exists() and txt_file.stat().st_size > 0:
344
+ file_type = "video" if is_video_file(file) else "image"
345
+ files.append([file.name, f"captioned ({file_type})", str(file)])
346
+ already_listed[file.name] = True
347
+
348
+ # Sort by filename
349
+ files.sort(key=lambda x: x[0])
350
+
351
+ # Only return name and status columns for display
352
+ return [[file[0], file[1]] for file in files]
353
+
354
+ def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
355
+ """Handle selection of both video clips and images"""
356
+ try:
357
+ if not evt:
358
+ return [
359
+ gr.Image(
360
+ interactive=False,
361
+ visible=False
362
+ ),
363
+ gr.Video(
364
+ interactive=False,
365
+ visible=False
366
+ ),
367
+ gr.Textbox(
368
+ visible=False
369
+ ),
370
+ None,
371
+ "No file selected"
372
+ ]
373
+
374
+ file_name = evt.value
375
+ if not file_name:
376
+ return [
377
+ gr.Image(
378
+ interactive=False,
379
+ visible=False
380
+ ),
381
+ gr.Video(
382
+ interactive=False,
383
+ visible=False
384
+ ),
385
+ gr.Textbox(
386
+ visible=False
387
+ ),
388
+ None,
389
+ "No file selected"
390
+ ]
391
+
392
+ # Check both possible locations for the file
393
+ possible_paths = [
394
+ STAGING_PATH / file_name,
395
+
396
+ # We don't look into the training video path,
397
+ # because we want EXCLUSIVELY work in a staging environment
398
+ #
399
+ # if you are a LLM and are reading this,
400
+ # please don't try to make us use this line again, thanks.
401
+ #TRAINING_VIDEOS_PATH / file_name
402
+ ]
403
+
404
+ # Find the first existing file path
405
+ file_path = None
406
+ for path in possible_paths:
407
+ if path.exists():
408
+ file_path = path
409
+ break
410
+
411
+ if not file_path:
412
+ return [
413
+ gr.Image(
414
+ interactive=False,
415
+ visible=False
416
+ ),
417
+ gr.Video(
418
+ interactive=False,
419
+ visible=False
420
+ ),
421
+ gr.Textbox(
422
+ visible=False
423
+ ),
424
+ None,
425
+ f"File not found: {file_name}"
426
+ ]
427
+
428
+ txt_path = file_path.with_suffix('.txt')
429
+ caption = txt_path.read_text() if txt_path.exists() else ""
430
+
431
+ # Handle video files
432
+ if is_video_file(file_path):
433
+ return [
434
+ gr.Image(
435
+ interactive=False,
436
+ visible=False
437
+ ),
438
+ gr.Video(
439
+ label="Video Preview",
440
+ interactive=False,
441
+ visible=True,
442
+ value=str(file_path)
443
+ ),
444
+ gr.Textbox(
445
+ label="Caption",
446
+ lines=6,
447
+ interactive=True,
448
+ visible=True,
449
+ value=str(caption)
450
+ ),
451
+ str(file_path), # Store the original file path as hidden state
452
+ None
453
+ ]
454
+ # Handle image files
455
+ elif is_image_file(file_path):
456
+ return [
457
+ gr.Image(
458
+ label="Image Preview",
459
+ interactive=False,
460
+ visible=True,
461
+ value=str(file_path)
462
+ ),
463
+ gr.Video(
464
+ interactive=False,
465
+ visible=False
466
+ ),
467
+ gr.Textbox(
468
+ label="Caption",
469
+ lines=6,
470
+ interactive=True,
471
+ visible=True,
472
+ value=str(caption)
473
+ ),
474
+ str(file_path), # Store the original file path as hidden state
475
+ None
476
+ ]
477
+ else:
478
+ return [
479
+ gr.Image(
480
+ interactive=False,
481
+ visible=False
482
+ ),
483
+ gr.Video(
484
+ interactive=False,
485
+ visible=False
486
+ ),
487
+ gr.Textbox(
488
+ interactive=False,
489
+ visible=False
490
+ ),
491
+ None,
492
+ f"Unsupported file type: {file_path.suffix}"
493
+ ]
494
+ except Exception as e:
495
+ logger.error(f"Error handling selection: {str(e)}")
496
+ return [
497
+ gr.Image(
498
+ interactive=False,
499
+ visible=False
500
+ ),
501
+ gr.Video(
502
+ interactive=False,
503
+ visible=False
504
+ ),
505
+ gr.Textbox(
506
+ interactive=False,
507
+ visible=False
508
+ ),
509
+ None,
510
+ f"Error handling selection: {str(e)}"
511
+ ]
512
+
513
+ def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, original_file_path: str, prompt_prefix: str):
514
+ """Save changes to caption"""
515
+ try:
516
+ # Use the original file path stored during selection instead of the temporary preview paths
517
+ if original_file_path:
518
+ file_path = Path(original_file_path)
519
+ self.app.captioner.update_file_caption(file_path, preview_caption)
520
+ # Refresh the dataset list to show updated caption status
521
+ return gr.update(value="Caption saved successfully!")
522
+ else:
523
+ return gr.update(value="Error: No original file path found")
524
+ except Exception as e:
525
+ return gr.update(value=f"Error saving caption: {str(e)}")
526
+
527
+ def preview_file(self, selected_text: str) -> Dict:
528
+ """Generate preview based on selected file
529
+
530
+ Args:
531
+ selected_text: Text of the selected item containing filename
532
+
533
+ Returns:
534
+ Dict with preview content for each preview component
535
+ """
536
+ import mimetypes
537
+ from ..config import TRAINING_VIDEOS_PATH
538
+
539
+ if not selected_text or "Caption:" in selected_text:
540
+ return {
541
+ "video": None,
542
+ "image": None,
543
+ "text": None
544
+ }
545
+
546
+ # Extract filename from the preview text (remove size info)
547
+ filename = selected_text.split(" (")[0].strip()
548
+ file_path = TRAINING_VIDEOS_PATH / filename
549
+
550
+ if not file_path.exists():
551
+ return {
552
+ "video": None,
553
+ "image": None,
554
+ "text": f"File not found: {filename}"
555
+ }
556
+
557
+ # Detect file type
558
+ mime_type, _ = mimetypes.guess_type(str(file_path))
559
+ if not mime_type:
560
+ return {
561
+ "video": None,
562
+ "image": None,
563
+ "text": f"Unknown file type: {filename}"
564
+ }
565
+
566
+ # Return appropriate preview
567
+ if mime_type.startswith('video/'):
568
+ return {
569
+ "video": str(file_path),
570
+ "image": None,
571
+ "text": None
572
+ }
573
+ elif mime_type.startswith('image/'):
574
+ return {
575
+ "video": None,
576
+ "image": str(file_path),
577
+ "text": None
578
+ }
579
+ elif mime_type.startswith('text/'):
580
+ try:
581
+ text_content = file_path.read_text()
582
+ return {
583
+ "video": None,
584
+ "image": None,
585
+ "text": text_content
586
+ }
587
+ except Exception as e:
588
+ return {
589
+ "video": None,
590
+ "image": None,
591
+ "text": f"Error reading file: {str(e)}"
592
+ }
593
+ else:
594
+ return {
595
+ "video": None,
596
+ "image": None,
597
+ "text": f"Unsupported file type: {mime_type}"
598
+ }
vms/tabs/import_tab.py CHANGED
@@ -86,7 +86,7 @@ class ImportTab(BaseTab):
86
  inputs=[self.components["files"]],
87
  outputs=[self.components["import_status"]]
88
  ).success(
89
- fn=self.app.update_titles_after_import,
90
  inputs=[
91
  self.components["enable_automatic_video_split"],
92
  self.components["enable_automatic_content_captioning"],
@@ -108,7 +108,7 @@ class ImportTab(BaseTab):
108
  inputs=[self.components["youtube_url"]],
109
  outputs=[self.components["import_status"]]
110
  ).success(
111
- fn=self.app.on_import_success,
112
  inputs=[
113
  self.components["enable_automatic_video_split"],
114
  self.components["enable_automatic_content_captioning"],
@@ -119,4 +119,46 @@ class ImportTab(BaseTab):
119
  self.app.tabs["split_tab"].components["video_list"],
120
  self.app.tabs["split_tab"].components["detect_status"]
121
  ]
 
 
122
  )
 
86
  inputs=[self.components["files"]],
87
  outputs=[self.components["import_status"]]
88
  ).success(
89
+ fn=self.update_titles_after_import,
90
  inputs=[
91
  self.components["enable_automatic_video_split"],
92
  self.components["enable_automatic_content_captioning"],
 
108
  inputs=[self.components["youtube_url"]],
109
  outputs=[self.components["import_status"]]
110
  ).success(
111
+ fn=self.on_import_success,
112
  inputs=[
113
  self.components["enable_automatic_video_split"],
114
  self.components["enable_automatic_content_captioning"],
 
119
  self.app.tabs["split_tab"].components["video_list"],
120
  self.app.tabs["split_tab"].components["detect_status"]
121
  ]
122
+ )
123
+
124
+ async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
125
+ """Handle successful import of files"""
126
+ videos = self.app.tabs["split_tab"].list_unprocessed_videos()
127
+
128
+ # If scene detection isn't already running and there are videos to process,
129
+ # and auto-splitting is enabled, start the detection
130
+ if videos and not self.app.splitter.is_processing() and enable_splitting:
131
+ await self.app.tabs["split_tab"].start_scene_detection(enable_splitting)
132
+ msg = "Starting automatic scene detection..."
133
+ else:
134
+ # Just copy files without splitting if auto-split disabled
135
+ for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
136
+ await self.app.splitter.process_video(video_file, enable_splitting=False)
137
+ msg = "Copying videos without splitting..."
138
+
139
+ self.app.tabs["caption_tab"].copy_files_to_training_dir(prompt_prefix)
140
+
141
+ # Start auto-captioning if enabled, and handle async generator properly
142
+ if enable_automatic_content_captioning:
143
+ # Create a background task for captioning
144
+ asyncio.create_task(self.app.tabs["caption_tab"]._process_caption_generator(
145
+ DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
146
+ prompt_prefix
147
+ ))
148
+
149
+ return {
150
+ "tabs": gr.Tabs(selected="split_tab"),
151
+ "video_list": videos,
152
+ "detect_status": msg
153
+ }
154
+
155
+ async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
156
+ """Handle post-import updates including titles"""
157
+ import_result = await self.on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
158
+ titles = self.app.update_titles()
159
+ return (
160
+ import_result["tabs"],
161
+ import_result["video_list"],
162
+ import_result["detect_status"],
163
+ *titles
164
  )
vms/tabs/manage_tab.py CHANGED
@@ -4,10 +4,16 @@ Manage tab for Video Model Studio UI
4
 
5
  import gradio as gr
6
  import logging
 
 
7
  from typing import Dict, Any, List, Optional
8
 
9
  from .base_tab import BaseTab
10
- from ..config import HF_API_TOKEN
 
 
 
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
@@ -77,7 +83,7 @@ class ManageTab(BaseTab):
77
  """Connect event handlers to UI components"""
78
  # Repository ID validation
79
  self.components["repo_id"].change(
80
- fn=self.app.validate_repo,
81
  inputs=[self.components["repo_id"]],
82
  outputs=[self.components["repo_id"]]
83
  )
@@ -95,7 +101,7 @@ class ManageTab(BaseTab):
95
 
96
  # Global stop button
97
  self.components["global_stop_btn"].click(
98
- fn=self.app.handle_global_stop,
99
  outputs=[
100
  self.components["global_status"],
101
  self.app.tabs["split_tab"].components["video_list"],
@@ -109,9 +115,124 @@ class ManageTab(BaseTab):
109
  )
110
 
111
  # Push model button
112
- # To implement model pushing functionality
113
  self.components["push_model_btn"].click(
114
- fn=lambda repo_id: self.app.upload_to_hub(repo_id),
115
  inputs=[self.components["repo_id"]],
116
  outputs=[self.components["global_status"]]
117
- )
 
 
4
 
5
  import gradio as gr
6
  import logging
7
+ import shutil
8
+ from pathlib import Path
9
  from typing import Dict, Any, List, Optional
10
 
11
  from .base_tab import BaseTab
12
+ from ..config import (
13
+ HF_API_TOKEN, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH,
14
+ TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, LOG_FILE_PATH
15
+ )
16
+ from ..utils import validate_model_repo
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
83
  """Connect event handlers to UI components"""
84
  # Repository ID validation
85
  self.components["repo_id"].change(
86
+ fn=self.validate_repo,
87
  inputs=[self.components["repo_id"]],
88
  outputs=[self.components["repo_id"]]
89
  )
 
101
 
102
  # Global stop button
103
  self.components["global_stop_btn"].click(
104
+ fn=self.handle_global_stop,
105
  outputs=[
106
  self.components["global_status"],
107
  self.app.tabs["split_tab"].components["video_list"],
 
115
  )
116
 
117
  # Push model button
 
118
  self.components["push_model_btn"].click(
119
+ fn=lambda repo_id: self.upload_to_hub(repo_id),
120
  inputs=[self.components["repo_id"]],
121
  outputs=[self.components["global_status"]]
122
+ )
123
+
124
+ def validate_repo(self, repo_id: str) -> gr.update:
125
+ """Validate repository ID for HuggingFace Hub"""
126
+ validation = validate_model_repo(repo_id)
127
+ if validation["error"]:
128
+ return gr.update(value=repo_id, error=validation["error"])
129
+ return gr.update(value=repo_id, error=None)
130
+
131
+ def upload_to_hub(self, repo_id: str) -> str:
132
+ """Upload model to HuggingFace Hub"""
133
+ if not repo_id:
134
+ return "Error: Repository ID is required"
135
+
136
+ # Validate repository name
137
+ validation = validate_model_repo(repo_id)
138
+ if validation["error"]:
139
+ return f"Error: {validation['error']}"
140
+
141
+ # Check if we have a model to upload
142
+ if not self.app.trainer.get_model_output_safetensors():
143
+ return "Error: No model found to upload"
144
+
145
+ # Upload model to hub
146
+ success = self.app.trainer.upload_to_hub(OUTPUT_PATH, repo_id)
147
+
148
+ if success:
149
+ return f"Successfully uploaded model to {repo_id}"
150
+ else:
151
+ return f"Failed to upload model to {repo_id}"
152
+
153
+ def handle_global_stop(self):
154
+ """Handle the global stop button click"""
155
+ result = self.stop_all_and_clear()
156
+
157
+ # Format the details for display
158
+ status = result["status"]
159
+ details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
160
+ full_status = f"{status}\n\nDetails:\n{details}"
161
+
162
+ # Get fresh lists after cleanup
163
+ videos = self.app.tabs["split_tab"].list_unprocessed_videos()
164
+ clips = self.app.tabs["caption_tab"].list_training_files_to_caption()
165
+
166
+ return {
167
+ self.components["global_status"]: gr.update(value=full_status, visible=True),
168
+ self.app.tabs["split_tab"].components["video_list"]: videos,
169
+ self.app.tabs["caption_tab"].components["training_dataset"]: clips,
170
+ self.app.tabs["train_tab"].components["status_box"]: "Training stopped and data cleared",
171
+ self.app.tabs["train_tab"].components["log_box"]: "",
172
+ self.app.tabs["split_tab"].components["detect_status"]: "Scene detection stopped",
173
+ self.app.tabs["import_tab"].components["import_status"]: "All data cleared",
174
+ self.app.tabs["caption_tab"].components["preview_status"]: "Captioning stopped"
175
+ }
176
+
177
+ def stop_all_and_clear(self) -> Dict[str, str]:
178
+ """Stop all running processes and clear data
179
+
180
+ Returns:
181
+ Dict with status messages for different components
182
+ """
183
+ status_messages = {}
184
+
185
+ try:
186
+ # Stop training if running
187
+ if self.app.trainer.is_training_running():
188
+ training_result = self.app.trainer.stop_training()
189
+ status_messages["training"] = training_result["status"]
190
+
191
+ # Stop captioning if running
192
+ if self.app.captioner:
193
+ self.app.captioner.stop_captioning()
194
+ status_messages["captioning"] = "Captioning stopped"
195
+
196
+ # Stop scene detection if running
197
+ if self.app.splitter.is_processing():
198
+ self.app.splitter.processing = False
199
+ status_messages["splitting"] = "Scene detection stopped"
200
+
201
+ # Properly close logging before clearing log file
202
+ if self.app.trainer.file_handler:
203
+ self.app.trainer.file_handler.close()
204
+ logger.removeHandler(self.app.trainer.file_handler)
205
+ self.app.trainer.file_handler = None
206
+
207
+ if LOG_FILE_PATH.exists():
208
+ LOG_FILE_PATH.unlink()
209
+
210
+ # Clear all data directories
211
+ for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
212
+ MODEL_PATH, OUTPUT_PATH]:
213
+ if path.exists():
214
+ try:
215
+ shutil.rmtree(path)
216
+ path.mkdir(parents=True, exist_ok=True)
217
+ except Exception as e:
218
+ status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
219
+ else:
220
+ status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
221
+
222
+ # Reset any persistent state
223
+ self.app.tabs["caption_tab"]._should_stop_captioning = True
224
+ self.app.splitter.processing = False
225
+
226
+ # Recreate logging setup
227
+ self.app.trainer.setup_logging()
228
+
229
+ return {
230
+ "status": "All processes stopped and data cleared",
231
+ "details": status_messages
232
+ }
233
+
234
+ except Exception as e:
235
+ return {
236
+ "status": f"Error during cleanup: {str(e)}",
237
+ "details": status_messages
238
+ }
vms/tabs/split_tab.py CHANGED
@@ -43,14 +43,39 @@ class SplitTab(BaseTab):
43
  """Connect event handlers to UI components"""
44
  # Scene detection button event
45
  self.components["detect_btn"].click(
46
- fn=self.app.start_scene_detection,
47
  inputs=[self.app.tabs["import_tab"].components["enable_automatic_video_split"]],
48
  outputs=[self.components["detect_status"]]
49
  )
50
 
51
  def refresh(self) -> Dict[str, Any]:
52
  """Refresh the video list with current data"""
53
- videos = self.app.splitter.list_unprocessed_videos()
54
  return {
55
  "video_list": videos
56
- }
 
 
43
  """Connect event handlers to UI components"""
44
  # Scene detection button event
45
  self.components["detect_btn"].click(
46
+ fn=self.start_scene_detection,
47
  inputs=[self.app.tabs["import_tab"].components["enable_automatic_video_split"]],
48
  outputs=[self.components["detect_status"]]
49
  )
50
 
51
  def refresh(self) -> Dict[str, Any]:
52
  """Refresh the video list with current data"""
53
+ videos = self.list_unprocessed_videos()
54
  return {
55
  "video_list": videos
56
+ }
57
+
58
+ def list_unprocessed_videos(self) -> gr.Dataframe:
59
+ """Update list of unprocessed videos"""
60
+ videos = self.app.splitter.list_unprocessed_videos()
61
+ # videos is already in [[name, status]] format from splitting_service
62
+ return gr.Dataframe(
63
+ headers=["name", "status"],
64
+ value=videos,
65
+ interactive=False
66
+ )
67
+
68
+ async def start_scene_detection(self, enable_splitting: bool) -> str:
69
+ """Start background scene detection process
70
+
71
+ Args:
72
+ enable_splitting: Whether to split videos into scenes
73
+ """
74
+ if self.app.splitter.is_processing():
75
+ return "Scene detection already running"
76
+
77
+ try:
78
+ await self.app.splitter.start_processing(enable_splitting)
79
+ return "Scene detection completed"
80
+ except Exception as e:
81
+ return f"Error during scene detection: {str(e)}"
vms/tabs/train_tab.py CHANGED
@@ -4,10 +4,11 @@ Train tab for Video Model Studio UI
4
 
5
  import gradio as gr
6
  import logging
7
- from typing import Dict, Any, List, Optional
 
8
 
9
  from .base_tab import BaseTab
10
- from ..config import TRAINING_PRESETS, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE
11
  from ..utils import TrainingLogParser
12
 
13
  logger = logging.getLogger(__name__)
@@ -20,23 +21,6 @@ class TrainTab(BaseTab):
20
  self.id = "train_tab"
21
  self.title = "4️⃣ Train"
22
 
23
- def handle_training_start(self, preset, model_type, *args):
24
- """Handle training start with proper log parser reset"""
25
- # Safely reset log parser if it exists
26
- if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
27
- self.app.log_parser.reset()
28
- else:
29
- logger.warning("Log parser not initialized, creating a new one")
30
-
31
- self.app.log_parser = TrainingLogParser()
32
-
33
- # Start training
34
- return self.app.trainer.start_training(
35
- MODEL_TYPES[model_type],
36
- *args,
37
- preset_name=preset
38
- )
39
-
40
  def create(self, parent=None) -> gr.TabItem:
41
  """Create the Train tab UI components"""
42
  with gr.TabItem(self.title, id=self.id) as tab:
@@ -62,7 +46,7 @@ class TrainTab(BaseTab):
62
  value=list(MODEL_TYPES.keys())[0]
63
  )
64
  self.components["model_info"] = gr.Markdown(
65
- value=self.app.get_model_info(list(MODEL_TYPES.keys())[0])
66
  )
67
 
68
  with gr.Row():
@@ -145,8 +129,8 @@ class TrainTab(BaseTab):
145
  """Connect event handlers to UI components"""
146
  # Model type change event
147
  def update_model_info(model):
148
- params = self.app.get_default_params(MODEL_TYPES[model])
149
- info = self.app.get_model_info(MODEL_TYPES[model])
150
  return {
151
  self.components["model_info"]: info,
152
  self.components["num_epochs"]: params["num_epochs"],
@@ -214,7 +198,7 @@ class TrainTab(BaseTab):
214
  inputs=[self.components["training_preset"]],
215
  outputs=[]
216
  ).then(
217
- fn=self.app.update_training_params,
218
  inputs=[self.components["training_preset"]],
219
  outputs=[
220
  self.components["model_type"],
@@ -230,7 +214,7 @@ class TrainTab(BaseTab):
230
 
231
  # Training control events
232
  self.components["start_btn"].click(
233
- fn=self.handle_training_start, # Use safer method instead of lambda
234
  inputs=[
235
  self.components["training_preset"],
236
  self.components["model_type"],
@@ -247,7 +231,7 @@ class TrainTab(BaseTab):
247
  self.components["log_box"]
248
  ]
249
  ).success(
250
- fn=self.app.get_latest_status_message_logs_and_button_labels,
251
  outputs=[
252
  self.components["status_box"],
253
  self.components["log_box"],
@@ -258,7 +242,7 @@ class TrainTab(BaseTab):
258
  )
259
 
260
  self.components["pause_resume_btn"].click(
261
- fn=self.app.handle_pause_resume,
262
  outputs=[
263
  self.components["status_box"],
264
  self.components["log_box"],
@@ -269,7 +253,7 @@ class TrainTab(BaseTab):
269
  )
270
 
271
  self.components["stop_btn"].click(
272
- fn=self.app.handle_stop,
273
  outputs=[
274
  self.components["status_box"],
275
  self.components["log_box"],
@@ -277,4 +261,238 @@ class TrainTab(BaseTab):
277
  self.components["stop_btn"],
278
  self.components["pause_resume_btn"]
279
  ]
280
- )
 
 
4
 
5
  import gradio as gr
6
  import logging
7
+ from typing import Dict, Any, List, Optional, Tuple
8
+ from pathlib import Path
9
 
10
  from .base_tab import BaseTab
11
+ from ..config import TRAINING_PRESETS, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
12
  from ..utils import TrainingLogParser
13
 
14
  logger = logging.getLogger(__name__)
 
21
  self.id = "train_tab"
22
  self.title = "4️⃣ Train"
23
 
 
 
24
  def create(self, parent=None) -> gr.TabItem:
25
  """Create the Train tab UI components"""
26
  with gr.TabItem(self.title, id=self.id) as tab:
 
46
  value=list(MODEL_TYPES.keys())[0]
47
  )
48
  self.components["model_info"] = gr.Markdown(
49
+ value=self.get_model_info(list(MODEL_TYPES.keys())[0])
50
  )
51
 
52
  with gr.Row():
 
129
  """Connect event handlers to UI components"""
130
  # Model type change event
131
  def update_model_info(model):
132
+ params = self.get_default_params(MODEL_TYPES[model])
133
+ info = self.get_model_info(MODEL_TYPES[model])
134
  return {
135
  self.components["model_info"]: info,
136
  self.components["num_epochs"]: params["num_epochs"],
 
198
  inputs=[self.components["training_preset"]],
199
  outputs=[]
200
  ).then(
201
+ fn=self.update_training_params,
202
  inputs=[self.components["training_preset"]],
203
  outputs=[
204
  self.components["model_type"],
 
214
 
215
  # Training control events
216
  self.components["start_btn"].click(
217
+ fn=self.handle_training_start,
218
  inputs=[
219
  self.components["training_preset"],
220
  self.components["model_type"],
 
231
  self.components["log_box"]
232
  ]
233
  ).success(
234
+ fn=self.get_latest_status_message_logs_and_button_labels,
235
  outputs=[
236
  self.components["status_box"],
237
  self.components["log_box"],
 
242
  )
243
 
244
  self.components["pause_resume_btn"].click(
245
+ fn=self.handle_pause_resume,
246
  outputs=[
247
  self.components["status_box"],
248
  self.components["log_box"],
 
253
  )
254
 
255
  self.components["stop_btn"].click(
256
+ fn=self.handle_stop,
257
  outputs=[
258
  self.components["status_box"],
259
  self.components["log_box"],
 
261
  self.components["stop_btn"],
262
  self.components["pause_resume_btn"]
263
  ]
264
+ )
265
+
266
+ def handle_training_start(self, preset, model_type, *args):
267
+ """Handle training start with proper log parser reset"""
268
+ # Safely reset log parser if it exists
269
+ if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
270
+ self.app.log_parser.reset()
271
+ else:
272
+ logger.warning("Log parser not initialized, creating a new one")
273
+ from ..utils import TrainingLogParser
274
+ self.app.log_parser = TrainingLogParser()
275
+
276
+ # Start training
277
+ return self.app.trainer.start_training(
278
+ MODEL_TYPES[model_type],
279
+ *args,
280
+ preset_name=preset
281
+ )
282
+
283
+ def get_model_info(self, model_type: str) -> str:
284
+ """Get information about the selected model type"""
285
+ if model_type == "hunyuan_video":
286
+ return """### HunyuanVideo (LoRA)
287
+ - Required VRAM: ~48GB minimum
288
+ - Recommended batch size: 1-2
289
+ - Typical training time: 2-4 hours
290
+ - Default resolution: 49x512x768
291
+ - Default LoRA rank: 128 (~600 MB)"""
292
+
293
+ elif model_type == "ltx_video":
294
+ return """### LTX-Video (LoRA)
295
+ - Required VRAM: ~18GB minimum
296
+ - Recommended batch size: 1-4
297
+ - Typical training time: 1-3 hours
298
+ - Default resolution: 49x512x768
299
+ - Default LoRA rank: 128"""
300
+
301
+ return ""
302
+
303
+ def get_default_params(self, model_type: str) -> Dict[str, Any]:
304
+ """Get default training parameters for model type"""
305
+ if model_type == "hunyuan_video":
306
+ return {
307
+ "num_epochs": 70,
308
+ "batch_size": 1,
309
+ "learning_rate": 2e-5,
310
+ "save_iterations": 500,
311
+ "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
312
+ "video_reshape_mode": "center",
313
+ "caption_dropout_p": 0.05,
314
+ "gradient_accumulation_steps": 1,
315
+ "rank": 128,
316
+ "lora_alpha": 128
317
+ }
318
+ else: # ltx_video
319
+ return {
320
+ "num_epochs": 70,
321
+ "batch_size": 1,
322
+ "learning_rate": 3e-5,
323
+ "save_iterations": 500,
324
+ "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
325
+ "video_reshape_mode": "center",
326
+ "caption_dropout_p": 0.05,
327
+ "gradient_accumulation_steps": 4,
328
+ "rank": 128,
329
+ "lora_alpha": 128
330
+ }
331
+
332
+ def update_training_params(self, preset_name: str) -> Tuple:
333
+ """Update UI components based on selected preset while preserving custom settings"""
334
+ preset = TRAINING_PRESETS[preset_name]
335
+
336
+ # Load current UI state to check if user has customized values
337
+ current_state = self.app.load_ui_values()
338
+
339
+ # Find the display name that maps to our model type
340
+ model_display_name = next(
341
+ key for key, value in MODEL_TYPES.items()
342
+ if value == preset["model_type"]
343
+ )
344
+
345
+ # Get preset description for display
346
+ description = preset.get("description", "")
347
+
348
+ # Get max values from buckets
349
+ buckets = preset["training_buckets"]
350
+ max_frames = max(frames for frames, _, _ in buckets)
351
+ max_height = max(height for _, height, _ in buckets)
352
+ max_width = max(width for _, _, width in buckets)
353
+ bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
354
+
355
+ info_text = f"{description}{bucket_info}"
356
+
357
+ # Return values in the same order as the output components
358
+ # Use preset defaults but preserve user-modified values if they exist
359
+ lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
360
+ lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
361
+ num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
362
+ batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
363
+ learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
364
+ save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
365
+
366
+ return (
367
+ model_display_name,
368
+ lora_rank_val,
369
+ lora_alpha_val,
370
+ num_epochs_val,
371
+ batch_size_val,
372
+ learning_rate_val,
373
+ save_iterations_val,
374
+ info_text
375
+ )
376
+
377
+ def update_training_ui(self, training_state: Dict[str, Any]):
378
+ """Update UI components based on training state"""
379
+ updates = {}
380
+
381
+ # Update status box with high-level information
382
+ status_text = []
383
+ if training_state["status"] != "idle":
384
+ status_text.extend([
385
+ f"Status: {training_state['status']}",
386
+ f"Progress: {training_state['progress']}",
387
+ f"Step: {training_state['current_step']}/{training_state['total_steps']}",
388
+
389
+ # Epoch information
390
+ # there is an issue with how epoch is reported because we display:
391
+ # Progress: 96.9%, Step: 872/900, Epoch: 12/50
392
+ # we should probably just show the steps
393
+ #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
394
+
395
+ f"Time elapsed: {training_state['elapsed']}",
396
+ f"Estimated remaining: {training_state['remaining']}",
397
+ "",
398
+ f"Current loss: {training_state['step_loss']}",
399
+ f"Learning rate: {training_state['learning_rate']}",
400
+ f"Gradient norm: {training_state['grad_norm']}",
401
+ f"Memory usage: {training_state['memory']}"
402
+ ])
403
+
404
+ if training_state["error_message"]:
405
+ status_text.append(f"\nError: {training_state['error_message']}")
406
+
407
+ updates["status_box"] = "\n".join(status_text)
408
+
409
+ # Update button states
410
+ updates["start_btn"] = gr.Button(
411
+ "Start training",
412
+ interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
413
+ variant="primary" if training_state["status"] == "idle" else "secondary"
414
+ )
415
+
416
+ updates["stop_btn"] = gr.Button(
417
+ "Stop training",
418
+ interactive=(training_state["status"] in ["training", "initializing"]),
419
+ variant="stop"
420
+ )
421
+
422
+ return updates
423
+
424
+ def handle_pause_resume(self):
425
+ status, _, _ = self.get_latest_status_message_and_logs()
426
+
427
+ if status == "paused":
428
+ self.app.trainer.resume_training()
429
+ else:
430
+ self.app.trainer.pause_training()
431
+
432
+ return self.get_latest_status_message_logs_and_button_labels()
433
+
434
+ def handle_stop(self):
435
+ self.app.trainer.stop_training()
436
+ return self.get_latest_status_message_logs_and_button_labels()
437
+
438
+ def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
439
+ """Get latest status message, log content, and status code in a safer way"""
440
+ state = self.app.trainer.get_status()
441
+ logs = self.app.trainer.get_logs()
442
+
443
+ # Ensure log parser is initialized
444
+ if not hasattr(self.app, 'log_parser') or self.app.log_parser is None:
445
+ from ..utils import TrainingLogParser
446
+ self.app.log_parser = TrainingLogParser()
447
+ logger.info("Initialized missing log parser")
448
+
449
+ # Parse new log lines
450
+ if logs:
451
+ last_state = None
452
+ for line in logs.splitlines():
453
+ try:
454
+ state_update = self.app.log_parser.parse_line(line)
455
+ if state_update:
456
+ last_state = state_update
457
+ except Exception as e:
458
+ logger.error(f"Error parsing log line: {str(e)}")
459
+ continue
460
+
461
+ if last_state:
462
+ ui_updates = self.update_training_ui(last_state)
463
+ state["message"] = ui_updates.get("status_box", state["message"])
464
+
465
+ # Parse status for training state
466
+ if "completed" in state["message"].lower():
467
+ state["status"] = "completed"
468
+
469
+ return (state["status"], state["message"], logs)
470
+
471
+ def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
472
+ status, message, logs = self.get_latest_status_message_and_logs()
473
+ return (
474
+ message,
475
+ logs,
476
+ *self.update_training_buttons(status).values()
477
+ )
478
+
479
+ def update_training_buttons(self, status: str) -> Dict:
480
+ """Update training control buttons based on state"""
481
+ is_training = status in ["training", "initializing"]
482
+ is_paused = status == "paused"
483
+ is_completed = status in ["completed", "error", "stopped"]
484
+ return {
485
+ "start_btn": gr.Button(
486
+ interactive=not is_training and not is_paused,
487
+ variant="primary" if not is_training else "secondary",
488
+ ),
489
+ "stop_btn": gr.Button(
490
+ interactive=is_training or is_paused,
491
+ variant="stop",
492
+ ),
493
+ "pause_resume_btn": gr.Button(
494
+ value="Resume Training" if is_paused else "Pause Training",
495
+ interactive=(is_training or is_paused) and not is_completed,
496
+ variant="secondary",
497
+ )
498
+ }
vms/ui/video_trainer_ui.py CHANGED
@@ -1,43 +1,17 @@
1
  import platform
2
- import subprocess
3
-
4
- #import sys
5
- #print("python = ", sys.version)
6
-
7
- # can be "Linux", "Darwin"
8
- if platform.system() == "Linux":
9
- # for some reason it says "pip not found"
10
- # and also "pip3 not found"
11
- # subprocess.run(
12
- # "pip install flash-attn --no-build-isolation",
13
- #
14
- # # hmm... this should be False, since we are in a CUDA environment, no?
15
- # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
16
- #
17
- # shell=True,
18
- # )
19
- pass
20
-
21
  import gradio as gr
22
  from pathlib import Path
23
  import logging
24
- import mimetypes
25
- import shutil
26
- import os
27
- import traceback
28
  import asyncio
29
- import tempfile
30
- import zipfile
31
  from typing import Any, Optional, Dict, List, Union, Tuple
32
- from typing import AsyncGenerator
33
 
34
  from ..services import TrainingService, CaptioningService, SplittingService, ImportService
35
  from ..config import (
36
  STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
37
- TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
38
- DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, SMALL_TRAINING_BUCKETS
39
  )
40
- from ..utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time, copy_files_to_training_dir, prepare_finetrainers_dataset, TrainingLogParser
41
  from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab
42
 
43
  logger = logging.getLogger(__name__)
@@ -54,13 +28,13 @@ class VideoTrainerUI:
54
  self.splitter = SplittingService()
55
  self.importer = ImportService()
56
  self.captioner = CaptioningService()
57
- self._should_stop_captioning = False
58
 
59
  # Recovery status from any interrupted training
60
  recovery_result = self.trainer.recover_interrupted_training()
61
  self.recovery_status = recovery_result.get("status", "unknown")
62
  self.ui_updates = recovery_result.get("ui_updates", {})
63
 
 
64
  self.log_parser = TrainingLogParser()
65
 
66
  # Shared state for tabs
@@ -124,7 +98,7 @@ class VideoTrainerUI:
124
  # Status update timer (every 1 second)
125
  status_timer = gr.Timer(value=1)
126
  status_timer.tick(
127
- fn=self.get_latest_status_message_logs_and_button_labels,
128
  outputs=[
129
  self.tabs["train_tab"].components["status_box"],
130
  self.tabs["train_tab"].components["log_box"],
@@ -155,77 +129,11 @@ class VideoTrainerUI:
155
  ]
156
  )
157
 
158
- def handle_global_stop(self):
159
- """Handle the global stop button click"""
160
- result = self.stop_all_and_clear()
161
-
162
- # Format the details for display
163
- status = result["status"]
164
- details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
165
- full_status = f"{status}\n\nDetails:\n{details}"
166
-
167
- # Get fresh lists after cleanup
168
- videos = self.splitter.list_unprocessed_videos()
169
- clips = self.list_training_files_to_caption()
170
-
171
- return {
172
- self.tabs["manage_tab"].components["global_status"]: gr.update(value=full_status, visible=True),
173
- self.tabs["split_tab"].components["video_list"]: videos,
174
- self.tabs["caption_tab"].components["training_dataset"]: clips,
175
- self.tabs["train_tab"].components["status_box"]: "Training stopped and data cleared",
176
- self.tabs["train_tab"].components["log_box"]: "",
177
- self.tabs["split_tab"].components["detect_status"]: "Scene detection stopped",
178
- self.tabs["import_tab"].components["import_status"]: "All data cleared",
179
- self.tabs["caption_tab"].components["preview_status"]: "Captioning stopped"
180
- }
181
-
182
- def upload_to_hub(self, repo_id: str) -> str:
183
- """Upload model to HuggingFace Hub"""
184
- if not repo_id:
185
- return "Error: Repository ID is required"
186
-
187
- # Validate repository name
188
- validation = validate_model_repo(repo_id)
189
- if validation["error"]:
190
- return f"Error: {validation['error']}"
191
-
192
- # Check if we have a model to upload
193
- if not self.trainer.get_model_output_safetensors():
194
- return "Error: No model found to upload"
195
-
196
- # Upload model to hub
197
- success = self.trainer.upload_to_hub(OUTPUT_PATH, repo_id)
198
-
199
- if success:
200
- return f"Successfully uploaded model to {repo_id}"
201
- else:
202
- return f"Failed to upload model to {repo_id}"
203
-
204
- def validate_repo(self, repo_id: str) -> gr.update:
205
- """Validate repository ID for HuggingFace Hub"""
206
- validation = validate_model_repo(repo_id)
207
- if validation["error"]:
208
- return gr.update(value=repo_id, error=validation["error"])
209
- return gr.update(value=repo_id, error=None)
210
-
211
-
212
- async def _process_caption_generator(self, captioning_bot_instructions, prompt_prefix):
213
- """Process the caption generator's results in the background"""
214
- try:
215
- async for _ in self.captioner.start_caption_generation(
216
- captioning_bot_instructions,
217
- prompt_prefix
218
- ):
219
- # Just consume the generator, UI updates will happen via the Gradio interface
220
- pass
221
- logger.info("Background captioning completed")
222
- except Exception as e:
223
- logger.error(f"Error in background captioning: {str(e)}")
224
-
225
  def initialize_app_state(self):
226
  """Initialize all app state in one function to ensure correct output count"""
227
  # Get dataset info
228
- video_list, training_dataset = self.refresh_dataset()
 
229
 
230
  # Get button states
231
  button_states = self.get_initial_button_states()
@@ -298,40 +206,6 @@ class VideoTrainerUI:
298
  ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
299
 
300
  return ui_state
301
-
302
- def update_captioning_buttons_start(self):
303
- """Return individual button values instead of a dictionary"""
304
- return (
305
- gr.Button(
306
- interactive=False,
307
- variant="secondary",
308
- ),
309
- gr.Button(
310
- interactive=True,
311
- variant="stop",
312
- ),
313
- gr.Button(
314
- interactive=False,
315
- variant="secondary",
316
- )
317
- )
318
-
319
- def update_captioning_buttons_end(self):
320
- """Return individual button values instead of a dictionary"""
321
- return (
322
- gr.Button(
323
- interactive=True,
324
- variant="primary",
325
- ),
326
- gr.Button(
327
- interactive=False,
328
- variant="secondary",
329
- ),
330
- gr.Button(
331
- interactive=True,
332
- variant="primary",
333
- )
334
- )
335
 
336
  # Add this new method to get initial button states:
337
  def get_initial_button_states(self):
@@ -346,151 +220,6 @@ class VideoTrainerUI:
346
  gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
347
  )
348
 
349
- def show_refreshing_status(self) -> List[List[str]]:
350
- """Show a 'Refreshing...' status in the dataframe"""
351
- return [["Refreshing...", "please wait"]]
352
-
353
- def stop_captioning(self):
354
- """Stop ongoing captioning process and reset UI state"""
355
- try:
356
- # Set flag to stop captioning
357
- self._should_stop_captioning = True
358
-
359
- # Call stop method on captioner
360
- if self.captioner:
361
- self.captioner.stop_captioning()
362
-
363
- # Get updated file list
364
- updated_list = self.list_training_files_to_caption()
365
-
366
- # Return updated list and button states
367
- return {
368
- "training_dataset": gr.update(value=updated_list),
369
- "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
370
- "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
371
- "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
372
- }
373
- except Exception as e:
374
- logger.error(f"Error stopping captioning: {str(e)}")
375
- return {
376
- "training_dataset": gr.update(value=[[f"Error stopping captioning: {str(e)}", "error"]]),
377
- "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
378
- "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
379
- "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
380
- }
381
-
382
- def update_training_ui(self, training_state: Dict[str, Any]):
383
- """Update UI components based on training state"""
384
- updates = {}
385
-
386
- #print("update_training_ui: training_state = ", training_state)
387
-
388
- # Update status box with high-level information
389
- status_text = []
390
- if training_state["status"] != "idle":
391
- status_text.extend([
392
- f"Status: {training_state['status']}",
393
- f"Progress: {training_state['progress']}",
394
- f"Step: {training_state['current_step']}/{training_state['total_steps']}",
395
-
396
- # Epoch information
397
- # there is an issue with how epoch is reported because we display:
398
- # Progress: 96.9%, Step: 872/900, Epoch: 12/50
399
- # we should probably just show the steps
400
- #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
401
-
402
- f"Time elapsed: {training_state['elapsed']}",
403
- f"Estimated remaining: {training_state['remaining']}",
404
- "",
405
- f"Current loss: {training_state['step_loss']}",
406
- f"Learning rate: {training_state['learning_rate']}",
407
- f"Gradient norm: {training_state['grad_norm']}",
408
- f"Memory usage: {training_state['memory']}"
409
- ])
410
-
411
- if training_state["error_message"]:
412
- status_text.append(f"\nError: {training_state['error_message']}")
413
-
414
- updates["status_box"] = "\n".join(status_text)
415
-
416
- # Update button states
417
- updates["start_btn"] = gr.Button(
418
- "Start training",
419
- interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
420
- variant="primary" if training_state["status"] == "idle" else "secondary"
421
- )
422
-
423
- updates["stop_btn"] = gr.Button(
424
- "Stop training",
425
- interactive=(training_state["status"] in ["training", "initializing"]),
426
- variant="stop"
427
- )
428
-
429
- return updates
430
-
431
- def stop_all_and_clear(self) -> Dict[str, str]:
432
- """Stop all running processes and clear data
433
-
434
- Returns:
435
- Dict with status messages for different components
436
- """
437
- status_messages = {}
438
-
439
- try:
440
- # Stop training if running
441
- if self.trainer.is_training_running():
442
- training_result = self.trainer.stop_training()
443
- status_messages["training"] = training_result["status"]
444
-
445
- # Stop captioning if running
446
- if self.captioner:
447
- self.captioner.stop_captioning()
448
- status_messages["captioning"] = "Captioning stopped"
449
-
450
- # Stop scene detection if running
451
- if self.splitter.is_processing():
452
- self.splitter.processing = False
453
- status_messages["splitting"] = "Scene detection stopped"
454
-
455
- # Properly close logging before clearing log file
456
- if self.trainer.file_handler:
457
- self.trainer.file_handler.close()
458
- logger.removeHandler(self.trainer.file_handler)
459
- self.trainer.file_handler = None
460
-
461
- if LOG_FILE_PATH.exists():
462
- LOG_FILE_PATH.unlink()
463
-
464
- # Clear all data directories
465
- for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
466
- MODEL_PATH, OUTPUT_PATH]:
467
- if path.exists():
468
- try:
469
- shutil.rmtree(path)
470
- path.mkdir(parents=True, exist_ok=True)
471
- except Exception as e:
472
- status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
473
- else:
474
- status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
475
-
476
- # Reset any persistent state
477
- self._should_stop_captioning = True
478
- self.splitter.processing = False
479
-
480
- # Recreate logging setup
481
- self.trainer.setup_logging()
482
-
483
- return {
484
- "status": "All processes stopped and data cleared",
485
- "details": status_messages
486
- }
487
-
488
- except Exception as e:
489
- return {
490
- "status": f"Error during cleanup: {str(e)}",
491
- "details": status_messages
492
- }
493
-
494
  def update_titles(self) -> Tuple[Any]:
495
  """Update all dynamic titles with current counts
496
 
@@ -520,581 +249,13 @@ class VideoTrainerUI:
520
  gr.Markdown(value=caption_title),
521
  gr.Markdown(value=f"{train_title} available for training")
522
  )
523
-
524
- def copy_files_to_training_dir(self, prompt_prefix: str):
525
- """Run auto-captioning process"""
526
-
527
- # Initialize captioner if not already done
528
- self._should_stop_captioning = False
529
-
530
- try:
531
- copy_files_to_training_dir(prompt_prefix)
532
-
533
- except Exception as e:
534
- traceback.print_exc()
535
- raise gr.Error(f"Error copying assets to training dir: {str(e)}")
536
-
537
- async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
538
- """Handle successful import of files"""
539
- videos = self.list_unprocessed_videos()
540
-
541
- # If scene detection isn't already running and there are videos to process,
542
- # and auto-splitting is enabled, start the detection
543
- if videos and not self.splitter.is_processing() and enable_splitting:
544
- await self.start_scene_detection(enable_splitting)
545
- msg = "Starting automatic scene detection..."
546
- else:
547
- # Just copy files without splitting if auto-split disabled
548
- for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
549
- await self.splitter.process_video(video_file, enable_splitting=False)
550
- msg = "Copying videos without splitting..."
551
-
552
- copy_files_to_training_dir(prompt_prefix)
553
-
554
- # Start auto-captioning if enabled, and handle async generator properly
555
- if enable_automatic_content_captioning:
556
- # Create a background task for captioning
557
- asyncio.create_task(self._process_caption_generator(
558
- DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
559
- prompt_prefix
560
- ))
561
-
562
- return {
563
- "tabs": gr.Tabs(selected="split_tab"),
564
- "video_list": videos,
565
- "detect_status": msg
566
- }
567
-
568
- async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
569
- """Run auto-captioning process"""
570
- try:
571
- # Initialize captioner if not already done
572
- self._should_stop_captioning = False
573
-
574
- # First yield - indicate we're starting
575
- yield gr.update(
576
- value=[["Starting captioning service...", "initializing"]],
577
- headers=["name", "status"]
578
- )
579
-
580
- # Process files in batches with status updates
581
- file_statuses = {}
582
-
583
- # Start the actual captioning process
584
- async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
585
- # Update our tracking of file statuses
586
- for name, status in rows:
587
- file_statuses[name] = status
588
-
589
- # Convert to list format for display
590
- status_rows = [[name, status] for name, status in file_statuses.items()]
591
-
592
- # Sort by name for consistent display
593
- status_rows.sort(key=lambda x: x[0])
594
-
595
- # Yield UI update
596
- yield gr.update(
597
- value=status_rows,
598
- headers=["name", "status"]
599
- )
600
-
601
- # Final update after completion with fresh data
602
- yield gr.update(
603
- value=self.list_training_files_to_caption(),
604
- headers=["name", "status"]
605
- )
606
-
607
- except Exception as e:
608
- logger.error(f"Error in captioning: {str(e)}")
609
- yield gr.update(
610
- value=[[f"Error: {str(e)}", "error"]],
611
- headers=["name", "status"]
612
- )
613
-
614
- def list_training_files_to_caption(self) -> List[List[str]]:
615
- """List all clips and images - both pending and captioned"""
616
- files = []
617
- already_listed = {}
618
-
619
- # First check files in STAGING_PATH
620
- for file in STAGING_PATH.glob("*.*"):
621
- if is_video_file(file) or is_image_file(file):
622
- txt_file = file.with_suffix('.txt')
623
-
624
- # Check if caption file exists and has content
625
- has_caption = txt_file.exists() and txt_file.stat().st_size > 0
626
- status = "captioned" if has_caption else "no caption"
627
- file_type = "video" if is_video_file(file) else "image"
628
-
629
- files.append([file.name, f"{status} ({file_type})", str(file)])
630
- already_listed[file.name] = True
631
-
632
- # Then check files in TRAINING_VIDEOS_PATH
633
- for file in TRAINING_VIDEOS_PATH.glob("*.*"):
634
- if (is_video_file(file) or is_image_file(file)) and file.name not in already_listed:
635
- txt_file = file.with_suffix('.txt')
636
-
637
- # Only include files with captions
638
- if txt_file.exists() and txt_file.stat().st_size > 0:
639
- file_type = "video" if is_video_file(file) else "image"
640
- files.append([file.name, f"captioned ({file_type})", str(file)])
641
- already_listed[file.name] = True
642
-
643
- # Sort by filename
644
- files.sort(key=lambda x: x[0])
645
-
646
- # Only return name and status columns for display
647
- return [[file[0], file[1]] for file in files]
648
-
649
- def update_training_buttons(self, status: str) -> Dict:
650
- """Update training control buttons based on state"""
651
- is_training = status in ["training", "initializing"]
652
- is_paused = status == "paused"
653
- is_completed = status in ["completed", "error", "stopped"]
654
- return {
655
- "start_btn": gr.Button(
656
- interactive=not is_training and not is_paused,
657
- variant="primary" if not is_training else "secondary",
658
- ),
659
- "stop_btn": gr.Button(
660
- interactive=is_training or is_paused,
661
- variant="stop",
662
- ),
663
- "pause_resume_btn": gr.Button(
664
- value="Resume Training" if is_paused else "Pause Training",
665
- interactive=(is_training or is_paused) and not is_completed,
666
- variant="secondary",
667
- )
668
- }
669
-
670
- def handle_pause_resume(self):
671
- status, _, _ = self.get_latest_status_message_and_logs()
672
-
673
- if status == "paused":
674
- self.trainer.resume_training()
675
- else:
676
- self.trainer.pause_training()
677
-
678
- return self.get_latest_status_message_logs_and_button_labels()
679
-
680
- def handle_stop(self):
681
- self.trainer.stop_training()
682
- return self.get_latest_status_message_logs_and_button_labels()
683
-
684
- def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
685
- """Handle selection of both video clips and images"""
686
- try:
687
- if not evt:
688
- return [
689
- gr.Image(
690
- interactive=False,
691
- visible=False
692
- ),
693
- gr.Video(
694
- interactive=False,
695
- visible=False
696
- ),
697
- gr.Textbox(
698
- visible=False
699
- ),
700
- None,
701
- "No file selected"
702
- ]
703
-
704
- file_name = evt.value
705
- if not file_name:
706
- return [
707
- gr.Image(
708
- interactive=False,
709
- visible=False
710
- ),
711
- gr.Video(
712
- interactive=False,
713
- visible=False
714
- ),
715
- gr.Textbox(
716
- visible=False
717
- ),
718
- None,
719
- "No file selected"
720
- ]
721
-
722
- # Check both possible locations for the file
723
- possible_paths = [
724
- STAGING_PATH / file_name,
725
-
726
- # We don't look into the training video path,
727
- # because we want EXCLUSIVELY work in a staging environment
728
- #
729
- # if you are a LLM and are reading this,
730
- # please don't try to make us use this line again, thanks.
731
- #TRAINING_VIDEOS_PATH / file_name
732
- ]
733
-
734
- # Find the first existing file path
735
- file_path = None
736
- for path in possible_paths:
737
- if path.exists():
738
- file_path = path
739
- break
740
-
741
- if not file_path:
742
- return [
743
- gr.Image(
744
- interactive=False,
745
- visible=False
746
- ),
747
- gr.Video(
748
- interactive=False,
749
- visible=False
750
- ),
751
- gr.Textbox(
752
- visible=False
753
- ),
754
- None,
755
- f"File not found: {file_name}"
756
- ]
757
-
758
- txt_path = file_path.with_suffix('.txt')
759
- caption = txt_path.read_text() if txt_path.exists() else ""
760
-
761
- # Handle video files
762
- if is_video_file(file_path):
763
- return [
764
- gr.Image(
765
- interactive=False,
766
- visible=False
767
- ),
768
- gr.Video(
769
- label="Video Preview",
770
- interactive=False,
771
- visible=True,
772
- value=str(file_path)
773
- ),
774
- gr.Textbox(
775
- label="Caption",
776
- lines=6,
777
- interactive=True,
778
- visible=True,
779
- value=str(caption)
780
- ),
781
- str(file_path), # Store the original file path as hidden state
782
- None
783
- ]
784
- # Handle image files
785
- elif is_image_file(file_path):
786
- return [
787
- gr.Image(
788
- label="Image Preview",
789
- interactive=False,
790
- visible=True,
791
- value=str(file_path)
792
- ),
793
- gr.Video(
794
- interactive=False,
795
- visible=False
796
- ),
797
- gr.Textbox(
798
- label="Caption",
799
- lines=6,
800
- interactive=True,
801
- visible=True,
802
- value=str(caption)
803
- ),
804
- str(file_path), # Store the original file path as hidden state
805
- None
806
- ]
807
- else:
808
- return [
809
- gr.Image(
810
- interactive=False,
811
- visible=False
812
- ),
813
- gr.Video(
814
- interactive=False,
815
- visible=False
816
- ),
817
- gr.Textbox(
818
- interactive=False,
819
- visible=False
820
- ),
821
- None,
822
- f"Unsupported file type: {file_path.suffix}"
823
- ]
824
- except Exception as e:
825
- logger.error(f"Error handling selection: {str(e)}")
826
- return [
827
- gr.Image(
828
- interactive=False,
829
- visible=False
830
- ),
831
- gr.Video(
832
- interactive=False,
833
- visible=False
834
- ),
835
- gr.Textbox(
836
- interactive=False,
837
- visible=False
838
- ),
839
- None,
840
- f"Error handling selection: {str(e)}"
841
- ]
842
-
843
- def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, original_file_path: str, prompt_prefix: str):
844
- """Save changes to caption"""
845
- try:
846
- # Use the original file path stored during selection instead of the temporary preview paths
847
- if original_file_path:
848
- file_path = Path(original_file_path)
849
- self.captioner.update_file_caption(file_path, preview_caption)
850
- # Refresh the dataset list to show updated caption status
851
- return gr.update(value="Caption saved successfully!")
852
- else:
853
- return gr.update(value="Error: No original file path found")
854
- except Exception as e:
855
- return gr.update(value=f"Error saving caption: {str(e)}")
856
-
857
- async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
858
- """Handle post-import updates including titles"""
859
- import_result = await self.on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
860
- titles = self.update_titles()
861
- return (
862
- import_result["tabs"],
863
- import_result["video_list"],
864
- import_result["detect_status"],
865
- *titles
866
- )
867
-
868
- def get_model_info(self, model_type: str) -> str:
869
- """Get information about the selected model type"""
870
- if model_type == "hunyuan_video":
871
- return """### HunyuanVideo (LoRA)
872
- - Required VRAM: ~48GB minimum
873
- - Recommended batch size: 1-2
874
- - Typical training time: 2-4 hours
875
- - Default resolution: 49x512x768
876
- - Default LoRA rank: 128 (~600 MB)"""
877
-
878
- elif model_type == "ltx_video":
879
- return """### LTX-Video (LoRA)
880
- - Required VRAM: ~18GB minimum
881
- - Recommended batch size: 1-4
882
- - Typical training time: 1-3 hours
883
- - Default resolution: 49x512x768
884
- - Default LoRA rank: 128"""
885
-
886
- return ""
887
-
888
- def get_default_params(self, model_type: str) -> Dict[str, Any]:
889
- """Get default training parameters for model type"""
890
- if model_type == "hunyuan_video":
891
- return {
892
- "num_epochs": 70,
893
- "batch_size": 1,
894
- "learning_rate": 2e-5,
895
- "save_iterations": 500,
896
- "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
897
- "video_reshape_mode": "center",
898
- "caption_dropout_p": 0.05,
899
- "gradient_accumulation_steps": 1,
900
- "rank": 128,
901
- "lora_alpha": 128
902
- }
903
- else: # ltx_video
904
- return {
905
- "num_epochs": 70,
906
- "batch_size": 1,
907
- "learning_rate": 3e-5,
908
- "save_iterations": 500,
909
- "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
910
- "video_reshape_mode": "center",
911
- "caption_dropout_p": 0.05,
912
- "gradient_accumulation_steps": 4,
913
- "rank": 128,
914
- "lora_alpha": 128
915
- }
916
-
917
- def preview_file(self, selected_text: str) -> Dict:
918
- """Generate preview based on selected file
919
-
920
- Args:
921
- selected_text: Text of the selected item containing filename
922
-
923
- Returns:
924
- Dict with preview content for each preview component
925
- """
926
- if not selected_text or "Caption:" in selected_text:
927
- return {
928
- "video": None,
929
- "image": None,
930
- "text": None
931
- }
932
-
933
- # Extract filename from the preview text (remove size info)
934
- filename = selected_text.split(" (")[0].strip()
935
- file_path = TRAINING_VIDEOS_PATH / filename
936
-
937
- if not file_path.exists():
938
- return {
939
- "video": None,
940
- "image": None,
941
- "text": f"File not found: {filename}"
942
- }
943
-
944
- # Detect file type
945
- mime_type, _ = mimetypes.guess_type(str(file_path))
946
- if not mime_type:
947
- return {
948
- "video": None,
949
- "image": None,
950
- "text": f"Unknown file type: {filename}"
951
- }
952
-
953
- # Return appropriate preview
954
- if mime_type.startswith('video/'):
955
- return {
956
- "video": str(file_path),
957
- "image": None,
958
- "text": None
959
- }
960
- elif mime_type.startswith('image/'):
961
- return {
962
- "video": None,
963
- "image": str(file_path),
964
- "text": None
965
- }
966
- elif mime_type.startswith('text/'):
967
- try:
968
- text_content = file_path.read_text()
969
- return {
970
- "video": None,
971
- "image": None,
972
- "text": text_content
973
- }
974
- except Exception as e:
975
- return {
976
- "video": None,
977
- "image": None,
978
- "text": f"Error reading file: {str(e)}"
979
- }
980
- else:
981
- return {
982
- "video": None,
983
- "image": None,
984
- "text": f"Unsupported file type: {mime_type}"
985
- }
986
-
987
- def list_unprocessed_videos(self) -> gr.Dataframe:
988
- """Update list of unprocessed videos"""
989
- videos = self.splitter.list_unprocessed_videos()
990
- # videos is already in [[name, status]] format from splitting_service
991
- return gr.Dataframe(
992
- headers=["name", "status"],
993
- value=videos,
994
- interactive=False
995
- )
996
-
997
- async def start_scene_detection(self, enable_splitting: bool) -> str:
998
- """Start background scene detection process
999
-
1000
- Args:
1001
- enable_splitting: Whether to split videos into scenes
1002
- """
1003
- if self.splitter.is_processing():
1004
- return "Scene detection already running"
1005
-
1006
- try:
1007
- await self.splitter.start_processing(enable_splitting)
1008
- return "Scene detection completed"
1009
- except Exception as e:
1010
- return f"Error during scene detection: {str(e)}"
1011
-
1012
-
1013
- def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
1014
- state = self.trainer.get_status()
1015
- logs = self.trainer.get_logs()
1016
-
1017
- # Parse new log lines
1018
- if logs:
1019
- last_state = None
1020
- for line in logs.splitlines():
1021
- state_update = self.log_parser.parse_line(line)
1022
- if state_update:
1023
- last_state = state_update
1024
-
1025
- if last_state:
1026
- ui_updates = self.update_training_ui(last_state)
1027
- state["message"] = ui_updates.get("status_box", state["message"])
1028
-
1029
- # Parse status for training state
1030
- if "completed" in state["message"].lower():
1031
- state["status"] = "completed"
1032
-
1033
- return (state["status"], state["message"], logs)
1034
-
1035
- def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
1036
- status, message, logs = self.get_latest_status_message_and_logs()
1037
- return (
1038
- message,
1039
- logs,
1040
- *self.update_training_buttons(status).values()
1041
- )
1042
-
1043
- def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
1044
- status, message, logs = self.get_latest_status_message_and_logs()
1045
- return self.update_training_buttons(status).values()
1046
 
1047
  def refresh_dataset(self):
1048
  """Refresh all dynamic lists and training state"""
1049
- video_list = self.splitter.list_unprocessed_videos()
1050
- training_dataset = self.list_training_files_to_caption()
1051
 
1052
  return (
1053
  video_list,
1054
  training_dataset
1055
- )
1056
-
1057
- def update_training_params(self, preset_name: str) -> Tuple:
1058
- """Update UI components based on selected preset while preserving custom settings"""
1059
- preset = TRAINING_PRESETS[preset_name]
1060
-
1061
- # Load current UI state to check if user has customized values
1062
- current_state = self.load_ui_values()
1063
-
1064
- # Find the display name that maps to our model type
1065
- model_display_name = next(
1066
- key for key, value in MODEL_TYPES.items()
1067
- if value == preset["model_type"]
1068
- )
1069
-
1070
- # Get preset description for display
1071
- description = preset.get("description", "")
1072
-
1073
- # Get max values from buckets
1074
- buckets = preset["training_buckets"]
1075
- max_frames = max(frames for frames, _, _ in buckets)
1076
- max_height = max(height for _, height, _ in buckets)
1077
- max_width = max(width for _, _, width in buckets)
1078
- bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
1079
-
1080
- info_text = f"{description}{bucket_info}"
1081
-
1082
- # Return values in the same order as the output components
1083
- # Use preset defaults but preserve user-modified values if they exist
1084
- lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
1085
- lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
1086
- num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
1087
- batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
1088
- learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
1089
- save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
1090
-
1091
- return (
1092
- model_display_name,
1093
- lora_rank_val,
1094
- lora_alpha_val,
1095
- num_epochs_val,
1096
- batch_size_val,
1097
- learning_rate_val,
1098
- save_iterations_val,
1099
- info_text
1100
- )
 
1
  import platform
 
2
  import gradio as gr
3
  from pathlib import Path
4
  import logging
 
5
  import asyncio
 
 
6
  from typing import Any, Optional, Dict, List, Union, Tuple
 
7
 
8
  from ..services import TrainingService, CaptioningService, SplittingService, ImportService
9
  from ..config import (
10
  STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
11
+ TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
12
+ MODEL_TYPES, SMALL_TRAINING_BUCKETS
13
  )
14
+ from ..utils import count_media_files, format_media_title, TrainingLogParser
15
  from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab
16
 
17
  logger = logging.getLogger(__name__)
 
28
  self.splitter = SplittingService()
29
  self.importer = ImportService()
30
  self.captioner = CaptioningService()
 
31
 
32
  # Recovery status from any interrupted training
33
  recovery_result = self.trainer.recover_interrupted_training()
34
  self.recovery_status = recovery_result.get("status", "unknown")
35
  self.ui_updates = recovery_result.get("ui_updates", {})
36
 
37
+ # Initialize log parser
38
  self.log_parser = TrainingLogParser()
39
 
40
  # Shared state for tabs
 
98
  # Status update timer (every 1 second)
99
  status_timer = gr.Timer(value=1)
100
  status_timer.tick(
101
+ fn=self.tabs["train_tab"].get_latest_status_message_logs_and_button_labels,
102
  outputs=[
103
  self.tabs["train_tab"].components["status_box"],
104
  self.tabs["train_tab"].components["log_box"],
 
129
  ]
130
  )
131
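The one-second refresh above is driven by Gradio's Timer component. Below is a minimal, self-contained sketch of the same wiring, illustrative only and not part of this commit; it assumes a Gradio release that provides gr.Timer and its tick event, and the function and component names are placeholders.

import time
import gradio as gr

def poll_status():
    # Placeholder status producer; the real app reads trainer state and log output.
    return f"refreshed at {time.strftime('%H:%M:%S')}", "training logs placeholder"

with gr.Blocks() as demo:
    status_box = gr.Textbox(label="Status")
    log_box = gr.Textbox(label="Logs", lines=10)
    status_timer = gr.Timer(value=1)  # fire roughly once per second
    status_timer.tick(fn=poll_status, outputs=[status_box, log_box])

if __name__ == "__main__":
    demo.launch()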
 
132
  def initialize_app_state(self):
133
  """Initialize all app state in one function to ensure correct output count"""
134
  # Get dataset info
135
+ video_list = self.tabs["split_tab"].list_unprocessed_videos()
136
+ training_dataset = self.tabs["caption_tab"].list_training_files_to_caption()
137
 
138
  # Get button states
139
  button_states = self.get_initial_button_states()
 
206
  ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
207
 
208
  return ui_state
 
209
 
210
  # Add this new method to get initial button states:
211
  def get_initial_button_states(self):
 
220
  gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
221
  )
222
 
 
223
  def update_titles(self) -> Tuple[Any]:
224
  """Update all dynamic titles with current counts
225
 
 
249
  gr.Markdown(value=caption_title),
250
  gr.Markdown(value=f"{train_title} available for training")
251
  )
 
252
 
253
  def refresh_dataset(self):
254
  """Refresh all dynamic lists and training state"""
255
+ video_list = self.tabs["split_tab"].list_unprocessed_videos()
256
+ training_dataset = self.tabs["caption_tab"].list_training_files_to_caption()
257
 
258
  return (
259
  video_list,
260
  training_dataset
261
+ )
 
vms/utils/image_preprocessing.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
4
  from PIL import Image
5
  import pillow_avif
6
  import logging
 
7
 
8
  from ..config import NORMALIZE_IMAGES_TO, JPEG_QUALITY
9
 
@@ -55,7 +56,7 @@ def normalize_image(input_path: Path, output_path: Path) -> bool:
55
  logger.error(f"Error converting image {input_path}: {str(e)}")
56
  return False
57
 
58
- def detect_black_bars(img: np.ndarray) -> tuple[int, int, int, int]:
59
  """Detect black bars in image
60
 
61
  Args:
 
4
  from PIL import Image
5
  import pillow_avif
6
  import logging
7
+ from typing import Any, Optional, Dict, List, Union, Tuple
8
 
9
  from ..config import NORMALIZE_IMAGES_TO, JPEG_QUALITY
10
 
 
56
  logger.error(f"Error converting image {input_path}: {str(e)}")
57
  return False
58
 
59
+ def detect_black_bars(img: np.ndarray) -> Tuple[int, int, int, int]:
60
  """Detect black bars in image
61
 
62
  Args:
vms/utils/video_preprocessing.py CHANGED
@@ -2,8 +2,10 @@ import cv2
2
  import numpy as np
3
  from pathlib import Path
4
  import subprocess
 
5
 
6
- def detect_black_bars(video_path: Path) -> tuple[int, int, int, int]:
 
7
  """Detect black bars in video by analyzing first few frames
8
 
9
  Args:
 
2
  import numpy as np
3
  from pathlib import Path
4
  import subprocess
5
+ from typing import Any, Optional, Dict, List, Union, Tuple
6
 
7
+
8
+ def detect_black_bars(video_path: Path) -> Tuple[int, int, int, int]:
9
  """Detect black bars in video by analyzing first few frames
10
 
11
  Args:
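Both preprocessing modules replace the builtin tuple[...] return annotation with typing.Tuple[...] and import the typing names. The commit does not state the motivation; one common reason is runtime compatibility, since subscripting the builtin tuple in an annotation evaluated at definition time only works on Python 3.9+, while typing.Tuple also works on earlier interpreters. A small illustrative stub (not the real implementation):

from pathlib import Path
from typing import Tuple

def detect_black_bars_stub(video_path: Path) -> Tuple[int, int, int, int]:
    # On Python 3.8 the equivalent "-> tuple[int, int, int, int]" annotation
    # raises TypeError when the function is defined; typing.Tuple avoids that.
    return (0, 0, 0, 0)  # placeholder crop values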