# NOTE(review): the three lines below were web-scrape residue ("Spaces:" /
# "Runtime error" x2), not Python — commented out so the module parses.
# Spaces:
# Runtime error
# Runtime error
""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import torch
from diffusion import gaussian_diffusion as gd
from diffusion.respace import space_timesteps, SpacedDiffusion
from model.diffusion import FiLMTransformer
from torch.nn import functional as F
def get_person_num(config_path):
    """Extract the capture-subject ID embedded in a config path.

    Args:
        config_path: path or string expected to contain exactly one of the
            known subject IDs as a substring.

    Returns:
        The matched subject ID string.

    Raises:
        ValueError: if no known ID appears in ``config_path``.  (Previously
            ``assert False`` — but asserts are stripped under ``python -O``,
            so validation must raise explicitly.)
    """
    known_ids = ("PXB184", "RLW104", "TXB805", "GQS883")
    for person in known_ids:
        if person in config_path:
            return person
    raise ValueError(f"something wrong with config: {config_path}")
def load_model(model, state_dict):
    """Load ``state_dict`` into ``model`` non-strictly.

    Unexpected keys are never allowed; missing keys are tolerated only when
    they belong to the ``transformer.`` or ``tokenizer.`` sub-modules (those
    weights are expected to be loaded elsewhere).
    """
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected) == 0, unexpected
    # str.startswith accepts a tuple of prefixes — one call per key.
    allowed_prefixes = ("transformer.", "tokenizer.")
    assert all(k.startswith(allowed_prefixes) for k in missing), missing
def create_model_and_diffusion(args, split_type):
    """Build a float32 FiLMTransformer plus its paired diffusion process.

    Returns:
        ``(model, diffusion)`` — the network configured from ``args`` for
        ``split_type``, and the matching ``SpacedDiffusion`` instance.
    """
    kwargs = get_model_args(args, split_type=split_type)
    network = FiLMTransformer(**kwargs).to(torch.float32)
    return network, create_gaussian_diffusion(args)
def get_model_args(args, split_type):
    """Assemble the FiLMTransformer constructor kwargs for ``args``.

    Feature/latent widths are selected by ``args.data_format``
    ("face" → 256/512, "pose" → 104/256).

    Args:
        args: parsed config namespace; mutated in place to backfill
            ``num_audio_layers`` when absent (backwards compat).
        split_type: dataset split tag forwarded to the model.

    Returns:
        dict of keyword arguments for ``FiLMTransformer``.

    Raises:
        ValueError: on an unrecognized ``args.data_format``.  (Previously
            the function fell through and crashed later with a confusing
            ``NameError`` on ``nfeat``.)
    """
    if args.data_format == "face":
        nfeat, lfeat = 256, 512
    elif args.data_format == "pose":
        nfeat, lfeat = 104, 256
    else:
        raise ValueError(f"unsupported data_format: {args.data_format}")
    if not hasattr(args, "num_audio_layers"):
        args.num_audio_layers = 3  # backwards compat with older configs
    model_args = {
        "args": args,
        "nfeats": nfeat,
        "latent_dim": lfeat,
        "ff_size": 1024,
        "num_layers": args.layers,
        "num_heads": args.heads,
        "dropout": 0.1,
        "cond_feature_dim": 512 * 2,
        "activation": F.gelu,
        "use_rotary": not args.not_rotary,
        "cond_mode": "uncond" if args.unconstrained else "audio",
        "split_type": split_type,
        "num_audio_layers": args.num_audio_layers,
        "device": args.device,
    }
    return model_args
def create_gaussian_diffusion(args):
    """Construct the ``SpacedDiffusion`` process described by ``args``.

    Fixed choices (constants in the original code): 1000 diffusion steps,
    x0-prediction (START_X mean type), non-learned sigma, MSE loss, no
    timestep rescaling.
    """
    num_steps = 1000
    betas = gd.get_named_beta_schedule(args.noise_schedule, num_steps, 1.0)

    # Empty/None respacing means "use every step".
    respacing = args.timestep_respacing or [num_steps]

    # predict_xstart is always True here -> START_X; learn_sigma is always
    # False -> fixed variance, small vs. large chosen by the config flag.
    mean_type = gd.ModelMeanType.START_X
    var_type = (
        gd.ModelVarType.FIXED_SMALL
        if args.sigma_small
        else gd.ModelVarType.FIXED_LARGE
    )

    # Training configs carry save_dir; inference configs carry model_path.
    checkpoint_name = (
        args.save_dir if hasattr(args, "save_dir") else args.model_path
    )

    return SpacedDiffusion(
        use_timesteps=space_timesteps(num_steps, respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        data_format=args.data_format,
        loss_type=gd.LossType.MSE,
        rescale_timesteps=False,
        lambda_vel=args.lambda_vel,
        model_path=checkpoint_name,
    )