zejunyang committed
Commit: e4de730
Parent(s): fab87df
debug
Files changed: src/create_modules.py (+40 -35)

src/create_modules.py CHANGED
@@ -33,14 +33,11 @@ from src.utils.crop_face_single import crop_face
 
 class Processer():
     def __init__(self):
-        self.create_models()
-
+        self.a2m_model, self.pipe = self.create_models()
+
     @spaces.GPU
     def create_models(self):
 
-        self.lmk_extractor = LMKExtractor()
-        self.vis = FaceMeshVisualizer(forehead_edge=False)
-
         config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
 
         if config.weight_dtype == "fp16":
@@ -50,64 +47,69 @@ class Processer():
 
         audio_infer_config = OmegaConf.load(config.audio_inference_config)
         # prepare model
-        self.a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
-        self.a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
-        self.a2m_model.to("cuda").eval()
+        a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
+        a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
+        a2m_model.to("cuda").eval()
 
-        self.vae = AutoencoderKL.from_pretrained(
+        vae = AutoencoderKL.from_pretrained(
             config.pretrained_vae_path,
         ).to("cuda", dtype=weight_dtype)
 
-        self.reference_unet = UNet2DConditionModel.from_pretrained(
+        reference_unet = UNet2DConditionModel.from_pretrained(
             config.pretrained_base_model_path,
             subfolder="unet",
         ).to(dtype=weight_dtype, device="cuda")
 
         inference_config_path = config.inference_config
         infer_config = OmegaConf.load(inference_config_path)
-        self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            config.pretrained_base_model_path,
            config.motion_module_path,
            subfolder="unet",
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(dtype=weight_dtype, device="cuda")
 
-        self.pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+        pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
 
-        self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
+        image_enc = CLIPVisionModelWithProjection.from_pretrained(
            config.image_encoder_path
        ).to(dtype=weight_dtype, device="cuda")
 
         sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
-        self.scheduler = DDIMScheduler(**sched_kwargs)
+        scheduler = DDIMScheduler(**sched_kwargs)
 
         # load pretrained weights
-        self.denoising_unet.load_state_dict(
+        denoising_unet.load_state_dict(
            torch.load(config.denoising_unet_path, map_location="cpu"),
            strict=False,
        )
-        self.reference_unet.load_state_dict(
+        reference_unet.load_state_dict(
            torch.load(config.reference_unet_path, map_location="cpu"),
        )
-        self.pose_guider.load_state_dict(
+        pose_guider.load_state_dict(
            torch.load(config.pose_guider_path, map_location="cpu"),
        )
 
-        self.pipe = Pose2VideoPipeline(
-            vae=self.vae,
-            image_encoder=self.image_enc,
-            reference_unet=self.reference_unet,
-            denoising_unet=self.denoising_unet,
-            pose_guider=self.pose_guider,
-            scheduler=self.scheduler,
+        pipe = Pose2VideoPipeline(
+            vae=vae,
+            image_encoder=image_enc,
+            reference_unet=reference_unet,
+            denoising_unet=denoising_unet,
+            pose_guider=pose_guider,
+            scheduler=scheduler,
        )
-        self.pipe = self.pipe.to("cuda", dtype=weight_dtype)
+        pipe = pipe.to("cuda", dtype=weight_dtype)
+
+        return a2m_model, pipe
 
 
     @spaces.GPU
     def audio2video(self, input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
         fps = 30
         cfg = 3.5
+
+        lmk_extractor = LMKExtractor()
+        vis = FaceMeshVisualizer()
 
         config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
         audio_infer_config = OmegaConf.load(config.audio_inference_config)
@@ -123,19 +125,19 @@ class Processer():
         save_dir.mkdir(exist_ok=True, parents=True)
 
         ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
-        ref_image_np = crop_face(ref_image_np, self.lmk_extractor)
+        ref_image_np = crop_face(ref_image_np, lmk_extractor)
         if ref_image_np is None:
             return None, Image.fromarray(ref_img)
 
         ref_image_np = cv2.resize(ref_image_np, (size, size))
         ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
 
-        face_result = self.lmk_extractor(ref_image_np)
+        face_result = lmk_extractor(ref_image_np)
         if face_result is None:
             return None, ref_image_pil
 
         lmks = face_result['lmks'].astype(np.float32)
-        ref_pose = self.vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
+        ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
 
         sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
         sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
@@ -148,7 +150,7 @@ class Processer():
         pred = pred + face_result['lmks3d']
 
         if headpose_video is not None:
-            pose_seq = get_headpose_temp(headpose_video, self.lmk_extractor)
+            pose_seq = get_headpose_temp(headpose_video, lmk_extractor)
         else:
             pose_seq = np.load(config['pose_temp'])
         mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
@@ -159,7 +161,7 @@ class Processer():
 
         pose_images = []
         for i, verts in enumerate(projected_vertices):
-            lmk_img = self.vis.draw_landmarks((width, height), verts, normed=False)
+            lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
             pose_images.append(lmk_img)
 
         pose_list = []
@@ -210,6 +212,9 @@ class Processer():
     @spaces.GPU
     def video2video(self, ref_img, source_video, size=512, steps=25, length=150, seed=42):
         cfg = 3.5
+
+        lmk_extractor = LMKExtractor()
+        vis = FaceMeshVisualizer()
 
         generator = torch.manual_seed(seed)
         width, height = size, size
@@ -222,19 +227,19 @@ class Processer():
         save_dir.mkdir(exist_ok=True, parents=True)
 
         ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
-        ref_image_np = crop_face(ref_image_np, self.lmk_extractor)
+        ref_image_np = crop_face(ref_image_np, lmk_extractor)
         if ref_image_np is None:
             return None, Image.fromarray(ref_img)
 
         ref_image_np = cv2.resize(ref_image_np, (size, size))
         ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
 
-        face_result = self.lmk_extractor(ref_image_np)
+        face_result = lmk_extractor(ref_image_np)
         if face_result is None:
             return None, ref_image_pil
 
         lmks = face_result['lmks'].astype(np.float32)
-        ref_pose = self.vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
+        ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
 
         source_images = read_frames(source_video)
         src_fps = get_fps(source_video)
@@ -257,7 +262,7 @@ class Processer():
             src_tensor_list.append(pose_transform(src_image_pil))
             src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
             frame_height, frame_width, _ = src_img_np.shape
-            src_img_result = self.lmk_extractor(src_img_np)
+            src_img_result = lmk_extractor(src_img_np)
             if src_img_result is None:
                 break
             pose_trans_list.append(src_img_result['trans_mat'])
@@ -291,7 +296,7 @@ class Processer():
 
         pose_list = []
         for i, verts in enumerate(projected_vertices):
-            lmk_img = self.vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
+            lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
             pose_image_np = cv2.resize(lmk_img, (width, height))
             pose_list.append(pose_image_np)
 
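For context, here is a minimal sketch of how the refactored Processer would presumably be driven from the Space's Gradio entry point. It is illustrative only, not code from this commit: the app.py wiring, Gradio components, and the handler name run_audio2video are assumptions. The pieces taken from the diff are Processer(), whose __init__ now keeps the tuple returned by the @spaces.GPU-decorated create_models() as self.a2m_model and self.pipe, and audio2video(...), which now builds its own LMKExtractor and FaceMeshVisualizer on each call instead of reading them from self.

# Hypothetical app.py wiring (not from this commit), assuming a ZeroGPU Space.
import gradio as gr

from src.create_modules import Processer

# Built once at startup; create_models() now returns (a2m_model, pipe)
# rather than only assigning attributes, and __init__ stores the result.
processer = Processer()

def run_audio2video(input_audio, ref_img, headpose_video):
    # audio2video is @spaces.GPU-decorated, so its LMKExtractor and
    # FaceMeshVisualizer are created inside the GPU-scoped call.
    # The diff only shows the failure paths returning (None, reference image);
    # the success return is assumed here to be (generated video, reference image).
    return processer.audio2video(
        input_audio, ref_img, headpose_video,
        size=512, steps=25, length=150, seed=42,
    )

demo = gr.Interface(
    fn=run_audio2video,
    inputs=[gr.Audio(type="filepath"), gr.Image(), gr.Video()],
    outputs=[gr.Video(), gr.Image()],
)

if __name__ == "__main__":
    demo.launch()

The net effect of the commit, as far as the diff shows, is that all model construction and the per-call LMKExtractor/FaceMeshVisualizer instantiation now happen inside @spaces.GPU-decorated methods, while __init__ only stores the a2m_model and pipe that create_models() returns.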