Upload option_llm.py with huggingface_hub
Browse files- option_llm.py +73 -0
option_llm.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
def get_args_parser():
|
4 |
+
parser = argparse.ArgumentParser(description='Optimal Transport AutoEncoder training for AIST',
|
5 |
+
add_help=True,
|
6 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
7 |
+
|
8 |
+
## device
|
9 |
+
parser.add_argument('--device', type=str, default='cuda:0', help='device')
|
10 |
+
|
11 |
+
## LLM
|
12 |
+
parser.add_argument('--llm-backbone', type=str, default='/home/zbz5349/WorkSpace/aigeeks/Motion-Agent/gemma-2-2b', help='name of huggingface model backbone')
|
13 |
+
parser.add_argument('--lora-r-t2m', type=int, default=64, help='lora_r for t2m')
|
14 |
+
parser.add_argument('--lora-alpha-t2m', type=int, default=64, help='lora_alpha for t2m')
|
15 |
+
parser.add_argument('--lora-r-m2t', type=int, default=32, help='lora_r for m2t')
|
16 |
+
parser.add_argument('--lora-alpha-m2t', type=int, default=32, help='lora_alpha for m2t')
|
17 |
+
parser.add_argument('--lora-dropout', type=float, default=0.1, help='lora_dropout')
|
18 |
+
|
19 |
+
## dataloader
|
20 |
+
parser.add_argument('--dataname', type=str, default='kit', help='dataset directory')
|
21 |
+
parser.add_argument('--batch-size', default=128, type=int, help='batch size')
|
22 |
+
parser.add_argument('--window-size', type=int, default=64, help='training motion length')
|
23 |
+
|
24 |
+
## optimization
|
25 |
+
parser.add_argument('--total-iter', default=200000, type=int, help='number of total iterations to run')
|
26 |
+
parser.add_argument('--warm-up-iter', default=1000, type=int, help='number of total iterations for warmup')
|
27 |
+
parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
|
28 |
+
parser.add_argument('--lr-scheduler', default=[50000, 400000], nargs="+", type=int, help="learning rate schedule (iterations)")
|
29 |
+
parser.add_argument('--gamma', default=0.05, type=float, help="learning rate decay")
|
30 |
+
|
31 |
+
parser.add_argument('--weight-decay', default=0.01, type=float, help='weight decay')
|
32 |
+
parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss")
|
33 |
+
parser.add_argument('--loss-vel', type=float, default=0.1, help='hyper-parameter for the velocity loss')
|
34 |
+
parser.add_argument('--recons-loss', type=str, default='l2', help='reconstruction loss')
|
35 |
+
|
36 |
+
## vqvae arch
|
37 |
+
parser.add_argument("--code-dim", type=int, default=512, help="embedding dimension")
|
38 |
+
parser.add_argument("--nb-code", type=int, default=512, help="nb of embedding")
|
39 |
+
parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
|
40 |
+
parser.add_argument("--down-t", type=int, default=2, help="downsampling rate")
|
41 |
+
parser.add_argument("--stride-t", type=int, default=2, help="stride size")
|
42 |
+
parser.add_argument("--width", type=int, default=512, help="width of the network")
|
43 |
+
parser.add_argument("--depth", type=int, default=3, help="depth of the network")
|
44 |
+
parser.add_argument("--dilation-growth-rate", type=int, default=3, help="dilation growth rate")
|
45 |
+
parser.add_argument("--output-emb-width", type=int, default=512, help="output embedding width")
|
46 |
+
parser.add_argument('--vq-act', type=str, default='relu', choices = ['relu', 'silu', 'gelu'], help='dataset directory')
|
47 |
+
parser.add_argument('--vq-norm', type=str, default=None, help='dataset directory')
|
48 |
+
|
49 |
+
## quantizer
|
50 |
+
parser.add_argument("--quantizer", type=str, default='ema_reset', choices = ['ema', 'orig', 'ema_reset', 'reset'], help="eps for optimal transport")
|
51 |
+
parser.add_argument('--beta', type=float, default=1.0, help='commitment loss in standard VQ')
|
52 |
+
|
53 |
+
## resume
|
54 |
+
parser.add_argument("--resume-pth", type=str, default=None, help='resume pth for VQ')
|
55 |
+
parser.add_argument("--resume-gpt", type=str, default=None, help='resume pth for GPT')
|
56 |
+
|
57 |
+
|
58 |
+
## output directory
|
59 |
+
parser.add_argument('--out-dir', type=str, default='experiments', help='output directory')
|
60 |
+
parser.add_argument('--results-dir', type=str, default='visual_results/', help='output directory')
|
61 |
+
parser.add_argument('--visual-name', type=str, default='baseline', help='output directory')
|
62 |
+
parser.add_argument('--exp-name', type=str, default='exp_debug', help='name of the experiment, will create a file inside out-dir')
|
63 |
+
|
64 |
+
## other
|
65 |
+
parser.add_argument('--print-iter', default=200, type=int, help='print frequency')
|
66 |
+
parser.add_argument('--eval-iter', default=1000, type=int, help='evaluation frequency')
|
67 |
+
parser.add_argument('--seed', default=123, type=int, help='seed for initializing training.')
|
68 |
+
|
69 |
+
parser.add_argument('--vis-gt', action='store_true', help='whether visualize GT motions')
|
70 |
+
parser.add_argument('--nb-vis', default=20, type=int, help='nb of visualizations')
|
71 |
+
|
72 |
+
|
73 |
+
return parser.parse_args()
|