# TRME / app.py
import os
import warnings

import gradio as gr
import clip
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection)

import models.vqvae as vqvae
import models.t2m_trans as trans
# recover_from_ric converts the model's rotation-invariant features back to
# joint positions; assumed to live in utils.motion_process as in T2M-GPT.
from utils.motion_process import recover_from_ric

warnings.filterwarnings('ignore')
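
# Select the compute device and build the two halves of the pipeline: a VQ-VAE
# that decodes motion tokens into pose features, and a transformer that
# generates those tokens from text.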
device = "cuda" if torch.cuda.is_available() else "cpu"
vqvae_model = vqvae.VQVAE().to(device)
transformer_model = trans.Text2Motion_Transformer().to(device)
vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
vqvae_model.eval()
transformer_model.eval()
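
# Load CLIP to turn the prompt into a text feature. The original script
# tokenized the text but never encoded it; Text2Motion_Transformer.sample
# expects a CLIP text embedding, so the ViT-B/32 backbone is assumed here.
clip_model, _ = clip.load("ViT-B/32", device=device, jit=False)
clip_model.eval()

# Per-dimension mean/std used to de-standardize the decoder output.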
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)

def generate_motion(text, vqvae_model, transformer_model):
    tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        # Encode the prompt, sample motion token indices conditioned on it
        # (False disables categorical sampling), then decode to pose features.
        text_feature = clip_model.encode_text(tokens).float()
        motion_indices = transformer_model.sample(text_feature, False)
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        # De-normalize and recover 22 joint positions per frame.
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)

def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    joints = np.asarray(joints)  # (frames, 22, 3)

    # Fix the axis limits up front so the view does not rescale every frame.
    mins, maxs = joints.min(axis=(0, 1)), joints.max(axis=(0, 1))
    ax.set(xlim=(mins[0], maxs[0]), ylim=(mins[1], maxs[1]),
           zlim=(mins[2], maxs[2]), title=title)
    line, = ax.plot(joints[0, :, 0], joints[0, :, 1], joints[0, :, 2], 'o')

    def update(num):
        # Move the joint markers to the positions of frame `num`.
        line.set_data(joints[num, :, 0], joints[num, :, 1])
        line.set_3d_properties(joints[num, :, 2])
        return line,

    # blit must stay off for 3D axes; PillowWriter encodes the GIF.
    ani = FuncAnimation(fig, update, frames=len(joints), interval=50, blit=False)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    ani.save(save_path, writer=PillowWriter(fps=20))
    plt.close(fig)
    return save_path

examples = [
    "Person doing yoga",
    "A person is dancing ballet",
]

def infer(text):
    motion_data = generate_motion(text, vqvae_model, transformer_model)
    gif_path = create_animation(motion_data)
    return gif_path

with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    with gr.Column():
        gr.Markdown("## 3D Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
        text_input = gr.Textbox(label="Describe the action",
                                placeholder="Enter text description for the action here...")
        output_image = gr.Image(label="Generated Motion Animation")
        submit_button = gr.Button("Generate Motion")
        submit_button.click(
            fn=infer,
            inputs=[text_input],
            outputs=[output_image],
        )
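        # Surface the example prompts defined above as clickable suggestions;
        # the original script built the list but never wired it into the UI.
        gr.Examples(examples=examples, inputs=[text_input])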

if __name__ == "__main__":
    demo.launch()