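"""Gradio demo: text-driven 3D human motion generation with a VQ-VAE decoder
and a text-to-motion transformer (T2M-GPT-style pipeline). The checkpoint and
statistics paths under output/ are assumed to match this repo's layout."""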
import os
import warnings

import gradio as gr
import clip
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection

import models.vqvae as vqvae
import models.t2m_trans as trans
# recover_from_ric converts RIC motion features back to joint positions;
# it is assumed to live in utils/motion_process.py, as in the T2M-GPT codebase.
from utils.motion_process import recover_from_ric

warnings.filterwarnings('ignore')
device = "cuda" if torch.cuda.is_available() else "cpu"
vqvae_model = vqvae.VQVAE().to(device)
transformer_model = trans.Text2Motion_Transformer().to(device)
vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
vqvae_model.eval()
transformer_model.eval()
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)
def generate_motion(text, vqvae_model, transformer_model):
    """Encode the prompt with CLIP, sample motion tokens, decode to joint XYZ."""
    tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        text_feat = clip_model.encode_text(tokens).float()
        motion_indices = transformer_model.sample(text_feat, False)
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        # De-normalize, then recover the 22 HumanML3D joints from RIC features.
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
    """Render a (frames, 22, 3) joint sequence as a GIF (3D scatter per frame;
    no bone connectivity -- this fixes the shape bugs of the original line plot)."""
    joints = np.asarray(joints)
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title(title)
    # Fix the axis limits up front so the view does not rescale every frame.
    ax.set_xlim(joints[..., 0].min(), joints[..., 0].max())
    ax.set_ylim(joints[..., 1].min(), joints[..., 1].max())
    ax.set_zlim(joints[..., 2].min(), joints[..., 2].max())
    scat = ax.scatter(joints[0, :, 0], joints[0, :, 1], joints[0, :, 2])

    def update(num):
        # _offsets3d is the usual idiom for moving a 3D scatter in place.
        scat._offsets3d = (joints[num, :, 0], joints[num, :, 1], joints[num, :, 2])
        return scat,

    ani = FuncAnimation(fig, update, frames=len(joints), interval=50, blit=False)
    ani.save(save_path, writer=PillowWriter(fps=20))
    plt.close(fig)
    return save_path
examples = [
    "Person doing yoga",
    "A person is dancing ballet",
]
def infer(text):
    motion_data = generate_motion(text, vqvae_model, transformer_model)
    gif_path = create_animation(motion_data, title=text)
    return gif_path
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    with gr.Column():
        gr.Markdown("## 3D Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
        text_input = gr.Textbox(label="Describe the action", placeholder="Enter a text description of the action here...")
        output_image = gr.Image(label="Generated Motion Animation")
        submit_button = gr.Button("Generate Motion")
        # Wire up the example prompts defined above (they were previously unused).
        gr.Examples(examples=examples, inputs=text_input)
        submit_button.click(
            fn=infer,
            inputs=[text_input],
            outputs=[output_image],
        )

if __name__ == "__main__":
    demo.launch()
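# Usage: `python app.py`, then open the printed local URL. For a temporary
# public link, call demo.launch(share=True) (a standard Gradio option).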