""" | |
Preview tab for Video Model Studio UI | |
""" | |
import gradio as gr | |
import logging | |
import json | |
from pathlib import Path | |
from typing import Dict, Any, List, Optional, Tuple | |
import time | |
from vms.utils import BaseTab | |
from vms.config import ( | |
OUTPUT_PATH, MODEL_TYPES, DEFAULT_PROMPT_PREFIX, MODEL_VERSIONS | |
) | |
# Module-level logger; set to INFO so model-version discovery and
# generation messages from this tab appear in the application logs.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class PreviewTab(BaseTab):
    """Preview tab for testing trained models.

    Lets the user generate sample videos from the most recently trained
    model — optionally applying its LoRA weights — to judge training
    quality without leaving the app.
    """

    def __init__(self, app_state):
        """Initialize the tab.

        Args:
            app_state: Shared application state passed to BaseTab
                (presumably exposed as ``self.app``, which this class
                reads throughout — TODO confirm against BaseTab).
        """
        super().__init__(app_state)
        # Gradio TabItem id and the label shown in the tab strip
        self.id = "preview_tab"
        self.title = "4️⃣ Preview"
    def create(self, parent=None) -> gr.TabItem:
        """Create the Preview tab UI components.

        Layout: a left column (scale 2) with all generation parameters and
        the "Generate Video" button, and a right column (scale 3) with the
        generated video, a status box, and a collapsible log. Every widget
        is registered in ``self.components`` under a stable key so
        ``connect_events`` can wire handlers by name.

        Args:
            parent: Unused here; kept for the BaseTab.create signature.

        Returns:
            The ``gr.TabItem`` containing the whole tab.
        """
        with gr.TabItem(self.title, id=self.id) as tab:
            with gr.Row():
                gr.Markdown("## 🔬 Preview your model")
            with gr.Row():
                with gr.Column(scale=2):
                    # Add dropdown to choose between LoRA and original model.
                    # The available choices depend on whether LoRA weights
                    # were found on disk at tab-creation time.
                    has_lora = self.check_lora_model_exists()
                    lora_choices = []
                    default_lora_choice = ""
                    if has_lora:
                        lora_choices = ["Use LoRA model", "Use original model"]
                        default_lora_choice = "Use LoRA model"
                    else:
                        lora_choices = ["Cannot find LoRA model", "Use original model"]
                        default_lora_choice = "Use original model"
                    self.components["use_lora"] = gr.Dropdown(
                        choices=lora_choices,
                        label="Model Selection",
                        value=default_lora_choice
                    )
                    self.components["prompt"] = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3
                    )
                    self.components["negative_prompt"] = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Enter negative prompt here...",
                        lines=3,
                        value="worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background"
                    )
                    self.components["prompt_prefix"] = gr.Textbox(
                        label="Global Prompt Prefix",
                        placeholder="Prefix to add to all prompts",
                        value=DEFAULT_PROMPT_PREFIX
                    )
                    # Ensure seed is interactive with a slider
                    self.components["seed"] = gr.Slider(
                        label="Generation Seed (-1 for random)",
                        minimum=-1,
                        maximum=2147483647,  # 2^31 - 1
                        step=1,
                        value=-1,
                        info="Set to -1 for random seed or specific value for reproducible results",
                        interactive=True
                    )
                    with gr.Row():
                        # Get the currently selected model type from training tab if possible
                        default_model = self.get_default_model_type()
                        with gr.Column():
                            # Make model_type read-only (disabled), as it must match what was trained
                            self.components["model_type"] = gr.Dropdown(
                                choices=list(MODEL_TYPES.keys()),
                                label="Model Type (from training)",
                                value=default_model,
                                interactive=False
                            )
                            # Add model version selection based on model type
                            self.components["model_version"] = gr.Dropdown(
                                label="Model Version",
                                choices=self.get_model_version_choices(default_model),
                                value=self.get_default_model_version(default_model)
                            )
                    # Add image input for image-to-video models; hidden until
                    # update_model_version_ui() reveals it for i2v versions.
                    self.components["conditioning_image"] = gr.Image(
                        label="Conditioning Image (for Image-to-Video models)",
                        type="filepath",
                        visible=False
                    )
                    with gr.Row():
                        self.components["resolution_preset"] = gr.Dropdown(
                            choices=["480p", "720p"],
                            label="Resolution Preset",
                            value="480p"
                        )
                    with gr.Row():
                        self.components["width"] = gr.Number(
                            label="Width",
                            value=832,
                            precision=0
                        )
                        self.components["height"] = gr.Number(
                            label="Height",
                            value=480,
                            precision=0
                        )
                    with gr.Row():
                        self.components["num_frames"] = gr.Slider(
                            label="Number of Frames",
                            minimum=1,
                            maximum=257,
                            step=8,
                            value=49
                        )
                        self.components["fps"] = gr.Slider(
                            label="FPS",
                            minimum=1,
                            maximum=60,
                            step=1,
                            value=16
                        )
                    with gr.Row():
                        self.components["guidance_scale"] = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0
                        )
                        self.components["flow_shift"] = gr.Slider(
                            label="Flow Shift",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=3.0
                        )
                    with gr.Row():
                        self.components["lora_scale"] = gr.Slider(
                            label="LoRA Scale",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.7,
                            visible=has_lora  # Only visible if using LoRA
                        )
                        self.components["inference_steps"] = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=20
                        )
                    self.components["enable_cpu_offload"] = gr.Checkbox(
                        label="Enable Model CPU Offload (for low-VRAM GPUs)",
                        value=False  # let's assume user is using a video model training rig with a good GPU
                    )
                    self.components["generate_btn"] = gr.Button(
                        "Generate Video",
                        variant="primary"
                    )
                with gr.Column(scale=3):
                    self.components["preview_video"] = gr.Video(
                        label="Generated Video",
                        interactive=False
                    )
                    self.components["status"] = gr.Textbox(
                        label="Status",
                        interactive=False
                    )
                    with gr.Accordion("Log", open=False):
                        self.components["log"] = gr.TextArea(
                            label="Generation Log",
                            interactive=False,
                            lines=60
                        )
        return tab
def check_lora_model_exists(self) -> bool: | |
"""Check if any LoRA model files exist in the output directory""" | |
# Look for the standard LoRA weights file | |
lora_path = OUTPUT_PATH / "pytorch_lora_weights.safetensors" | |
if lora_path.exists(): | |
return True | |
# If not found in the expected location, try to find in checkpoints | |
checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*")) | |
has_checkpoints = len(checkpoints) > 0 | |
if not checkpoints: | |
return False | |
for checkpoint in checkpoints: | |
lora_path = checkpoint / "pytorch_lora_weights.safetensors" | |
if lora_path.exists(): | |
return True | |
return False | |
def update_lora_ui(self, use_lora_value: str) -> Dict[str, Any]: | |
"""Update UI based on LoRA selection""" | |
is_using_lora = "Use LoRA model" in use_lora_value | |
return { | |
self.components["lora_scale"]: gr.Slider(visible=is_using_lora) | |
} | |
def get_model_version_choices(self, model_type: str) -> List[str]: | |
"""Get model version choices based on model type""" | |
# Convert UI display name to internal name | |
internal_type = MODEL_TYPES.get(model_type) | |
if not internal_type or internal_type not in MODEL_VERSIONS: | |
logger.warning(f"No model versions found for {model_type} (internal type: {internal_type})") | |
return [] | |
# Return just the model IDs as a list of simple strings | |
version_ids = list(MODEL_VERSIONS.get(internal_type, {}).keys()) | |
logger.info(f"Found {len(version_ids)} versions for {model_type}: {version_ids}") | |
# Ensure they're all strings | |
return [str(version) for version in version_ids] | |
def get_default_model_version(self, model_type: str) -> str: | |
"""Get default model version for the given model type""" | |
# Convert UI display name to internal name | |
internal_type = MODEL_TYPES.get(model_type) | |
logger.debug(f"get_default_model_version({model_type}) = {internal_type}") | |
if not internal_type or internal_type not in MODEL_VERSIONS: | |
logger.warning(f"No valid model versions found for {model_type}") | |
return "" | |
# Get the first version available for this model type | |
versions = list(MODEL_VERSIONS.get(internal_type, {}).keys()) | |
if versions: | |
default_version = versions[0] | |
logger.debug(f"Default version for {model_type}: {default_version}") | |
return default_version | |
return "" | |
def get_default_model_type(self) -> str: | |
"""Get the model type from the latest training session""" | |
try: | |
# First check the session.json which contains the actual training data | |
session_file = OUTPUT_PATH / "session.json" | |
if session_file.exists(): | |
with open(session_file, 'r') as f: | |
session_data = json.load(f) | |
# Get the internal model type from the session parameters | |
if "params" in session_data and "model_type" in session_data["params"]: | |
internal_model_type = session_data["params"]["model_type"] | |
# Convert internal model type to display name | |
for display_name, internal_name in MODEL_TYPES.items(): | |
if internal_name == internal_model_type: | |
logger.info(f"Using model type '{display_name}' from session file") | |
return display_name | |
# If we couldn't map it, log a warning | |
logger.warning(f"Could not map internal model type '{internal_model_type}' to a display name") | |
# If we couldn't get it from session.json, try to get it from UI state | |
ui_state = self.app.training.load_ui_state() | |
model_type = ui_state.get("model_type") | |
# Make sure it's a valid model type | |
if model_type in MODEL_TYPES: | |
return model_type | |
# If we still couldn't get a valid model type, try to get it from the training tab | |
if hasattr(self.app, 'tabs') and 'train_tab' in self.app.tabs: | |
train_tab = self.app.tabs['train_tab'] | |
if hasattr(train_tab, 'components') and 'model_type' in train_tab.components: | |
train_model_type = train_tab.components['model_type'].value | |
if train_model_type in MODEL_TYPES: | |
return train_model_type | |
# Fallback to first model type | |
return list(MODEL_TYPES.keys())[0] | |
except Exception as e: | |
logger.warning(f"Failed to get default model type from session: {e}") | |
return list(MODEL_TYPES.keys())[0] | |
def extract_model_id(self, model_version_choice: str) -> str: | |
"""Extract model ID from model version choice string""" | |
if " - " in model_version_choice: | |
return model_version_choice.split(" - ")[0].strip() | |
return model_version_choice | |
def get_model_version_type(self, model_type: str, model_version: str) -> str: | |
"""Get the model version type (text-to-video or image-to-video)""" | |
# Convert UI display name to internal name | |
internal_type = MODEL_TYPES.get(model_type) | |
if not internal_type: | |
return "text-to-video" | |
# Extract model_id from model version choice | |
model_id = self.extract_model_id(model_version) | |
# Get versions from preview service | |
versions = self.app.previewing.get_model_versions(internal_type) | |
model_version_info = versions.get(model_id, {}) | |
# Return the model version type or default to text-to-video | |
return model_version_info.get("type", "text-to-video") | |
    def connect_events(self) -> None:
        """Connect event handlers to UI components.

        Registers, in order: resolution-preset syncing, tab-selection sync
        of model type/version, conditioning-image visibility, LoRA slider
        visibility, saved-state loading on tab selection, per-component
        state persistence, and the generate button.
        """
        # Update resolution when preset changes
        self.components["resolution_preset"].change(
            fn=self.update_resolution,
            inputs=[self.components["resolution_preset"]],
            outputs=[
                self.components["width"],
                self.components["height"],
                self.components["flow_shift"]
            ]
        )
        # Update model_version choices when model_type changes or tab is selected.
        # NOTE(review): Tabs.select presumably fires when ANY tab is selected,
        # not only this one — confirm that is intended.
        if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
            self.app.tabs_component.select(
                fn=self.sync_model_type_and_versions,
                inputs=[],
                outputs=[
                    self.components["model_type"],
                    self.components["model_version"]
                ]
            )
        # Update model version-specific UI elements when version changes
        self.components["model_version"].change(
            fn=self.update_model_version_ui,
            inputs=[
                self.components["model_type"],
                self.components["model_version"]
            ],
            outputs=[
                self.components["conditioning_image"]
            ]
        )
        # Connect LoRA selection dropdown to update LoRA weight visibility
        self.components["use_lora"].change(
            fn=self.update_lora_ui,
            inputs=[self.components["use_lora"]],
            outputs=[self.components["lora_scale"]]
        )
        # Load preview UI state when the tab is selected
        # (second .select registration; runs in addition to the sync above)
        if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
            self.app.tabs_component.select(
                fn=self.load_preview_state,
                inputs=[],
                outputs=[
                    self.components["prompt"],
                    self.components["negative_prompt"],
                    self.components["prompt_prefix"],
                    self.components["width"],
                    self.components["height"],
                    self.components["num_frames"],
                    self.components["fps"],
                    self.components["guidance_scale"],
                    self.components["flow_shift"],
                    self.components["lora_scale"],
                    self.components["inference_steps"],
                    self.components["enable_cpu_offload"],
                    self.components["model_version"],
                    self.components["seed"],
                    self.components["use_lora"]
                ]
            )
        # Save preview UI state when values change. Each component gets its
        # own .change handler that forwards only its new value; the handler
        # recovers the component's name internally.
        for component_name in [
            "prompt", "negative_prompt", "prompt_prefix", "model_version", "resolution_preset",
            "width", "height", "num_frames", "fps", "guidance_scale", "flow_shift",
            "lora_scale", "inference_steps", "enable_cpu_offload", "seed", "use_lora"
        ]:
            if component_name in self.components:
                self.components[component_name].change(
                    fn=self.save_preview_state_value,
                    inputs=[self.components[component_name]],
                    outputs=[]
                )
        # Generate button click
        self.components["generate_btn"].click(
            fn=self.generate_video,
            inputs=[
                self.components["model_type"],
                self.components["model_version"],
                self.components["prompt"],
                self.components["negative_prompt"],
                self.components["prompt_prefix"],
                self.components["width"],
                self.components["height"],
                self.components["num_frames"],
                self.components["guidance_scale"],
                self.components["flow_shift"],
                self.components["lora_scale"],
                self.components["inference_steps"],
                self.components["enable_cpu_offload"],
                self.components["fps"],
                self.components["conditioning_image"],
                self.components["seed"],
                self.components["use_lora"]
            ],
            outputs=[
                self.components["preview_video"],
                self.components["status"],
                self.components["log"]
            ]
        )
def update_model_version_ui(self, model_type: str, model_version: str) -> Dict[str, Any]: | |
"""Update UI based on the selected model version""" | |
model_version_type = self.get_model_version_type(model_type, model_version) | |
# Show conditioning image input only for image-to-video models | |
show_conditioning_image = model_version_type == "image-to-video" | |
return { | |
self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image) | |
} | |
def sync_model_type_and_versions(self) -> Tuple[str, str]: | |
"""Sync model type with training tab when preview tab is selected and update model version choices""" | |
model_type = self.get_default_model_type() | |
model_version = "" | |
# Try to get model_version from session or UI state | |
ui_state = self.app.training.load_ui_state() | |
preview_state = ui_state.get("preview", {}) | |
model_version = preview_state.get("model_version", "") | |
# If no model version specified or invalid, use default | |
if not model_version: | |
# Get the internal model type | |
internal_type = MODEL_TYPES.get(model_type) | |
if internal_type and internal_type in MODEL_VERSIONS: | |
versions = list(MODEL_VERSIONS[internal_type].keys()) | |
if versions: | |
model_version = versions[0] | |
return model_type, model_version | |
def update_resolution(self, preset: str) -> Tuple[int, int, float]: | |
"""Update resolution and flow shift based on preset""" | |
if preset == "480p": | |
return 832, 480, 3.0 | |
elif preset == "720p": | |
return 1280, 720, 5.0 | |
else: | |
return 832, 480, 3.0 | |
def load_preview_state(self) -> Tuple: | |
"""Load saved preview UI state""" | |
# Try to get the saved state | |
try: | |
state = self.app.training.load_ui_state() | |
preview_state = state.get("preview", {}) | |
# Get model type (can't be changed in UI) | |
model_type = self.get_default_model_type() | |
# If model_version not in choices for current model_type, use default | |
model_version = preview_state.get("model_version", "") | |
model_version_choices = self.get_model_version_choices(model_type) | |
if model_version not in model_version_choices and model_version_choices: | |
model_version = model_version_choices[0] | |
# Check if LoRA exists and set appropriate dropdown options | |
has_lora = self.check_lora_model_exists() | |
use_lora = preview_state.get("use_lora", "") | |
# Validate use_lora value against current state | |
if has_lora: | |
valid_choices = ["Use LoRA model", "Use original model"] | |
if use_lora not in valid_choices: | |
use_lora = "Use LoRA model" # Default when LoRA exists | |
else: | |
valid_choices = ["Cannot find LoRA model", "Use original model"] | |
if use_lora not in valid_choices: | |
use_lora = "Use original model" # Default when no LoRA | |
# Update the dropdown choices in the UI | |
try: | |
self.components["use_lora"].choices = valid_choices | |
except Exception as e: | |
logger.error(f"Failed to update use_lora choices: {e}") | |
return ( | |
preview_state.get("prompt", ""), | |
preview_state.get("negative_prompt", "worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background"), | |
preview_state.get("prompt_prefix", DEFAULT_PROMPT_PREFIX), | |
preview_state.get("width", 832), | |
preview_state.get("height", 480), | |
preview_state.get("num_frames", 49), | |
preview_state.get("fps", 16), | |
preview_state.get("guidance_scale", 5.0), | |
preview_state.get("flow_shift", 3.0), | |
preview_state.get("lora_scale", 0.7), | |
preview_state.get("inference_steps", 30), | |
preview_state.get("enable_cpu_offload", True), | |
model_version, | |
preview_state.get("seed", -1), | |
use_lora | |
) | |
except Exception as e: | |
logger.error(f"Error loading preview state: {e}") | |
# Return defaults if loading fails | |
return ( | |
"", | |
"worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background", | |
DEFAULT_PROMPT_PREFIX, | |
832, 480, 49, 16, 5.0, 3.0, 0.7, 30, True, | |
self.get_default_model_version(self.get_default_model_type()), | |
-1, | |
"Use original model" if not self.check_lora_model_exists() else "Use LoRA model" | |
) | |
    def save_preview_state_value(self, value: Any) -> None:
        """Persist a single preview setting into the shared UI state.

        HACK: the component the value belongs to is recovered by walking
        the caller's stack frame and reading a local named 'component'.
        NOTE(review): this relies on Gradio internals keeping a local with
        that exact name in the calling frame; when it is absent the value
        is silently dropped — confirm against the Gradio version in use,
        or bind the component name explicitly at wiring time instead.
        """
        try:
            # Get the component name from the event context
            import inspect
            frame = inspect.currentframe()
            frame = inspect.getouterframes(frame)[1]
            event_context = frame.frame.f_locals
            component = event_context.get('component')
            if component is None:
                # Cannot tell which component fired; skip saving
                return
            # Find the component name by matching against our registry
            component_name = None
            for name, comp in self.components.items():
                if comp == component:
                    component_name = name
                    break
            if component_name is None:
                return
            # Load current state
            state = self.app.training.load_ui_state()
            if "preview" not in state:
                state["preview"] = {}
            # Update the value
            state["preview"][component_name] = value
            # Save state
            self.app.training.save_ui_state(state)
        except Exception as e:
            logger.error(f"Error saving preview state: {e}")
def generate_video( | |
self, | |
model_type: str, | |
model_version: str, | |
prompt: str, | |
negative_prompt: str, | |
prompt_prefix: str, | |
width: int, | |
height: int, | |
num_frames: int, | |
guidance_scale: float, | |
flow_shift: float, | |
lora_scale: float, | |
inference_steps: int, | |
enable_cpu_offload: bool, | |
fps: int, | |
conditioning_image: Optional[str] = None, | |
seed: int = -1, | |
use_lora: str = "Use LoRA model" | |
) -> Tuple[Optional[str], str, str]: | |
"""Handler for generate button click, delegates to preview service""" | |
# Save all the parameters to preview state before generating | |
print("preview_tab: generate_video() has been called") | |
try: | |
state = self.app.training.load_ui_state() | |
if "preview" not in state: | |
state["preview"] = {} | |
# Extract model ID from model version choice | |
model_version_id = self.extract_model_id(model_version) | |
# Update all values | |
preview_state = { | |
"prompt": prompt, | |
"negative_prompt": negative_prompt, | |
"prompt_prefix": prompt_prefix, | |
"model_type": model_type, | |
"model_version": model_version, | |
"width": width, | |
"height": height, | |
"num_frames": num_frames, | |
"fps": fps, | |
"guidance_scale": guidance_scale, | |
"flow_shift": flow_shift, | |
"lora_scale": lora_scale, | |
"inference_steps": inference_steps, | |
"enable_cpu_offload": enable_cpu_offload, | |
"seed": seed, | |
"use_lora": use_lora | |
} | |
state["preview"] = preview_state | |
self.app.training.save_ui_state(state) | |
except Exception as e: | |
logger.error(f"Error saving preview state before generation: {e}") | |
# Extract model ID from model version choice string | |
model_version_id = self.extract_model_id(model_version) | |
# Initial UI update | |
video_path, status, log = None, "Initializing generation...", "Starting video generation process..." | |
# Set lora_path to None if not using LoRA | |
use_lora_model = use_lora == "Use LoRA model" | |
# Start actual generation | |
# If not using LoRA, set lora_scale to 0 to disable it | |
effective_lora_scale = lora_scale if use_lora_model else 0.0 | |
result = self.app.previewing.generate_video( | |
model_type=model_type, | |
model_version=model_version_id, | |
prompt=prompt, | |
negative_prompt=negative_prompt, | |
prompt_prefix=prompt_prefix, | |
width=width, | |
height=height, | |
num_frames=num_frames, | |
guidance_scale=guidance_scale, | |
flow_shift=flow_shift, | |
lora_scale=effective_lora_scale, # Use 0.0 if not using LoRA | |
inference_steps=inference_steps, | |
enable_cpu_offload=enable_cpu_offload, | |
fps=fps, | |
conditioning_image=conditioning_image, | |
seed=seed | |
) | |
# Return final result | |
return result |