test_virtual / utils /vq_parser_utils.py
whyun13's picture
Upload folder using huggingface_hub
882f6e2 verified
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import argparse
def _add_dataset_args(parser):
parser.add_argument("--dataname", type=str, default="kit", help="dataset directory")
parser.add_argument("--data_root", type=str, default=None, help="dataset directory")
parser.add_argument("--max_seq_length", default=600, type=int)
parser.add_argument("--add_frame_cond", type=float, choices=[1], default=None)
parser.add_argument(
"--data_format", type=str, default="pose", choices=["pose", "face"]
)
parser.add_argument("--dataset", default="social", type=str)
parser.add_argument("--batch_size", default=64, type=int, help="batch size")
def _add_optim_args(parser):
parser.add_argument(
"--total_iter",
default=300_000,
type=int,
help="number of total iterations to run",
)
parser.add_argument(
"--warm_up_iter",
default=1000,
type=int,
help="number of total iterations for warmup",
)
parser.add_argument("--lr", default=2e-4, type=float, help="max learning rate")
parser.add_argument(
"--lr_scheduler",
default=[300_000],
nargs="+",
type=int,
help="learning rate schedule (iterations)",
)
parser.add_argument("--gamma", default=0.05, type=float, help="learning rate decay")
parser.add_argument("--weight_decay", default=0.0, type=float, help="weight decay")
parser.add_argument(
"--commit",
type=float,
default=0.02,
help="hyper-parameter for the commitment loss",
)
parser.add_argument(
"--loss_vel",
type=float,
default=0.1,
help="hyper-parameter for the velocity loss",
)
def _add_model_args(parser):
parser.add_argument("--code_dim", type=int, default=512, help="embedding dimension")
parser.add_argument("--depth", type=int, default=3, help="depth of the network")
parser.add_argument(
"--output_emb_width", type=int, default=512, help="output embedding width"
)
parser.add_argument(
"--resume_pth", type=str, default=None, help="resume pth for VQ"
)
def train_args(argv=None):
    """Build the VQ training argument parser and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), ``argparse`` reads ``sys.argv[1:]``, which matches the
            previous behaviour. Passing a list allows programmatic use and
            testing without touching ``sys.argv``.

    Returns:
        ``argparse.Namespace`` holding dataset, optimisation, model, output
        directory and logging/evaluation/seed options.

    Raises:
        SystemExit: raised by ``argparse`` on invalid arguments, on ``--help``,
            or when the required ``--out_dir`` is missing.
    """
    parser = argparse.ArgumentParser(
        description="Optimal Transport AutoEncoder training for AIST",
        add_help=True,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    _add_dataset_args(parser)
    _add_optim_args(parser)
    _add_model_args(parser)

    # Output directory (mandatory).
    parser.add_argument("--out_dir", type=str, required=True, help="output directory")

    # Logging / evaluation cadence and reproducibility.
    parser.add_argument("--print_iter", default=200, type=int, help="print frequency")
    parser.add_argument(
        "--eval_iter", default=1000, type=int, help="evaluation frequency"
    )
    parser.add_argument(
        "--seed", default=123, type=int, help="seed for initializing training."
    )
    return parser.parse_args(argv)