jbilcke-hf (HF staff) committed
Commit d2662cc · 1 Parent(s): 4a3f789
vms/config.py CHANGED
@@ -61,7 +61,7 @@ JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
 MODEL_TYPES = {
     "HunyuanVideo": "hunyuan_video",
     "LTX-Video": "ltx_video",
-    "Wan-2.1-T2V": "wan"
+    "Wan": "wan"
 }
 
 # Training types
@@ -70,8 +70,8 @@ TRAINING_TYPES = {
     "Full Finetune": "full-finetune"
 }
 
-# Model variants for each model type
-MODEL_VARIANTS = {
+# Model versions for each model type
+MODEL_VERSIONS = {
     "wan": {
         "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": {
             "name": "Wan 2.1 T2V 1.3B (text-only, smaller)",
@@ -342,7 +342,7 @@ class TrainingConfig:
 
     # Optional arguments follow
     revision: Optional[str] = None
-    variant: Optional[str] = None
+    version: Optional[str] = None
     cache_dir: Optional[str] = None
 
     # Dataset arguments
@@ -415,7 +415,7 @@ class TrainingConfig:
            train_steps=DEFAULT_NB_TRAINING_STEPS,
            lr=2e-5,
            gradient_checkpointing=True,
-           id_token="afkx",
+           id_token=None,
            gradient_accumulation_steps=1,
            lora_rank=DEFAULT_LORA_RANK,
            lora_alpha=DEFAULT_LORA_ALPHA,
@@ -437,7 +437,7 @@ class TrainingConfig:
            train_steps=DEFAULT_NB_TRAINING_STEPS,
            lr=DEFAULT_LEARNING_RATE,
            gradient_checkpointing=True,
-           id_token="BW_STYLE",
+           id_token=None,
            gradient_accumulation_steps=4,
            lora_rank=DEFAULT_LORA_RANK,
            lora_alpha=DEFAULT_LORA_ALPHA,
@@ -459,7 +459,7 @@ class TrainingConfig:
            train_steps=DEFAULT_NB_TRAINING_STEPS,
            lr=1e-5,
            gradient_checkpointing=True,
-           id_token="BW_STYLE",
+           id_token=None,
            gradient_accumulation_steps=1,
            video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
@@ -479,7 +479,7 @@ class TrainingConfig:
            train_steps=DEFAULT_NB_TRAINING_STEPS,
            lr=5e-5,
            gradient_checkpointing=True,
-           id_token=None, # Default is no ID token for Wan
+           id_token=None,
            gradient_accumulation_steps=1,
            lora_rank=32,
            lora_alpha=32,
@@ -502,8 +502,8 @@ class TrainingConfig:
         args.extend(["--pretrained_model_name_or_path", self.pretrained_model_name_or_path])
         if self.revision:
             args.extend(["--revision", self.revision])
-        if self.variant:
-            args.extend(["--variant", self.variant])
+        if self.version:
+            args.extend(["--variant", self.version])
         if self.cache_dir:
             args.extend(["--cache_dir", self.cache_dir])
vms/ui/app_ui.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Optional, Dict, List, Union, Tuple
 from vms.config import (
     STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
     TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
-    MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
+    MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES, MODEL_VERSIONS,
     DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
     DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
     DEFAULT_LEARNING_RATE,
@@ -220,6 +220,7 @@ class AppUI:
             self.project_tabs["train_tab"].components["pause_resume_btn"],
             self.project_tabs["train_tab"].components["training_preset"],
             self.project_tabs["train_tab"].components["model_type"],
+            self.project_tabs["train_tab"].components["model_version"],
             self.project_tabs["train_tab"].components["training_type"],
             self.project_tabs["train_tab"].components["lora_rank"],
             self.project_tabs["train_tab"].components["lora_alpha"],
@@ -378,6 +379,20 @@ class AppUI:
             model_type_val = list(MODEL_TYPES.keys())[0]
             logger.warning(f"Invalid model type '{model_type_val}', using default: {model_type_val}")
 
+        # Get model_version value
+        model_version_val = ""
+        # First get the internal model type for the currently selected model
+        model_internal_type = MODEL_TYPES.get(model_type_val)
+        if model_internal_type and model_internal_type in MODEL_VERSIONS:
+            # If there's a saved model_version and it's valid for this model type
+            if "model_version" in ui_state and ui_state["model_version"] in MODEL_VERSIONS.get(model_internal_type, {}):
+                model_version_val = ui_state["model_version"]
+            else:
+                # Otherwise use the first available version
+                versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
+                if versions:
+                    model_version_val = versions[0]
+
         # Ensure training_type is a valid display name
         training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
         if training_type_val not in TRAINING_TYPES:
@@ -436,6 +451,7 @@ class AppUI:
             delete_checkpoints_btn,
             training_preset,
             model_type_val,
+            model_version_val,
             training_type_val,
             lora_rank_val,
             lora_alpha_val,
@@ -453,10 +469,22 @@ class AppUI:
         """Initialize UI components from saved state"""
         ui_state = self.load_ui_values()
 
+        # Get model type and determine the default model version if not specified
+        model_type = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
+        model_internal_type = MODEL_TYPES.get(model_type)
+
+        # Get model_version, defaulting to first available version if not set
+        model_version = ui_state.get("model_version", "")
+        if not model_version and model_internal_type and model_internal_type in MODEL_VERSIONS:
+            versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
+            if versions:
+                model_version = versions[0]
+
         # Return values in order matching the outputs in app.load
         return (
             ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
-            ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
+            model_type,
+            model_version,
             ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
             ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
             ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
vms/ui/monitoring/services/monitoring.py CHANGED
@@ -22,6 +22,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 class MonitoringService:
     """Service for monitoring system resources and performance"""
vms/ui/monitoring/tabs/general_tab.py CHANGED
@@ -17,7 +17,7 @@ from vms.config import STORAGE_PATH
 from vms.ui.monitoring.utils import get_folder_size, human_readable_size
 
 logger = logging.getLogger(__name__)
-
+logger.setLevel(logging.INFO)
 
 class GeneralTab(BaseTab):
     """Monitor tab for general system resource monitoring"""
vms/ui/project/services/captioning.py CHANGED
@@ -21,6 +21,7 @@ from vms.config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MO
 from vms.utils import extract_scene_info, is_image_file, is_video_file, copy_files_to_training_dir, prepare_finetrainers_dataset
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 @dataclass
 class CaptioningProgress:
vms/ui/project/services/previewing.py CHANGED
@@ -6,17 +6,20 @@ Handles the video generation logic and model integration
 
 import logging
 import tempfile
+import traceback
+import random
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Callable
 import time
 
 from vms.config import (
     OUTPUT_PATH, STORAGE_PATH, MODEL_TYPES, TRAINING_PATH,
-    DEFAULT_PROMPT_PREFIX, MODEL_VARIANTS
+    DEFAULT_PROMPT_PREFIX, MODEL_VERSIONS
 )
 from vms.utils import format_time
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 class PreviewingService:
     """Handles the video generation logic and model integration"""
@@ -48,14 +51,14 @@ class PreviewingService:
             logger.error(f"Error finding LoRA weights: {e}")
             return None
 
-    def get_model_variants(self, model_type: str) -> Dict[str, Dict[str, str]]:
-        """Get available model variants for the given model type"""
-        return MODEL_VARIANTS.get(model_type, {})
+    def get_model_versions(self, model_type: str) -> Dict[str, Dict[str, str]]:
+        """Get available model versions for the given model type"""
+        return MODEL_VERSIONS.get(model_type, {})
 
     def generate_video(
         self,
         model_type: str,
-        model_variant: str,
+        model_version: str,
         prompt: str,
         negative_prompt: str,
         prompt_prefix: str,
@@ -66,13 +69,15 @@ class PreviewingService:
         flow_shift: float,
         lora_weight: float,
         inference_steps: int,
-        enable_cpu_offload: bool,
-        fps: int,
+        seed: int = -1,
+        enable_cpu_offload: bool = True,
+        fps: int = 16,
         conditioning_image: Optional[str] = None
     ) -> Tuple[Optional[str], str, str]:
         """Generate a video using the trained model"""
        try:
             log_messages = []
+            print("generate_video")
 
             def log(msg: str):
                 log_messages.append(msg)
@@ -102,33 +107,46 @@ class PreviewingService:
             if not internal_model_type:
                 return None, f"Error: Invalid model type {model_type}", log(f"Error: Invalid model type {model_type}")
 
-            # Check if model variant is valid for this model type
-            variants = self.get_model_variants(internal_model_type)
-            if model_variant not in variants:
-                # Use default variant if specified one is invalid
-                if len(variants) > 0:
-                    model_variant = next(iter(variants.keys()))
-                    log(f"Warning: Invalid model variant, using default: {model_variant}")
+            # Check if model version is valid
+            # This section uses model_version directly from parameter
+            if model_version:
+                # Verify that the specified model_version exists in our versions
+                versions = self.get_model_versions(internal_model_type)
+                if model_version not in versions:
+                    log(f"Warning: Specified model version '{model_version}' is not recognized")
+                    # Fall back to default version for this model
+                    if len(versions) > 0:
+                        model_version = next(iter(versions.keys()))
+                        log(f"Using default model version instead: {model_version}")
                 else:
-                    # Fall back to default IDs if no variants defined
+                    log(f"Using specified model version: {model_version}")
+            else:
+                # No model version specified, use default
+                versions = self.get_model_versions(internal_model_type)
+                if len(versions) > 0:
+                    model_version = next(iter(versions.keys()))
+                    log(f"No model version specified, using default: {model_version}")
+                else:
+                    # Fall back to hardcoded defaults if no versions defined
                     if internal_model_type == "wan":
-                        model_variant = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+                        model_version = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
                     elif internal_model_type == "ltx_video":
-                        model_variant = "Lightricks/LTX-Video"
+                        model_version = "Lightricks/LTX-Video"
                     elif internal_model_type == "hunyuan_video":
-                        model_variant = "hunyuanvideo-community/HunyuanVideo"
-                    log(f"Warning: No variants defined for model type, using default: {model_variant}")
+                        model_version = "hunyuanvideo-community/HunyuanVideo"
+                    log(f"No versions defined for model type, using default: {model_version}")
 
             # Check if this is an image-to-video model but no image was provided
-            variant_info = variants.get(model_variant, {})
-            if variant_info.get("type") == "image-to-video" and not conditioning_image:
-                return None, "Error: This model requires a conditioning image", log("Error: This model variant requires a conditioning image but none was provided")
+            model_version_info = versions.get(model_version, {})
+            if model_version_info.get("type") == "image-to-video" and not conditioning_image:
+                return None, "Error: This model requires a conditioning image", log("Error: This model version requires a conditioning image but none was provided")
 
             log(f"Generating video with model type: {internal_model_type}")
-            log(f"Using model variant: {model_variant}")
+            log(f"Using model version: {model_version}")
             log(f"Using LoRA weights from: {lora_path}")
             log(f"Resolution: {width}x{height}, Frames: {num_frames}, FPS: {fps}")
             log(f"Guidance Scale: {guidance_scale}, Flow Shift: {flow_shift}, LoRA Weight: {lora_weight}")
+            log(f"Generation Seed: {seed}")
             log(f"Prompt: {full_prompt}")
             log(f"Negative Prompt: {negative_prompt}")
 
@@ -137,22 +155,22 @@ class PreviewingService:
                 return self.generate_wan_video(
                     full_prompt, negative_prompt, width, height, num_frames,
                     guidance_scale, flow_shift, lora_path, lora_weight,
-                    inference_steps, enable_cpu_offload, fps, log,
-                    model_variant, conditioning_image
+                    inference_steps, seed, enable_cpu_offload, fps, log,
+                    model_version, conditioning_image
                 )
             elif internal_model_type == "ltx_video":
                 return self.generate_ltx_video(
                     full_prompt, negative_prompt, width, height, num_frames,
                     guidance_scale, flow_shift, lora_path, lora_weight,
-                    inference_steps, enable_cpu_offload, fps, log,
-                    model_variant, conditioning_image
+                    inference_steps, seed, enable_cpu_offload, fps, log,
+                    model_version, conditioning_image
                 )
             elif internal_model_type == "hunyuan_video":
                 return self.generate_hunyuan_video(
                     full_prompt, negative_prompt, width, height, num_frames,
                     guidance_scale, flow_shift, lora_path, lora_weight,
-                    inference_steps, enable_cpu_offload, fps, log,
-                    model_variant, conditioning_image
+                    inference_steps, seed, enable_cpu_offload, fps, log,
+                    model_version, conditioning_image
                 )
             else:
                 return None, f"Error: Unsupported model type {internal_model_type}", log(f"Error: Unsupported model type {internal_model_type}")
@@ -173,16 +191,18 @@ class PreviewingService:
         lora_path: str,
         lora_weight: float,
         inference_steps: int,
-        enable_cpu_offload: bool,
-        fps: int,
-        log_fn: Callable,
-        model_variant: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+        seed: int = -1,
+        enable_cpu_offload: bool = True,
+        fps: int = 16,
+        log_fn: Callable = print,
+        model_version: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
         conditioning_image: Optional[str] = None
     ) -> Tuple[Optional[str], str, str]:
         """Generate video using Wan model"""
 
         try:
             import torch
+            import numpy as np
             from diffusers import AutoencoderKLWan, WanPipeline
             from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
             from diffusers.utils import export_to_video
@@ -192,14 +212,26 @@ class PreviewingService:
             start_time = torch.cuda.Event(enable_timing=True)
             end_time = torch.cuda.Event(enable_timing=True)
 
-
+            print("Initializing wan generation..")
             log_fn("Importing Wan model components...")
 
-            log_fn(f"Loading VAE from {model_variant}...")
-            vae = AutoencoderKLWan.from_pretrained(model_variant, subfolder="vae", torch_dtype=torch.float32)
+            # Set up random seed
+            if seed == -1:
+                seed = random.randint(0, 2**32 - 1)
+                log_fn(f"Using randomly generated seed: {seed}")
+
+            # Set random seeds for reproducibility
+            random.seed(seed)
+            np.random.seed(seed)
+            torch.manual_seed(seed)
+            generator = torch.Generator(device="cuda")
+            generator = generator.manual_seed(seed)
 
-            log_fn(f"Loading transformer from {model_variant}...")
-            pipe = WanPipeline.from_pretrained(model_variant, vae=vae, torch_dtype=torch.bfloat16)
+            log_fn(f"Loading VAE from {model_version}...")
+            vae = AutoencoderKLWan.from_pretrained(model_version, subfolder="vae", torch_dtype=torch.float32)
+
+            log_fn(f"Loading transformer from {model_version}...")
+            pipe = WanPipeline.from_pretrained(model_version, vae=vae, torch_dtype=torch.bfloat16)
 
             log_fn(f"Configuring scheduler with flow_shift={flow_shift}...")
             pipe.scheduler = UniPCMultistepScheduler.from_config(
@@ -213,11 +245,13 @@ class PreviewingService:
             if enable_cpu_offload:
                 log_fn("Enabling model CPU offload...")
                 pipe.enable_model_cpu_offload()
-
+
             log_fn(f"Loading LoRA weights from {lora_path} with weight {lora_weight}...")
             pipe.load_lora_weights(lora_path)
-            pipe.fuse_lora(lora_weight)
 
+            # TODO: Set the lora scale directly instead of using fuse_lora
+            #pipe._lora_scale = lora_weight
+
             # Create temporary file for the output
             with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
                 output_path = temp_file.name
@@ -226,7 +260,7 @@ class PreviewingService:
             start_time.record()
 
             # Check if this is an image-to-video model
-            is_i2v = "I2V" in model_variant
+            is_i2v = "I2V" in model_version
 
             if is_i2v and conditioning_image:
                 log_fn(f"Loading conditioning image from {conditioning_image}...")
@@ -243,6 +277,7 @@ class PreviewingService:
                     num_frames=num_frames,
                     guidance_scale=guidance_scale,
                     num_inference_steps=inference_steps,
+                    generator=generator,
                 ).frames[0]
             else:
                 log_fn("Generating video with text-only conditioning...")
@@ -254,6 +289,7 @@ class PreviewingService:
                     num_frames=num_frames,
                    guidance_scale=guidance_scale,
                     num_inference_steps=inference_steps,
+                    generator=generator,
                 ).frames[0]
 
             end_time.record()
@@ -274,11 +310,12 @@ class PreviewingService:
             return output_path, "Video generated successfully!", log_fn(f"Generation completed in {format_time(generation_time)}")
 
         except Exception as e:
+            traceback.print_exc()
             log_fn(f"Error generating video with Wan: {str(e)}")
             # Clean up CUDA memory
             torch.cuda.empty_cache()
             return None, f"Error: {str(e)}", log_fn(f"Exception occurred: {str(e)}")
-
+
     def generate_ltx_video(
         self,
         prompt: str,
@@ -291,27 +328,41 @@ class PreviewingService:
         lora_path: str,
         lora_weight: float,
        inference_steps: int,
-        enable_cpu_offload: bool,
-        fps: int,
-        log_fn: Callable,
-        model_variant: str = "Lightricks/LTX-Video",
+        seed: int = -1,
+        enable_cpu_offload: bool = True,
+        fps: int = 16,
+        log_fn: Callable = print,
+        model_version: str = "Lightricks/LTX-Video",
         conditioning_image: Optional[str] = None
     ) -> Tuple[Optional[str], str, str]:
         """Generate video using LTX model"""
 
         try:
             import torch
+            import numpy as np
             from diffusers import LTXPipeline
             from diffusers.utils import export_to_video
             from PIL import Image
 
             start_time = torch.cuda.Event(enable_timing=True)
             end_time = torch.cuda.Event(enable_timing=True)
+
+            # Set up random seed
+            if seed == -1:
+                seed = random.randint(0, 2**32 - 1)
+                log_fn(f"Using randomly generated seed: {seed}")
+
+            # Set random seeds for reproducibility
+            random.seed(seed)
+            np.random.seed(seed)
+            torch.manual_seed(seed)
+            generator = torch.Generator(device="cuda")
+            generator = generator.manual_seed(seed)
 
             log_fn("Importing LTX model components...")
 
-            log_fn(f"Loading pipeline from {model_variant}...")
-            pipe = LTXPipeline.from_pretrained(model_variant, torch_dtype=torch.bfloat16)
+            log_fn(f"Loading pipeline from {model_version}...")
+            pipe = LTXPipeline.from_pretrained(model_version, torch_dtype=torch.bfloat16)
 
             log_fn("Moving pipeline to CUDA device...")
             pipe.to("cuda")
@@ -342,6 +393,7 @@ class PreviewingService:
                 decode_timestep=0.03,
                 decode_noise_scale=0.025,
                 num_inference_steps=inference_steps,
+                generator=generator,
             ).frames[0]
 
             end_time.record()
@@ -379,10 +431,11 @@ class PreviewingService:
         lora_path: str,
         lora_weight: float,
         inference_steps: int,
-        enable_cpu_offload: bool,
-        fps: int,
-        log_fn: Callable,
-        model_variant: str = "hunyuanvideo-community/HunyuanVideo",
+        seed: int = -1,
+        enable_cpu_offload: bool = True,
+        fps: int = 16,
+        log_fn: Callable = print,
+        model_version: str = "hunyuanvideo-community/HunyuanVideo",
         conditioning_image: Optional[str] = None
     ) -> Tuple[Optional[str], str, str]:
         """Generate video using HunyuanVideo model"""
@@ -390,24 +443,37 @@ class PreviewingService:
 
         try:
             import torch
+            import numpy as np
             from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, AutoencoderKLHunyuanVideo
             from diffusers.utils import export_to_video
 
             start_time = torch.cuda.Event(enable_timing=True)
             end_time = torch.cuda.Event(enable_timing=True)
+
+            # Set up random seed
+            if seed == -1:
+                seed = random.randint(0, 2**32 - 1)
+                log_fn(f"Using randomly generated seed: {seed}")
+
+            # Set random seeds for reproducibility
+            random.seed(seed)
+            np.random.seed(seed)
+            torch.manual_seed(seed)
+            generator = torch.Generator(device="cuda")
+            generator = generator.manual_seed(seed)
 
             log_fn("Importing HunyuanVideo model components...")
 
-            log_fn(f"Loading transformer from {model_variant}...")
+            log_fn(f"Loading transformer from {model_version}...")
             transformer = HunyuanVideoTransformer3DModel.from_pretrained(
-                model_variant,
+                model_version,
                 subfolder="transformer",
                 torch_dtype=torch.bfloat16
             )
 
-            log_fn(f"Loading pipeline from {model_variant}...")
+            log_fn(f"Loading pipeline from {model_version}...")
             pipe = HunyuanVideoPipeline.from_pretrained(
-                model_variant,
+                model_version,
                 transformer=transformer,
                 torch_dtype=torch.float16
             )
@@ -446,6 +512,7 @@ class PreviewingService:
                 guidance_scale=guidance_scale,
                 true_cfg_scale=1.0,
                 num_inference_steps=inference_steps,
+                generator=generator,
             ).frames[0]
 
             end_time.record()
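The three generate_*_video methods above share the same seeding convention introduced in this commit: seed == -1 means "pick a fresh random seed", any other value is applied to Python and PyTorch and to the diffusers generator so a preview can be reproduced. A condensed, hedged sketch of that convention (assumes a CUDA device is available, as the code above does; the helper name is illustrative):

```python
import random
import torch

def make_generator(seed: int = -1, device: str = "cuda") -> torch.Generator:
    # -1 means "pick a fresh random seed"; any other value is reused verbatim.
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    random.seed(seed)
    torch.manual_seed(seed)
    # Assumes the requested device exists (the preview code always targets CUDA).
    return torch.Generator(device=device).manual_seed(seed)

# The returned generator is then passed as `generator=` to the diffusers pipeline call.
```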
vms/ui/project/services/splitting.py CHANGED
@@ -16,6 +16,7 @@ from vms.config import TRAINING_PATH, STORAGE_PATH, TRAINING_VIDEOS_PATH, VIDEOS
 from vms.utils import remove_black_bars, extract_scene_info, is_video_file, is_image_file, add_prefix_to_caption
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 class SplittingService:
     def __init__(self):
vms/ui/project/services/training.py CHANGED
@@ -23,7 +23,7 @@ from huggingface_hub import upload_folder, create_repo
 from vms.config import (
     TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH,
     STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN,
-    MODEL_TYPES, TRAINING_TYPES,
+    MODEL_TYPES, TRAINING_TYPES, MODEL_VERSIONS,
     DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
     DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
     DEFAULT_LEARNING_RATE,
@@ -50,6 +50,7 @@ from vms.utils import (
 )
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 class TrainingService:
     def __init__(self, app=None):
@@ -134,6 +135,7 @@ class TrainingService:
         validated_values = {}
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
+            "model_version": "",
             "training_type": list(TRAINING_TYPES.keys())[0],
             "lora_rank": DEFAULT_LORA_RANK_STR,
             "lora_alpha": DEFAULT_LORA_ALPHA_STR,
@@ -213,6 +215,7 @@ class TrainingService:
         ui_state_file = OUTPUT_PATH / "ui_state.json"
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
+            "model_version": "",
             "training_type": list(TRAINING_TYPES.keys())[0],
             "lora_rank": DEFAULT_LORA_RANK_STR,
             "lora_alpha": DEFAULT_LORA_ALPHA_STR,
@@ -255,7 +258,7 @@ class TrainingService:
         if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
             saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
             logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")
-
+
         # Convert numeric values to appropriate types
         if "train_steps" in saved_state:
             try:
@@ -302,6 +305,18 @@ class TrainingService:
         if not model_found:
             merged_state["model_type"] = default_state["model_type"]
             logger.warning(f"Invalid model type in saved state, using default")
+
+        # Validate model_version is appropriate for model_type
+        if "model_type" in merged_state and "model_version" in merged_state:
+            model_internal_type = MODEL_TYPES.get(merged_state["model_type"])
+            if model_internal_type:
+                valid_versions = MODEL_VERSIONS.get(model_internal_type, {}).keys()
+                if merged_state["model_version"] not in valid_versions:
+                    # Set to default for this model type
+                    from vms.ui.project.tabs.train_tab import TrainTab
+                    train_tab = TrainTab(None)  # Temporary instance just for the helper method
+                    merged_state["model_version"] = train_tab.get_default_model_version(saved_state["model_type"])
+                    logger.warning(f"Invalid model version for {merged_state['model_type']}, using default")
 
         # Validate training_type is in available choices
         if merged_state["training_type"] not in TRAINING_TYPES:
@@ -545,6 +560,7 @@ class TrainingService:
         repo_id: str,
         preset_name: str,
         training_type: str = DEFAULT_TRAINING_TYPE,
+        model_version: str = "",
         resume_from_checkpoint: Optional[str] = None,
         num_gpus: int = DEFAULT_NUM_GPUS,
         precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
@@ -869,6 +885,7 @@ class TrainingService:
         # Save session info including repo_id for later hub upload
         self.save_session({
             "model_type": model_type,
+            "model_version": model_version,
             "training_type": training_type,
             "lora_rank": lora_rank,
             "lora_alpha": lora_alpha,
@@ -1039,6 +1056,7 @@ class TrainingService:
         last_session = {
             "params": {
                 "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
+                "model_version": ui_state.get("model_version", ""),
                 "training_type": TRAINING_TYPES.get(ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])),
                 "lora_rank": ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
                 "lora_alpha": ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
@@ -1102,8 +1120,9 @@ class TrainingService:
        # Add UI updates to restore the training parameters in the UI
        # This shows the user what values are being used for the resumed training
        ui_updates.update({
-            "model_type": model_type_display, # Use the display name for the UI dropdown
-            "training_type": training_type_display, # Use the display name for training type
+            "model_type": model_type_display,
+            "model_version": params.get('model_version', ''),
+            "training_type": training_type_display,
            "lora_rank": params.get('lora_rank', DEFAULT_LORA_RANK_STR),
            "lora_alpha": params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR),
            "train_steps": params.get('train_steps', DEFAULT_NB_TRAINING_STEPS),
@@ -1122,19 +1141,19 @@ class TrainingService:
        # Use the internal model_type for the actual training
        # But keep model_type_display for the UI
        result = self.start_training(
-            model_type=model_type_internal,
+            model_type=model_internal_type,
            lora_rank=params.get('lora_rank', DEFAULT_LORA_RANK_STR),
            lora_alpha=params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR),
            train_size=params.get('train_steps', DEFAULT_NB_TRAINING_STEPS),
            batch_size=params.get('batch_size', DEFAULT_BATCH_SIZE),
            learning_rate=params.get('learning_rate', DEFAULT_LEARNING_RATE),
            save_iterations=params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
+            model_version=params.get('model_version', ''),
            repo_id=params.get('repo_id', ''),
            preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
            training_type=training_type_internal,
            resume_from_checkpoint=str(latest_checkpoint)
        )
-
        # Set buttons for active training
        ui_updates.update({
            "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
@@ -1142,7 +1161,7 @@ class TrainingService:
            "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
            "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
        })
-
+
        return {
            "status": "recovered",
            "message": f"Training resumed from checkpoint {checkpoint_step}",
vms/ui/project/tabs/preview_tab.py CHANGED
@@ -4,16 +4,18 @@ Preview tab for Video Model Studio UI
4
 
5
  import gradio as gr
6
  import logging
 
7
  from pathlib import Path
8
  from typing import Dict, Any, List, Optional, Tuple
9
  import time
10
 
11
  from vms.utils import BaseTab
12
  from vms.config import (
13
- MODEL_TYPES, DEFAULT_PROMPT_PREFIX
14
  )
15
 
16
  logger = logging.getLogger(__name__)
 
17
 
18
  class PreviewTab(BaseTab):
19
  """Preview tab for testing trained models"""
@@ -49,25 +51,35 @@ class PreviewTab(BaseTab):
49
  placeholder="Prefix to add to all prompts",
50
  value=DEFAULT_PROMPT_PREFIX
51
  )
 
 
 
 
 
 
 
 
 
52
 
53
  with gr.Row():
54
  # Get the currently selected model type from training tab if possible
55
  default_model = self.get_default_model_type()
56
 
57
- # Make model_type read-only (disabled), as it must match what was trained
58
- self.components["model_type"] = gr.Dropdown(
59
- choices=list(MODEL_TYPES.keys()),
60
- label="Model Type (from training)",
61
- value=default_model,
62
- interactive=False
63
- )
64
-
65
- # Add model variant selection based on model type
66
- self.components["model_variant"] = gr.Dropdown(
67
- label="Model Variant",
68
- choices=self.get_variant_choices(default_model),
69
- value=self.get_default_variant(default_model)
70
- )
 
71
 
72
  # Add image input for image-to-video models
73
  self.components["conditioning_image"] = gr.Image(
@@ -177,36 +189,55 @@ class PreviewTab(BaseTab):
177
 
178
  return tab
179
 
180
- def get_variant_choices(self, model_type: str) -> List[str]:
181
- """Get model variant choices based on model type"""
182
  # Convert UI display name to internal name
183
  internal_type = MODEL_TYPES.get(model_type)
184
  if not internal_type:
185
  return []
186
 
187
- # Get variants from preview service
188
- variants = self.app.previewing.get_model_variants(internal_type)
189
- if not variants:
190
  return []
191
 
192
  # Format choices with display name and description
193
  choices = []
194
- for model_id, info in variants.items():
195
  choices.append(f"{model_id} - {info.get('name', '')}")
196
 
197
  return choices
198
 
199
- def get_default_variant(self, model_type: str) -> str:
200
- """Get default model variant for the model type"""
201
- choices = self.get_variant_choices(model_type)
202
  if choices:
203
  return choices[0]
204
  return ""
205
-
206
  def get_default_model_type(self) -> str:
207
- """Get the currently selected model type from training tab"""
208
  try:
209
- # Try to get the model type from UI state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  ui_state = self.app.training.load_ui_state()
211
  model_type = ui_state.get("model_type")
212
 
@@ -214,7 +245,7 @@ class PreviewTab(BaseTab):
214
  if model_type in MODEL_TYPES:
215
  return model_type
216
 
217
- # If we couldn't get a valid model type, try to get it from the training tab directly
218
  if hasattr(self.app, 'tabs') and 'train_tab' in self.app.tabs:
219
  train_tab = self.app.tabs['train_tab']
220
  if hasattr(train_tab, 'components') and 'model_type' in train_tab.components:
@@ -225,31 +256,31 @@ class PreviewTab(BaseTab):
225
  # Fallback to first model type
226
  return list(MODEL_TYPES.keys())[0]
227
  except Exception as e:
228
- logger.warning(f"Failed to get default model type: {e}")
229
  return list(MODEL_TYPES.keys())[0]
230
 
231
- def extract_model_id(self, variant_choice: str) -> str:
232
- """Extract model ID from variant choice string"""
233
- if " - " in variant_choice:
234
- return variant_choice.split(" - ")[0].strip()
235
- return variant_choice
236
 
237
- def get_variant_type(self, model_type: str, model_variant: str) -> str:
238
- """Get the variant type (text-to-video or image-to-video)"""
239
  # Convert UI display name to internal name
240
  internal_type = MODEL_TYPES.get(model_type)
241
  if not internal_type:
242
  return "text-to-video"
243
 
244
- # Extract model_id from variant choice
245
- model_id = self.extract_model_id(model_variant)
246
 
247
- # Get variants from preview service
248
- variants = self.app.previewing.get_model_variants(internal_type)
249
- variant_info = variants.get(model_id, {})
250
 
251
- # Return the variant type or default to text-to-video
252
- return variant_info.get("type", "text-to-video")
253
 
254
  def connect_events(self) -> None:
255
  """Connect event handlers to UI components"""
@@ -264,23 +295,23 @@ class PreviewTab(BaseTab):
264
  ]
265
  )
266
 
267
- # Update model_variant choices when model_type changes or tab is selected
268
  if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
269
  self.app.tabs_component.select(
270
- fn=self.sync_model_type_and_variants,
271
  inputs=[],
272
  outputs=[
273
  self.components["model_type"],
274
- self.components["model_variant"]
275
  ]
276
  )
277
 
278
- # Update variant-specific UI elements when variant changes
279
- self.components["model_variant"].change(
280
- fn=self.update_variant_ui,
281
  inputs=[
282
  self.components["model_type"],
283
- self.components["model_variant"]
284
  ],
285
  outputs=[
286
  self.components["conditioning_image"]
@@ -305,13 +336,13 @@ class PreviewTab(BaseTab):
305
  self.components["lora_weight"],
306
  self.components["inference_steps"],
307
  self.components["enable_cpu_offload"],
308
- self.components["model_variant"]
309
  ]
310
  )
311
 
312
  # Save preview UI state when values change
313
  for component_name in [
314
- "prompt", "negative_prompt", "prompt_prefix", "model_variant", "resolution_preset",
315
  "width", "height", "num_frames", "fps", "guidance_scale", "flow_shift",
316
  "lora_weight", "inference_steps", "enable_cpu_offload"
317
  ]:
@@ -327,7 +358,7 @@ class PreviewTab(BaseTab):
327
  fn=self.generate_video,
328
  inputs=[
329
  self.components["model_type"],
330
- self.components["model_variant"],
331
  self.components["prompt"],
332
  self.components["negative_prompt"],
333
  self.components["prompt_prefix"],
@@ -349,22 +380,41 @@ class PreviewTab(BaseTab):
349
  ]
350
  )
351
 
352
- def update_variant_ui(self, model_type: str, model_variant: str) -> Dict[str, Any]:
353
- """Update UI based on the selected model variant"""
354
- variant_type = self.get_variant_type(model_type, model_variant)
355
 
356
  # Show conditioning image input only for image-to-video models
357
- show_conditioning_image = variant_type == "image-to-video"
358
 
359
  return {
360
  self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image)
361
  }
362
 
363
- def sync_model_type_and_variants(self) -> Tuple[str, str]:
364
- """Sync model type with training tab when preview tab is selected and update variant choices"""
365
  model_type = self.get_default_model_type()
366
- model_variant = self.get_default_variant(model_type)
367
- return model_type, model_variant
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
  def update_resolution(self, preset: str) -> Tuple[int, int, float]:
370
  """Update resolution and flow shift based on preset"""
@@ -385,11 +435,11 @@ class PreviewTab(BaseTab):
385
  # Get model type (can't be changed in UI)
386
  model_type = self.get_default_model_type()
387
 
388
- # If model_variant not in choices for current model_type, use default
389
- model_variant = preview_state.get("model_variant", "")
390
- variant_choices = self.get_variant_choices(model_type)
391
- if model_variant not in variant_choices and variant_choices:
392
- model_variant = variant_choices[0]
393
 
394
  return (
395
  preview_state.get("prompt", ""),
@@ -404,7 +454,7 @@ class PreviewTab(BaseTab):
404
  preview_state.get("lora_weight", 0.7),
405
  preview_state.get("inference_steps", 30),
406
  preview_state.get("enable_cpu_offload", True),
407
- model_variant
408
  )
409
  except Exception as e:
410
  logger.error(f"Error loading preview state: {e}")
@@ -414,7 +464,7 @@ class PreviewTab(BaseTab):
414
  "worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background",
415
  DEFAULT_PROMPT_PREFIX,
416
  832, 480, 49, 16, 5.0, 3.0, 0.7, 30, True,
417
- self.get_default_variant(self.get_default_model_type())
418
  )
419
 
420
  def save_preview_state_value(self, value: Any) -> None:
@@ -456,7 +506,7 @@ class PreviewTab(BaseTab):
456
  def generate_video(
457
  self,
458
  model_type: str,
459
- model_variant: str,
460
  prompt: str,
461
  negative_prompt: str,
462
  prompt_prefix: str,
@@ -473,13 +523,14 @@ class PreviewTab(BaseTab):
473
  ) -> Tuple[Optional[str], str, str]:
474
  """Handler for generate button click, delegates to preview service"""
475
  # Save all the parameters to preview state before generating
 
476
  try:
477
  state = self.app.training.load_ui_state()
478
  if "preview" not in state:
479
  state["preview"] = {}
480
 
481
- # Extract model ID from variant choice
482
- model_variant_id = self.extract_model_id(model_variant)
483
 
484
  # Update all values
485
  preview_state = {
@@ -487,7 +538,7 @@ class PreviewTab(BaseTab):
487
  "negative_prompt": negative_prompt,
488
  "prompt_prefix": prompt_prefix,
489
  "model_type": model_type,
490
- "model_variant": model_variant,
491
  "width": width,
492
  "height": height,
493
  "num_frames": num_frames,
@@ -504,40 +555,30 @@ class PreviewTab(BaseTab):
504
  except Exception as e:
505
  logger.error(f"Error saving preview state before generation: {e}")
506
 
507
- # Clear the log display at the start to make room for new logs
508
- # Yield and sleep briefly to allow UI update
509
- yield None, "Starting generation...", ""
510
- time.sleep(0.1)
511
 
512
- # Extract model ID from variant choice string
513
- model_variant_id = self.extract_model_id(model_variant)
514
 
515
- # Use streaming updates to provide real-time feedback during generation
516
- def generate_with_updates():
517
- # Initial UI update
518
- yield None, "Initializing generation...", "Starting video generation process..."
519
-
520
- # Start actual generation
521
- result = self.app.previewing.generate_video(
522
- model_type=model_type,
523
- model_variant=model_variant_id,
524
- prompt=prompt,
525
- negative_prompt=negative_prompt,
526
- prompt_prefix=prompt_prefix,
527
- width=width,
528
- height=height,
529
- num_frames=num_frames,
530
- guidance_scale=guidance_scale,
531
- flow_shift=flow_shift,
532
- lora_weight=lora_weight,
533
- inference_steps=inference_steps,
534
- enable_cpu_offload=enable_cpu_offload,
535
- fps=fps,
536
- conditioning_image=conditioning_image
537
- )
538
-
539
- # Return final result
540
- return result
541
 
542
- # Return the generator for streaming updates
543
- return generate_with_updates()
 
4
 
5
  import gradio as gr
6
  import logging
7
+ import json
8
  from pathlib import Path
9
  from typing import Dict, Any, List, Optional, Tuple
10
  import time
11
 
12
  from vms.utils import BaseTab
13
  from vms.config import (
14
+ OUTPUT_PATH, MODEL_TYPES, DEFAULT_PROMPT_PREFIX, MODEL_VERSIONS
15
  )
16
 
17
  logger = logging.getLogger(__name__)
18
+ logger.setLevel(logging.INFO)
19
 
20
  class PreviewTab(BaseTab):
21
  """Preview tab for testing trained models"""
 
51
  placeholder="Prefix to add to all prompts",
52
  value=DEFAULT_PROMPT_PREFIX
53
  )
54
+
55
+ self.components["seed"] = gr.Slider(
56
+ label="Generation Seed (-1 for random)",
57
+ minimum=-1,
58
+ maximum=2147483647, # 2^31 - 1
59
+ step=1,
60
+ value=-1,
61
+ info="Set to -1 for random seed or specific value for reproducible results"
62
+ )
63
 
64
  with gr.Row():
65
  # Get the currently selected model type from training tab if possible
66
  default_model = self.get_default_model_type()
67
 
68
+ with gr.Column():
69
+ # Make model_type read-only (disabled), as it must match what was trained
70
+ self.components["model_type"] = gr.Dropdown(
71
+ choices=list(MODEL_TYPES.keys()),
72
+ label="Model Type (from training)",
73
+ value=default_model,
74
+ interactive=False
75
+ )
76
+
77
+ # Add model version selection based on model type
78
+ self.components["model_version"] = gr.Dropdown(
79
+ label="Model Version",
80
+ choices=self.get_model_version_choices(default_model),
81
+ value=self.get_default_model_version(default_model)
82
+ )
83
 
84
  # Add image input for image-to-video models
85
  self.components["conditioning_image"] = gr.Image(
 
189
 
190
  return tab
191
 
192
+ def get_model_version_choices(self, model_type: str) -> List[str]:
193
+ """Get model version choices based on model type"""
194
  # Convert UI display name to internal name
195
  internal_type = MODEL_TYPES.get(model_type)
196
  if not internal_type:
197
  return []
198
 
199
+ # Get versions from preview service
200
+ versions = self.app.previewing.get_model_versions(internal_type)
201
+ if not versions:
202
  return []
203
 
204
  # Format choices with display name and description
205
  choices = []
206
+ for model_id, info in versions.items():
207
  choices.append(f"{model_id} - {info.get('name', '')}")
208
 
209
  return choices
210
 
211
+ def get_default_model_version(self, model_type: str) -> str:
212
+ """Get default model version for the model type"""
213
+ choices = self.get_model_version_choices(model_type)
214
  if choices:
215
  return choices[0]
216
  return ""
217
+
218
  def get_default_model_type(self) -> str:
219
+ """Get the model type from the latest training session"""
220
  try:
221
+ # First check session.json, which contains the actual training data
222
+ session_file = OUTPUT_PATH / "session.json"
223
+ if session_file.exists():
224
+ with open(session_file, 'r') as f:
225
+ session_data = json.load(f)
226
+
227
+ # Get the internal model type from the session parameters
228
+ if "params" in session_data and "model_type" in session_data["params"]:
229
+ internal_model_type = session_data["params"]["model_type"]
230
+
231
+ # Convert internal model type to display name
232
+ for display_name, internal_name in MODEL_TYPES.items():
233
+ if internal_name == internal_model_type:
234
+ logger.info(f"Using model type '{display_name}' from session file")
235
+ return display_name
236
+
237
+ # If we couldn't map it, log a warning
238
+ logger.warning(f"Could not map internal model type '{internal_model_type}' to a display name")
239
+
240
+ # If we couldn't get it from session.json, try to get it from UI state
241
  ui_state = self.app.training.load_ui_state()
242
  model_type = ui_state.get("model_type")
243
 
 
245
  if model_type in MODEL_TYPES:
246
  return model_type
247
 
248
+ # If we still couldn't get a valid model type, try to get it from the training tab
249
  if hasattr(self.app, 'tabs') and 'train_tab' in self.app.tabs:
250
  train_tab = self.app.tabs['train_tab']
251
  if hasattr(train_tab, 'components') and 'model_type' in train_tab.components:
 
256
  # Fallback to first model type
257
  return list(MODEL_TYPES.keys())[0]
258
  except Exception as e:
259
+ logger.warning(f"Failed to get default model type from session: {e}")
260
  return list(MODEL_TYPES.keys())[0]
261
 
262
+ def extract_model_id(self, model_version_choice: str) -> str:
263
+ """Extract model ID from model version choice string"""
264
+ if " - " in model_version_choice:
265
+ return model_version_choice.split(" - ")[0].strip()
266
+ return model_version_choice
267
 
268
+ def get_model_version_type(self, model_type: str, model_version: str) -> str:
269
+ """Get the model version type (text-to-video or image-to-video)"""
270
  # Convert UI display name to internal name
271
  internal_type = MODEL_TYPES.get(model_type)
272
  if not internal_type:
273
  return "text-to-video"
274
 
275
+ # Extract model_id from model version choice
276
+ model_id = self.extract_model_id(model_version)
277
 
278
+ # Get versions from preview service
279
+ versions = self.app.previewing.get_model_versions(internal_type)
280
+ model_version_info = versions.get(model_id, {})
281
 
282
+ # Return the model version type or default to text-to-video
283
+ return model_version_info.get("type", "text-to-video")
284
 
285
  def connect_events(self) -> None:
286
  """Connect event handlers to UI components"""
 
295
  ]
296
  )
297
 
298
+ # Update model_version choices when model_type changes or tab is selected
299
  if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
300
  self.app.tabs_component.select(
301
+ fn=self.sync_model_type_and_versions,
302
  inputs=[],
303
  outputs=[
304
  self.components["model_type"],
305
+ self.components["model_version"]
306
  ]
307
  )
308
 
309
+ # Update model version-specific UI elements when version changes
310
+ self.components["model_version"].change(
311
+ fn=self.update_model_version_ui,
312
  inputs=[
313
  self.components["model_type"],
314
+ self.components["model_version"]
315
  ],
316
  outputs=[
317
  self.components["conditioning_image"]
 
336
  self.components["lora_weight"],
337
  self.components["inference_steps"],
338
  self.components["enable_cpu_offload"],
339
+ self.components["model_version"]
340
  ]
341
  )
342
 
343
  # Save preview UI state when values change
344
  for component_name in [
345
+ "prompt", "negative_prompt", "prompt_prefix", "model_version", "resolution_preset",
346
  "width", "height", "num_frames", "fps", "guidance_scale", "flow_shift",
347
  "lora_weight", "inference_steps", "enable_cpu_offload"
348
  ]:
 
358
  fn=self.generate_video,
359
  inputs=[
360
  self.components["model_type"],
361
+ self.components["model_version"],
362
  self.components["prompt"],
363
  self.components["negative_prompt"],
364
  self.components["prompt_prefix"],
 
380
  ]
381
  )
382
 
383
+ def update_model_version_ui(self, model_type: str, model_version: str) -> Dict[str, Any]:
384
+ """Update UI based on the selected model version"""
385
+ model_version_type = self.get_model_version_type(model_type, model_version)
386
 
387
  # Show conditioning image input only for image-to-video models
388
+ show_conditioning_image = model_version_type == "image-to-video"
389
 
390
  return {
391
  self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image)
392
  }
393
 
394
+ def sync_model_type_and_versions(self) -> Tuple[str, str]:
395
+ """Sync model type with training tab when preview tab is selected and update model version choices"""
396
  model_type = self.get_default_model_type()
397
+ model_version = ""
398
+
399
+ # Try to get model_version from session or UI state
400
+ ui_state = self.app.training.load_ui_state()
401
+ preview_state = ui_state.get("preview", {})
402
+ model_version = preview_state.get("model_version", "")
403
+
404
+ if not model_version:
405
+ # Format it as a display choice
406
+ internal_type = MODEL_TYPES.get(model_type)
407
+ if internal_type and internal_type in MODEL_VERSIONS:
408
+ first_version = next(iter(MODEL_VERSIONS[internal_type].keys()), "")
409
+ if first_version:
410
+ model_version_info = MODEL_VERSIONS[internal_type][first_version]
411
+ model_version = f"{first_version} - {model_version_info.get('name', '')}"
412
+
413
+ # If we couldn't get it, use default
414
+ if not model_version:
415
+ model_version = self.get_default_model_version(model_type)
416
+
417
+ return model_type, model_version
418
 
419
  def update_resolution(self, preset: str) -> Tuple[int, int, float]:
420
  """Update resolution and flow shift based on preset"""
 
435
  # Get model type (can't be changed in UI)
436
  model_type = self.get_default_model_type()
437
 
438
+ # If model_version not in choices for current model_type, use default
439
+ model_version = preview_state.get("model_version", "")
440
+ model_version_choices = self.get_model_version_choices(model_type)
441
+ if model_version not in model_version_choices and model_version_choices:
442
+ model_version = model_version_choices[0]
443
 
444
  return (
445
  preview_state.get("prompt", ""),
 
454
  preview_state.get("lora_weight", 0.7),
455
  preview_state.get("inference_steps", 30),
456
  preview_state.get("enable_cpu_offload", True),
457
+ model_version
458
  )
459
  except Exception as e:
460
  logger.error(f"Error loading preview state: {e}")
 
464
  "worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background",
465
  DEFAULT_PROMPT_PREFIX,
466
  832, 480, 49, 16, 5.0, 3.0, 0.7, 30, True,
467
+ self.get_default_model_version(self.get_default_model_type())
468
  )
469
 
470
  def save_preview_state_value(self, value: Any) -> None:
 
506
  def generate_video(
507
  self,
508
  model_type: str,
509
+ model_version: str,
510
  prompt: str,
511
  negative_prompt: str,
512
  prompt_prefix: str,
 
523
  ) -> Tuple[Optional[str], str, str]:
524
  """Handler for generate button click, delegates to preview service"""
525
  # Save all the parameters to preview state before generating
526
+ print("preview_tab: generate_video() has been called")
527
  try:
528
  state = self.app.training.load_ui_state()
529
  if "preview" not in state:
530
  state["preview"] = {}
531
 
532
+ # Extract model ID from model version choice
533
+ model_version_id = self.extract_model_id(model_version)
534
 
535
  # Update all values
536
  preview_state = {
 
538
  "negative_prompt": negative_prompt,
539
  "prompt_prefix": prompt_prefix,
540
  "model_type": model_type,
541
+ "model_version": model_version,
542
  "width": width,
543
  "height": height,
544
  "num_frames": num_frames,
 
555
  except Exception as e:
556
  logger.error(f"Error saving preview state before generation: {e}")
557
 
558
+ # Extract model ID from model version choice string
559
+ model_version_id = self.extract_model_id(model_version)
560
 
561
+ # Initial UI update
562
+ video_path, status, log = None, "Initializing generation...", "Starting video generation process..."
563
 
564
+ # Start actual generation
565
+ result = self.app.previewing.generate_video(
566
+ model_type=model_type,
567
+ model_version=model_version_id,
568
+ prompt=prompt,
569
+ negative_prompt=negative_prompt,
570
+ prompt_prefix=prompt_prefix,
571
+ width=width,
572
+ height=height,
573
+ num_frames=num_frames,
574
+ guidance_scale=guidance_scale,
575
+ flow_shift=flow_shift,
576
+ lora_weight=lora_weight,
577
+ inference_steps=inference_steps,
578
+ enable_cpu_offload=enable_cpu_offload,
579
+ fps=fps,
580
+ conditioning_image=conditioning_image
581
+ )
 
582
 
583
+ # Return final result
584
+ return result
vms/ui/project/tabs/train_tab.py CHANGED
@@ -5,12 +5,15 @@ Train tab for Video Model Studio UI with improved task progress display
5
  import gradio as gr
6
  import logging
7
  import os
 
8
  from typing import Dict, Any, List, Optional, Tuple
9
  from pathlib import Path
10
 
11
  from vms.utils import BaseTab
12
  from vms.config import (
13
- TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
 
 
14
  DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
15
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
16
  DEFAULT_LEARNING_RATE,
@@ -53,12 +56,27 @@ class TrainTab(BaseTab):
53
 
54
  with gr.Row():
55
  with gr.Column():
56
  self.components["model_type"] = gr.Dropdown(
57
  choices=list(MODEL_TYPES.keys()),
58
  label="Model Type",
59
- value=list(MODEL_TYPES.keys())[0]
 
60
  )
61
- with gr.Column():
62
  self.components["training_type"] = gr.Dropdown(
63
  choices=list(TRAINING_TYPES.keys()),
64
  label="Training Type",
@@ -198,45 +216,36 @@ class TrainTab(BaseTab):
198
 
199
  def connect_events(self) -> None:
200
  """Connect event handlers to UI components"""
201
- # Model type change event
202
- def update_model_info(model, training_type):
203
- params = self.get_default_params(MODEL_TYPES[model], TRAINING_TYPES[training_type])
204
- info = self.get_model_info(model, training_type)
205
- show_lora_params = training_type == list(TRAINING_TYPES.keys())[0] # Show if LoRA Finetune
206
-
207
- return {
208
- self.components["model_info"]: info,
209
- self.components["train_steps"]: params["train_steps"],
210
- self.components["batch_size"]: params["batch_size"],
211
- self.components["learning_rate"]: params["learning_rate"],
212
- self.components["save_iterations"]: params["save_iterations"],
213
- self.components["lora_params_row"]: gr.Row(visible=show_lora_params)
214
- }
215
-
216
  self.components["model_type"].change(
217
  fn=lambda v: self.app.update_ui_state(model_type=v),
218
  inputs=[self.components["model_type"]],
219
  outputs=[]
220
  ).then(
221
- fn=update_model_info,
 
222
  inputs=[self.components["model_type"], self.components["training_type"]],
223
- outputs=[
224
- self.components["model_info"],
225
- self.components["train_steps"],
226
- self.components["batch_size"],
227
- self.components["learning_rate"],
228
- self.components["save_iterations"],
229
- self.components["lora_params_row"]
230
- ]
231
  )
232
 
233
  # Training type change event
234
  self.components["training_type"].change(
235
  fn=lambda v: self.app.update_ui_state(training_type=v),
236
  inputs=[self.components["training_type"]],
237
  outputs=[]
238
  ).then(
239
- fn=update_model_info,
240
  inputs=[self.components["model_type"], self.components["training_type"]],
241
  outputs=[
242
  self.components["model_info"],
@@ -248,7 +257,6 @@ class TrainTab(BaseTab):
248
  ]
249
  )
250
 
251
-
252
  # Add in the connect_events() method:
253
  self.components["num_gpus"].change(
254
  fn=lambda v: self.app.update_ui_state(num_gpus=v),
@@ -326,7 +334,9 @@ class TrainTab(BaseTab):
326
  self.components["lora_params_row"],
327
  self.components["num_gpus"],
328
  self.components["precomputation_items"],
329
- self.components["lr_warmup_steps"]
 
 
330
  ]
331
  )
332
 
@@ -336,6 +346,7 @@ class TrainTab(BaseTab):
336
  inputs=[
337
  self.components["training_preset"],
338
  self.components["model_type"],
 
339
  self.components["training_type"],
340
  self.components["lora_rank"],
341
  self.components["lora_alpha"],
@@ -383,9 +394,19 @@ class TrainTab(BaseTab):
383
  fn=lambda: self.app.training.delete_all_checkpoints(),
384
  outputs=[self.components["status_box"]]
385
  )
386
 
387
  def handle_training_start(
388
- self, preset, model_type, training_type, lora_rank, lora_alpha, train_steps, batch_size, learning_rate, save_iterations, repo_id, progress=gr.Progress()
 
 
389
  ):
390
  """Handle training start with proper log parser reset and checkpoint detection"""
391
  # Safely reset log parser if it exists
@@ -396,9 +417,6 @@ class TrainTab(BaseTab):
396
  from ..utils import TrainingLogParser
397
  self.app.log_parser = TrainingLogParser()
398
 
399
- # Initialize progress
400
- #progress(0, desc="Initializing training")
401
-
402
  # Check for latest checkpoint
403
  checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
404
  resume_from = None
@@ -408,10 +426,6 @@ class TrainTab(BaseTab):
408
  latest_checkpoint = max(checkpoints, key=os.path.getmtime)
409
  resume_from = str(latest_checkpoint)
410
  logger.info(f"Found checkpoint at {resume_from}, will resume training")
411
- #progress(0.05, desc=f"Resuming from checkpoint {Path(resume_from).name}")
412
- else:
413
- #progress(0.05, desc="Starting new training run")
414
- pass
415
 
416
  # Convert model_type display name to internal name
417
  model_internal_type = MODEL_TYPES.get(model_type)
@@ -432,9 +446,6 @@ class TrainTab(BaseTab):
432
  precomputation_items = int(self.components["precomputation_items"].value)
433
  lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
434
 
435
- # Progress update
436
- #progress(0.1, desc="Preparing dataset")
437
-
438
  # Start training (it will automatically use the checkpoint if provided)
439
  try:
440
  return self.app.training.start_training(
@@ -448,6 +459,7 @@ class TrainTab(BaseTab):
448
  repo_id,
449
  preset_name=preset,
450
  training_type=training_internal_type,
 
451
  resume_from_checkpoint=resume_from,
452
  num_gpus=num_gpus,
453
  precomputation_items=precomputation_items,
@@ -458,6 +470,52 @@ class TrainTab(BaseTab):
458
  logger.exception("Error starting training")
459
  return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
460
 
461
  def get_model_info(self, model_type: str, training_type: str) -> str:
462
  """Get information about the selected model type and training method"""
463
  if model_type == "HunyuanVideo":
@@ -483,14 +541,14 @@ class TrainTab(BaseTab):
483
  else:
484
  return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
485
 
486
- elif model_type == "Wan-2.1-T2V":
487
- base_info = """### Wan-2.1-T2V
488
- - Recommended batch size: ?
489
- - Typical training time: ? hours
490
  - Default resolution: 49x512x768"""
491
 
492
  if training_type == "LoRA Finetune":
493
- return base_info + "\n- Required VRAM: ?GB minimum\n- Default LoRA rank: 32 (~120 MB)"
494
  else:
495
  return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
496
 
@@ -601,6 +659,10 @@ class TrainTab(BaseTab):
601
  precomputation_items_val = current_state.get("precomputation_items") if current_state.get("precomputation_items") != preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS) else preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS)
602
  lr_warmup_steps_val = current_state.get("lr_warmup_steps") if current_state.get("lr_warmup_steps") != preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS) else preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS)
603
 
 
  # Return values in the same order as the output components
605
  return (
606
  model_display_name,
@@ -615,9 +677,11 @@ class TrainTab(BaseTab):
615
  gr.Row(visible=show_lora_params),
616
  num_gpus_val,
617
  precomputation_items_val,
618
- lr_warmup_steps_val
 
619
  )
620
-
 
621
  def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
622
  """Get latest status message, log content, and status code in a safer way"""
623
  state = self.app.training.get_status()
 
5
  import gradio as gr
6
  import logging
7
  import os
8
+ import json
9
  from typing import Dict, Any, List, Optional, Tuple
10
  from pathlib import Path
11
 
12
  from vms.utils import BaseTab
13
  from vms.config import (
14
+ OUTPUT_PATH, ASK_USER_TO_DUPLICATE_SPACE,
15
+ SMALL_TRAINING_BUCKETS,
16
+ TRAINING_PRESETS, TRAINING_TYPES, MODEL_TYPES, MODEL_VERSIONS,
17
  DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
18
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
19
  DEFAULT_LEARNING_RATE,
 
56
 
57
  with gr.Row():
58
  with gr.Column():
59
+ # Get the default model type from the first preset
60
+ default_model_type = list(MODEL_TYPES.keys())[0]
61
+
62
  self.components["model_type"] = gr.Dropdown(
63
  choices=list(MODEL_TYPES.keys()),
64
  label="Model Type",
65
+ value=default_model_type,
66
+ interactive=True
67
  )
68
+
69
+ # Get model versions for the default model type
70
+ default_model_versions = self.get_model_version_choices(default_model_type)
71
+ default_model_version = self.get_default_model_version(default_model_type)
72
+
73
+ self.components["model_version"] = gr.Dropdown(
74
+ choices=default_model_versions,
75
+ label="Model Version",
76
+ value=default_model_version,
77
+ interactive=True
78
+ )
79
+
80
  self.components["training_type"] = gr.Dropdown(
81
  choices=list(TRAINING_TYPES.keys()),
82
  label="Training Type",
 
216
 
217
  def connect_events(self) -> None:
218
  """Connect event handlers to UI components"""
219
+ # Model type change event - Update model version dropdown choices
220
  self.components["model_type"].change(
221
+ fn=self.update_model_versions,
222
+ inputs=[self.components["model_type"]],
223
+ outputs=[self.components["model_version"]]
224
+ ).then(
225
  fn=lambda v: self.app.update_ui_state(model_type=v),
226
  inputs=[self.components["model_type"]],
227
  outputs=[]
228
  ).then(
229
+ # Use get_model_info instead of update_model_info
230
+ fn=self.get_model_info,
231
  inputs=[self.components["model_type"], self.components["training_type"]],
232
+ outputs=[self.components["model_info"]]
 
  )
234
 
235
+ # Model version change event
236
+ self.components["model_version"].change(
237
+ fn=lambda v: self.app.update_ui_state(model_version=v),
238
+ inputs=[self.components["model_version"]],
239
+ outputs=[]
240
+ )
241
+
242
  # Training type change event
243
  self.components["training_type"].change(
244
  fn=lambda v: self.app.update_ui_state(training_type=v),
245
  inputs=[self.components["training_type"]],
246
  outputs=[]
247
  ).then(
248
+ fn=self.update_model_info,
249
  inputs=[self.components["model_type"], self.components["training_type"]],
250
  outputs=[
251
  self.components["model_info"],
 
257
  ]
258
  )
259
 
 
260
  # Add in the connect_events() method:
261
  self.components["num_gpus"].change(
262
  fn=lambda v: self.app.update_ui_state(num_gpus=v),
 
334
  self.components["lora_params_row"],
335
  self.components["num_gpus"],
336
  self.components["precomputation_items"],
337
+ self.components["lr_warmup_steps"],
338
+ # Add model_version to the outputs
339
+ self.components["model_version"]
340
  ]
341
  )
342
 
 
346
  inputs=[
347
  self.components["training_preset"],
348
  self.components["model_type"],
349
+ self.components["model_version"], # Add model_version to the inputs
350
  self.components["training_type"],
351
  self.components["lora_rank"],
352
  self.components["lora_alpha"],
 
394
  fn=lambda: self.app.training.delete_all_checkpoints(),
395
  outputs=[self.components["status_box"]]
396
  )
397
+
398
+ def update_model_versions(self, model_type: str) -> Dict:
399
+ """Update model version choices based on selected model type"""
400
+ model_versions = self.get_model_version_choices(model_type)
401
+ default_version = self.get_default_model_version(model_type)
402
+
403
+ # Update the model_version dropdown with new choices and default value
404
+ return gr.Dropdown(choices=model_versions, value=default_version)
405
 
406
  def handle_training_start(
407
+ self, preset, model_type, model_version, training_type,
408
+ lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
409
+ save_iterations, repo_id, progress=gr.Progress()
410
  ):
411
  """Handle training start with proper log parser reset and checkpoint detection"""
412
  # Safely reset log parser if it exists
 
417
  from ..utils import TrainingLogParser
418
  self.app.log_parser = TrainingLogParser()
419
 
 
 
 
420
  # Check for latest checkpoint
421
  checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
422
  resume_from = None
 
426
  latest_checkpoint = max(checkpoints, key=os.path.getmtime)
427
  resume_from = str(latest_checkpoint)
428
  logger.info(f"Found checkpoint at {resume_from}, will resume training")
 
 
 
 
429
 
430
  # Convert model_type display name to internal name
431
  model_internal_type = MODEL_TYPES.get(model_type)
 
446
  precomputation_items = int(self.components["precomputation_items"].value)
447
  lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
448
 
 
 
 
449
  # Start training (it will automatically use the checkpoint if provided)
450
  try:
451
  return self.app.training.start_training(
 
459
  repo_id,
460
  preset_name=preset,
461
  training_type=training_internal_type,
462
+ model_version=model_version, # Pass the model version from dropdown
463
  resume_from_checkpoint=resume_from,
464
  num_gpus=num_gpus,
465
  precomputation_items=precomputation_items,
 
470
  logger.exception("Error starting training")
471
  return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
472
 
473
+ def get_model_version_choices(self, model_type: str) -> List[str]:
474
+ """Get model version choices based on model type"""
475
+ # Convert UI display name to internal name
476
+ internal_type = MODEL_TYPES.get(model_type)
477
+ if not internal_type or internal_type not in MODEL_VERSIONS:
478
+ return []
479
+
480
+ # Get versions and return them as choices
481
+ versions = MODEL_VERSIONS.get(internal_type, {})
482
+ return list(versions.keys())
483
+
484
+ def get_default_model_version(self, model_type: str) -> str:
485
+ """Get default model version for the given model type"""
486
+ # Convert UI display name to internal name
487
+ internal_type = MODEL_TYPES.get(model_type)
488
+ if not internal_type or internal_type not in MODEL_VERSIONS:
489
+ return ""
490
+
491
+ # Get the first version available for this model type
492
+ versions = MODEL_VERSIONS.get(internal_type, {})
493
+ if versions:
494
+ return next(iter(versions.keys()))
495
+
496
+ return ""
497
+
498
+ def update_model_info(self, model_type: str, training_type: str) -> Dict:
499
+ """Update model info and related UI components based on model type and training type"""
500
+ # Get model info text
501
+ model_info = self.get_model_info(model_type, training_type)
502
+
503
+ # Get default parameters for this model type and training type
504
+ params = self.get_default_params(MODEL_TYPES.get(model_type), TRAINING_TYPES.get(training_type))
505
+
506
+ # Check if LoRA params should be visible
507
+ show_lora_params = training_type == "LoRA Finetune"
508
+
509
+ # Return updates for UI components
510
+ return {
511
+ self.components["model_info"]: model_info,
512
+ self.components["train_steps"]: params["train_steps"],
513
+ self.components["batch_size"]: params["batch_size"],
514
+ self.components["learning_rate"]: params["learning_rate"],
515
+ self.components["save_iterations"]: params["save_iterations"],
516
+ self.components["lora_params_row"]: gr.Row(visible=show_lora_params)
517
+ }
518
+
519
  def get_model_info(self, model_type: str, training_type: str) -> str:
520
  """Get information about the selected model type and training method"""
521
  if model_type == "HunyuanVideo":
 
541
  else:
542
  return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
543
 
544
+ elif model_type == "Wan":
545
+ base_info = """### Wan
546
+ - Recommended batch size: 1-4
547
+ - Typical training time: 1-3 hours
548
  - Default resolution: 49x512x768"""
549
 
550
  if training_type == "LoRA Finetune":
551
+ return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
552
  else:
553
  return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
554
 
 
659
  precomputation_items_val = current_state.get("precomputation_items") if current_state.get("precomputation_items") != preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS) else preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS)
660
  lr_warmup_steps_val = current_state.get("lr_warmup_steps") if current_state.get("lr_warmup_steps") != preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS) else preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS)
661
 
662
+ # Get the appropriate model version for the selected model type
663
+ model_versions = self.get_model_version_choices(model_display_name)
664
+ default_model_version = self.get_default_model_version(model_display_name)
665
+
666
  # Return values in the same order as the output components
667
  return (
668
  model_display_name,
 
677
  gr.Row(visible=show_lora_params),
678
  num_gpus_val,
679
  precomputation_items_val,
680
+ lr_warmup_steps_val,
681
+ gr.Dropdown(choices=model_versions, value=default_model_version)
682
  )
683
+
684
+
685
  def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
686
  """Get latest status message, log content, and status code in a safer way"""
687
  state = self.app.training.get_status()