|
from openai import AzureOpenAI |
|
from models.motion_agent import MotionAgent |
|
from models.mllm import MotionLLM |
|
from options.option_llm import get_args_parser |
|
from utils.motion_utils import recover_from_ric, plot_3d_motion |
|
from utils.paramUtil import t2m_kinematic_chain |
|
import torch |
|
import os |
|
from openai import OpenAI |
|
|
|
|
|
def motion_agent_demo(*, device='cuda:1', save_dir='./demo'):
    """Launch an interactive MotionAgent chat session.

    Builds an OpenAI client from the ``OPENAI_API_KEY`` environment variable,
    loads the project's LLM options via ``get_args_parser()``, overrides the
    output directory and device, and hands control to ``MotionAgent.chat()``.

    Args:
        device: Torch device string assigned to ``args.device``. Defaults to
            ``'cuda:1'``, the value this demo has always used.
        save_dir: Directory assigned to ``args.save_dir``. Defaults to
            ``'./demo'``.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is unset or empty.
    """
    # Credentials must come from the environment (or a secrets manager);
    # never commit API keys to source control.
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; export it before running the demo."
        )
    client = OpenAI(api_key=api_key)

    args = get_args_parser()
    args.save_dir = save_dir
    args.device = device

    motion_agent = MotionAgent(args, client)
    motion_agent.chat()
|
|
|
def motionllm_demo(
    *,
    caption='A man is doing cartwheels.',
    ckpt_path='ckpt/motionllm.pth',
    output_path='motionllm_demo.mp4',
):
    """Generate a motion from a text caption with MotionLLM and render it.

    Loads MotionLLM weights from ``ckpt_path`` and puts the LLM in eval mode
    on the default CUDA device. It then generates a motion for ``caption``,
    converts it to 3D joint positions, prints the resulting tensor shape, and
    writes an animation to ``output_path`` via ``plot_3d_motion``.

    Args:
        caption: Text prompt passed to ``model.generate`` and used as the
            plot title. Defaults to the original demo prompt.
        ckpt_path: Checkpoint file passed to ``model.load_model``.
        output_path: Destination file for the rendered animation.
    """
    model = MotionLLM()
    model.load_model(ckpt_path)
    model.llm.eval()
    model.llm.cuda()

    motion = model.generate(caption)

    # denormalize operates on a NumPy array, so move the output to CPU first;
    # presumably this undoes dataset feature normalization — confirm in MotionLLM.
    motion = MotionLLM.denormalize(motion.detach().cpu().numpy())
    # 22 is passed as recover_from_ric's second argument, presumably the joint
    # count of the t2m skeleton (matches t2m_kinematic_chain) — TODO confirm.
    motion = recover_from_ric(torch.from_numpy(motion).float().cuda(), 22)
    print(motion.shape)

    plot_3d_motion(
        output_path,
        t2m_kinematic_chain,
        motion.squeeze().detach().cpu().numpy(),
        title=caption,
        fps=20,
        radius=4,
    )
|
|
|
if __name__ == "__main__":
    # Default entry point: the interactive MotionAgent chat demo.
    # motionllm_demo() is the standalone text-to-motion alternative.
    motion_agent_demo()