""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import copy
import json
from typing import Dict, Tuple, Union

import gradio as gr
import numpy as np
import torch
import torchaudio
from attrdict import AttrDict
from diffusion.respace import SpacedDiffusion
from model.cfg_sampler import ClassifierFreeSampleModel
from model.diffusion import FiLMTransformer
from utils.misc import fixseed
from utils.model_util import create_model_and_diffusion, load_model
from visualize.render_codes import BodyRenderer
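

# The demo expects the pretrained PXB184 assets to be unpacked into the working
# directory, matching the paths hard-coded below:
#   checkpoints/diffusion/c1_face/{args.json, model000155000.pt}
#   checkpoints/diffusion/c1_pose/{args.json, model000340000.pt}
#   checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt
#   checkpoints/ca_body/data/PXB184/    (body renderer assets)
#   dataset/PXB184/data_stats.pth       (standardization statistics)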
class GradioModel:
    def __init__(self, face_args, pose_args) -> None:
        self.face_model, self.face_diffusion, self.device = self._setup_model(
            face_args, "checkpoints/diffusion/c1_face/model000155000.pt"
        )
        self.pose_model, self.pose_diffusion, _ = self._setup_model(
            pose_args, "checkpoints/diffusion/c1_pose/model000340000.pt"
        )
        # load standardization stats (audio/face-code/pose means and stds)
        stats = torch.load("dataset/PXB184/data_stats.pth")
        stats["pose_mean"] = stats["pose_mean"].reshape(-1)
        stats["pose_std"] = stats["pose_std"].reshape(-1)
        self.stats = stats
        # set up the photoreal body renderer
        config_base = "./checkpoints/ca_body/data/PXB184"
        self.body_renderer = BodyRenderer(
            config_base=config_base,
            render_rgb=True,
        )
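
    # _setup_model reads the training args JSON for one diffusion model, rebuilds
    # the network, restores the given checkpoint, and wraps it in a classifier-free
    # guidance sampler. Sampling is respaced to 100 DDIM steps ("ddim100").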
    def _setup_model(
        self,
        args_path: str,
        model_path: str,
    ) -> Tuple[
        Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion, str
    ]:
        with open(args_path) as f:
            args = json.load(f)
        args = AttrDict(args)
        args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        print("running on...", args.device)
        args.model_path = model_path
        args.output_dir = "/tmp/gradio/"
        args.timestep_respacing = "ddim100"
        if args.data_format == "pose":
            args.resume_trans = "checkpoints/guide/c1_pose/checkpoints/iter-0100000.pt"

        ## create model
        model, diffusion = create_model_and_diffusion(args, split_type="test")
        print(f"Loading checkpoints from [{args.model_path}]...")
        state_dict = torch.load(args.model_path, map_location=args.device)
        load_model(model, state_dict)
        model = ClassifierFreeSampleModel(model)
        model.eval()
        model.to(args.device)
        return model, diffusion, args.device
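
    # _replace_keyframes autoregressively samples coarse pose keyframes from the
    # guide transformer conditioned on the audio, using nucleus (top-p) sampling,
    # then decodes the residual-quantized tokens back into pose space. B is the
    # number of sequences to draw; T comes from curr_seq_length / 30, i.e. the
    # clip length in seconds at the 30 fps motion rate.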
    def _replace_keyframes(
        self,
        model_kwargs: Dict[str, Dict[str, torch.Tensor]],
        B: int,
        T: int,
        top_p: float = 0.97,
    ) -> torch.Tensor:
        with torch.no_grad():
            tokens = self.pose_model.transformer.generate(
                model_kwargs["y"]["audio"],
                T,
                layers=self.pose_model.tokenizer.residual_depth,
                n_sequences=B,
                top_p=top_p,
            )
        tokens = tokens.reshape((B, -1, self.pose_model.tokenizer.residual_depth))
        pred = self.pose_model.tokenizer.decode(tokens).detach()
        return pred
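
    # _run_single_diffusion draws samples with the DDIM sampling loop. The requested
    # shape is (num_repetitions, model.nfeats, 1, curr_seq_length): one feature
    # vector per frame of the 30 fps motion sequence, batched over repetitions.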
    def _run_single_diffusion(
        self,
        model_kwargs: Dict[str, Dict[str, torch.Tensor]],
        diffusion: SpacedDiffusion,
        model: Union[FiLMTransformer, ClassifierFreeSampleModel],
        curr_seq_length: int,
        num_repetitions: int = 1,
    ) -> torch.Tensor:
        sample_fn = diffusion.ddim_sample_loop
        with torch.no_grad():
            sample = sample_fn(
                model,
                (num_repetitions, model.nfeats, 1, curr_seq_length),
                clip_denoised=False,
                model_kwargs=model_kwargs,
                init_image=None,
                progress=True,
                dump_steps=None,
                noise=None,
                const_noise=False,
            )
        return sample
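
    # generate_sequences dispatches to either the face or the pose diffusion model.
    # Conditioning lives in model_kwargs["y"]: the audio, a classifier-free guidance
    # scale, and, for pose only, an all-ones temporal mask plus the keyframes
    # sampled by _replace_keyframes above.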
    def generate_sequences(
        self,
        model_kwargs: Dict[str, Dict[str, torch.Tensor]],
        data_format: str,
        curr_seq_length: int,
        num_repetitions: int = 5,
        guidance_param: float = 10.0,
        top_p: float = 0.97,
        # batch_size: int = 1,
    ) -> Dict[str, np.ndarray]:
        if data_format == "pose":
            model = self.pose_model
            diffusion = self.pose_diffusion
        else:
            model = self.face_model
            diffusion = self.face_diffusion

        all_motions = []
        model_kwargs["y"]["scale"] = torch.ones(num_repetitions) * guidance_param
        model_kwargs["y"] = {
            key: val.to(self.device) if torch.is_tensor(val) else val
            for key, val in model_kwargs["y"].items()
        }
        if data_format == "pose":
            model_kwargs["y"]["mask"] = (
                torch.ones((num_repetitions, 1, 1, curr_seq_length))
                .to(self.device)
                .bool()
            )
            model_kwargs["y"]["keyframes"] = self._replace_keyframes(
                model_kwargs,
                num_repetitions,
                int(curr_seq_length / 30),
                top_p=top_p,
            )
        sample = self._run_single_diffusion(
            model_kwargs, diffusion, model, curr_seq_length, num_repetitions
        )
        all_motions.append(sample.cpu().numpy())
        print(f"created {len(all_motions) * num_repetitions} samples")
        return np.concatenate(all_motions, axis=0)
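

# generate_results preprocesses the recorded audio (mono, resampled to 48 kHz,
# truncated to a whole multiple of 4 seconds), standardizes it with the dataset
# statistics, runs the face and pose models, and un-standardizes the outputs.
# Motion is generated at 30 fps, so a clip of N seconds yields 30 * N frames.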
def generate_results(audio: np.ndarray, num_repetitions: int, top_p: float):
    if audio is None:
        raise gr.Error("Please record audio to start")
    sr, y = audio
    # convert to mono and resample to 48 kHz
    y = torch.Tensor(y)
    if y.dim() == 2:
        dim = 0 if y.shape[0] == 2 else 1
        y = torch.mean(y, dim=dim)
    y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=48_000)
    sr = 48_000
    # require at least 4 seconds of audio
    if len(y) < (sr * 4):
        raise gr.Error("Please record at least 4 seconds of audio")
    if num_repetitions is None or num_repetitions <= 0 or num_repetitions > 10:
        raise gr.Error(
            f"Invalid number of samples: {num_repetitions}. Please specify a number between 1-10"
        )
    # truncate the waveform to a whole multiple of 4 seconds
    cutoff = int(len(y) / (sr * 4))
    y = y[: cutoff * sr * 4]
    curr_seq_length = int(len(y) / sr) * 30
    # create model_kwargs
    model_kwargs = {"y": {}}
    dual_audio = np.random.normal(0, 0.001, (1, len(y), 2))
    dual_audio[:, :, 0] = y / max(y)
    dual_audio = (dual_audio - gradio_model.stats["audio_mean"]) / gradio_model.stats[
        "audio_std_flat"
    ]
    model_kwargs["y"]["audio"] = (
        torch.Tensor(dual_audio).float().tile(num_repetitions, 1, 1)
    )
    face_results = (
        gradio_model.generate_sequences(
            model_kwargs, "face", curr_seq_length, num_repetitions=int(num_repetitions)
        )
        .squeeze(2)
        .transpose(0, 2, 1)
    )
    face_results = (
        face_results * gradio_model.stats["code_std"] + gradio_model.stats["code_mean"]
    )
    pose_results = (
        gradio_model.generate_sequences(
            model_kwargs,
            "pose",
            curr_seq_length,
            num_repetitions=int(num_repetitions),
            guidance_param=2.0,
            top_p=top_p,
        )
        .squeeze(2)
        .transpose(0, 2, 1)
    )
    pose_results = (
        pose_results * gradio_model.stats["pose_std"] + gradio_model.stats["pose_mean"]
    )
    dual_audio = (
        dual_audio * gradio_model.stats["audio_std_flat"]
        + gradio_model.stats["audio_mean"]
    )
    return face_results, pose_results, dual_audio[0].transpose(1, 0).astype(np.float32)
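

# audio_to_avatar renders each generated (face, pose) sequence to an mp4 with the
# photoreal body renderer and returns up to 10 gr.Video components; unused slots
# stay hidden so they match the fixed-size output list of the Interface below.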
def audio_to_avatar(audio: np.ndarray, num_repetitions: int, top_p: float):
    face_results, pose_results, audio = generate_results(audio, num_repetitions, top_p)
    # face_results: num_rep x T x 256, pose_results: num_rep x T x 104
    B = len(face_results)
    results = []
    for i in range(B):
        render_data_block = {
            "audio": audio,  # 2 x T
            "body_motion": pose_results[i, ...],  # T x 104
            "face_motion": face_results[i, ...],  # T x 256
        }
        gradio_model.body_renderer.render_full_video(
            render_data_block, f"/tmp/sample{i}", audio_sr=48_000
        )
        results += [gr.Video(value=f"/tmp/sample{i}_pred.mp4", visible=True)]
    results += [gr.Video(visible=False) for _ in range(B, 10)]
    return results
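

# The models, stats, and renderer are loaded once at module import; every Gradio
# request then reuses this single GradioModel instance.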
gradio_model = GradioModel(
    face_args="./checkpoints/diffusion/c1_face/args.json",
    pose_args="./checkpoints/diffusion/c1_pose/args.json",
)

demo = gr.Interface(
    audio_to_avatar,  # function
    [
        gr.Audio(sources=["microphone"]),
        gr.Number(
            value=3,
            label="Number of Samples (default = 3)",
            precision=0,
            minimum=1,
            maximum=10,
        ),
        gr.Number(
            value=0.97,
            label="Sample Diversity (default = 0.97)",
            precision=None,
            minimum=0.01,
            step=0.01,
            maximum=1.00,
        ),
    ],  # inputs
    [gr.Video(format="mp4", visible=True)]
    + [gr.Video(format="mp4", visible=False) for _ in range(9)],  # outputs
    title='"From Audio to Photoreal Embodiment: Synthesizing Humans in Conversations" Demo',
    description="You can generate a photorealistic avatar from your voice! <br/>\
        1) Start by recording your audio. <br/>\
        2) Specify the number of samples to generate. <br/>\
        3) Specify how diverse you want the samples to be. This tunes the cumulative probability in nucleus sampling: 0.01 = low diversity, 1.0 = high diversity. <br/>\
        4) Then, sit back and wait for the rendering to happen! This may take a while (e.g. 30 minutes). <br/>\
        5) Afterwards, you can view the videos and download the ones you like. <br/>",
    article="Relevant links: [Project Page](https://people.eecs.berkeley.edu/~evonne_ng/projects/audio2photoreal)",  # TODO: code and arxiv
)

if __name__ == "__main__":
    fixseed(10)
    demo.launch(share=True)
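
# A minimal offline sketch (not part of the original demo), assuming a speech
# recording of at least 4 seconds on disk; it bypasses the Gradio UI and calls
# generate_results directly. The (sample_rate, waveform) tuple mimics what
# gr.Audio hands to the callback, and "my_recording.wav" is purely hypothetical.
#
#   wav, in_sr = torchaudio.load("my_recording.wav")
#   audio_tuple = (in_sr, wav.mean(dim=0).numpy())  # mono numpy array, like gr.Audio
#   face, pose, audio = generate_results(audio_tuple, num_repetitions=1, top_p=0.97)
#   gradio_model.body_renderer.render_full_video(
#       {"audio": audio, "body_motion": pose[0], "face_motion": face[0]},
#       "/tmp/offline_sample",
#       audio_sr=48_000,
#   )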