|
import json
import logging
import shutil
from pathlib import Path
from typing import Any, Optional, Dict, List, Union, Tuple

import gradio as gr

from ..config import (
    STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES,
    DEFAULT_VALIDATION_NB_STEPS,
    DEFAULT_VALIDATION_HEIGHT,
    DEFAULT_VALIDATION_WIDTH,
    DEFAULT_VALIDATION_NB_FRAMES,
    DEFAULT_VALIDATION_FRAMERATE
)
from .utils import get_video_fps, extract_scene_info, make_archive, is_image_file, is_video_file
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
def prepare_finetrainers_dataset() -> Tuple[Optional[Path], Optional[Path]]:
    """Prepare a Finetrainers-compatible dataset structure.

    Creates:
    training/
    ├── prompts.txt       # All captions, one per line
    ├── videos.txt        # All video paths, one per line
    └── videos/           # Directory containing all mp4 files
        ├── 00000.mp4
        ├── 00001.mp4
        └── ...

    Returns:
        Tuple of (videos_file_path, prompts_file_path), or (None, None) when
        no valid video/caption pairs exist or the generated files disagree.
    """
    TRAINING_VIDEOS_PATH.mkdir(exist_ok=True)

    # Remove any stale index files from a previous run so old and new entries
    # never mix ("prompt.txt" is a legacy name, cleaned up as well).
    for stale_name in ("videos.txt", "prompts.txt", "prompt.txt"):
        stale = TRAINING_PATH / stale_name
        if stale.is_file():
            stale.unlink()

    videos_file = TRAINING_PATH / "videos.txt"
    prompts_file = TRAINING_PATH / "prompts.txt"

    media_files = []
    captions = []

    # Pair every mp4 with its same-stem .txt caption; sort for stable ordering
    # so videos.txt and prompts.txt line up deterministically.
    for file in sorted(TRAINING_VIDEOS_PATH.glob("*.mp4")):
        caption_file = file.with_suffix('.txt')
        if caption_file.exists():
            # Captions must be single-line: collapse all whitespace runs.
            caption = ' '.join(caption_file.read_text().strip().split())

            media_files.append(f"videos/{file.name}")
            captions.append(caption)

    if media_files and captions:
        videos_file.write_text('\n'.join(media_files))
        prompts_file.write_text('\n'.join(captions))
        logger.info(f"Created dataset with {len(media_files)} video/caption pairs")
    else:
        logger.warning("No valid video/caption pairs found in training directory")
        return None, None

    # Sanity-check the files we just wrote: line counts must match 1:1,
    # otherwise Finetrainers would pair videos with the wrong prompts.
    with open(videos_file) as vf:
        video_lines = [l.strip() for l in vf.readlines() if l.strip()]
    with open(prompts_file) as pf:
        prompt_lines = [l.strip() for l in pf.readlines() if l.strip()]

    if len(video_lines) != len(prompt_lines):
        logger.error(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts")
        return None, None

    return videos_file, prompts_file
|
|
|
def copy_files_to_training_dir(prompt_prefix: str) -> int:
    """Copy staged media files and their captions into the training directory.

    Non-destructive: existing training files are left in place. For each staged
    file the caption is assembled from, in order: an FPS hint line (videos
    only), the global prompt prefix, the file's own .txt caption, and the
    parent clip's caption (for files produced by scene-splitting, identified
    by "___" in the stem). Files without any resulting caption are skipped.

    Args:
        prompt_prefix: Text prepended to every caption (if not already present).

    Returns:
        Number of file/caption pairs successfully copied.
    """
    gr.Info("Copying assets to the training dataset..")

    video_files = list(STAGING_PATH.glob("*.mp4"))
    image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
    all_files = video_files + image_files

    nb_copied_pairs = 0

    for file_path in all_files:
        # Per-file caption lives in a sibling .txt with the same stem.
        caption = ""
        file_caption_path = file_path.with_suffix('.txt')
        if file_caption_path.exists():
            logger.debug(f"Found caption file: {file_caption_path}")
            caption = file_caption_path.read_text()

        # Scene-split outputs are named "<parent>___<scene>"; inherit the
        # parent clip's caption as extra context.
        parent_caption = ""
        if "___" in file_path.stem:
            parent_name, _ = extract_scene_info(file_path.stem)
            parent_caption_path = STAGING_PATH / f"{parent_name}.txt"
            if parent_caption_path.exists():
                logger.debug(f"Found parent caption file: {parent_caption_path}")
                parent_caption = parent_caption_path.read_text().strip()

        target_file_path = TRAINING_VIDEOS_PATH / file_path.name
        target_caption_path = target_file_path.with_suffix('.txt')

        if parent_caption and not caption.endswith(parent_caption):
            caption = f"{caption}\n{parent_caption}"

        # Videos get an FPS hint line unless one is already present.
        if is_video_file(file_path) and caption:
            if not any("FPS, " in line for line in caption.split('\n')):
                fps_info = get_video_fps(file_path)
                if fps_info:
                    caption = f"{fps_info}{caption}"

        if prompt_prefix and not caption.startswith(prompt_prefix):
            caption = f"{prompt_prefix}{caption}"

        # Only copy files that ended up with a non-empty caption.
        if caption:
            try:
                target_caption_path.write_text(caption)
                shutil.copy2(file_path, target_file_path)
                nb_copied_pairs += 1
            except Exception as e:
                # Best-effort: log the failure and keep copying the rest.
                logger.warning(f"failed to copy one of the pairs: {e}")

    prepare_finetrainers_dataset()

    gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")

    return nb_copied_pairs
|
|
|
|
|
|
|
def create_validation_config() -> Optional[Path]:
    """Create a validation configuration JSON file for Finetrainers.

    Selects up to 4 training videos that have caption files and writes one
    validation entry per video (prompt, resolution, frame count, framerate)
    to OUTPUT_PATH/validation_config.json.

    Returns:
        Path to the validation JSON file, or None if no usable
        video/caption pairs exist.
    """
    if not TRAINING_VIDEOS_PATH.exists() or not any(TRAINING_VIDEOS_PATH.glob("*.mp4")):
        logger.warning("No training videos found for validation")
        return None

    # Use at most the first 4 videos; slicing clamps safely for shorter lists.
    training_videos = list(TRAINING_VIDEOS_PATH.glob("*.mp4"))
    validation_videos = training_videos[:4]

    if not validation_videos:
        logger.warning("No validation videos selected")
        return None

    validation_data = {"data": []}

    for video_path in validation_videos:
        # Only videos with a sibling caption file can be validated.
        caption_path = video_path.with_suffix('.txt')
        if not caption_path.exists():
            logger.warning(f"Missing caption for {video_path}, skipping for validation")
            continue

        caption = caption_path.read_text().strip()

        validation_data["data"].append({
            "caption": caption,
            "image_path": "",
            "video_path": str(video_path),
            "num_inference_steps": DEFAULT_VALIDATION_NB_STEPS,
            "height": DEFAULT_VALIDATION_HEIGHT,
            "width": DEFAULT_VALIDATION_WIDTH,
            "num_frames": DEFAULT_VALIDATION_NB_FRAMES,
            "frame_rate": DEFAULT_VALIDATION_FRAMERATE,
        })

    if not validation_data["data"]:
        logger.warning("No valid validation entries created")
        return None

    validation_file = OUTPUT_PATH / "validation_config.json"
    with open(validation_file, 'w') as f:
        json.dump(validation_data, f, indent=2)

    logger.info(f"Created validation config with {len(validation_data['data'])} entries")
    return validation_file
|
|