import os
import warnings

import gradio as gr
import clip
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection)

import models.vqvae as vqvae
import models.t2m_trans as trans
# recover_from_ric is assumed to live in the repo's HumanML3D utilities.
from utils.motion_process import recover_from_ric

warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained VQ-VAE and text-to-motion transformer checkpoints.
vqvae_model = vqvae.VQVAE().to(device)
transformer_model = trans.Text2Motion_Transformer().to(device)
vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
vqvae_model.eval()
transformer_model.eval()

# CLIP text encoder: tokenized text must be encoded into features before it
# is passed to the transformer's sampler (ViT-B/32 is assumed here).
clip_model, _ = clip.load("ViT-B/32", device=device, jit=False)
clip_model.eval()

# Dataset statistics used to de-normalize the decoder output.
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)


def generate_motion(text, vqvae_model, transformer_model):
    """Generate a (frames, 22, 3) array of joint positions from a text prompt."""
    text_tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens).float()
        motion_indices = transformer_model.sample(text_features, False)
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        # De-normalize, then recover 22 joint positions from the RIC representation.
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)


def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
    """Animate the root-joint trajectory over time and save it as a GIF."""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title(title)

    # (frames, 22, 3) -> (3, frames): trace joint 0 (the root) across frames.
    data = np.asarray(joints)[:, 0, :].T
    line, = ax.plot(data[0, 0:1], data[1, 0:1], data[2, 0:1])
    ax.set_xlim(data[0].min(), data[0].max())
    ax.set_ylim(data[1].min(), data[1].max())
    ax.set_zlim(data[2].min(), data[2].max())

    def update(num, data, line):
        line.set_data(data[:2, :num])
        line.set_3d_properties(data[2, :num])
        return line,

    # Blitting is not supported on 3D axes, so redraw the full frame each step.
    ani = FuncAnimation(fig, update, frames=data.shape[1], fargs=(data, line),
                        interval=50, blit=False)
    ani.save(save_path, writer=PillowWriter(fps=20))
    plt.close(fig)
    return save_path


examples = [
    "Person doing yoga",
    "A person is dancing ballet",
]


def infer(text):
    motion_data = generate_motion(text, vqvae_model, transformer_model)
    gif_path = create_animation(motion_data)
    return gif_path


with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    with gr.Column():
        gr.Markdown("## 3D Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
        text_input = gr.Textbox(label="Describe the action",
                                placeholder="Enter text description for the action here...")
        output_image = gr.Image(label="Generated Motion Animation")
        submit_button = gr.Button("Generate Motion")
        submit_button.click(fn=infer, inputs=[text_input], outputs=[output_image])
        gr.Examples(examples=examples, inputs=[text_input])

if __name__ == "__main__":
    demo.launch()
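
# --- Headless usage sketch (illustrative, not part of the original demo) ---
# The pipeline can be exercised without starting the Gradio server by calling
# the functions above directly; `infer` returns the path the GIF was written to.
# The prompt below is an arbitrary example.
#
#   joints = generate_motion("a person jumps in place", vqvae_model, transformer_model)
#   print(joints.shape)                       # (frames, 22, 3)
#   print(infer("a person jumps in place"))   # e.g. static/animation.gif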