Delik commited on
Commit
5f3dff7
·
verified ·
1 Parent(s): 668dbc7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +403 -0
app.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import argparse
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ import gradio as gr
9
+ import shutil
10
+ import librosa
11
+ import python_speech_features
12
+ import time
13
+ from LIA_Model import LIA_Model
14
+ import os
15
+ from tqdm import tqdm
16
+ import argparse
17
+ import numpy as np
18
+ from torchvision import transforms
19
+ from templates import *
20
+ import argparse
21
+ import shutil
22
+ from moviepy.editor import *
23
+ import librosa
24
+ import python_speech_features
25
+ import importlib.util
26
+ import time
27
+ import os
28
+ import time
29
+ import numpy as np
30
+
31
+
32
+
33
+ # Disable Gradio analytics to avoid network-related issues
34
+ gr.analytics_enabled = False
35
+
36
+
37
+ def check_package_installed(package_name):
38
+ package_spec = importlib.util.find_spec(package_name)
39
+ if package_spec is None:
40
+ print(f"{package_name} is not installed.")
41
+ return False
42
+ else:
43
+ print(f"{package_name} is installed.")
44
+ return True
45
+
46
+ def frames_to_video(input_path, audio_path, output_path, fps=25):
47
+ image_files = [os.path.join(input_path, img) for img in sorted(os.listdir(input_path))]
48
+ clips = [ImageClip(m).set_duration(1/fps) for m in image_files]
49
+ video = concatenate_videoclips(clips, method="compose")
50
+
51
+ audio = AudioFileClip(audio_path)
52
+ final_video = video.set_audio(audio)
53
+ final_video.write_videofile(output_path, fps=fps, codec='libx264', audio_codec='aac')
54
+
55
+ def load_image(filename, size):
56
+ img = Image.open(filename).convert('RGB')
57
+ img = img.resize((size, size))
58
+ img = np.asarray(img)
59
+ img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
60
+ return img / 255.0
61
+
62
+ def img_preprocessing(img_path, size):
63
+ img = load_image(img_path, size) # [0, 1]
64
+ img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
65
+ imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
66
+ return imgs_norm
67
+
68
+ def saved_image(img_tensor, img_path):
69
+ toPIL = transforms.ToPILImage()
70
+ img = toPIL(img_tensor.detach().cpu().squeeze(0)) # 使用squeeze(0)来移除批次维度
71
+ img.save(img_path)
72
+
73
+ def main(args):
74
+ frames_result_saved_path = os.path.join(args.result_path, 'frames')
75
+ os.makedirs(frames_result_saved_path, exist_ok=True)
76
+ test_image_name = os.path.splitext(os.path.basename(args.test_image_path))[0]
77
+ audio_name = os.path.splitext(os.path.basename(args.test_audio_path))[0]
78
+ predicted_video_256_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}.mp4')
79
+ predicted_video_512_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}_SR.mp4')
80
+
81
+ #======Loading Stage 1 model=========
82
+ lia = LIA_Model(motion_dim=args.motion_dim, fusion_type='weighted_sum')
83
+ lia.load_lightning_model(args.stage1_checkpoint_path)
84
+ lia.to(args.device)
85
+ #============================
86
+
87
+ conf = ffhq256_autoenc()
88
+ conf.seed = args.seed
89
+ conf.decoder_layers = args.decoder_layers
90
+ conf.infer_type = args.infer_type
91
+ conf.motion_dim = args.motion_dim
92
+
93
+ if args.infer_type == 'mfcc_full_control':
94
+ conf.face_location=True
95
+ conf.face_scale=True
96
+ conf.mfcc = True
97
+ elif args.infer_type == 'mfcc_pose_only':
98
+ conf.face_location=False
99
+ conf.face_scale=False
100
+ conf.mfcc = True
101
+ elif args.infer_type == 'hubert_pose_only':
102
+ conf.face_location=False
103
+ conf.face_scale=False
104
+ conf.mfcc = False
105
+ elif args.infer_type == 'hubert_audio_only':
106
+ conf.face_location=False
107
+ conf.face_scale=False
108
+ conf.mfcc = False
109
+ elif args.infer_type == 'hubert_full_control':
110
+ conf.face_location=True
111
+ conf.face_scale=True
112
+ conf.mfcc = False
113
+ else:
114
+ print('Type NOT Found!')
115
+ exit(0)
116
+
117
+ if not os.path.exists(args.test_image_path):
118
+ print(f'{args.test_image_path} does not exist!')
119
+ exit(0)
120
+
121
+ if not os.path.exists(args.test_audio_path):
122
+ print(f'{args.test_audio_path} does not exist!')
123
+ exit(0)
124
+
125
+ img_source = img_preprocessing(args.test_image_path, args.image_size).to(args.device)
126
+ one_shot_lia_start, one_shot_lia_direction, feats = lia.get_start_direction_code(img_source, img_source, img_source, img_source)
127
+
128
+ #======Loading Stage 2 model=========
129
+ model = LitModel(conf)
130
+ state = torch.load(args.stage2_checkpoint_path, map_location='cpu')
131
+ model.load_state_dict(state, strict=True)
132
+ model.ema_model.eval()
133
+ model.ema_model.to(args.device)
134
+ #=================================
135
+
136
+ #======Audio Input=========
137
+ if conf.infer_type.startswith('mfcc'):
138
+ # MFCC features
139
+ wav, sr = librosa.load(args.test_audio_path, sr=16000)
140
+ input_values = python_speech_features.mfcc(signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01)
141
+ d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
142
+ d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
143
+ audio_driven_obj = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
144
+ frame_start, frame_end = 0, int(audio_driven_obj.shape[0]/4)
145
+ audio_start, audio_end = int(frame_start * 4), int(frame_end * 4) # The video frame is fixed to 25 hz and the audio is fixed to 100 hz
146
+
147
+ audio_driven = torch.Tensor(audio_driven_obj[audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
148
+
149
+ elif conf.infer_type.startswith('hubert'):
150
+ # Hubert features
151
+ if not os.path.exists(args.test_hubert_path):
152
+
153
+ if not check_package_installed('transformers'):
154
+ print('Please install transformers module first.')
155
+ exit(0)
156
+ hubert_model_path = './ckpts/chinese-hubert-large'
157
+ if not os.path.exists(hubert_model_path):
158
+ print('Please download the hubert weight into the ckpts path first.')
159
+ exit(0)
160
+ print('You did not extract the audio features in advance, extracting online now, which will increase processing delay')
161
+
162
+ start_time = time.time()
163
+
164
+ # load hubert model
165
+ from transformers import Wav2Vec2FeatureExtractor, HubertModel
166
+ audio_model = HubertModel.from_pretrained(hubert_model_path).to(args.device)
167
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_path)
168
+ audio_model.feature_extractor._freeze_parameters()
169
+ audio_model.eval()
170
+
171
+ # hubert model forward pass
172
+ audio, sr = librosa.load(args.test_audio_path, sr=16000)
173
+ input_values = feature_extractor(audio, sampling_rate=16000, padding=True, do_normalize=True, return_tensors="pt").input_values
174
+ input_values = input_values.to(args.device)
175
+ ws_feats = []
176
+ with torch.no_grad():
177
+ outputs = audio_model(input_values, output_hidden_states=True)
178
+ for i in range(len(outputs.hidden_states)):
179
+ ws_feats.append(outputs.hidden_states[i].detach().cpu().numpy())
180
+ ws_feat_obj = np.array(ws_feats)
181
+ ws_feat_obj = np.squeeze(ws_feat_obj, 1)
182
+ ws_feat_obj = np.pad(ws_feat_obj, ((0, 0), (0, 1), (0, 0)), 'edge') # align the audio length with video frame
183
+
184
+ execution_time = time.time() - start_time
185
+ print(f"Extraction Audio Feature: {execution_time:.2f} Seconds")
186
+
187
+ audio_driven_obj = ws_feat_obj
188
+ else:
189
+ print(f'Using audio feature from path: {args.test_hubert_path}')
190
+ audio_driven_obj = np.load(args.test_hubert_path)
191
+
192
+ frame_start, frame_end = 0, int(audio_driven_obj.shape[1]/2)
193
+ audio_start, audio_end = int(frame_start * 2), int(frame_end * 2) # The video frame is fixed to 25 hz and the audio is fixed to 50 hz
194
+
195
+ audio_driven = torch.Tensor(audio_driven_obj[:,audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
196
+ #============================
197
+
198
+ # Diffusion Noise
199
+ noisyT = torch.randn((1,frame_end, args.motion_dim)).to(args.device)
200
+
201
+ #======Inputs for Attribute Control=========
202
+ if os.path.exists(args.pose_driven_path):
203
+ pose_obj = np.load(args.pose_driven_path)
204
+
205
+ if len(pose_obj.shape) != 2:
206
+ print('please check your pose information. The shape must be like (T, 3).')
207
+ exit(0)
208
+ if pose_obj.shape[1] != 3:
209
+ print('please check your pose information. The shape must be like (T, 3).')
210
+ exit(0)
211
+
212
+ if pose_obj.shape[0] >= frame_end:
213
+ pose_obj = pose_obj[:frame_end,:]
214
+ else:
215
+ padding = np.tile(pose_obj[-1, :], (frame_end - pose_obj.shape[0], 1))
216
+ pose_obj = np.vstack((pose_obj, padding))
217
+
218
+ pose_signal = torch.Tensor(pose_obj).unsqueeze(0).to(args.device) / 90 # 90 is for normalization here
219
+ else:
220
+ yaw_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_yaw
221
+ pitch_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_pitch
222
+ roll_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_roll
223
+ pose_signal = torch.cat((yaw_signal, pitch_signal, roll_signal), dim=-1)
224
+
225
+ pose_signal = torch.clamp(pose_signal, -1, 1)
226
+
227
+ face_location_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_location
228
+ face_scae_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_scale
229
+ #===========================================
230
+
231
+ start_time = time.time()
232
+
233
+ #======Diffusion Denosing Process=========
234
+ generated_directions = model.render(one_shot_lia_start, one_shot_lia_direction, audio_driven, face_location_signal, face_scae_signal, pose_signal, noisyT, args.step_T, control_flag=args.control_flag)
235
+ #=========================================
236
+
237
+ execution_time = time.time() - start_time
238
+ print(f"Motion Diffusion Model: {execution_time:.2f} Seconds")
239
+
240
+ generated_directions = generated_directions.detach().cpu().numpy()
241
+
242
+ start_time = time.time()
243
+ #======Rendering images frame-by-frame=========
244
+ for pred_index in tqdm(range(generated_directions.shape[1])):
245
+ ori_img_recon = lia.render(one_shot_lia_start, torch.Tensor(generated_directions[:,pred_index,:]).to(args.device), feats)
246
+ ori_img_recon = ori_img_recon.clamp(-1, 1)
247
+ wav_pred = (ori_img_recon.detach() + 1) / 2
248
+ saved_image(wav_pred, os.path.join(frames_result_saved_path, "%06d.png"%(pred_index)))
249
+ #==============================================
250
+
251
+ execution_time = time.time() - start_time
252
+ print(f"Renderer Model: {execution_time:.2f} Seconds")
253
+
254
+ frames_to_video(frames_result_saved_path, args.test_audio_path, predicted_video_256_path)
255
+
256
+ shutil.rmtree(frames_result_saved_path)
257
+
258
+ # Enhancer
259
+ if args.face_sr and check_package_installed('gfpgan'):
260
+ from face_sr.face_enhancer import enhancer_list
261
+ import imageio
262
+
263
+ # Super-resolution
264
+ imageio.mimsave(predicted_video_512_path+'.tmp.mp4', enhancer_list(predicted_video_256_path, method='gfpgan', bg_upsampler=None), fps=float(25))
265
+
266
+ # Merge audio and video
267
+ video_clip = VideoFileClip(predicted_video_512_path+'.tmp.mp4')
268
+ audio_clip = AudioFileClip(predicted_video_256_path)
269
+ final_clip = video_clip.set_audio(audio_clip)
270
+ final_clip.write_videofile(predicted_video_512_path, codec='libx264', audio_codec='aac')
271
+
272
+ os.remove(predicted_video_512_path+'.tmp.mp4')
273
+
274
+ if args.face_sr:
275
+ return predicted_video_256_path, predicted_video_512_path
276
+ else:
277
+ return predicted_video_256_path, predicted_video_256_path
278
+
279
+ def generate_video(uploaded_img, uploaded_audio, infer_type,
280
+ pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed):
281
+ if uploaded_img is None or uploaded_audio is None:
282
+ return None, gr.Markdown("Error: Input image or audio file is empty. Please check and upload both files.")
283
+
284
+ model_mapping = {
285
+ "mfcc_pose_only": "./ckpts/stage2_pose_only_mfcc.ckpt",
286
+ "mfcc_full_control": "./ckpts/stage2_more_controllable_mfcc.ckpt",
287
+ "hubert_audio_only": "./ckpts/stage2_audio_only_hubert.ckpt",
288
+ "hubert_pose_only": "./ckpts/stage2_pose_only_hubert.ckpt",
289
+ "hubert_full_control": "./ckpts/stage2_full_control_hubert.ckpt",
290
+ }
291
+
292
+ # if face_crop:
293
+ # uploaded_img_path = Path(uploaded_img)
294
+ # cropped_img_path = uploaded_img_path.with_name(uploaded_img_path.stem + "_crop" + uploaded_img_path.suffix)
295
+ # crop_image(uploaded_img, cropped_img_path)
296
+ # uploaded_img = str(cropped_img_path)
297
+
298
+ # import pdb;pdb.set_trace()
299
+
300
+ stage2_checkpoint_path = model_mapping.get(infer_type, "default_checkpoint.ckpt")
301
+ try:
302
+ args = argparse.Namespace(
303
+ infer_type=infer_type,
304
+ test_image_path=uploaded_img,
305
+ test_audio_path=uploaded_audio,
306
+ test_hubert_path='',
307
+ result_path='./outputs/',
308
+ stage1_checkpoint_path='./ckpts/stage1.ckpt',
309
+ stage2_checkpoint_path=stage2_checkpoint_path,
310
+ seed=seed,
311
+ control_flag=True,
312
+ pose_yaw=pose_yaw,
313
+ pose_pitch=pose_pitch,
314
+ pose_roll=pose_roll,
315
+ face_location=face_location,
316
+ pose_driven_path='not_supported_in_this_mode',
317
+ face_scale=face_scale,
318
+ step_T=step_T,
319
+ image_size=256,
320
+ device=device,
321
+ motion_dim=20,
322
+ decoder_layers=2,
323
+ face_sr=face_sr
324
+ )
325
+
326
+ # Save the uploaded audio to the expected path
327
+ # shutil.copy(uploaded_audio, args.test_audio_path)
328
+
329
+ # Run the main function
330
+ output_256_video_path, output_512_video_path = main(args)
331
+
332
+ # Check if the output video file exists
333
+ if not os.path.exists(output_256_video_path):
334
+ return None, gr.Markdown("Error: Video generation failed. Please check your inputs and try again.")
335
+ if output_256_video_path == output_512_video_path:
336
+ return gr.Video(value=output_256_video_path), None, gr.Markdown("Video (256*256 only) generated successfully!")
337
+ return gr.Video(value=output_256_video_path), gr.Video(value=output_512_video_path), gr.Markdown("Video generated successfully!")
338
+
339
+ except Exception as e:
340
+ return None, None, gr.Markdown(f"Error: An unexpected error occurred - {str(e)}")
341
+
342
+ default_values = {
343
+ "pose_yaw": 0,
344
+ "pose_pitch": 0,
345
+ "pose_roll": 0,
346
+ "face_location": 0.5,
347
+ "face_scale": 0.5,
348
+ "step_T": 50,
349
+ "seed": 0,
350
+ "device": "cuda"
351
+ }
352
+
353
+ with gr.Blocks() as demo:
354
+ gr.Markdown('# AniTalker')
355
+ gr.Markdown('![]()')
356
+ with gr.Row():
357
+ with gr.Column():
358
+ uploaded_img = gr.Image(type="filepath", label="Reference Image")
359
+ uploaded_audio = gr.Audio(type="filepath", label="Input Audio")
360
+ with gr.Column():
361
+ output_video_256 = gr.Video(label="Generated Video (256)")
362
+ output_video_512 = gr.Video(label="Generated Video (512)")
363
+ output_message = gr.Markdown()
364
+
365
+
366
+
367
+ generate_button = gr.Button("Generate Video")
368
+
369
+ with gr.Accordion("Configuration", open=True):
370
+ infer_type = gr.Dropdown(
371
+ label="Inference Type",
372
+ choices=['mfcc_pose_only', 'mfcc_full_control', 'hubert_audio_only', 'hubert_pose_only'],
373
+ value='hubert_audio_only'
374
+ )
375
+ face_sr = gr.Checkbox(label="Enable Face Super-Resolution (512*512)", value=False)
376
+ # face_crop = gr.Checkbox(label="Face Crop (Dlib)", value=False)
377
+ # face_crop = False # TODO
378
+ seed = gr.Number(label="Seed", value=default_values["seed"])
379
+ pose_yaw = gr.Slider(label="pose_yaw", minimum=-1, maximum=1, value=default_values["pose_yaw"])
380
+ pose_pitch = gr.Slider(label="pose_pitch", minimum=-1, maximum=1, value=default_values["pose_pitch"])
381
+ pose_roll = gr.Slider(label="pose_roll", minimum=-1, maximum=1, value=default_values["pose_roll"])
382
+ face_location = gr.Slider(label="face_location", minimum=0, maximum=1, value=default_values["face_location"])
383
+ face_scale = gr.Slider(label="face_scale", minimum=0, maximum=1, value=default_values["face_scale"])
384
+ step_T = gr.Slider(label="step_T", minimum=1, maximum=100, step=1, value=default_values["step_T"])
385
+ device = gr.Radio(label="Device", choices=["cuda", "cpu"], value=default_values["device"])
386
+
387
+
388
+ generate_button.click(
389
+ generate_video,
390
+ inputs=[
391
+ uploaded_img, uploaded_audio, infer_type,
392
+ pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed
393
+ ],
394
+ outputs=[output_video_256, output_video_512, output_message]
395
+ )
396
+
397
+ if __name__ == '__main__':
398
+ parser = argparse.ArgumentParser(description='EchoMimic')
399
+ parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
400
+ parser.add_argument('--server_port', type=int, default=3001, help='Server port')
401
+ args = parser.parse_args()
402
+
403
+ demo.launch()