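"""Top-level Gradio UI for Video Model Studio (VMS).

Defines the AppUI class, which wires the project services (import, caption,
train, preview, manage) and the monitoring services into a tabbed Gradio app.
"""
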
import asyncio
import logging
import platform
from pathlib import Path
from typing import Any, Optional, Dict, List, Union, Tuple

import gradio as gr
from vms.config import (
    STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
    TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH,
    MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES, MODEL_VERSIONS,
    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
    DEFAULT_LEARNING_RATE,
    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR,
    DEFAULT_SEED,
    DEFAULT_NUM_GPUS,
    DEFAULT_MAX_GPUS,
    DEFAULT_PRECOMPUTATION_ITEMS,
    DEFAULT_NB_LR_WARMUP_STEPS,
    DEFAULT_AUTO_RESUME
)
from vms.utils import (
    get_recommended_precomputation_items,
    count_media_files,
    format_media_title,
    TrainingLogParser
)
from vms.ui.project.services import (
    TrainingService, CaptioningService, SplittingService, ImportingService, PreviewingService
)
from vms.ui.project.tabs import (
    ImportTab, CaptionTab, TrainTab, PreviewTab, ManageTab
)
from vms.ui.monitoring.services import (
    MonitoringService
)
from vms.ui.monitoring.tabs import (
    GeneralTab, GPUTab
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

httpx_logger = logging.getLogger('httpx')
httpx_logger.setLevel(logging.WARNING)


class AppUI:
    def __init__(self):
        """Initialize services and tabs"""
        # Project view
        self.training = TrainingService(self)
        self.splitting = SplittingService()
        self.importing = ImportingService()
        self.captioning = CaptioningService()
        self.previewing = PreviewingService()

        # Monitoring view
        self.monitoring = MonitoringService()
        self.monitoring.start_monitoring()

        # Recovery status from any interrupted training
        recovery_result = self.training.recover_interrupted_training()

        # Add null check for recovery_result
        if recovery_result is None:
            recovery_result = {"status": "unknown", "ui_updates": {}}

        self.recovery_status = recovery_result.get("status", "unknown")
        self.ui_updates = recovery_result.get("ui_updates", {})

        # Initialize log parser
        self.log_parser = TrainingLogParser()

        # Shared state for tabs
        self.state = {
            "recovery_result": recovery_result
        }

        # Initialize tabs dictionary
        self.tabs = {}
        self.project_tabs = {}
        self.monitor_tabs = {}
        self.main_tabs = None  # Main tabbed interface
        self.project_tabs_component = None  # Project sub-tabs
        self.monitor_tabs_component = None  # Monitor sub-tabs

        # Log recovery status
        logger.info(f"Initialization complete. Recovery status: {self.recovery_status}")

    def add_periodic_callback(self, callback_fn, interval=1.0):
        """Add a periodic callback function to the UI

        Args:
            callback_fn: Function to call periodically
            interval: Time in seconds between calls (default: 1.0)
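
        Example (illustrative; `my_tab.refresh` is a hypothetical callback):
            app_ui.add_periodic_callback(my_tab.refresh, interval=2.0)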
""" | |
try: | |
# Store a reference to the callback function | |
if not hasattr(self, "_periodic_callbacks"): | |
self._periodic_callbacks = [] | |
self._periodic_callbacks.append(callback_fn) | |
# Add the callback to the Gradio app | |
self.app.add_callback( | |
interval, # Interval in seconds | |
callback_fn, # Function to call | |
inputs=None, # No inputs needed | |
outputs=list(self.components.values()) # All components as possible outputs | |
) | |
logger.info(f"Added periodic callback {callback_fn.__name__} with interval {interval}s") | |
except Exception as e: | |
logger.error(f"Error adding periodic callback: {e}", exc_info=True) | |

    def switch_to_tab(self, tab_index: int):
        """Switch to the specified tab index

        Args:
            tab_index: Index of the tab to select (0 for Project, 1 for Monitor)

        Returns:
            A gr.Tabs update selecting the requested tab
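
        Example (as wired to the sidebar buttons in create_ui):
            btn.click(fn=lambda: self.switch_to_tab(1), outputs=[self.main_tabs])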
""" | |
return gr.Tabs(selected=tab_index) | |

    def create_ui(self):
        """Create the main Gradio UI with tabbed navigation"""
        self.components = {}

        with gr.Blocks(
            title="ποΈ Video Model Studio",
            # Hide the main tab headers: switching between the Project and
            # Monitor views is driven by the sidebar buttons instead.
            css="#main-tabs > .tab-wrapper{ display: none; }") as app:
            self.app = app

            # Main container with sidebar and tab area
            with gr.Row():
                # Sidebar for navigation
                with gr.Sidebar(position="left", open=True):
                    gr.Markdown("# ποΈ Video Model Studio")
                    self.components["current_project_btn"] = gr.Button("π New Project", variant="primary")
                    self.components["system_monitoring_btn"] = gr.Button("π‘οΈ System Monitoring")

                # Main content area with tabs
                with gr.Column():
                    # Main tabbed interface for switching between Project and Monitor views
                    with gr.Tabs(elem_id="main-tabs") as main_tabs:
                        self.main_tabs = main_tabs

                        # Project View Tab
                        with gr.Tab("π New Project", id=0) as project_view:
                            # Create project tabs
                            with gr.Tabs() as project_tabs:
                                # Store reference to project tabs component
                                self.project_tabs_component = project_tabs

                                # Initialize project tab objects
                                self.project_tabs["import_tab"] = ImportTab(self)
                                self.project_tabs["caption_tab"] = CaptionTab(self)
                                self.project_tabs["train_tab"] = TrainTab(self)
                                self.project_tabs["preview_tab"] = PreviewTab(self)
                                self.project_tabs["manage_tab"] = ManageTab(self)

                                # Create tab UI components for project
                                for tab_id, tab_obj in self.project_tabs.items():
                                    tab_obj.create(project_tabs)
                        # Monitoring View Tab
                        with gr.Tab("π‘οΈ System Monitoring", id=1) as monitoring_view:
                            # Create monitoring tabs
                            with gr.Tabs() as monitoring_tabs:
                                # Store reference to monitoring tabs component
                                self.monitor_tabs_component = monitoring_tabs

                                # Initialize monitoring tab objects
                                self.monitor_tabs["general_tab"] = GeneralTab(self)
                                self.monitor_tabs["gpu_tab"] = GPUTab(self)

                                # Create tab UI components for monitoring
                                for tab_id, tab_obj in self.monitor_tabs.items():
                                    tab_obj.create(monitoring_tabs)

            # Combine all tabs into a single dictionary for event handling
            self.tabs = {**self.project_tabs, **self.monitor_tabs}

            # Connect event handlers for all tabs - this must happen AFTER all tabs are created
            for tab_id, tab_obj in self.tabs.items():
                tab_obj.connect_events()

            # App-level timers for auto-refresh functionality
            self._add_timers()
            # Connect navigation events using tab switching
            self.components["current_project_btn"].click(
                fn=lambda: self.switch_to_tab(0),
                outputs=[self.main_tabs],
            )

            self.components["system_monitoring_btn"].click(
                fn=lambda: self.switch_to_tab(1),
                outputs=[self.main_tabs],
            )

            # Initialize app state on load
            app.load(
                fn=self.initialize_app_state,
                outputs=[
                    self.project_tabs["caption_tab"].components["training_dataset"],
                    self.project_tabs["train_tab"].components["start_btn"],
                    self.project_tabs["train_tab"].components["resume_btn"],
                    self.project_tabs["train_tab"].components["stop_btn"],
                    self.project_tabs["train_tab"].components["delete_checkpoints_btn"],
                    self.project_tabs["train_tab"].components["training_preset"],
                    self.project_tabs["train_tab"].components["model_type"],
                    self.project_tabs["train_tab"].components["model_version"],
                    self.project_tabs["train_tab"].components["training_type"],
                    self.project_tabs["train_tab"].components["lora_rank"],
                    self.project_tabs["train_tab"].components["lora_alpha"],
                    self.project_tabs["train_tab"].components["train_steps"],
                    self.project_tabs["train_tab"].components["batch_size"],
                    self.project_tabs["train_tab"].components["learning_rate"],
                    self.project_tabs["train_tab"].components["save_iterations"],
                    self.project_tabs["train_tab"].components["current_task_box"],
                    self.project_tabs["train_tab"].components["num_gpus"],
                    self.project_tabs["train_tab"].components["precomputation_items"],
                    self.project_tabs["train_tab"].components["lr_warmup_steps"],
                    self.project_tabs["train_tab"].components["auto_resume"]
                ]
            )

        return app

    def _add_timers(self):
        """Add auto-refresh timers to the UI"""
        # Status update timer for text components (every 1 second)
        status_timer = gr.Timer(value=1)
        status_timer.tick(
            fn=self.project_tabs["train_tab"].get_status_updates,
            outputs=[
                self.project_tabs["train_tab"].components["status_box"],
                self.project_tabs["train_tab"].components["log_box"],
                self.project_tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.project_tabs["train_tab"].components else None
            ]
        )

        # Button update timer for button components (every 1 second)
        button_timer = gr.Timer(value=1)
        button_outputs = [
            self.project_tabs["train_tab"].components["start_btn"],
            self.project_tabs["train_tab"].components["resume_btn"],
            self.project_tabs["train_tab"].components["stop_btn"],
            self.project_tabs["train_tab"].components["delete_checkpoints_btn"]
        ]
        button_timer.tick(
            fn=self.project_tabs["train_tab"].get_button_updates,
            outputs=button_outputs
        )

        # Dataset refresh timer (every 5 seconds)
        dataset_timer = gr.Timer(value=5)
        dataset_timer.tick(
            fn=self.refresh_dataset,
            outputs=[
                self.project_tabs["caption_tab"].components["training_dataset"]
            ]
        )

        # Titles update timer (every 6 seconds)
        titles_timer = gr.Timer(value=6)
        titles_timer.tick(
            fn=self.update_titles,
            outputs=[
                self.project_tabs["caption_tab"].components["caption_title"],
                self.project_tabs["train_tab"].components["train_title"]
            ]
        )

    def initialize_app_state(self):
        """Initialize all app state in one function to ensure correct output count"""
        # Get dataset info
        training_dataset = self.project_tabs["caption_tab"].list_training_files_to_caption()

        # Get button states based on recovery status
        button_states = self.get_initial_button_states()
        start_btn = button_states[0]
        resume_btn = button_states[1]
        stop_btn = button_states[2]
        delete_checkpoints_btn = button_states[3]

        # Get UI form values - possibly from the recovery
        if self.recovery_status in ["recovered", "ready_to_recover", "running"] and "ui_updates" in self.state["recovery_result"]:
            recovery_ui = self.state["recovery_result"]["ui_updates"]

            # If we recovered training parameters from the original session
            ui_state = {}

            # Handle model_type specifically - could be internal or display name
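            # MODEL_TYPES and TRAINING_TYPES map display names to internal names
            # (illustrative shape: {"<display name>": "<internal_name>"}); recovery
            # data may contain either form, so it is normalized to the display name here.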
if "model_type" in recovery_ui: | |
model_type_value = recovery_ui["model_type"] | |
# Remove " (LoRA)" suffix if present | |
if " (LoRA)" in model_type_value: | |
model_type_value = model_type_value.replace(" (LoRA)", "") | |
logger.info(f"Removed (LoRA) suffix from model type: {model_type_value}") | |
# If it's an internal name, convert to display name | |
if model_type_value not in MODEL_TYPES: | |
# Find the display name for this internal model type | |
for display_name, internal_name in MODEL_TYPES.items(): | |
if internal_name == model_type_value: | |
model_type_value = display_name | |
logger.info(f"Converted internal model type '{recovery_ui['model_type']}' to display name '{model_type_value}'") | |
break | |
ui_state["model_type"] = model_type_value | |
# Handle training_type | |
if "training_type" in recovery_ui: | |
training_type_value = recovery_ui["training_type"] | |
# If it's an internal name, convert to display name | |
if training_type_value not in TRAINING_TYPES: | |
for display_name, internal_name in TRAINING_TYPES.items(): | |
if internal_name == training_type_value: | |
training_type_value = display_name | |
logger.info(f"Converted internal training type '{recovery_ui['training_type']}' to display name '{training_type_value}'") | |
break | |
ui_state["training_type"] = training_type_value | |
# Copy other parameters | |
for param in ["lora_rank", "lora_alpha", "train_steps", | |
"batch_size", "learning_rate", "save_iterations", "training_preset"]: | |
if param in recovery_ui: | |
ui_state[param] = recovery_ui[param] | |
# Merge with existing UI state if needed | |
if ui_state: | |
current_state = self.load_ui_values() | |
current_state.update(ui_state) | |
self.training.save_ui_state(current_state) | |
logger.info(f"Updated UI state from recovery: {ui_state}") | |

        # Load values (potentially with recovery updates applied)
        ui_state = self.load_ui_values()

        # Ensure model_type is a valid display name
        model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])

        # Remove " (LoRA)" suffix if present
        if " (LoRA)" in model_type_val:
            model_type_val = model_type_val.replace(" (LoRA)", "")
            logger.info(f"Removed (LoRA) suffix from model type: {model_type_val}")

        # Ensure it's a valid model type in the dropdown
        if model_type_val not in MODEL_TYPES:
            # Convert from internal to display name or use default
            model_type_found = False
            for display_name, internal_name in MODEL_TYPES.items():
                if internal_name == model_type_val:
                    model_type_val = display_name
                    model_type_found = True
                    break

            # If still not found, use the first model type
            if not model_type_found:
                invalid_model_type = model_type_val
                model_type_val = list(MODEL_TYPES.keys())[0]
                logger.warning(f"Invalid model type '{invalid_model_type}', using default: {model_type_val}")

        # Get model_version value
        model_version_val = ""

        auto_resume_val = ui_state.get("auto_resume", DEFAULT_AUTO_RESUME)

        # First get the internal model type for the currently selected model
        model_internal_type = MODEL_TYPES.get(model_type_val)
        logger.info(f"Initializing model version for model_type: {model_type_val} (internal: {model_internal_type})")

        if model_internal_type and model_internal_type in MODEL_VERSIONS:
            # Get available versions for this model type as simple strings
            available_model_versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())

            # Log for debugging
            logger.info(f"Available versions: {available_model_versions}")

            # Set model_version_val to saved value if valid, otherwise first available
            if "model_version" in ui_state and ui_state["model_version"] in available_model_versions:
                model_version_val = ui_state["model_version"]
                logger.info(f"Using saved model version: {model_version_val}")
            elif available_model_versions:
                model_version_val = available_model_versions[0]
                logger.info(f"Using first available model version: {model_version_val}")

            # IMPORTANT: Create a new list of simple strings for the dropdown choices
            # This ensures each choice is a single string, not a tuple or other structure
            simple_choices = [str(version) for version in available_model_versions]

            # Update the dropdown choices directly in the UI component
            try:
                self.project_tabs["train_tab"].components["model_version"].choices = simple_choices
                logger.info(f"Updated model_version dropdown choices: {len(simple_choices)} options")
            except Exception as e:
                logger.error(f"Error updating model_version dropdown: {str(e)}")
        else:
            logger.warning(f"No versions available for model type: {model_type_val}")
            # Set empty choices to avoid errors
            try:
                self.project_tabs["train_tab"].components["model_version"].choices = []
            except Exception as e:
                logger.error(f"Error setting empty model_version choices: {str(e)}")

        # Ensure training_type is a valid display name
        training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
        if training_type_val not in TRAINING_TYPES:
            # Convert from internal to display name or use default
            training_type_found = False
            for display_name, internal_name in TRAINING_TYPES.items():
                if internal_name == training_type_val:
                    training_type_val = display_name
                    training_type_found = True
                    break

            # If still not found, use the first training type
            if not training_type_found:
                invalid_training_type = training_type_val
                training_type_val = list(TRAINING_TYPES.keys())[0]
                logger.warning(f"Invalid training type '{invalid_training_type}', using default: {training_type_val}")

        # Validate training preset
        training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
        if training_preset not in TRAINING_PRESETS:
            invalid_training_preset = training_preset
            training_preset = list(TRAINING_PRESETS.keys())[0]
            logger.warning(f"Invalid training preset '{invalid_training_preset}', using default: {training_preset}")

        lora_rank_val = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
        lora_alpha_val = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
        batch_size_val = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
        learning_rate_val = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
        save_iterations_val = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))
        num_gpus_val = int(ui_state.get("num_gpus", DEFAULT_NUM_GPUS))

        # Calculate recommended precomputation items based on video count
        video_count = len(list(TRAINING_VIDEOS_PATH.glob('*.mp4')))
        recommended_precomputation = get_recommended_precomputation_items(video_count, num_gpus_val)
        precomputation_items_val = int(ui_state.get("precomputation_items", recommended_precomputation))

        # Ensure warmup steps are not more than training steps
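        # Illustrative arithmetic (hypothetical values): with train_steps=1000 and
        # DEFAULT_NB_LR_WARMUP_STEPS=400, the default warmup is min(400, 200) = 200,
        # and whatever value is loaded is then clamped to at most train_steps.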
        train_steps_val = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
        default_warmup = min(DEFAULT_NB_LR_WARMUP_STEPS, int(train_steps_val * 0.2))
        lr_warmup_steps_val = int(ui_state.get("lr_warmup_steps", default_warmup))

        # Ensure warmup steps <= training steps
        lr_warmup_steps_val = min(lr_warmup_steps_val, train_steps_val)

        # Initial current task value
        current_task_val = ""
        if hasattr(self, 'log_parser') and self.log_parser:
            current_task_val = self.log_parser.get_current_task_display()

        # Return all values in the exact order expected by outputs
        return (
            training_dataset,
            start_btn,
            resume_btn,
            stop_btn,
            delete_checkpoints_btn,
            training_preset,
            model_type_val,
            model_version_val,
            training_type_val,
            lora_rank_val,
            lora_alpha_val,
            train_steps_val,
            batch_size_val,
            learning_rate_val,
            save_iterations_val,
            current_task_val,
            num_gpus_val,
            precomputation_items_val,
            lr_warmup_steps_val,
            auto_resume_val
        )

    def initialize_ui_from_state(self):
        """Initialize UI components from saved state"""
        ui_state = self.load_ui_values()

        # Get model type and determine the default model version if not specified
        model_type = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
        model_internal_type = MODEL_TYPES.get(model_type)

        # Get model_version, defaulting to first available version if not set
        model_version = ui_state.get("model_version", "")
        if not model_version and model_internal_type and model_internal_type in MODEL_VERSIONS:
            versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
            if versions:
                model_version = versions[0]

        # Return values in order matching the outputs in app.load
        return (
            ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
            model_type,
            model_version,
            ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
            ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
            ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
            ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
            ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
            ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
            ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
        )

    def update_ui_state(self, **kwargs):
        """Update UI state with new values"""
        current_state = self.training.load_ui_state()
        current_state.update(kwargs)
        self.training.save_ui_state(current_state)
        # Don't return anything to avoid Gradio warnings
        return None

    def load_ui_values(self):
        """Load UI state values for initializing form fields"""
        ui_state = self.training.load_ui_state()

        # Ensure proper type conversion for numeric values
        ui_state["lora_rank"] = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
        ui_state["lora_alpha"] = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
        ui_state["train_steps"] = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
        ui_state["batch_size"] = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
        ui_state["learning_rate"] = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
        ui_state["save_iterations"] = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))

        return ui_state

    def get_initial_button_states(self):
        """Get the initial states for training buttons based on recovery status"""
        recovery_result = self.state.get("recovery_result") or self.training.recover_interrupted_training()
        ui_updates = recovery_result.get("ui_updates", {})
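        # ui_updates is expected to map button names (e.g. "start_btn") to gr.Button
        # keyword arguments such as {"interactive": ..., "variant": ..., "value": ...},
        # matching how the dictionaries are used below.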

        # Check for checkpoints to determine start button text
        checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
        has_checkpoints = len(checkpoints) > 0

        # Default button states if recovery didn't provide any
        if not ui_updates or not ui_updates.get("start_btn"):
            is_training = self.training.is_training_running()

            if is_training:
                # Active training detected
                start_btn_props = {"interactive": False, "variant": "secondary", "value": "π Start new training"}
                resume_btn_props = {"interactive": False, "variant": "secondary", "value": "πΈ Start from latest checkpoint"}
                stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
                delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
            else:
                # No active training
                start_btn_props = {"interactive": True, "variant": "primary", "value": "π Start new training"}
                resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "πΈ Start from latest checkpoint"}
                stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
                delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
        else:
            # Use button states from recovery, adding the new resume button
            start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "π Start new training"})
            resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
                                "variant": "primary", "value": "πΈ Start from latest checkpoint"}
            stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
            delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})

        # Return button states in the correct order
        return (
            gr.Button(**start_btn_props),
            gr.Button(**resume_btn_props),  # the resume button
            gr.Button(**stop_btn_props),
            gr.Button(**delete_btn_props)
        )

    def update_titles(self) -> Tuple[Any, Any]:
        """Update all dynamic titles with current counts

        Returns:
            Tuple of Gradio Markdown updates (caption title, train title)
        """
        # Count files for captioning
        caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
        caption_title = format_media_title(
            "caption", caption_videos, caption_images, caption_size
        )

        # Count files for training
        train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH)
        train_title = format_media_title(
            "train", train_videos, train_images, train_size
        )

        return (
            gr.Markdown(value=caption_title),
            gr.Markdown(value=train_title)
        )

    def refresh_dataset(self):
        """Refresh all dynamic lists and training state"""
        training_dataset = self.project_tabs["caption_tab"].list_training_files_to_caption()

        return training_dataset
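

# Illustrative usage sketch (assumption: the actual launcher lives elsewhere in
# the project; names below are only an example, not the real entry point):
#
#     if __name__ == "__main__":
#         ui = AppUI()
#         app = ui.create_ui()
#         app.queue().launch()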