jbilcke-hf (HF Staff) committed
Commit 37a6639 · 1 parent: dd2d897

big refactoring

Files changed (1)
  1. app.py +455 -117
app.py CHANGED
@@ -1,21 +1,53 @@
1
  import gradio as gr
2
- import subprocess
3
  import os
 
4
  import tempfile
5
  import shutil
6
  from pathlib import Path
7
  import torch
8
  import logging
9
  from huggingface_hub import snapshot_download
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
13
  logger = logging.getLogger(__name__)
14
 
15
  # Constants
 
16
  DEFAULT_CONFIG_PATH = "configs/inference_1.3B.yaml"
17
- DEFAULT_INPUT_FILE = "examples/infer_samples.txt"
18
- MODELS_DIR = Path("pretrained_models")
19
 
20
  def download_models():
21
  """Download required models if they don't exist"""
@@ -61,17 +93,344 @@ def download_models():
61
  logger.error(f"Failed to download {model['name']}: {str(e)}")
62
  raise gr.Error(f"Failed to download {model['name']}: {str(e)}")
63
 
64
- # Initialize models on module import (for Hugging Face Spaces)
65
- logger.info("Initializing OmniAvatar...")
66
- logger.info("Checking and downloading required models...")
67
- download_models()
68
- logger.info("Model initialization complete")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def generate_avatar_video(
71
  reference_image,
72
  audio_file,
73
  text_prompt,
74
- seed=42,
 
75
  num_steps=15,
76
  guidance_scale=4.5,
77
  audio_scale=None,
@@ -81,144 +440,116 @@ def generate_avatar_video(
81
  resolution="720p",
82
  progress=gr.Progress()
83
  ):
84
- """Generate an avatar video using OmniAvatar
85
-
86
- Args:
87
- reference_image: Path to reference avatar image
88
- audio_file: Path to audio file for lip sync
89
- text_prompt: Text description of the video to generate
90
- seed: Random seed for generation
91
- num_steps: Number of inference steps
92
- guidance_scale: Classifier-free guidance scale
93
- audio_scale: Audio guidance scale (uses guidance_scale if None)
94
- overlap_frames: Number of overlapping frames between chunks
95
- fps: Frames per second
96
- silence_duration: Duration of silence to add before/after audio
97
- resolution: Output resolution ("480p" or "720p")
98
- progress: Gradio progress callback
99
-
100
- Returns:
101
- str: Path to generated video file
102
- """
103
 
104
  try:
105
- progress(0.1, desc="Preparing inputs")
106
 
107
- # Create temporary directory for this generation
108
  with tempfile.TemporaryDirectory() as temp_dir:
109
  temp_path = Path(temp_dir)
110
 
 
 
111
  # Copy input files to temp directory
112
  temp_image = temp_path / "input_image.jpeg"
113
  temp_audio = temp_path / "input_audio.mp3"
114
  shutil.copy(reference_image, temp_image)
115
  shutil.copy(audio_file, temp_audio)
116
 
117
- # Create input file for inference script
118
- input_file = temp_path / "input.txt"
119
- # Format: prompt@@image_path@@audio_path
120
- with open(input_file, 'w') as f:
121
- f.write(f"{text_prompt}@@{temp_image}@@{temp_audio}\n")
122
-
123
- progress(0.2, desc="Configuring generation parameters")
124
 
125
- # Determine max_hw based on resolution
126
- max_hw = 720 if resolution == "480p" else 1280
127
 
128
- # Build command to run inference script
129
- cmd = [
130
- "torchrun",
131
- "--nproc_per_node=1",
132
- "scripts/inference.py",
133
- "--config", DEFAULT_CONFIG_PATH,
134
- "--input_file", str(input_file),
135
- "-hp", f"seed={seed},num_steps={num_steps},guidance_scale={guidance_scale},"
136
- f"overlap_frame={overlap_frames},fps={fps},silence_duration_s={silence_duration},"
137
- f"max_hw={max_hw},use_audio=True,i2v=True"
138
- ]
139
 
140
- # Add audio scale if specified
141
- if audio_scale is not None:
142
- cmd[-1] += f",audio_scale={audio_scale}"
143
 
144
- progress(0.3, desc="Running OmniAvatar generation")
145
- logger.info(f"Running command: {' '.join(cmd)}")
146
-
147
- # Run the inference script
148
- env = os.environ.copy()
149
- env['CUDA_VISIBLE_DEVICES'] = '0' # Use first GPU
150
-
151
- process = subprocess.Popen(
152
- cmd,
153
- stdout=subprocess.PIPE,
154
- stderr=subprocess.PIPE,
155
- text=True,
156
- env=env
157
  )
158
 
159
- # Monitor progress (simplified - in reality you'd parse the output)
160
- stdout_lines = []
161
- stderr_lines = []
162
 
163
- while True:
164
- output = process.stdout.readline()
165
- if output:
166
- stdout_lines.append(output.strip())
167
- logger.info(output.strip())
168
-
169
- # Update progress based on output
170
- if "Starting video generation" in output:
171
- progress(0.5, desc="Generating video frames")
172
- elif "[1/" in output: # First chunk
173
- progress(0.6, desc="Processing video chunks")
174
- elif "Saving video" in output:
175
- progress(0.9, desc="Finalizing video")
176
-
177
- if process.poll() is not None:
178
- break
179
 
180
- # Get any remaining output
181
- remaining_stdout, remaining_stderr = process.communicate()
182
- if remaining_stdout:
183
- stdout_lines.extend(remaining_stdout.strip().split('\n'))
184
- if remaining_stderr:
185
- stderr_lines.extend(remaining_stderr.strip().split('\n'))
186
 
187
- if process.returncode != 0:
188
- error_msg = '\n'.join(stderr_lines)
189
- logger.error(f"Inference failed with return code {process.returncode}")
190
- logger.error(f"Error output: {error_msg}")
191
- raise gr.Error(f"Video generation failed: {error_msg}")
 
 
 
 
192
 
193
- progress(0.95, desc="Retrieving generated video")
194
 
195
  # Find the generated video file
196
- # The inference script saves to demo_out/{exp_name}/res_{input_file_name}_...
197
- # We need to find the most recent video file
198
- generated_videos = list(Path("demo_out").rglob("result_000.mp4"))
199
  if not generated_videos:
200
  raise gr.Error("No video file was generated")
201
 
202
- # Get the most recent video
203
- latest_video = max(generated_videos, key=lambda p: p.stat().st_mtime)
204
 
205
- # Create a temporary file for the output video
206
- # This file will persist beyond the context manager since we're using delete=False
207
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_output:
208
  output_path = tmp_output.name
209
 
210
- # Copy the generated video to the temporary file
211
  shutil.copy(latest_video, output_path)
212
 
213
  progress(1.0, desc="Generation complete")
214
- logger.info(f"Video saved to temporary path: {output_path}")
215
 
216
- return output_path
217
 
218
  except Exception as e:
219
- logger.error(f"Error generating video: {str(e)}")
220
  raise gr.Error(f"Error generating video: {str(e)}")
221
 
222
  # Create the Gradio interface
223
  with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
224
  gr.Markdown("""
@@ -252,12 +583,17 @@ with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
252
 
253
  with gr.Accordion("Advanced Settings", open=False):
254
  with gr.Row():
255
  seed = gr.Slider(
256
- label="Seed",
257
- minimum=-1,
258
  maximum=2147483647,
259
  step=1,
260
- value=-1
261
  )
262
 
263
  resolution = gr.Radio(
@@ -322,6 +658,12 @@ with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
322
  "🎬 Generate Avatar Video",
323
  variant="primary"
324
  )
325
 
326
  with gr.Column(scale=1):
327
  # Output component
@@ -351,6 +693,7 @@ with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
351
  audio_file,
352
  text_prompt,
353
  seed,
 
354
  num_steps,
355
  guidance_scale,
356
  audio_scale,
@@ -359,7 +702,7 @@ with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
359
  silence_duration,
360
  resolution
361
  ],
362
- outputs=output_video
363
  )
364
 
365
  gr.Markdown("""
@@ -372,9 +715,4 @@ with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
372
 
373
  # Launch the app
374
  if __name__ == "__main__":
375
- # Download models on startup
376
- logger.info("Checking and downloading required models...")
377
- download_models()
378
- logger.info("Model download complete, launching app...")
379
-
380
  app.launch(share=True)
app.py (after changes)
1
  import gradio as gr
 
2
  import os
3
+ import sys
4
  import tempfile
5
  import shutil
6
  from pathlib import Path
7
  import torch
8
  import logging
9
  from huggingface_hub import snapshot_download
10
+ import math
11
+ import random
12
+ import librosa
13
+ import numpy as np
14
+ import torch.nn as nn
15
+ from tqdm import tqdm
16
+ from functools import partial
17
+ from datetime import datetime
18
+ import torchvision.transforms as TT
19
+ from transformers import Wav2Vec2FeatureExtractor
20
+ import torchvision.transforms as transforms
21
+ import torch.nn.functional as F
22
+ from glob import glob
23
+
24
+ # Add parent directory to path for imports
25
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
26
+
27
+ from OmniAvatar.utils.args_config import parse_args
28
+ from OmniAvatar.utils.io_utils import load_state_dict
29
+ from peft import LoraConfig, inject_adapter_in_model
30
+ from OmniAvatar.models.model_manager import ModelManager
31
+ from OmniAvatar.wan_video import WanVideoPipeline
32
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
33
+ import torch.distributed as dist
34
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
35
+ from OmniAvatar.distributed.fsdp import shard_model
36
 
37
  # Configure logging
38
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
39
  logger = logging.getLogger(__name__)
40
 
41
  # Constants
42
+ MODELS_DIR = Path(os.environ.get('MODELS_DIR', 'pretrained_models'))
43
  DEFAULT_CONFIG_PATH = "configs/inference_1.3B.yaml"
44
+
45
+ def set_seed(seed: int = 42):
46
+ random.seed(seed)
47
+ np.random.seed(seed)
48
+ torch.manual_seed(seed)
49
+ torch.cuda.manual_seed(seed)
50
+ torch.cuda.manual_seed_all(seed)
51
 
52
  def download_models():
53
  """Download required models if they don't exist"""
 
93
  logger.error(f"Failed to download {model['name']}: {str(e)}")
94
  raise gr.Error(f"Failed to download {model['name']}: {str(e)}")
95
 
96
+ # Utility functions from inference.py
97
+ def match_size(image_size, h, w):
98
+ ratio_ = 9999
99
+ size_ = 9999
100
+ select_size = None
101
+ for image_s in image_size:
102
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
103
+ size_tmp = abs(max(image_s) - max(w, h))
104
+ if ratio_tmp < ratio_:
105
+ ratio_ = ratio_tmp
106
+ size_ = size_tmp
107
+ select_size = image_s
108
+ if ratio_ == ratio_tmp:
109
+ if size_ == size_tmp:
110
+ select_size = image_s
111
+ return select_size
112
+
113
+ def resize_pad(image, ori_size, tgt_size):
114
+ h, w = ori_size
115
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
116
+ scale_h = int(h * scale_ratio)
117
+ scale_w = int(w * scale_ratio)
118
+
119
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
120
+
121
+ padding_h = tgt_size[0] - scale_h
122
+ padding_w = tgt_size[1] - scale_w
123
+ pad_top = padding_h // 2
124
+ pad_bottom = padding_h - pad_top
125
+ pad_left = padding_w // 2
126
+ pad_right = padding_w - pad_left
127
+
128
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
129
+ return image
130
+
131
+ class WanInferencePipeline(nn.Module):
132
+ def __init__(self, args):
133
+ super().__init__()
134
+ self.args = args
135
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
136
+ if args.dtype=='bf16':
137
+ self.dtype = torch.bfloat16
138
+ elif args.dtype=='fp16':
139
+ self.dtype = torch.float16
140
+ else:
141
+ self.dtype = torch.float32
142
+ self.pipe = self.load_model()
143
+ if args.i2v:
144
+ chained_trainsforms = []
145
+ chained_trainsforms.append(TT.ToTensor())
146
+ self.transform = TT.Compose(chained_trainsforms)
147
+ if args.use_audio:
148
+ from OmniAvatar.models.wav2vec import Wav2VecModel
149
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
150
+ str(MODELS_DIR / "wav2vec2-base-960h")
151
+ )
152
+ self.audio_encoder = Wav2VecModel.from_pretrained(str(MODELS_DIR / "wav2vec2-base-960h"), local_files_only=True).to(device=self.device)
153
+ self.audio_encoder.feature_extractor._freeze_parameters()
154
+
155
+ def load_model(self):
156
+ # Initialize for single GPU
157
+ os.environ['MASTER_ADDR'] = 'localhost'
158
+ os.environ['MASTER_PORT'] = '12355'
159
+ os.environ['RANK'] = '0'
160
+ os.environ['WORLD_SIZE'] = '1'
161
+
162
+ dist.init_process_group(backend="nccl", init_method="env://")
163
+
164
+ from xfuser.core.distributed import (initialize_model_parallel,
165
+ init_distributed_environment)
166
+ init_distributed_environment(rank=0, world_size=1)
167
+ initialize_model_parallel(
168
+ sequence_parallel_degree=self.args.sp_size,
169
+ ring_degree=1,
170
+ ulysses_degree=self.args.sp_size,
171
+ )
172
+ torch.cuda.set_device(0)
173
+
174
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
175
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
176
+ if self.args.train_architecture == 'lora':
177
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
178
+ else:
179
+ resume_path = ckpt_path
180
+
181
+ self.step = 0
182
+
183
+ # Load models
184
+ model_manager = ModelManager(device="cpu", infer=True)
185
+ model_manager.load_models(
186
+ [
187
+ self.args.dit_path.split(","),
188
+ self.args.text_encoder_path,
189
+ self.args.vae_path
190
+ ],
191
+ torch_dtype=self.dtype,
192
+ device='cpu',
193
+ )
194
+
195
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
196
+ torch_dtype=self.dtype,
197
+ device=str(self.device),
198
+ use_usp=True if self.args.sp_size > 1 else False,
199
+ infer=True)
200
+ if self.args.train_architecture == "lora":
201
+ logger.info(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
202
+ self.add_lora_to_model(
203
+ pipe.denoising_model(),
204
+ lora_rank=self.args.lora_rank,
205
+ lora_alpha=self.args.lora_alpha,
206
+ lora_target_modules=self.args.lora_target_modules,
207
+ init_lora_weights=self.args.init_lora_weights,
208
+ pretrained_lora_path=pretrained_lora_path,
209
+ )
210
+ else:
211
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
212
+ logger.info(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
213
+ pipe.requires_grad_(False)
214
+ pipe.eval()
215
+ pipe.enable_vram_management(num_persistent_param_in_dit=self.args.num_persistent_param_in_dit)
216
+ if self.args.use_fsdp:
217
+ shard_fn = partial(shard_model, device_id=self.device)
218
+ pipe.dit = shard_fn(pipe.dit)
219
+ return pipe
220
+
221
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
222
+ self.lora_alpha = lora_alpha
223
+ if init_lora_weights == "kaiming":
224
+ init_lora_weights = True
225
+
226
+ lora_config = LoraConfig(
227
+ r=lora_rank,
228
+ lora_alpha=lora_alpha,
229
+ init_lora_weights=init_lora_weights,
230
+ target_modules=lora_target_modules.split(","),
231
+ )
232
+ model = inject_adapter_in_model(lora_config, model)
233
+
234
+ if pretrained_lora_path is not None:
235
+ state_dict = load_state_dict(pretrained_lora_path)
236
+ if state_dict_converter is not None:
237
+ state_dict = state_dict_converter(state_dict)
238
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
239
+ all_keys = [i for i, _ in model.named_parameters()]
240
+ num_updated_keys = len(all_keys) - len(missing_keys)
241
+ num_unexpected_keys = len(unexpected_keys)
242
+ logger.info(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
243
+
244
+ def forward(self, prompt,
245
+ image_path=None,
246
+ audio_path=None,
247
+ seq_len=101,
248
+ height=720,
249
+ width=720,
250
+ overlap_frame=None,
251
+ num_steps=None,
252
+ negative_prompt=None,
253
+ guidance_scale=None,
254
+ audio_scale=None):
255
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
256
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
257
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
258
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
259
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
260
+
261
+ if image_path is not None:
262
+ from PIL import Image
263
+ image = Image.open(image_path).convert("RGB")
264
+ image = self.transform(image).unsqueeze(0).to(self.device)
265
+ _, _, h, w = image.shape
266
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
267
+ image = resize_pad(image, (h, w), select_size)
268
+ image = image * 2.0 - 1.0
269
+ image = image[:, :, None]
270
+ else:
271
+ image = None
272
+ select_size = [height, width]
273
+ L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
274
+ L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
275
+ T = (L + 3) // 4 # latent frames
276
+
277
+ if self.args.i2v:
278
+ if self.args.random_prefix_frames:
279
+ fixed_frame = overlap_frame
280
+ assert fixed_frame % 4 == 1
281
+ else:
282
+ fixed_frame = 1
283
+ prefix_lat_frame = (3 + fixed_frame) // 4
284
+ first_fixed_frame = 1
285
+ else:
286
+ fixed_frame = 0
287
+ prefix_lat_frame = 0
288
+ first_fixed_frame = 0
289
+
290
+ if audio_path is not None and self.args.use_audio:
291
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
292
+ input_values = np.squeeze(
293
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
294
+ )
295
+ input_values = torch.from_numpy(input_values).float().to(device=self.device)
296
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
297
+ input_values = input_values.unsqueeze(0)
298
+ # padding audio
299
+ if audio_len < L - first_fixed_frame:
300
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
301
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
302
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
303
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
304
+ with torch.no_grad():
305
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
306
+ audio_embeddings = hidden_states.last_hidden_state
307
+ for mid_hidden_states in hidden_states.hidden_states:
308
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
309
+ seq_len = audio_len
310
+ audio_embeddings = audio_embeddings.squeeze(0)
311
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
312
+ else:
313
+ audio_embeddings = None
314
+
315
+ # loop
316
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
317
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
318
+ times += 1
319
+ video = []
320
+ image_emb = {}
321
+ img_lat = None
322
+ if self.args.i2v:
323
+ self.pipe.load_models_to_device(['vae'])
324
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
325
+
326
+ msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1])
327
+ image_cat = img_lat.repeat(1, 1, T, 1, 1)
328
+ msk[:, :, 1:] = 1
329
+ image_emb["y"] = torch.cat([image_cat, msk], dim=1)
330
+ for t in range(times):
331
+ logger.info(f"[{t+1}/{times}]")
332
+ audio_emb = {}
333
+ if t == 0:
334
+ overlap = first_fixed_frame
335
+ else:
336
+ overlap = fixed_frame
337
+ image_emb["y"][:, -1:, :prefix_lat_frame] = 0
338
+ prefix_overlap = (3 + overlap) // 4
339
+ if audio_embeddings is not None:
340
+ if t == 0:
341
+ audio_tensor = audio_embeddings[
342
+ :min(L - overlap, audio_embeddings.shape[0])
343
+ ]
344
+ else:
345
+ audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
346
+ audio_tensor = audio_embeddings[
347
+ audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
348
+ ]
349
+
350
+ audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
351
+ audio_prefix = audio_tensor[-fixed_frame:]
352
+ audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
353
+ audio_emb["audio_emb"] = audio_tensor
354
+ else:
355
+ audio_prefix = None
356
+ if image is not None and img_lat is None:
357
+ self.pipe.load_models_to_device(['vae'])
358
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
359
+ assert img_lat.shape[2] == prefix_overlap
360
+ img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1))], dim=2)
361
+ frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
362
+ negative_prompt, num_inference_steps=num_steps,
363
+ cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
364
+ return_latent=True,
365
+ tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
366
+ img_lat = None
367
+ image = (frames[:, -fixed_frame:].clip(0, 1) * 2 - 1).permute(0, 2, 1, 3, 4).contiguous()
368
+ if t == 0:
369
+ video.append(frames)
370
+ else:
371
+ video.append(frames[:, overlap:])
372
+ video = torch.cat(video, dim=1)
373
+ video = video[:, :ori_audio_len + 1]
374
+ return video
375
+
376
+ # Initialize the pipeline globally
377
+ inference_pipeline = None
378
+ args_global = None
379
+
380
+ def initialize_inference_pipeline():
381
+ """Initialize the inference pipeline with arguments"""
382
+ global inference_pipeline, args_global
383
+
384
+ if inference_pipeline is not None:
385
+ return inference_pipeline
386
+
387
+ # Create a minimal args object
388
+ class Args:
389
+ def __init__(self):
390
+ self.rank = 0
391
+ self.dtype = 'bf16'
392
+ self.exp_path = str(MODELS_DIR / "OmniAvatar-1.3B")
393
+ self.dit_path = str(MODELS_DIR / "Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
394
+ self.text_encoder_path = str(MODELS_DIR / "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
395
+ self.vae_path = str(MODELS_DIR / "Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
396
+ self.wav2vec_path = str(MODELS_DIR / "wav2vec2-base-960h")
397
+ self.train_architecture = 'lora'
398
+ self.lora_rank = 128
399
+ self.lora_alpha = 64.0
400
+ self.lora_target_modules = 'q,k,v,o,ffn.0,ffn.2'
401
+ self.init_lora_weights = 'kaiming'
402
+ self.sp_size = 1
403
+ self.num_persistent_param_in_dit = None
404
+ self.use_fsdp = False
405
+ self.i2v = True
406
+ self.use_audio = True
407
+ self.random_prefix_frames = True
408
+ self.overlap_frame = 13
409
+ self.num_steps = 15
410
+ self.negative_prompt = 'Vivid color tones, background/camera moving quickly, screen switching, subtitles and special effects, mutation, overexposed, static, blurred details, subtitles, style, work, painting, image, still, overall grayish, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image, chaotic background, three legs, crowded background with many people, walking backward'
411
+ self.guidance_scale = 4.5
412
+ self.audio_scale = 0
413
+ self.max_tokens = 30000
414
+ self.sample_rate = 16000
415
+ self.fps = 25
416
+ self.max_hw = 720
417
+ self.tea_cache_l1_thresh = 0
418
+ self.image_sizes_720 = [[400, 720], [720, 720], [720, 400]]
419
+ self.image_sizes_1280 = [[720, 720], [528, 960], [960, 528], [720, 1280], [1280, 720]]
420
+ self.seq_len = 200
421
+
422
+ args_global = Args()
423
+ logger.info("Initializing inference pipeline...")
424
+ inference_pipeline = WanInferencePipeline(args_global)
425
+ logger.info("Inference pipeline initialized successfully")
426
+ return inference_pipeline
427
 
428
  def generate_avatar_video(
429
  reference_image,
430
  audio_file,
431
  text_prompt,
432
+ seed=None,
433
+ use_random_seed=True,
434
  num_steps=15,
435
  guidance_scale=4.5,
436
  audio_scale=None,
 
440
  resolution="720p",
441
  progress=gr.Progress()
442
  ):
443
+ """Generate an avatar video using OmniAvatar"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
  try:
446
+ progress(0.1, desc="Initializing")
447
+
448
+ if use_random_seed or seed is None or seed == -1:
449
+ seed = random.randint(0, 2147483647)
450
+
451
+ set_seed(seed)
452
+
453
+ # Initialize pipeline if needed
454
+ pipeline = initialize_inference_pipeline()
455
 
 
456
  with tempfile.TemporaryDirectory() as temp_dir:
457
  temp_path = Path(temp_dir)
458
 
459
+ progress(0.2, desc="Preparing inputs")
460
+
461
  # Copy input files to temp directory
462
  temp_image = temp_path / "input_image.jpeg"
463
  temp_audio = temp_path / "input_audio.mp3"
464
  shutil.copy(reference_image, temp_image)
465
  shutil.copy(audio_file, temp_audio)
466
 
467
+ # Add silence to audio
468
+ if silence_duration > 0:
469
+ audio_with_silence = temp_path / "audio_with_silence.wav"
470
+ add_silence_to_audio_ffmpeg(str(temp_audio), str(audio_with_silence), silence_duration)
471
+ input_audio_path = str(audio_with_silence)
472
+ else:
473
+ input_audio_path = str(temp_audio)
474
 
475
+ progress(0.3, desc="Configuring generation parameters")
 
476
 
477
+ # Update args for this generation
478
+ args_global.seed = seed
479
+ args_global.num_steps = num_steps
480
+ args_global.guidance_scale = guidance_scale
481
+ args_global.audio_scale = audio_scale if audio_scale is not None and audio_scale > 0 else 0
482
+ args_global.overlap_frame = overlap_frames
483
+ args_global.fps = fps
484
+ args_global.silence_duration_s = silence_duration
485
+ args_global.max_hw = 720 if resolution == "480p" else 1280
 
 
486
 
487
+ progress(0.4, desc="Running OmniAvatar generation")
 
 
488
 
489
+ # Generate video
490
+ video = pipeline(
491
+ prompt=text_prompt,
492
+ image_path=str(temp_image),
493
+ audio_path=input_audio_path,
494
+ seq_len=args_global.seq_len
495
  )
496
 
497
+ progress(0.8, desc="Saving video")
 
 
498
 
499
+ # Create output directory in temp folder
500
+ output_dir = temp_path / "output"
501
+ output_dir.mkdir(exist_ok=True)
502
 
503
+ # Add audio offset for final output
504
+ audio_with_offset = temp_path / "audio_with_offset.wav"
505
+ add_silence_to_audio_ffmpeg(str(temp_audio), str(audio_with_offset), 1.0 / fps + silence_duration)
506
 
507
+ # Save video
508
+ save_video_as_grid_and_mp4(
509
+ video,
510
+ str(output_dir),
511
+ fps,
512
+ prompt=text_prompt,
513
+ audio_path=str(audio_with_offset) if args_global.use_audio else None,
514
+ prefix=f'result_000'
515
+ )
516
 
517
+ progress(0.9, desc="Finalizing")
518
 
519
  # Find the generated video file
520
+ generated_videos = list(output_dir.glob("result_000_*.mp4"))
521
+ if not generated_videos:
522
+ # Also check for result_000.mp4 (without suffix)
523
+ generated_videos = list(output_dir.glob("result_000.mp4"))
524
+
525
  if not generated_videos:
526
  raise gr.Error("No video file was generated")
527
 
528
+ # Get the first (and should be only) video
529
+ latest_video = generated_videos[0]
530
 
531
+ # Create a persistent temporary file for Gradio
 
532
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_output:
533
  output_path = tmp_output.name
534
 
535
+ # Copy the generated video to the persistent temp file
536
  shutil.copy(latest_video, output_path)
537
 
538
  progress(1.0, desc="Generation complete")
539
+ logger.info(f"Video saved to: {output_path}")
540
 
541
+ return output_path, seed
542
 
543
  except Exception as e:
544
+ logger.error(f"Error generating video: {str(e)}", exc_info=True)
545
  raise gr.Error(f"Error generating video: {str(e)}")
546
 
547
+ # Initialize models on module import (for Hugging Face Spaces)
548
+ logger.info("Initializing OmniAvatar...")
549
+ logger.info("Checking and downloading required models...")
550
+ download_models()
551
+ logger.info("Model initialization complete")
552
+
553
  # Create the Gradio interface
554
  with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
555
  gr.Markdown("""
 
583
 
584
  with gr.Accordion("Advanced Settings", open=False):
585
  with gr.Row():
586
+ use_random_seed = gr.Checkbox(
587
+ label="Use random seed",
588
+ value=True
589
+ )
590
+
591
  seed = gr.Slider(
592
+ label="Seed (ignored if random seed is checked)",
593
+ minimum=0,
594
  maximum=2147483647,
595
  step=1,
596
+ value=42
597
  )
598
 
599
  resolution = gr.Radio(
 
658
  "🎬 Generate Avatar Video",
659
  variant="primary"
660
  )
661
+
662
+ # Add seed output display
663
+ seed_output = gr.Number(
664
+ label="Seed used",
665
+ interactive=False
666
+ )
667
 
668
  with gr.Column(scale=1):
669
  # Output component
 
693
  audio_file,
694
  text_prompt,
695
  seed,
696
+ use_random_seed,
697
  num_steps,
698
  guidance_scale,
699
  audio_scale,
 
702
  silence_duration,
703
  resolution
704
  ],
705
+ outputs=[output_video, seed_output]
706
  )
707
 
708
  gr.Markdown("""
 
715
 
716
  # Launch the app
717
  if __name__ == "__main__":
718
  app.launch(share=True)