import json
import os
import time

import numpy as np
import torch

from models.mllm import MotionLLM
from utils.motion_utils import recover_from_ric, plot_3d_motion
from utils.paramUtil import t2m_kinematic_chain


class MotionAgent:
    def __init__(self, args, client):
        self.args = args
        self.device = args.device

        # Load the pretrained MotionLLM checkpoint and move it to the target device
        self.model = MotionLLM(self.args)
        self.model.load_model('ckpt/motionllm.pth')
        self.model.eval()
        self.model.to(self.device)

        self.client = client
        self.save_dir = args.save_dir
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        # self.if_replace_each_turn = args.if_replace_each_turn

        self.context = []          # running chat history sent to the LLM
        self.motion_history = {}   # cache: description -> generated motion tokens

        print("Loading example prompt from example_prompt.txt, feel free to use your own prompt")
        with open("example_prompt.txt", "r") as f:  # example system prompt, feel free to use your own
            prompt = f.read()
        self.context.append({"role": "system", "content": prompt})

    def process_motion_dialogue(self, message):
        # If the message references an .npy file, the user wants to reason on an existing motion
        motion_input = None
        if 'npy' in message:
            motion_file = message.split(' ')[-1]
            assert motion_file.endswith('.npy'), "The file must be a npy file and should be the last word of your message"
            message = message.replace(motion_file, '<motion_file>')  # replace the motion file with a placeholder
            motion_input = np.load(motion_file)

        # Update context with the new message
        self.context.append({"role": "user", "content": message})

        # Create chat completion request
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=self.context
        )

        # Extract and store the assistant's response
        assistant_response = response.choices[0].message.content
        self.context.append({"role": "assistant", "content": assistant_response})

        # Parse the assistant's response to get the plan and/or reasoning
        try:
            parsed = json.loads(assistant_response)
        except (json.JSONDecodeError, TypeError):
            parsed = {}
        if not isinstance(parsed, dict):
            parsed = {}
        plan = parsed.get("plan")
        reasoning = parsed.get("reasoning")

        # A non-empty plan means the assistant wants to generate a motion or caption an existing one
        if plan is not None:
            if "generate" in plan:
                motion_tokens_to_generate = []  # motion token tensors, one per sub-description
                descriptions = plan.split(";")
                for description in descriptions:
                    description = description.strip()
                    if description:
                        # Strip the MotionLLM.generate('...') wrapper to recover the plain text prompt
                        description = description.split("MotionLLM.generate('")[1].rstrip("');")
                        if description not in self.motion_history:
                            motion_tokens = self.model.generate(description)
                            self.motion_history[description] = motion_tokens
                            motion_tokens_to_generate.append(motion_tokens)
                        else:
                            # Reuse cached tokens for a description generated earlier in the session
                            motion_tokens_to_generate.append(self.motion_history[description])
                motion_tokens = torch.cat(motion_tokens_to_generate)
                motion = self.model.net.forward_decoder(motion_tokens)
                motion = self.model.denormalize(motion.detach().cpu().numpy())
                motion = recover_from_ric(torch.from_numpy(motion).float().to(self.device), 22)

                # Use a single timestamp so the .gif and .npy filenames match
                timestamp = int(time.time())
                filename = f"{self.save_dir}/motion_{timestamp}.gif"
                print('Plotting motion...')
                plot_3d_motion(filename, t2m_kinematic_chain, motion.squeeze().detach().cpu().numpy(), title=message, fps=20, radius=4)
                np.save(f"{self.save_dir}/motion_{timestamp}.npy", motion.squeeze().detach().cpu().numpy())
                print(f"Motion saved to {filename}")
            elif 'caption' in plan:
                assert motion_input is not None, "Captioning requires an .npy motion file in the message"
                caption = self.model.caption(motion_input)
                new_message = f"MotionLLM: '{caption}'"
                self.process_motion_dialogue(new_message)
            else:
                raise ValueError(f"Invalid format of the assistant's response: {assistant_response}")
        # A reasoning field means the assistant is answering directly about the motion
        elif reasoning is not None:
            print(reasoning)
        else:
            raise ValueError(f"Invalid format of the assistant's response: {assistant_response}")

    def clean(self):
        self.context = []
        self.motion_history = {}
        print("Cleaned up the context and motion history")

    def chat(self):
        print("Welcome to Motion-Agent! Type 'exit' to quit.")
        print("Generate a motion: directly type your prompt.")
        print("Reason on a motion: type the file name of the npy motion file you want to reason on after your question.")
        print("Clean the context and motion history: type 'clean'.")
        while True:
            message = input("User: ")
            if message == "exit":
                break
            if message == "clean":
                self.clean()
                continue
            try:
                self.process_motion_dialogue(message)
            except Exception as e:
                print(f"Error: {e}")