File size: 3,326 Bytes
882f6e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import torch

from diffusion import gaussian_diffusion as gd
from diffusion.respace import space_timesteps, SpacedDiffusion
from model.diffusion import FiLMTransformer
from torch.nn import functional as F


def get_person_num(config_path):
    """Return the person identifier contained in ``config_path``.

    The known identifiers are tried in a fixed order and the first one that
    occurs as a substring of ``config_path`` is returned.

    Args:
        config_path: String (typically a config file path) expected to
            contain one of the known person identifiers.

    Returns:
        The matching identifier string.

    Raises:
        AssertionError: If none of the known identifiers occurs in
            ``config_path``.
    """
    # Order only matters when a path contains several IDs: first match wins.
    known_people = ("PXB184", "RLW104", "TXB805", "GQS883")
    for person in known_people:
        if person in config_path:
            return person
    # Raised explicitly instead of via `assert False` so the check survives
    # `python -O`; a stripped assert would surface as an UnboundLocalError.
    raise AssertionError(f"something wrong with config: {config_path}")


def load_model(model, state_dict):
    """Load ``state_dict`` into ``model`` non-strictly and sanity-check it.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=False)``
            that returns a ``(missing_keys, unexpected_keys)`` pair.
        state_dict: Mapping of parameter names to values to load.

    Raises:
        AssertionError: If the state dict contains keys the model does not
            know, or if the model is missing keys other than those prefixed
            with ``"transformer."`` or ``"tokenizer."``.
    """
    result = model.load_state_dict(state_dict, strict=False)
    missing_keys, unexpected_keys = result

    # No key in the checkpoint may be unknown to the model.
    assert not unexpected_keys, unexpected_keys

    # Only these sub-modules are allowed to be absent from the checkpoint.
    optional_prefixes = ("transformer.", "tokenizer.")
    assert all(key.startswith(optional_prefixes) for key in missing_keys), missing_keys


def create_model_and_diffusion(args, split_type):
    """Construct the model and its diffusion process from ``args``.

    Args:
        args: Configuration namespace forwarded to ``get_model_args`` and
            ``create_gaussian_diffusion``.
        split_type: Value forwarded to the model as ``split_type``.

    Returns:
        A ``(model, diffusion)`` tuple, where ``model`` is a
        ``FiLMTransformer`` converted to ``torch.float32``.
    """
    model_kwargs = get_model_args(args, split_type=split_type)
    model = FiLMTransformer(**model_kwargs).to(torch.float32)
    return model, create_gaussian_diffusion(args)


def get_model_args(args, split_type):
    """Assemble the keyword arguments used to construct the model.

    Args:
        args: Configuration namespace. Reads ``data_format``, ``layers``,
            ``heads``, ``not_rotary``, ``unconstrained`` and ``device``.
            ``num_audio_layers`` is read if present; otherwise it is set to
            3 on ``args`` (backwards compatibility with older configs).
        split_type: Value passed through unchanged as ``"split_type"``.

    Returns:
        A dict of keyword arguments for the model constructor.

    Raises:
        ValueError: If ``args.data_format`` is neither ``"face"`` nor
            ``"pose"``.
    """
    # data_format -> (nfeats, latent_dim)
    feature_dims = {
        "face": (256, 512),
        "pose": (104, 256),
    }
    try:
        nfeat, lfeat = feature_dims[args.data_format]
    except (KeyError, TypeError):
        # Fail here with a clear message instead of an UnboundLocalError
        # further down when the dimensions are first used.
        raise ValueError(
            f"unsupported data_format: {args.data_format!r} "
            "(expected 'face' or 'pose')"
        ) from None

    if not hasattr(args, "num_audio_layers"):
        # The default is written back onto ``args`` because ``args`` itself
        # is also handed to the model below.
        args.num_audio_layers = 3  # backwards compat

    return {
        "args": args,
        "nfeats": nfeat,
        "latent_dim": lfeat,
        "ff_size": 1024,
        "num_layers": args.layers,
        "num_heads": args.heads,
        "dropout": 0.1,
        "cond_feature_dim": 512 * 2,
        "activation": F.gelu,
        "use_rotary": not args.not_rotary,
        "cond_mode": "uncond" if args.unconstrained else "audio",
        "split_type": split_type,
        "num_audio_layers": args.num_audio_layers,
        "device": args.device,
    }


def create_gaussian_diffusion(args):
    """Build the ``SpacedDiffusion`` object configured by ``args``.

    The schedule length (1000 steps), beta scale (1.0), MSE loss type,
    ``START_X`` mean type and a fixed (non-learned) variance type are
    hard-coded here; everything else comes from ``args``.

    Args:
        args: Configuration namespace. Reads ``timestep_respacing``,
            ``noise_schedule``, ``sigma_small``, ``data_format`` and
            ``lambda_vel``, plus ``save_dir`` when present and
            ``model_path`` otherwise.

    Returns:
        The constructed ``SpacedDiffusion`` instance.
    """
    # Fixed settings that are not exposed through ``args``.
    total_steps = 1000
    beta_scale = 1.0
    use_xstart = True
    sigma_is_learned = False
    rescale = False

    # An empty/falsy respacing means "use every step of the schedule".
    respacing = args.timestep_respacing or [total_steps]

    beta_schedule = gd.get_named_beta_schedule(
        args.noise_schedule, total_steps, beta_scale
    )
    loss = gd.LossType.MSE

    if hasattr(args, "save_dir"):
        path = args.save_dir
    else:
        path = args.model_path

    spaced_steps = space_timesteps(total_steps, respacing)

    if use_xstart:
        mean_type = gd.ModelMeanType.START_X
    else:
        mean_type = gd.ModelMeanType.EPSILON

    if sigma_is_learned:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif args.sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=spaced_steps,
        betas=beta_schedule,
        model_mean_type=mean_type,
        model_var_type=var_type,
        data_format=args.data_format,
        loss_type=loss,
        rescale_timesteps=rescale,
        lambda_vel=args.lambda_vel,
        model_path=path,
    )