# SemanticBoost/motion/sample.py
from argparse import Namespace
import torch
from motion.dataset.recover_joints import recover_from_ric
from motion.model.cfg_sampler import ClassifierFreeSampleModel
from motion.model_util import create_model_and_diffusion, load_model_wo_clip
import os
import numpy as np
from motion.dataset.recover_smr import *
import json
from motion.double_take import double_take
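

# Inference wrapper: loads a SemanticBoost config plus diffusion checkpoint
# from the `path` dict and exposes predict() for text-conditioned motion
# sampling (with double_take stitching for long multi-clip prompts).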
class Predictor(object):
    def __init__(self, **kwargs):
        self.path = kwargs["path"]
        # Stitching parameters for double_take long-motion generation.
        self.handshake_size = 20
        self.blend_size = 10
        # Load the base configuration into an argparse-style namespace.
        args = Namespace()
        with open(self.path["config"], 'r') as f:
            params = json.load(f)
        for key, value in params.items():
            setattr(args, key, value)
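        # Mode selects architecture and checkpoint: the CADM variants decode
        # the SMR representation, while plain MDM uses the "t2m"
        # (HumanML3D-style) features.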
        mode = kwargs.get("mode", "cadm")
        if mode in ("cadm", "cadm-augment"):
            args.arch = "refined_decoder"
            args.encode_full = 2
            args.txt_tokens = 1
            args.model_path = self.path[mode]
            args.rep = "smr"
        elif mode == "mdm":
            args.arch = "trans_enc"
            args.encode_full = 0
            args.txt_tokens = 0
            args.model_path = self.path["mdm"]
            args.rep = "t2m"
        else:
            raise ValueError(f"unknown mode: {mode!r}")
        self.skip_steps = kwargs.get("skip_steps", 0)
        self.device = kwargs.get("device", "cpu")
        self.args = args
        self.rep = args.rep
        self.num_frames = args.num_frames
        self.condition = kwargs.get("condition", "text")
        if self.condition == "uncond":
            self.args.guidance_param = 0
        if self.rep == "t2m":
            extension = ""
        elif self.rep == "smr":
            extension = "_smr"
        self.mean = torch.from_numpy(np.load(os.path.join(self.path["dataset_dir"], 'Mean{}.npy'.format(extension)))).to(self.device)
        self.std = torch.from_numpy(np.load(os.path.join(self.path["dataset_dir"], 'Std{}.npy'.format(extension)))).to(self.device)
        print(f"Loading checkpoints from {self.args.model_path}...")
        self.model, self.diffusion = create_model_and_diffusion(args, args.control_signal, self.path)
        state_dict = torch.load(self.args.model_path, map_location='cpu')
        try:
            if self.args.ema:
                print("EMA Checkpoints Loading.")
                load_model_wo_clip(self.model, state_dict["ema"])
            else:
                print("Normal Checkpoints Loading.")
                load_model_wo_clip(self.model, state_dict["model"])
        except (KeyError, AttributeError):
            # Older checkpoints store the weights at the top level.
            load_model_wo_clip(self.model, state_dict)
        if self.args.guidance_param != 1 and not self.args.unconstrained:
            self.model = ClassifierFreeSampleModel(self.model)  # wrap the model with the classifier-free guidance sampler
        self.model.to(self.device)
        self.model.eval()  # disable random masking

    def predict(self, prompt, num_repetitions=1, path=None):
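        """Sample motion(s) for `prompt`.

        A prompt containing "|" is treated as a sequence of clips and stitched
        into one long motion via double_take; a leading "<nframes>," prefix
        overrides the configured clip length. Returns a batch-first tensor
        with one motion per repetition.
        """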
        double_split = prompt.split("|")
        if len(double_split) > 1:
            print("sample mode - double_take long motion")
            sample, step_sizes = double_take(prompt, path, num_repetitions, self.model, self.diffusion,
                                             self.handshake_size, self.blend_size, self.num_frames,
                                             self.args.guidance_param, self.device)
            sample = sample.permute(0, 2, 3, 1).float()
            sample = sample * self.std + self.mean  # denormalize
            if self.rep == "t2m":
                sample = recover_from_ric(sample, 22)  # 22 joints (HumanML3D skeleton)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
            elif self.rep == "smr":
                sample = sample.permute(0, 2, 3, 1)
        else:
            # An optional "<nframes>," prefix sets the clip length; anything
            # that fails to parse as an int falls back to the default.
            nframes = prompt.split(",")[0]
            try:
                nframes = int(nframes)
                prompt = ",".join(prompt.split(",")[1:])
            except ValueError:
                nframes = self.num_frames
            model_kwargs = {'y': {'text': str(prompt), 'lengths': nframes}}
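            # Classifier-free guidance: a per-sample `scale` tensor is consumed
            # by ClassifierFreeSampleModel to amplify the text conditioning.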
            if self.args.guidance_param != 1:
                model_kwargs['y']['scale'] = torch.ones(num_repetitions, device=self.device) * self.args.guidance_param
            sample_fn = self.diffusion.p_sample_loop
            sample = sample_fn(
                self.model,
                (num_repetitions, self.model.njoints, self.model.nfeats, nframes),
                clip_denoised=False,
                model_kwargs=model_kwargs,
                skip_timesteps=self.skip_steps,  # 0 is the default value - i.e. don't skip any step
                init_image=None,
                progress=True,
                dump_steps=None,
                noise=None,
                const_noise=False,
            )
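            # This project's p_sample_loop returns a dict; the denoised motion
            # is stored under "output".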
            sample = sample["output"]
            sample = sample.permute(0, 2, 3, 1).float()
            sample = sample * self.std + self.mean  # denormalize
            if self.rep == "t2m":
                sample = recover_from_ric(sample, 22)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
            elif self.rep == "smr":
                sample = sample.permute(0, 2, 3, 1)
        all_motions = sample.permute(0, 3, 1, 2)
        return all_motions
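

# Minimal usage sketch (illustrative, not part of the original module): every
# key and file location in `demo_path` below is an assumption; point them at
# your own config, checkpoints, and dataset statistics.
if __name__ == "__main__":
    demo_path = {
        "config": "configs/cadm.json",              # assumed config location
        "cadm": "checkpoints/cadm.pt",              # assumed checkpoint location
        "cadm-augment": "checkpoints/cadm_aug.pt",  # assumed checkpoint location
        "mdm": "checkpoints/mdm.pt",                # assumed checkpoint location
        "dataset_dir": "dataset/HumanML3D",         # needs Mean_smr.npy / Std_smr.npy
    }
    predictor = Predictor(path=demo_path, mode="cadm", device="cpu")
    # "120," asks for 120 frames; the rest is the text prompt.
    motions = predictor.predict("120,a person walks forward and waves")
    print(motions.shape)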