import os
import sys
import logging

import torch
from tqdm import tqdm
from omegaconf import OmegaConf

from ovi.utils.io_utils import save_video
from ovi.utils.processing_utils import format_prompt_for_filename, validate_and_process_user_prompt
from ovi.utils.utils import get_arguments
from ovi.distributed_comms.util import get_world_size, get_local_rank, get_global_rank
from ovi.distributed_comms.parallel_states import initialize_sequence_parallel_state, get_sequence_parallel_state, nccl_info
from ovi.ovi_fusion_engine import OviFusionEngine


def _init_logging(rank):
    """Configure root logging: INFO to stdout on rank 0, ERROR-only on other ranks.

    Args:
        rank: Global rank of this process.
    """
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def _resolution_tag(video_frame_height_width):
    """Return the resolution component used in output filenames.

    ``video_frame_height_width`` defaults to ``None`` in the config (the value is
    passed through to ``OviFusionEngine.generate`` unchanged). Joining ``None``
    would raise ``TypeError`` only after generation had already finished, so a
    fixed tag is used instead.

    Args:
        video_frame_height_width: ``None`` or an iterable of sizes, e.g. ``[h, w]``.

    Returns:
        ``"HxW"``-style string, or ``"default"`` when no size was configured.
    """
    if video_frame_height_width is None:
        return "default"
    return "x".join(map(str, video_frame_height_width))


def _resolve_sp_layout(global_rank, world_size):
    """Return ``(sp_rank, sp_group_id, num_sp_groups)`` for this process.

    With sequence parallelism enabled, the group id is ``global_rank // sp_size``,
    which assumes each SP group is a block of consecutive global ranks. Without
    SP, every rank is its own group of size 1.
    """
    if get_sequence_parallel_state():
        sp_size = nccl_info.sp_size
        return nccl_info.rank_within_group, global_rank // sp_size, world_size // sp_size
    # No SP: treat each GPU as its own group
    return 0, global_rank, world_size


def _shard_eval_data(all_eval_data, sp_group_id, num_sp_groups, pad_to_groups=False):
    """Return the round-robin slice of ``all_eval_data`` handled by one SP group.

    Args:
        all_eval_data: List of ``(text_prompt, image_path)`` pairs.
        sp_group_id: Index of this process's SP group.
        num_sp_groups: Total number of SP groups.
        pad_to_groups: If True, repeat the first item until the length is a
            multiple of ``num_sp_groups``, so every group gets the same count.

    Returns:
        A new list. ``all_eval_data`` is not modified. Returns ``[]`` (and logs an
        error) when ``all_eval_data`` is empty.
    """
    if not all_eval_data:
        logging.error("ERROR: No evaluation files found")
        return []

    data = list(all_eval_data)
    remainder = len(data) % num_sp_groups
    if pad_to_groups and remainder != 0:
        data += [data[0]] * (num_sp_groups - remainder)
    # Distribute across SP groups
    return data[sp_group_id::num_sp_groups]


def main(config, args):
    """Generate videos (and optional images) for every prompt in ``config``.

    Steps:
        1. Set the CUDA device and initialise ``torch.distributed`` when
           ``world_size > 1``.
        2. Validate the prompts and image paths before loading the model.
        3. Load ``OviFusionEngine``.
        4. Split the prompts across sequence-parallel groups.
        5. Per prompt, generate ``each_example_n_times`` samples with seeds
           ``seed, seed + 1, ...``.

    Only SP rank 0 of each group writes files into ``output_dir``.

    Args:
        config: OmegaConf config. Keys read here include ``sp_size``, ``mode``,
            ``text_prompt``, ``image_path``, ``output_dir`` and the sampling
            parameters below.
        args: Parsed CLI arguments; ``args.local_rank`` and ``args.device`` are
            set here.

    Raises:
        AssertionError: on an invalid ``sp_size``/``world_size`` combination, an
            invalid ``mode``, or a missing image file in ``i2v`` mode.
    """
    world_size = get_world_size()
    global_rank = get_global_rank()
    local_rank = get_local_rank()
    device = local_rank

    torch.cuda.set_device(local_rank)
    sp_size = config.get("sp_size", 1)
    assert sp_size <= world_size and world_size % sp_size == 0, "sp_size must be less than or equal to world_size and world_size must be divisible by sp_size."

    _init_logging(global_rank)

    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=global_rank,
            world_size=world_size)
    else:
        assert sp_size == 1, f"When world_size is 1, sp_size must also be 1, but got {sp_size}."
        ## TODO: assert not sharding t5 etc...

    initialize_sequence_parallel_state(sp_size)
    logging.info(f"Using SP: {get_sequence_parallel_state()}, SP_SIZE: {sp_size}")

    args.local_rank = local_rank
    args.device = device
    target_dtype = torch.bfloat16

    # validate inputs before loading model to not waste time if input is not valid
    text_prompt = config.get("text_prompt")
    image_path = config.get("image_path", None)
    mode = config.get("mode")
    assert mode in ["t2v", "i2v", "t2i2v"], f"Invalid mode {mode}, must be one of ['t2v', 'i2v', 't2i2v']"
    text_prompts, image_paths = validate_and_process_user_prompt(text_prompt, image_path, mode=mode)
    if mode != "i2v":
        logging.info(f"mode: {mode}, setting all image_paths to None")
        image_paths = [None] * len(text_prompts)
    else:
        assert all(p is not None and os.path.isfile(p) for p in image_paths), f"In i2v mode, all image paths must be provided.{image_paths}"

    logging.info("Loading OVI Fusion Engine...")
    ovi_engine = OviFusionEngine(config=config, device=device, target_dtype=target_dtype)
    logging.info("OVI Fusion Engine loaded!")

    output_dir = config.get("output_dir", "./outputs")
    os.makedirs(output_dir, exist_ok=True)

    all_eval_data = list(zip(text_prompts, image_paths))
    sp_rank, sp_group_id, num_sp_groups = _resolve_sp_layout(global_rank, world_size)
    # Padding is disabled: groups may receive unequal numbers of prompts.
    this_rank_eval_data = _shard_eval_data(all_eval_data, sp_group_id, num_sp_groups, pad_to_groups=False)

    # Sampling parameters do not change between prompts; read them once.
    video_frame_height_width = config.get("video_frame_height_width", None)
    seed = config.get("seed", 100)
    n_times = config.get("each_example_n_times", 1)
    sampling_kwargs = {
        "video_frame_height_width": video_frame_height_width,
        "solver_name": config.get("solver_name", "unipc"),
        "sample_steps": config.get("sample_steps", 50),
        "shift": config.get("shift", 5.0),
        "video_guidance_scale": config.get("video_guidance_scale", 4.0),
        "audio_guidance_scale": config.get("audio_guidance_scale", 3.0),
        "slg_layer": config.get("slg_layer", 11),
        "video_negative_prompt": config.get("video_negative_prompt", ""),
        "audio_negative_prompt": config.get("audio_negative_prompt", ""),
    }
    resolution_tag = _resolution_tag(video_frame_height_width)

    # Iterate the list directly so tqdm knows the total.
    for sample_prompt, sample_image_path in tqdm(this_rank_eval_data):
        for idx in range(n_times):
            sample_seed = seed + idx
            generated_video, generated_audio, generated_image = ovi_engine.generate(
                text_prompt=sample_prompt,
                image_path=sample_image_path,
                seed=sample_seed,
                **sampling_kwargs,
            )

            # Only SP rank 0 of each group writes output files.
            if sp_rank == 0:
                formatted_prompt = format_prompt_for_filename(sample_prompt)
                output_path = os.path.join(
                    output_dir,
                    f"{formatted_prompt}_{resolution_tag}_{sample_seed}_{global_rank}.mp4",
                )
                save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
                if generated_image is not None:
                    # Swap only the extension; str.replace would also hit ".mp4"
                    # appearing elsewhere in the path.
                    generated_image.save(os.path.splitext(output_path)[0] + ".png")


if __name__ == "__main__":
    args = get_arguments()
    config = OmegaConf.load(args.config_file)
    main(config=config, args=args)