# vms/utils/finetrainers_utils.py
import gradio as gr
from pathlib import Path
import logging
import json
import shutil
from typing import Any, Optional, Dict, List, Union, Tuple
from ..config import (
STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES,
DEFAULT_VALIDATION_NB_STEPS,
DEFAULT_VALIDATION_HEIGHT,
DEFAULT_VALIDATION_WIDTH,
DEFAULT_VALIDATION_NB_FRAMES,
DEFAULT_VALIDATION_FRAMERATE
)
from .utils import get_video_fps, extract_scene_info, make_archive, is_image_file, is_video_file
logger = logging.getLogger(__name__)
def prepare_finetrainers_dataset() -> Tuple[Optional[Path], Optional[Path]]:
    """Prepare a Finetrainers-compatible dataset structure

    Creates:
        training/
        ├── prompts.txt   # All captions, one per line
        ├── videos.txt    # All video paths, one per line
        └── videos/       # Directory containing all mp4 files
            ├── 00000.mp4
            ├── 00001.mp4
            └── ...

    Returns:
        Tuple of (videos_file_path, prompts_file_path), or (None, None) if no
        valid video/caption pairs were found or the generated files mismatch
    """
    # Ensure the videos subdirectory exists
    TRAINING_VIDEOS_PATH.mkdir(parents=True, exist_ok=True)
# Clear existing training lists
for f in TRAINING_PATH.glob("*"):
if f.is_file():
if f.name in ["videos.txt", "prompts.txt", "prompt.txt"]:
f.unlink()
videos_file = TRAINING_PATH / "videos.txt"
prompts_file = TRAINING_PATH / "prompts.txt" # Finetrainers can use either prompts.txt or prompt.txt
media_files = []
captions = []
# Process all video files from the videos subdirectory
    for file in sorted(TRAINING_VIDEOS_PATH.glob("*.mp4")):
caption_file = file.with_suffix('.txt')
if caption_file.exists():
# Normalize caption to single line
caption = caption_file.read_text().strip()
caption = ' '.join(caption.split())
# Use relative path from training root
relative_path = f"videos/{file.name}"
media_files.append(relative_path)
captions.append(caption)
# Write files if we have content
if media_files and captions:
videos_file.write_text('\n'.join(media_files))
prompts_file.write_text('\n'.join(captions))
logger.info(f"Created dataset with {len(media_files)} video/caption pairs")
else:
logger.warning("No valid video/caption pairs found in training directory")
return None, None
# Verify file contents
with open(videos_file) as vf:
video_lines = [l.strip() for l in vf.readlines() if l.strip()]
with open(prompts_file) as pf:
prompt_lines = [l.strip() for l in pf.readlines() if l.strip()]
if len(video_lines) != len(prompt_lines):
logger.error(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts")
return None, None
return videos_file, prompts_file
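# Usage sketch (illustrative, not executed in this module): callers should handle the
# (None, None) case before launching a training run, e.g.
#
#   videos_file, prompts_file = prepare_finetrainers_dataset()
#   if videos_file is None or prompts_file is None:
#       raise RuntimeError("No valid video/caption pairs in the training directory")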
def copy_files_to_training_dir(prompt_prefix: str) -> int:
    """Copy staged files and their captions into the training directory, without deleting anything

    Returns:
        The number of file/caption pairs that were copied
    """
    gr.Info("Copying assets to the training dataset...")
    # Collect all staged media files (videos and images)
video_files = list(STAGING_PATH.glob("*.mp4"))
image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
all_files = video_files + image_files
nb_copied_pairs = 0
for file_path in all_files:
caption = ""
file_caption_path = file_path.with_suffix('.txt')
if file_caption_path.exists():
logger.debug(f"Found caption file: {file_caption_path}")
caption = file_caption_path.read_text()
# Get parent caption if this is a clip
parent_caption = ""
if "___" in file_path.stem:
parent_name, _ = extract_scene_info(file_path.stem)
#print(f"parent_name is {parent_name}")
parent_caption_path = STAGING_PATH / f"{parent_name}.txt"
if parent_caption_path.exists():
logger.debug(f"Found parent caption file: {parent_caption_path}")
parent_caption = parent_caption_path.read_text().strip()
target_file_path = TRAINING_VIDEOS_PATH / file_path.name
target_caption_path = target_file_path.with_suffix('.txt')
if parent_caption and not caption.endswith(parent_caption):
caption = f"{caption}\n{parent_caption}"
        # Add FPS information for videos
        if is_video_file(file_path) and caption:
            # Only prepend the FPS info if it is not already present in the caption
            if not any("FPS, " in line for line in caption.split('\n')):
                fps_info = get_video_fps(file_path)
                if fps_info:
                    caption = f"{fps_info}{caption}"
if prompt_prefix and not caption.startswith(prompt_prefix):
caption = f"{prompt_prefix}{caption}"
        # Only copy over valid pairs, i.e. media files that ended up with a non-empty caption
        if caption:
            try:
                target_caption_path.write_text(caption)
                shutil.copy2(file_path, target_file_path)
                nb_copied_pairs += 1
            except Exception as e:
                logger.warning(f"Failed to copy pair for {file_path.name}: {e}")
prepare_finetrainers_dataset()
gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")
return nb_copied_pairs
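# Usage sketch (illustrative): in the app this is typically triggered from the Gradio UI,
# with the prefix coming from a user-provided textbox. The prefix value below is made up.
#
#   nb_pairs = copy_files_to_training_dir(prompt_prefix="in the style of XYZ, ")
#   logger.info(f"{nb_pairs} pairs are ready for training")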
def create_validation_config() -> Optional[Path]:
"""Create a validation configuration JSON file for Finetrainers
Creates a validation dataset file with a subset of the training data
Returns:
Path to the validation JSON file, or None if no training files exist
"""
# Ensure training dataset exists
if not TRAINING_VIDEOS_PATH.exists() or not any(TRAINING_VIDEOS_PATH.glob("*.mp4")):
logger.warning("No training videos found for validation")
return None
# Get a subset of the training videos (up to 4) for validation
training_videos = list(TRAINING_VIDEOS_PATH.glob("*.mp4"))
    validation_videos = training_videos[:4]
if not validation_videos:
logger.warning("No validation videos selected")
return None
# Create validation data entries
validation_data = {"data": []}
for video_path in validation_videos:
# Get caption from matching text file
caption_path = video_path.with_suffix('.txt')
if not caption_path.exists():
logger.warning(f"Missing caption for {video_path}, skipping for validation")
continue
caption = caption_path.read_text().strip()
        # Build the validation entry using the default validation resolution and sampling settings
        try:
            data_entry = {
"caption": caption,
"image_path": "", # No input image for text-to-video
"video_path": str(video_path),
"num_inference_steps": DEFAULT_VALIDATION_NB_STEPS,
"height": DEFAULT_VALIDATION_HEIGHT,
"width": DEFAULT_VALIDATION_WIDTH,
"num_frames": DEFAULT_VALIDATION_NB_FRAMES,
"frame_rate": DEFAULT_VALIDATION_FRAMERATE
}
validation_data["data"].append(data_entry)
except Exception as e:
logger.warning(f"Error adding validation entry for {video_path}: {e}")
if not validation_data["data"]:
logger.warning("No valid validation entries created")
return None
# Write validation config to file
validation_file = OUTPUT_PATH / "validation_config.json"
with open(validation_file, 'w') as f:
json.dump(validation_data, f, indent=2)
logger.info(f"Created validation config with {len(validation_data['data'])} entries")
return validation_file
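# Illustrative shape of the generated validation_config.json (the path and caption are
# hypothetical; the numeric fields are filled from the DEFAULT_VALIDATION_* settings
# imported from the config module):
#
#   {
#     "data": [
#       {
#         "caption": "24 FPS, a slow panning shot of a sunlit kitchen",
#         "image_path": "",
#         "video_path": "training/videos/00000.mp4",
#         "num_inference_steps": <DEFAULT_VALIDATION_NB_STEPS>,
#         "height": <DEFAULT_VALIDATION_HEIGHT>,
#         "width": <DEFAULT_VALIDATION_WIDTH>,
#         "num_frames": <DEFAULT_VALIDATION_NB_FRAMES>,
#         "frame_rate": <DEFAULT_VALIDATION_FRAMERATE>
#       }
#     ]
#   }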