import gradio as gr
import clip
import torch
import numpy as np
import models.vqvae as vqvae
import models.t2m_trans as trans
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import sys
import os
# recover_from_ric maps the model's rotation-invariant coordinates back to joint
# positions; it is assumed to live in the repo's utils/motion_process.py.
from utils.motion_process import recover_from_ric

import warnings
warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"
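
# Load the pretrained VQ-VAE (motion tokenizer/decoder) and the text-to-motion
# transformer, then put both in eval mode for inference.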
vqvae_model = vqvae.VQVAE().to(device)
transformer_model = trans.Text2Motion_Transformer().to(device)
vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
vqvae_model.eval()
transformer_model.eval()

# CLIP text encoder (ViT-B/32) that produces the conditioning feature for the transformer.
clip_model, _ = clip.load("ViT-B/32", device=device, jit=False)
clip_model.eval()

# Dataset statistics used to de-normalize the predicted poses.
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)
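

# Text-to-motion pipeline: prompt -> CLIP text feature -> sampled motion-token
# indices -> decoded pose features -> de-normalized 22-joint (x, y, z) positions.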
def generate_motion(text, vqvae_model, transformer_model):
    clip_text = [text]
    text_encoded = clip.tokenize(clip_text, truncate=True).to(device)
    with torch.no_grad():
        # Encode the prompt with CLIP and sample motion-token indices from the transformer.
        text_feature = clip_model.encode_text(text_encoded).float()
        motion_indices = transformer_model.sample(text_feature, False)
        # Decode the indices to normalized pose features, then de-normalize and
        # recover 22 joint positions from the RIC representation.
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
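

# Render the joint sequence to an animated GIF. This is a lightweight preview that
# traces the root joint's trajectory rather than drawing a full skeleton.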
def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title(title)

    # joints has shape (frames, 22, 3); animate the root joint's path as a growing line.
    joints = np.asarray(joints)
    data = joints[:, 0, :].T  # (3, frames)
    ax.set_xlim(data[0].min(), data[0].max())
    ax.set_ylim(data[1].min(), data[1].max())
    ax.set_zlim(data[2].min(), data[2].max())
    line, = ax.plot(data[0, 0:1], data[1, 0:1], data[2, 0:1])

    def update(num, data, line):
        line.set_data(data[:2, :num])
        line.set_3d_properties(data[2, :num])
        return line,

    ani = FuncAnimation(fig, update, frames=data.shape[1], fargs=(data, line), interval=50, blit=True)
    ani.save(save_path, writer=PillowWriter(fps=20))
    plt.close(fig)
    return save_path
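

# Example prompts surfaced as clickable suggestions in the UI.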
examples = [
    "Person doing yoga",
    "A person is dancing ballet",
]


def infer(text):
    motion_data = generate_motion(text, vqvae_model, transformer_model)
    gif_path = create_animation(motion_data)
    return gif_path
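

# Minimal Gradio UI: a text prompt, a generate button, and the rendered GIF.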
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    with gr.Column():
        gr.Markdown("## 3D Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
        text_input = gr.Textbox(label="Describe the action", placeholder="Enter text description for the action here...")
        output_image = gr.Image(label="Generated Motion Animation")
        submit_button = gr.Button("Generate Motion")
        gr.Examples(examples=examples, inputs=text_input)

    submit_button.click(
        fn=infer,
        inputs=[text_input],
        outputs=[output_image]
    )


if __name__ == "__main__":
    demo.launch()