from argparse import Namespace
import torch
from motion.dataset.recover_joints import recover_from_ric
from motion.model.cfg_sampler import ClassifierFreeSampleModel
from motion.model_util import create_model_and_diffusion, load_model_wo_clip
import os
import numpy as np
from motion.dataset.recover_smr import *
import json
from motion.double_take import double_take
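# Inference wrapper for the text-to-motion diffusion models in this repository:
# it loads a checkpoint for the selected mode ("cadm", "cadm-augment" or "mdm")
# and exposes `predict` for sampling motion from text prompts, including
# multi-prompt "double take" long sequences.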

class Predictor(object):
    def __init__(self, **kwargs):
        self.path = kwargs["path"]
        # Frame counts used when stitching multi-prompt ("double take") long motions.
        self.handshake_size = 20
        self.blend_size = 10

        # Load the model configuration from JSON into an argparse-style Namespace.
        args = Namespace()
        with open(self.path["config"], 'r') as f:
            params = json.load(f)
        for key, value in params.items():
            setattr(args, key, value)

        # Select architecture and checkpoint according to the requested mode.
        mode = kwargs.get("mode", "cadm")
        if mode == "cadm":
            args.arch = "refined_decoder"
            args.encode_full = 2
            args.txt_tokens = 1
            args.model_path = self.path["cadm"]
            args.rep = "smr"
        elif mode == "cadm-augment":
            args.arch = "refined_decoder"
            args.encode_full = 2
            args.txt_tokens = 1
            args.model_path = self.path["cadm-augment"]
            args.rep = "smr"
        elif mode == "mdm":
            args.arch = "trans_enc"
            args.encode_full = 0
            args.txt_tokens = 0
            args.model_path = self.path["mdm"]
            args.rep = "t2m"
        else:
            raise ValueError(f"Unsupported mode: {mode!r} (expected 'cadm', 'cadm-augment' or 'mdm')")

        self.skip_steps = kwargs.get("skip_steps", 0)
        self.device = kwargs.get("device", "cpu")
        self.args = args
        self.rep = args.rep
        self.num_frames = args.num_frames
        self.condition = kwargs.get("condition", "text")
        if self.condition == "uncond":
            # Unconditional sampling: disable classifier-free guidance.
            self.args.guidance_param = 0

        if self.rep == "t2m":
            extension = ""
        elif self.rep == "smr":
            extension = "_smr"

        self.mean = torch.from_numpy(np.load(os.path.join(self.path["dataset_dir"], 'Mean{}.npy'.format(extension)))).to(self.device)
        self.std = torch.from_numpy(np.load(os.path.join(self.path["dataset_dir"], 'Std{}.npy'.format(extension)))).to(self.device)

        print(f"Loading checkpoints from...")
        self.model, self.diffusion = create_model_and_diffusion(args, args.control_signal, self.path)
        state_dict = torch.load(self.args.model_path, map_location='cpu')
        try:
            if self.args.ema:
                print("EMA Checkpoints Loading.")
                load_model_wo_clip(self.model, state_dict["ema"])
            else:
                print("Normal Checkpoints Loading.")
                load_model_wo_clip(self.model, state_dict["model"])
        except:
            load_model_wo_clip(self.model, state_dict)

        if self.args.guidance_param != 1 and not self.args.unconstrained:
            self.model = ClassifierFreeSampleModel(self.model)   # wrapping model with the classifier-free sampler
        self.model.to(self.device)
        self.model.eval()  # disable random masking

    def predict(self, prompt, num_repetitions=1, path=None):
        # Prompts containing "|" are treated as multi-segment long-motion requests;
        # a single prompt may optionally start with "<nframes>," to set its length.
        double_split = prompt.split("|")
        if len(double_split) > 1:
            print("sample mode - double_take long motion")
            sample, step_sizes = double_take(prompt, path, num_repetitions, self.model, self.diffusion, self.handshake_size, 
                                    self.blend_size, self.num_frames, self.args.guidance_param, self.device)
            
            # De-normalize and convert back to joint coordinates.
            sample = sample.permute(0, 2, 3, 1).float()
            sample = sample * self.std + self.mean
            if self.rep == "t2m":
                # Recover 22-joint positions from the redundant (ric) representation.
                sample = recover_from_ric(sample, 22)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
            elif self.rep == "smr":
                sample = sample.permute(0, 2, 3, 1)
        else:
            # An optional leading "<nframes>," prefix overrides the default motion length.
            nframes = prompt.split(",")[0]
            try:
                nframes = int(nframes)
                prompt = ",".join(prompt.split(",")[1:])
            except ValueError:
                nframes = self.num_frames
            
            model_kwargs = {'y': {'text': str(prompt), 'lengths': nframes}}
            if self.args.guidance_param != 1:
                # Classifier-free guidance scale, one value per repetition.
                model_kwargs['y']['scale'] = torch.ones(num_repetitions, device=self.device) * self.args.guidance_param

            sample_fn = self.diffusion.p_sample_loop
            sample = sample_fn(
                self.model,
                (num_repetitions, self.model.njoints, self.model.nfeats, nframes),
                clip_denoised=False,
                model_kwargs=model_kwargs,
                skip_timesteps=self.skip_steps,  # 0 is the default value - i.e. don't skip any step
                init_image=None,
                progress=True,
                dump_steps=None,
                noise=None,
                const_noise=False
            )
            sample = sample["output"]
            sample = sample.permute(0, 2, 3, 1).float()
            sample = sample * self.std + self.mean

            if self.rep == "t2m":
                # Recover 22-joint positions from the redundant (ric) representation.
                sample = recover_from_ric(sample, 22)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
            elif self.rep == "smr":
                sample = sample.permute(0, 2, 3, 1)

        all_motions = sample.permute(0, 3, 1, 2)
        return all_motions
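
# Minimal usage sketch. The `path` keys below are the ones read in `__init__`;
# the concrete file locations are hypothetical placeholders and depend on how the
# surrounding project lays out its config, checkpoints and dataset statistics.
if __name__ == "__main__":
    paths = {
        "config": "checkpoints/config.json",        # model/diffusion hyper-parameters (hypothetical path)
        "cadm": "checkpoints/cadm.pt",              # used when mode="cadm" (hypothetical path)
        "cadm-augment": "checkpoints/cadm_aug.pt",  # used when mode="cadm-augment" (hypothetical path)
        "mdm": "checkpoints/mdm.pt",                # used when mode="mdm" (hypothetical path)
        "dataset_dir": "dataset",                   # directory holding Mean*.npy / Std*.npy (hypothetical path)
    }
    predictor = Predictor(path=paths, mode="cadm", device="cpu")

    # Single prompt with an explicit "<nframes>," length prefix; prompts joined
    # with "|" would instead go through the double_take long-motion branch.
    motions = predictor.predict("120,a person walks forward and waves")
    print(motions.shape)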