import spaces from huggingface_hub import snapshot_download, hf_hub_download import os import subprocess import importlib, site from PIL import Image import uuid import shutil # Re-discover all .pth/.egg-link files for sitedir in site.getsitepackages(): site.addsitedir(sitedir) # Clear caches so importlib will pick up new modules importlib.invalidate_caches() def sh(cmd): subprocess.check_call(cmd, shell=True) flash_attention_installed = False # FlashAttention 설치 시도 - 실패해도 계속 진행 try: print("Attempting to download and install FlashAttention wheel...") flash_attention_wheel = hf_hub_download( repo_id="alexnasa/flash-attn-3", repo_type="model", filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl", ) sh(f"pip install {flash_attention_wheel}") # tell Python to re-scan site-packages now that the egg-link exists import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches() flash_attention_installed = True print("FlashAttention installed successfully.") except Exception as e: print(f"⚠️ Could not install FlashAttention: {e}") print("Continuing without FlashAttention...") # ===== CRITICAL FIX: attention.py 파일 패치 ===== attention_file = "/home/user/app/ovi/modules/attention.py" if os.path.exists(attention_file): try: with open(attention_file, 'r') as f: content = f.read() # FLASH_ATTN_3_AVAILABLE 변수가 정의되지 않은 경우를 위한 패치 if 'FLASH_ATTN_3_AVAILABLE' not in content.split('try:')[0]: # 파일 시작 부분에 변수 초기화 추가 patched_content = f"FLASH_ATTN_3_AVAILABLE = False\n\n{content}" with open(attention_file, 'w') as f: f.write(patched_content) print("✓ Successfully patched attention.py") except Exception as e: print(f"⚠️ Could not patch attention.py: {e}") # ===== END FIX ===== import torch print(f"Torch version: {torch.__version__}") print(f"FlashAttention available: {flash_attention_installed}") os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results" import gradio as gr import argparse from ovi.ovi_fusion_engine import OviFusionEngine, DEFAULT_CONFIG from diffusers import FluxPipeline import tempfile from ovi.utils.io_utils import save_video from ovi.utils.processing_utils import clean_text, scale_hw_to_area_divisible # ---------------------------- # Parse CLI Args # ---------------------------- parser = argparse.ArgumentParser(description="Ovi Joint Video + Audio Gradio Demo") parser.add_argument( "--cpu_offload", action="store_true", help="Enable CPU offload for both OviFusionEngine and FluxPipeline" ) args = parser.parse_args() ckpt_dir = "./ckpts" # Wan2.2 wan_dir = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B") snapshot_download( repo_id="Wan-AI/Wan2.2-TI2V-5B", local_dir=wan_dir, allow_patterns=[ "google/*", "models_t5_umt5-xxl-enc-bf16.pth", "Wan2.2_VAE.pth" ] ) # MMAudio mm_audio_dir = os.path.join(ckpt_dir, "MMAudio") snapshot_download( repo_id="hkchengrex/MMAudio", local_dir=mm_audio_dir, allow_patterns=[ "ext_weights/best_netG.pt", "ext_weights/v1-16.pth" ] ) ovi_dir = os.path.join(ckpt_dir, "Ovi") snapshot_download( repo_id="chetwinlow1/Ovi", local_dir=ovi_dir, allow_patterns=[ "model.safetensors" ] ) # Initialize OviFusionEngine enable_cpu_offload = args.cpu_offload print(f"loading model...") DEFAULT_CONFIG['cpu_offload'] = enable_cpu_offload DEFAULT_CONFIG['mode'] = "t2v" ovi_engine = OviFusionEngine() print("loaded model") def resize_for_model(image_path): img = Image.open(image_path) w, h = img.size aspect_ratio = w / h if aspect_ratio > 1.5: target_size = (992, 512) elif aspect_ratio < 0.66: target_size = (512, 992) else: target_size = (512, 512) img.thumbnail(target_size, Image.Resampling.LANCZOS) new_img = Image.new("RGB", target_size, (0, 0, 0)) new_img.paste( img, ((target_size[0] - img.size[0]) // 2, (target_size[1] - img.size[1]) // 2) ) return new_img, target_size def generate_scene( text_prompt, image, sample_steps = 50, session_id = None, video_seed = 100, solver_name = "unipc", shift = 5, video_guidance_scale = 4, audio_guidance_scale = 3, slg_layer = 11, video_negative_prompt = "", audio_negative_prompt = "", progress=gr.Progress(track_tqdm=True) ): text_prompt_processed = (text_prompt or "").strip() if not image: raise gr.Error("Please provide an image") if not text_prompt_processed: raise gr.Error("Please enter a prompt.") return generate_video(text_prompt, image, sample_steps, session_id, video_seed, solver_name, shift, video_guidance_scale, audio_guidance_scale, slg_layer, video_negative_prompt, audio_negative_prompt, progress) def get_duration( text_prompt, image, sample_steps, session_id, video_seed, solver_name, shif, video_guidance_scale, audio_guidance_scale, slg_layer, video_negative_prompt, audio_negative_prompt, progress, ): warmup = 20 return int(sample_steps * 3 + warmup) @spaces.GPU(duration=get_duration) def generate_video( text_prompt, image, sample_steps = 50, session_id = None, video_seed = 100, solver_name = "unipc", shift = 5, video_guidance_scale = 4, audio_guidance_scale = 3, slg_layer = 11, video_negative_prompt = "", audio_negative_prompt = "", progress=gr.Progress(track_tqdm=True) ): try: image_path = None if image is not None: image_path = image if session_id is None: session_id = uuid.uuid4().hex output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id) os.makedirs(output_dir, exist_ok=True) output_path = os.path.join(output_dir, f"generated_video.mp4") _, target_size = resize_for_model(image_path) video_frame_width = target_size[0] video_frame_height = target_size[1] generated_video, generated_audio, _ = ovi_engine.generate( text_prompt=text_prompt, image_path=image_path, video_frame_height_width=[video_frame_height, video_frame_width], seed=video_seed, solver_name=solver_name, sample_steps=sample_steps, shift=shift, video_guidance_scale=video_guidance_scale, audio_guidance_scale=audio_guidance_scale, slg_layer=slg_layer, video_negative_prompt=video_negative_prompt, audio_negative_prompt=audio_negative_prompt, ) save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000) return output_path except Exception as e: print(f"Error during video generation: {e}") return None def cleanup(request: gr.Request): sid = request.session_hash if sid: d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid) shutil.rmtree(d1, ignore_errors=True) def start_session(request: gr.Request): return request.session_hash css = """ #col-container { margin: 0 auto; max-width: 1024px; } """ theme = gr.themes.Ocean() with gr.Blocks(css=css, theme=theme) as demo: session_state = gr.State() demo.load(start_session, outputs=[session_state]) with gr.Column(elem_id="col-container"): gr.HTML( """
🎥 Ovi – Twin Backbone Cross-Modal Fusion for Audio-Video Generation
[model]