File size: 22,856 Bytes
7b74407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import os
import shutil
import tempfile
import urllib.request
import uuid

import cv2
import imageio
import librosa
import numpy as np
import safetensors
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from scipy.io import loadmat, wavfile
from skimage import img_as_ubyte, transform

class SadTalker():
    """Top-level facade: builds the configuration, fetches checkpoints and runs inference.

    The constructor assembles ``self.cfg`` from the built-in defaults, an
    optional ``sadtalker_config.yaml`` found in ``config_path`` and the
    constructor arguments, downloads any missing checkpoint files and then
    creates the underlying ``SadTalkerModel``.
    """

    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop', old_version=False):
        """Create the configuration, download missing weights and build the model.

        Args:
            checkpoint_path: Directory the checkpoints are downloaded to and
                loaded from (stored as ``cfg['MODEL']['CHECKPOINTS_DIR']``).
            config_path: Directory that may contain ``sadtalker_config.yaml``.
            size: Stored as ``cfg['INPUT_IMAGE']['SIZE']``.
            preprocess: Stored as ``cfg['INPUT_IMAGE']['PREPROCESS']``.
            old_version: Stored as ``cfg['INPUT_IMAGE']['OLD_VERSION']``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = self.get_cfg_defaults()
        self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))

        model_cfg = self.cfg['MODEL']
        model_cfg['CHECKPOINTS_DIR'] = checkpoint_path
        model_cfg['CONFIG_DIR'] = config_path
        model_cfg['DEVICE'] = self.device
        # The INPUT_IMAGE section is always rebuilt from the constructor
        # arguments, replacing whatever the YAML file provided for it.
        self.cfg['INPUT_IMAGE'] = {
            'SOURCE_IMAGE': 'None',
            'DRIVEN_AUDIO': 'None',
            'PREPROCESS': preprocess,
            'SIZE': size,
            'OLD_VERSION': old_version,
        }

        # NOTE(review): the *_file / *_url names and the two enhancer URL
        # constants are not defined in the visible source; they must be
        # provided at module level for this loop to run.
        required_files = [
            (kp_file, kp_url),
            (aud_file, aud_url),
            (wav_file, wav_url),
            (gen_file, gen_url),
            (mapx_file, mapx_url),
            (den_file, den_url),
            ('GFPGANv1.4.pth', GFPGAN_URL),
            ('RealESRGAN_x2plus.pth', REALESRGAN_URL),
        ]
        for filename, url in required_files:
            # Download into the same directory the models are loaded from
            # (the original referenced an undefined ``checkpoint_dir`` here).
            download_model(url, filename, checkpoint_path)

        self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])

    def get_cfg_defaults(self):
        """Return a fresh dictionary holding the default configuration."""
        return {
            'MODEL': {
                'CHECKPOINTS_DIR': '',
                'CONFIG_DIR': '',
                'DEVICE': self.device,
                'SCALE': 64,
                'NUM_VOXEL_FRAMES': 8,
                'NUM_MOTION_FRAMES': 10,
                'MAX_FEATURES': 256,
                'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
                'VIDEO_FPS': 25,
                'OUTPUT_VIDEO_FPS': None,
                'OUTPUT_AUDIO_SAMPLE_RATE': None,
                'USE_ENHANCER': False,
                'ENHANCER_NAME': '',
                'BG_UPSAMPLER': None,
                'IS_HALF': False,
            },
            'INPUT_IMAGE': {},
        }

    @staticmethod
    def _deep_update(base, override):
        """Recursively merge ``override`` into ``base`` in place and return ``base``.

        Nested dictionaries are merged key by key so that a partial section
        in ``override`` does not discard the remaining keys of ``base``.
        """
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                SadTalker._deep_update(base[key], value)
            else:
                base[key] = value
        return base

    def merge_from_file(self, filepath):
        """Overlay the YAML file at ``filepath`` onto ``self.cfg``.

        A missing or empty file is ignored. Sections are merged recursively,
        so a file that overrides a single ``MODEL`` key keeps every other
        default.

        Raises:
            ValueError: If the file's top-level YAML value is not a mapping.
        """
        if not os.path.exists(filepath):
            return
        with open(filepath, 'r') as f:
            cfg_from_file = yaml.safe_load(f)
        if cfg_from_file is None:
            # An empty YAML document loads as None; nothing to merge.
            return
        if not isinstance(cfg_from_file, dict):
            raise ValueError(f"Configuration file {filepath} must contain a mapping at the top level")
        self._deep_update(self.cfg, cfg_from_file)

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en'):
        """Run one generation request and return the value of ``save_result()``.

        All arguments are forwarded unchanged to ``SadTalkerModel.test``.
        """
        self.sadtalker_model.test(
            source_image, driven_audio, preprocess, still_mode, use_enhancer,
            batch_size, size, pose_style, exp_scale, use_ref_video, ref_video,
            ref_info, use_idle_mode, length_of_audio, use_blink, result_dir,
            tts_text, tts_lang,
        )
        return self.sadtalker_model.save_result()

class SadTalkerModel():
    """Holds the loaded sub-models and dispatches inference requests.

    The sub-models are created once by ``SadTalkerInnerModel`` and re-exposed
    as attributes of this object so that ``SadTalkerInner`` can reach them.
    """

    def __init__(self, sadtalker_cfg, device_id=None):
        """Build the sub-models described by ``sadtalker_cfg``.

        Args:
            sadtalker_cfg: Configuration dictionary; ``['MODEL']['DEVICE']``
                selects the device and falls back to ``'cpu'``.
            device_id: Forwarded to ``SadTalkerInnerModel``. ``None`` (the
                default) is replaced by ``[0]``; a fresh list is created per
                call instead of sharing one mutable default.
        """
        if device_id is None:
            device_id = [0]
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)

        # Re-export the sub-models under the names SadTalkerInner reads.
        inner = self.sadtalker
        self.preprocesser = inner.preprocesser
        self.kp_extractor = inner.kp_extractor
        self.generator = inner.generator
        self.mapping = inner.mapping
        self.he_estimator = inner.he_estimator
        self.audio_to_coeff = inner.audio_to_coeff
        self.animate_from_coeff = inner.animate_from_coeff
        self.face_enhancer = inner.face_enhancer

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
        """Create a ``SadTalkerInner`` request, run it and return its result.

        The request object is kept on ``self.inner_test`` so that
        ``save_result`` can be queried afterwards.
        """
        self.inner_test = SadTalkerInner(
            self, source_image, driven_audio, preprocess, still_mode,
            use_enhancer, batch_size, size, pose_style, exp_scale,
            use_ref_video, ref_video, ref_info, use_idle_mode,
            length_of_audio, use_blink, result_dir, tts_text, tts_lang,
            jitter_amount, jitter_source_image,
        )
        return self.inner_test.test()

    def save_result(self):
        """Return the result of the most recent request (requires a prior ``test`` call)."""
        return self.inner_test.save_result()

class SadTalkerInner():
    """A single generation request: input preparation, inference and saving.

    ``test()`` runs the three stages in order and returns the path of the
    written video, which is also available through ``save_result()`` or by
    calling the instance.
    """

    def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
        """Store the request parameters; no work is done until ``test()``."""
        self.sadtalker_model = sadtalker_model
        self.source_image = source_image
        self.driven_audio = driven_audio
        self.preprocess = preprocess
        self.still_mode = still_mode
        self.use_enhancer = use_enhancer
        self.batch_size = batch_size
        self.size = size
        self.pose_style = pose_style
        self.exp_scale = exp_scale
        self.use_ref_video = use_ref_video
        self.ref_video = ref_video
        self.ref_info = ref_info
        self.use_idle_mode = use_idle_mode
        self.length_of_audio = length_of_audio
        self.use_blink = use_blink
        self.result_dir = result_dir
        self.tts_text = tts_text
        self.tts_lang = tts_lang
        self.jitter_amount = jitter_amount
        self.jitter_source_image = jitter_source_image
        self.device = self.sadtalker_model.device
        self.output_path = None
        # Temporary directory created for text-to-speech audio, if any.
        self._tts_temp_dir = None

    def _cleanup_tts_audio(self):
        """Remove the temporary TTS directory created by ``get_test_data`` (idempotent)."""
        temp_dir = getattr(self, '_tts_temp_dir', None)
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
            self._tts_temp_dir = None

    def get_test_data(self):
        """Load and preprocess the inputs.

        Returns:
            A ``(batch, audio_sample_rate)`` tuple where ``batch`` is the
            dictionary consumed by ``run_inference`` and ``post_processing``.
        """
        proc = self.sadtalker_model.preprocesser

        if self.tts_text is not None:
            temp_dir = tempfile.mkdtemp()
            # Remember exactly which directory was created so that cleanup
            # never removes anything else.
            self._tts_temp_dir = temp_dir
            audio_path = os.path.join(temp_dir, 'audio.wav')
            tts = TTSTalker()
            # NOTE(review): the return value of TTSTalker.test is discarded
            # and nothing visible here writes ``audio_path`` - confirm that
            # TTSTalker actually produces this file.
            tts.test(self.tts_text, self.tts_lang)
            self.driven_audio = audio_path

        # Close the image file as soon as the pixels have been converted.
        with Image.open(self.source_image) as opened:
            source_image_pil = opened.convert('RGB')

        if self.jitter_source_image:
            jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            pixels = np.array(source_image_pil)
            # Circular shift: horizontally by dx, then vertically by dy.
            pixels = np.roll(np.roll(pixels, jitter_dx, axis=1), jitter_dy, axis=0)
            source_image_pil = Image.fromarray(pixels)

        source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)

        if self.still_mode or self.use_idle_mode:
            ref_pose_coeff = proc.generate_still_pose(self.pose_style)
            ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
        else:
            ref_pose_coeff = None
            ref_expression_coeff = None

        audio_tensor, audio_sample_rate = proc.process_audio(
            self.driven_audio,
            self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'],
        )

        batch = {
            'source_image': source_image_tensor.unsqueeze(0).to(self.device),
            'audio': audio_tensor.unsqueeze(0).to(self.device),
            'ref_pose_coeff': ref_pose_coeff,
            'ref_expression_coeff': ref_expression_coeff,
            'source_image_crop': cropped_image,
            'crop_info': crop_info,
            'use_blink': self.use_blink,
            'pose_style': self.pose_style,
            'exp_scale': self.exp_scale,
            'ref_video': self.ref_video,
            'use_ref_video': self.use_ref_video,
            'ref_info': self.ref_info,
        }
        return batch, audio_sample_rate

    def run_inference(self, batch):
        """Predict the motion coefficients and render the frames.

        Args:
            batch: Dictionary produced by ``get_test_data``.

        Returns:
            Whatever ``animate_from_coeff.generate`` returns.
        """
        model = self.sadtalker_model
        kp_extractor = model.kp_extractor
        generator = model.generator
        mapping = model.mapping
        he_estimator = model.he_estimator
        audio_to_coeff = model.audio_to_coeff
        animate_from_coeff = model.animate_from_coeff
        face_enhancer = model.face_enhancer

        with torch.no_grad():
            kp_source = kp_extractor(batch['source_image'])
            audio = batch['audio']

            if self.still_mode or self.use_idle_mode:
                # Still and idle mode both use the fixed reference
                # coefficients (the original had a second, unreachable
                # ``elif self.use_idle_mode`` branch with the same body).
                pose_coeff = audio_to_coeff.get_pose_coeff(audio, batch['ref_pose_coeff'])
                expression_coeff = audio_to_coeff.get_exp_coeff(audio, batch['ref_expression_coeff'])
            else:
                if self.use_ref_video:
                    # NOTE(review): the reference keypoints are extracted
                    # from the source image, not from ``ref_video`` -
                    # confirm this is intended.
                    kp_ref = kp_extractor(batch['source_image'])
                    pose_coeff = audio_to_coeff.get_pose_coeff(audio, kp_ref=kp_ref, use_ref_info=batch['ref_info'])
                else:
                    pose_coeff = audio_to_coeff.get_pose_coeff(audio)
                expression_coeff = audio_to_coeff.get_exp_coeff(audio)

            coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
            coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(audio) if self.use_blink else None

            # The first output of the audio model is used as the driving signal.
            kp_driving = audio_to_coeff(audio)[0]
            coeff['kp_driving'] = animate_from_coeff.normalize_kp(kp_driving)
            # Four references to one identity tensor (the list shares it).
            coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4

            output_video = animate_from_coeff.generate(
                batch['source_image'], kp_source, coeff, generator, mapping,
                he_estimator, audio, batch['source_image_crop'],
                face_enhancer=face_enhancer,
            )
        return output_video

    def post_processing(self, output_video, audio_sample_rate, batch):
        """Write the rendered video to ``result_dir`` and record its path.

        Args:
            output_video: Frames returned by ``run_inference``.
            audio_sample_rate: Accepted for interface compatibility; unused.
            batch: Dictionary produced by ``get_test_data``.
        """
        base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
        audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
        stem = base_name + '_' + audio_name

        if self.result_dir:
            # The writers below do not create the output directory themselves.
            os.makedirs(self.result_dir, exist_ok=True)
        output_video_path = os.path.join(self.result_dir, stem + '.mp4')
        self.output_path = output_video_path

        try:
            if self.use_enhancer:
                enhanced_path = os.path.join(self.result_dir, stem + '_enhanced.mp4')
                save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
                paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio, output_video_path)
                os.remove(enhanced_path)
            else:
                save_video_with_watermark(output_video, self.driven_audio, output_video_path)
        finally:
            # Only the directory created for TTS audio is removed; the
            # original removed dirname(driven_audio), which could be a
            # user directory if driven_audio had not been replaced.
            self._cleanup_tts_audio()

    def save_result(self):
        """Return the output video path, or ``None`` before a successful run."""
        return self.output_path

    def __call__(self):
        """Return the output video path (same value as ``save_result``)."""
        return self.output_path

    def test(self):
        """Run preparation, inference and post-processing; return the output path."""
        try:
            batch, audio_sample_rate = self.get_test_data()
            output_video = self.run_inference(batch)
            self.post_processing(output_video, audio_sample_rate, batch)
        finally:
            # Ensure the temporary TTS directory is removed when any stage fails.
            self._cleanup_tts_audio()
        return self.save_result()

class SadTalkerInnerModel():
    """Constructs every sub-model used by the pipeline and exposes them as attributes."""

    def __init__(self, sadtalker_cfg, device_id=None):
        """Instantiate the sub-models on the configured device.

        Args:
            sadtalker_cfg: Configuration dictionary; ``['MODEL']['DEVICE']``
                selects the device (default ``'cpu'``) and
                ``['MODEL']['USE_ENHANCER']`` decides whether a
                ``FaceEnhancer`` is created.
            device_id: Accepted for interface compatibility; it is not used
                by this constructor.
        """
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        # The original created another SadTalkerInnerModel here, which
        # recursed until the interpreter's recursion limit was hit; nothing
        # reads that attribute, so the self-instantiation is removed.

        self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
        self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
        self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
        if sadtalker_cfg['MODEL']['USE_ENHANCER']:
            self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device)
        else:
            self.face_enhancer = None
        # NOTE(review): AnimateFromCoeff also builds its own Generator,
        # Mapping and OcclusionAwareDenseMotion, so these checkpoints are
        # loaded twice - confirm whether the instances could be shared.
        self.generator = Generator(sadtalker_cfg, self.device)
        self.mapping = Mapping(sadtalker_cfg, self.device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)

class Preprocesser():
    """Input preparation: face cropping, image/audio tensor conversion and still-mode coefficients."""

    def __init__(self, sadtalker_cfg, device):
        """Store the configuration and create the face helper and mouth detector."""
        self.cfg = sadtalker_cfg
        self.device = device
        self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
        self.mouth_detector = MouthDetector()

    @staticmethod
    def _compute_crop_box(bbox, preprocess_type, image_width, image_height):
        """Return the integer crop box ``(x_min, y_min, x_max, y_max)`` clamped to the image.

        For ``preprocess_type == 'crop'`` the box is a square, centred on the
        detected box, whose side is the larger of the detected width and
        height. For any other value the detected box is grown by 10% of its
        own width/height on every side.
        """
        x_min, y_min, x_max, y_max = bbox
        box_w = x_max - x_min
        box_h = y_max - y_min
        if preprocess_type == 'crop':
            x_center = (x_max + x_min) / 2
            y_center = (y_max + y_min) / 2
            half_side = max(box_w, box_h) / 2
            left = int(x_center - half_side)
            top = int(y_center - half_side)
            right = int(x_center + half_side)
            bottom = int(y_center + half_side)
        else:
            # Both margins are derived from the ORIGINAL box size. The
            # original code shrank x_min/y_min first and then measured the
            # already enlarged box, giving a bigger margin on the far side.
            margin_x = int(box_w * 0.1)
            margin_y = int(box_h * 0.1)
            left = int(x_min) - margin_x
            top = int(y_min) - margin_y
            right = int(x_max) + margin_x
            bottom = int(y_max) + margin_y
        return (
            max(0, left),
            max(0, top),
            min(image_width, right),
            min(image_height, bottom),
        )

    def crop(self, source_image_pil, preprocess_type, size=256):
        """Detect the face, crop around it and convert the crop to a tensor.

        Args:
            source_image_pil: Input image (converted with ``np.array``).
            preprocess_type: ``'crop'`` for a square crop; anything else
                grows the detected box by 10% per side.
            size: Side length the crop is resized to; ``None`` or ``0``
                keeps the cropped size.

        Returns:
            A tuple ``(tensor, crop_info, name)`` where ``crop_info`` is
            ``[[y_min, y_max], [x_min, x_max], old_size, final_size]`` and
            ``name`` is the basename of ``cfg['INPUT_IMAGE']['SOURCE_IMAGE']``.

        Raises:
            ValueError: If no face is detected or the crop region is empty.
        """
        source_image = np.array(source_image_pil)
        face_info = self.face3d_helper.run(source_image)
        if face_info is None:
            raise ValueError("No face detected")

        x_min, y_min, x_max, y_max = face_info[:4]
        old_size = (x_max - x_min, y_max - y_min)

        h, w = source_image.shape[:2]
        x_min, y_min, x_max, y_max = self._compute_crop_box((x_min, y_min, x_max, y_max), preprocess_type, w, h)
        if x_max <= x_min or y_max <= y_min:
            # An empty slice would otherwise fail later with an obscure error.
            raise ValueError("Detected face region is empty after clamping to the image bounds")

        cropped_image_pil = Image.fromarray(source_image[y_min:y_max, x_min:x_max])
        if size is not None and size != 0:
            cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)

        source_image_tensor = self.img2tensor(cropped_image_pil)
        crop_info = [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size]
        # NOTE(review): the third element is a file name taken from the
        # configuration, not the cropped pixels - confirm callers expect this.
        source_name = os.path.basename(self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
        return source_image_tensor, crop_info, source_name

    def img2tensor(self, img):
        """Convert an image to a float tensor scaled by 1/255 with axes reordered by ``(2, 0, 1)``."""
        pixels = np.array(img).astype(np.float32) / 255.0
        return torch.FloatTensor(np.transpose(pixels, (2, 0, 1)))

    def video_to_tensor(self, video, device):
        """Placeholder; always returns 0."""
        return 0

    def process_audio(self, audio_path, sample_rate):
        """Load ``audio_path`` via ``load_wav_util`` and return ``(tensor with a leading axis, sample_rate)``."""
        wav = load_wav_util(audio_path, sample_rate)
        wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
        return wav_tensor, sample_rate

    def generate_still_pose(self, pose_style):
        """Return a ``(1, 64)`` zero tensor whose third entry is ``pose_style * 0.3``."""
        ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
        return ref_pose_coeff

    def generate_still_expression(self, exp_scale):
        """Return a ``(1, 64)`` zero tensor whose third entry is ``exp_scale * 0.3``."""
        ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
        return ref_expression_coeff

    def generate_idles_pose(self, length_of_audio, pose_style):
        """Placeholder; always returns 0."""
        return 0

    def generate_idles_expression(self, length_of_audio):
        """Placeholder; always returns 0."""
        return 0

class KeyPointExtractor(nn.Module):
    """Wraps an ``OcclusionAwareKPDetector`` and loads its weights from ``kp_detector.safetensors``."""

    def __init__(self, sadtalker_cfg, device):
        """Create the detector on ``device`` and load its checkpoint."""
        super().__init__()
        model_cfg = sadtalker_cfg['MODEL']
        detector = OcclusionAwareKPDetector(
            kp_channels=model_cfg['NUM_MOTION_FRAMES'],
            num_kp=10,
            num_dilation_blocks=2,
            dropout_rate=0.1,
        )
        self.kp_extractor = detector.to(device)
        weights_file = os.path.join(model_cfg['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
        load_state_dict_robust(self.kp_extractor, weights_file, device, model_name='kp_detector')

    def forward(self, x):
        """Return the wrapped detector's output for ``x``."""
        return self.kp_extractor(x)

class Audio2Coeff(nn.Module):
    """Maps audio to pose, expression and blink coefficients.

    The audio is embedded with ``Wav2Vec2Model`` and three
    ``AudioCoeffsPredictor`` heads turn the embedding into coefficients.
    """

    def __init__(self, sadtalker_cfg, device):
        """Create the audio model and the three predictor heads and load their weights."""
        super().__init__()
        checkpoints_dir = sadtalker_cfg['MODEL']['CHECKPOINTS_DIR']

        self.audio_model = Wav2Vec2Model().to(device)
        checkpoint_path = os.path.join(checkpoints_dir, 'wav2vec2.pth')
        load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')

        self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
        # The file name is reproduced exactly as the loader expects it.
        mapping_checkpoint = os.path.join(checkpoints_dir, 'auido2pose_00140-model.pth')
        load_state_dict_robust(self, mapping_checkpoint, device)

    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        """Return the pose coefficients for ``audio_tensor``.

        Args:
            audio_tensor: Input passed to the audio model.
            ref_pose_coeff: When given, returned in place of the prediction.
            kp_ref: Optional keypoint dictionary with a ``'value'`` entry.
            use_ref_info: When equal to ``'pose'`` and ``kp_ref`` is given,
                the first six coefficient columns are overwritten with the
                normalised reference values.
        """
        if ref_pose_coeff is not None:
            # The prediction would be discarded, so the models are not run.
            pose_coeff = ref_pose_coeff
        else:
            pose_coeff = self.pose_mapper(self.audio_model(audio_tensor))
        if kp_ref is not None and use_ref_info == 'pose':
            ref_pose_6d = kp_ref['value'][:, :6]
            pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
        return pose_coeff

    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        """Return the expression coefficients, or ``ref_expression_coeff`` when it is given.

        The original only returned a value inside the ``if`` branch, so a
        call without a reference yielded ``None``.
        """
        if ref_expression_coeff is not None:
            return ref_expression_coeff
        return self.exp_mapper(self.audio_model(audio_tensor))

    def get_blink_coeff(self, audio_tensor):
        """Return the blink coefficients predicted for ``audio_tensor``."""
        return self.blink_mapper(self.audio_model(audio_tensor))

    def forward(self, audio):
        """Return ``(pose_coeff, expression_coeff, blink_coeff)`` from a single audio embedding."""
        audio_embedding = self.audio_model(audio)
        pose_coeff = self.pose_mapper(audio_embedding)
        expression_coeff = self.exp_mapper(audio_embedding)
        blink_coeff = self.blink_mapper(audio_embedding)
        return pose_coeff, expression_coeff, blink_coeff

    def mean_std_normalize(self, coeff):
        """Standardise ``coeff`` along ``dim=1`` (subtract mean, divide by std).

        The standard deviation is clamped to a tiny positive value so that a
        constant row yields zeros instead of NaN.
        """
        mean = coeff.mean(dim=1, keepdim=True)
        std = coeff.std(dim=1, keepdim=True).clamp_min(1e-8)
        return (coeff - mean) / std

class AnimateFromCoeff(nn.Module):
    """Renders frames from motion coefficients and optionally enhances them."""

    def __init__(self, sadtalker_cfg, device):
        """Create the generator, mapping network, keypoint normaliser and dense-motion estimator."""
        super().__init__()
        self.generator = Generator(sadtalker_cfg, device)
        self.mapping = Mapping(sadtalker_cfg, device)
        self.kp_norm = KeypointNorm(device=device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)

    def normalize_kp(self, kp_driving):
        """Return ``kp_driving`` passed through the keypoint normaliser."""
        return self.kp_norm(kp_driving)

    def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop, face_enhancer=None):
        """Render the output frames.

        Args:
            source_image: Source image passed to ``generator``.
            kp_source: Source keypoints passed to ``he_estimator``.
            coeff: Dictionary with the keys ``kp_driving``, ``jacobian``,
                ``pose_coeff``, ``expression_coeff`` and ``blink_coeff``.
            generator, mapping, he_estimator: Callables used for rendering
                (the ones passed in are used, not this object's own).
            audio, source_image_crop: Accepted for interface compatibility;
                not used by this method.
            face_enhancer: Optional object with a ``forward(frame)`` method
                applied to every frame.

        Returns:
            The sum of the generator's ``video_no_reocclusion`` and
            ``video_3d`` outputs, or a list of enhanced frames when
            ``face_enhancer`` is given.
        """
        kp_driving = coeff['kp_driving']
        jacobian = coeff['jacobian']
        pose_coeff = coeff['pose_coeff']
        expression_coeff = coeff['expression_coeff']
        blink_coeff = coeff['blink_coeff']

        if blink_coeff is not None:
            face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
        else:
            face_3d = mapping(expression_coeff, pose_coeff)

        sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
        dense_motion = sparse_motion['dense_motion']

        video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
        video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d)
        video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']

        if face_enhancer is not None:
            # tqdm is not imported by this module; show a progress bar when
            # it is installed and fall back to plain iteration otherwise.
            try:
                from tqdm import tqdm
                frames = tqdm(video_output, 'Face enhancer running')
            except ImportError:
                frames = video_output
            video_output_enhanced = []
            for frame in frames:
                # The enhancer is fed BGR data; its result is converted back to RGB.
                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                enhanced_image = face_enhancer.forward(np.array(pil_image))
                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
            video_output = video_output_enhanced
        return video_output

    def make_animation(self, video_array, fps=25):
        """Encode ``video_array`` as mp4 and return the frames read back by ``imageio.mimread``.

        Args:
            video_array: Non-empty sequence of RGB frames of equal size.
            fps: Frame rate used for encoding (default 25, the previously
                hard-coded value).

        The video goes to a uniquely named temporary file (instead of a
        fixed ``./tmp.mp4``) that is removed even when encoding fails.
        """
        H, W, _ = video_array[0].shape
        fd, tmp_path = tempfile.mkstemp(suffix='.mp4')
        os.close(fd)
        try:
            out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (W, H))
            try:
                for img in video_array:
                    out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            finally:
                out.release()
            return imageio.mimread(tmp_path)
        finally:
            os.remove(tmp_path)

class Generator(nn.Module):
    """Wraps an ``Hourglass`` network and loads its weights from ``generator.pth``."""

    def __init__(self, sadtalker_cfg, device):
        """Build the network from the ``MODEL`` settings and load its checkpoint."""
        super().__init__()
        model_cfg = sadtalker_cfg['MODEL']
        network = Hourglass(
            block_expansion=model_cfg['SCALE'],
            num_blocks=model_cfg['NUM_VOXEL_FRAMES'],
            max_features=model_cfg['MAX_FEATURES'],
            num_channels=3,
            kp_size=10,
            num_deform_blocks=model_cfg['NUM_MOTION_FRAMES'],
        )
        self.generator = network.to(device)
        weights_file = os.path.join(model_cfg['CHECKPOINTS_DIR'], 'generator.pth')
        load_state_dict_robust(self.generator, weights_file, device, model_name='generator')

    def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
        """Run the network once and return its output under both result keys.

        ``dense_motion`` is handed to the network as ``kp_driving``. The same
        output object is stored under ``'video_3d'`` and
        ``'video_no_reocclusion'``.
        """
        rendered = self.generator(
            source_image,
            kp_driving=dense_motion,
            bg_param=bg_param,
            face_3d_param=face_3d_param,
        )
        result = {}
        result['video_3d'] = rendered
        result['video_no_reocclusion'] = rendered
        return result

class Mapping(nn.Module):
    """Wraps a ``MappingNet`` that turns expression and pose coefficients into face parameters."""

    def __init__(self, sadtalker_cfg, device):
        """Create the mapping network on ``device`` and load ``mapping.pth``."""
        super().__init__()
        self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
        load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
        # Constant offset added to the network output (all zeros).
        self.f_3d_mean = torch.zeros(1, 64, device=device)

    def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
        """Return the mapped face parameters.

        The expression and pose coefficients are concatenated along
        ``dim=1`` and passed through the network. When ``blink_coeff`` is
        given it overwrites the last column of the result. The result is
        returned in both cases (the original chained the ``if`` and the
        ``return`` onto one statement line, which is not valid Python).
        """
        coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
        face_3d = self.mapping_net(coeff) + self.f_3d_mean
        if blink_coeff is not None:
            face_3d[:, -1:] = blink_coeff
        return face_3d

class OcclusionAwareDenseMotion(nn.Module):
    """Wraps a ``DenseMotionNetwork`` and loads its weights from ``dense_motion.pth``."""

    def __init__(self, sadtalker_cfg, device):
        """Build the network from the ``MODEL`` settings and load its checkpoint."""
        super().__init__()
        model_cfg = sadtalker_cfg['MODEL']
        network = DenseMotionNetwork(
            num_kp=10,
            num_channels=3,
            block_expansion=model_cfg['SCALE'],
            num_blocks=model_cfg['NUM_MOTION_FRAMES'] - 1,
            max_features=model_cfg['MAX_FEATURES'],
        )
        self.dense_motion_network = network.to(device)
        weights_file = os.path.join(model_cfg['CHECKPOINTS_DIR'], 'dense_motion.pth')
        load_state_dict_robust(self.dense_motion_network, weights_file, device, model_name='dense_motion')

    def forward(self, kp_source, kp_driving, jacobian):
        """Return the wrapped network's output for the given keypoints and jacobian."""
        return self.dense_motion_network(kp_source, kp_driving, jacobian)

class FaceEnhancer(nn.Module):
    """Optional frame enhancer backed by GFPGAN or Real-ESRGAN.

    The backend is chosen by ``cfg['MODEL']['ENHANCER_NAME']`` (``'gfpgan'``
    or ``'realesrgan'``); any other value disables enhancement and
    ``forward`` returns its input unchanged.
    """

    def __init__(self, sadtalker_cfg, device):
        """Create the configured backend, importing its package only when it is selected."""
        super().__init__()
        model_cfg = sadtalker_cfg['MODEL']
        enhancer_name = model_cfg['ENHANCER_NAME']
        bg_upsampler = model_cfg['BG_UPSAMPLER']
        # Remembered because the two backends expose different enhance() APIs.
        self._enhancer_kind = None

        if enhancer_name == 'gfpgan':
            from gfpgan import GFPGANer
            self.face_enhancer = GFPGANer(
                model_path=os.path.join(model_cfg['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
                upscale=1,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=bg_upsampler,
            )
            self._enhancer_kind = 'gfpgan'
        elif enhancer_name == 'realesrgan':
            from realesrgan import RealESRGANer
            # Half precision is never requested on the CPU.
            half = False if device == 'cpu' else model_cfg['IS_HALF']
            self.face_enhancer = RealESRGANer(
                scale=2,
                model_path=os.path.join(model_cfg['CHECKPOINTS_DIR'], 'RealESRGAN_x2plus.pth'),
                tile=0,
                tile_pad=10,
                pre_pad=0,
                half=half,
                device=device,
            )
            self._enhancer_kind = 'realesrgan'
        else:
            self.face_enhancer = None

    def forward(self, x):
        """Return the enhanced frame, or ``x`` itself when no backend is configured.

        ``GFPGANer.enhance`` takes no ``outscale`` argument and returns
        ``(cropped_faces, restored_faces, restored_img)``, so the restored
        image is the third element. ``RealESRGANer.enhance`` accepts
        ``outscale`` and returns the image first. The original used the
        Real-ESRGAN call for both backends.
        """
        if self.face_enhancer is None:
            return x
        if getattr(self, '_enhancer_kind', None) == 'gfpgan':
            restored = self.face_enhancer.enhance(
                x, has_aligned=False, only_center_face=False, paste_back=True
            )[2]
            # Fall back to the input if the backend produced no image.
            return x if restored is None else restored
        return self.face_enhancer.enhance(x, outscale=1)[0]

def download_model(url, filename, checkpoint_dir):
    """Download ``url`` to ``checkpoint_dir/filename`` unless that file already exists.

    The data is first written to ``<filename>.part`` and renamed once the
    transfer has finished, so an interrupted download is never mistaken for
    a complete checkpoint on the next run.

    Args:
        url: Location to fetch.
        filename: Name of the file inside ``checkpoint_dir``.
        checkpoint_dir: Target directory; created when missing.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises when the download fails.
    """
    target_path = os.path.join(checkpoint_dir, filename)
    if os.path.exists(target_path):
        print(f"{filename} already exists.")
        return

    print(f"Downloading {filename}...")
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    partial_path = target_path + '.part'
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, target_path)
    finally:
        # After a successful rename the partial file no longer exists; on
        # failure any leftover fragment is removed here.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"{filename} downloaded.")

def load_models():
    """Build a ``SadTalker`` instance with the stock settings and return it.

    Uses ``./checkpoints`` and ``./src/config``, size 256, ``'crop'``
    preprocessing and ``old_version=False``; prints a confirmation message
    once construction has finished.
    """
    sadtalker_instance = SadTalker('./checkpoints', './src/config', 256, 'crop', False)
    print("SadTalker models loaded successfully!")
    return sadtalker_instance

if __name__ == '__main__':
    # Script entry point: build the models once when the file is run directly.
    sadtalker_instance = load_models()