""" Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree. """ import copy import glob import os import re import subprocess from collections import OrderedDict from typing import Dict, List import mediapy import numpy as np import torch import torch as th import torchaudio from attrdict import AttrDict from omegaconf import OmegaConf from tqdm import tqdm from utils.model_util import get_person_num from visualize.ca_body.utils.image import linear2displayBatch from visualize.ca_body.utils.train import load_checkpoint, load_from_config ffmpeg_header = "ffmpeg -y " # -hide_banner -loglevel error " def filter_params(params, ignore_names): return OrderedDict( [ (k, v) for k, v in params.items() if not any([re.match(n, k) is not None for n in ignore_names]) ] ) def call_ffmpeg(command: str) -> None: print(command, "-" * 100) e = subprocess.call(command, shell=True) if e != 0: assert False, e class BodyRenderer(th.nn.Module): def __init__( self, config_base: str, render_rgb: bool, ): super().__init__() self.config_base = config_base ckpt_path = f"{config_base}/body_dec.ckpt" config_path = f"{config_base}/config.yml" assets_path = f"{config_base}/static_assets.pt" # config config = OmegaConf.load(config_path) gpu = config.get("gpu", 0) self.device = th.device(f"cuda:{gpu}") # assets static_assets = AttrDict(torch.load(assets_path)) # build model self.model = load_from_config(config.model, assets=static_assets).to( self.device ) self.model.cal_enabled = False self.model.pixel_cal_enabled = False self.model.learn_blur_enabled = False self.render_rgb = render_rgb if not self.render_rgb: self.model.rendering_enabled = None # load model checkpoints print("loading...", ckpt_path) load_checkpoint( ckpt_path, modules={"model": self.model}, ignore_names={"model": ["lbs_fn.*"]}, ) self.model.eval() self.model.to(self.device) # load default parameters for renderer person = get_person_num(config_path) self.default_inputs = th.load(f"assets/render_defaults_{person}.pth") def _write_video_stream( self, motion: np.ndarray, face: np.ndarray, save_name: str ) -> None: out = self._render_loop(motion, face) mediapy.write_video(save_name, out, fps=30) def _render_loop(self, body_pose: np.ndarray, face: np.ndarray) -> List[np.ndarray]: all_rgb = [] default_inputs_copy = copy.deepcopy(self.default_inputs) for b in tqdm(range(len(body_pose))): B = default_inputs_copy["K"].shape[0] default_inputs_copy["lbs_motion"] = ( th.tensor(body_pose[b : b + 1, :], device=self.device, dtype=th.float) .tile(B, 1) .to(self.device) ) geom = ( self.model.lbs_fn.lbs_fn( default_inputs_copy["lbs_motion"], self.model.lbs_fn.lbs_scale.unsqueeze(0).tile(B, 1), self.model.lbs_fn.lbs_template_verts.unsqueeze(0).tile(B, 1, 1), ) * self.model.lbs_fn.global_scaling ) default_inputs_copy["geom"] = geom face_codes = ( th.from_numpy(face).float().cuda() if not th.is_tensor(face) else face ) curr_face = th.tile(face_codes[b : b + 1, ...], (2, 1)) default_inputs_copy["face_embs"] = curr_face preds = self.model(**default_inputs_copy) rgb0 = linear2displayBatch(preds["rgb"])[0] rgb1 = linear2displayBatch(preds["rgb"])[1] rgb = th.cat((rgb0, rgb1), axis=-1).permute(1, 2, 0) rgb = rgb.clip(0, 255).to(th.uint8) all_rgb.append(rgb.contiguous().detach().byte().cpu().numpy()) return all_rgb def render_full_video( self, data_block: Dict[str, np.ndarray], animation_save_path: str, audio_sr: int = None, render_gt: bool = False, ) 
-> None: tag = os.path.basename(os.path.dirname(animation_save_path)) save_name = os.path.splitext(os.path.basename(animation_save_path))[0] save_name = f"{tag}_{save_name}" torchaudio.save( f"/tmp/audio_{save_name}.wav", torch.tensor(data_block["audio"]), audio_sr, ) if render_gt: tag = "gt" self._write_video_stream( data_block["gt_body"], data_block["gt_face"], f"/tmp/{tag}_{save_name}.mp4", ) else: tag = "pred" self._write_video_stream( data_block["body_motion"], data_block["face_motion"], f"/tmp/{tag}_{save_name}.mp4", ) command = f"{ffmpeg_header} -i /tmp/{tag}_{save_name}.mp4 -i /tmp/audio_{save_name}.wav -c:v copy -map 0:v:0 -map 1:a:0 -c:a aac -b:a 192k -pix_fmt yuva420p {animation_save_path}_{tag}.mp4" call_ffmpeg(command) subprocess.call( f"rm /tmp/audio_{save_name}.wav && rm /tmp/{tag}_{save_name}.mp4", shell=True, )
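

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the checkpoint
    # directory layout matches what __init__ expects, but the motion and
    # face-code dimensions (104 and 256) and the 48 kHz stereo audio shape
    # below are assumptions and must match your downloaded assets.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_base",
        type=str,
        required=True,
        help="directory containing body_dec.ckpt, config.yml, static_assets.pt",
    )
    args = parser.parse_args()

    renderer = BodyRenderer(config_base=args.config_base, render_rgb=True)
    T = 30  # one second of motion at 30 fps
    data_block = {
        "audio": np.zeros((2, 48_000), dtype=np.float32),  # silent stereo track
        "body_motion": np.zeros((T, 104), dtype=np.float32),  # assumed pose dim
        "face_motion": np.zeros((T, 256), dtype=np.float32),  # assumed face dim
    }
    # writes /tmp/demo_pred.mp4 with the rendered frames muxed with the audio
    renderer.render_full_video(data_block, "/tmp/demo", audio_sr=48_000)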