import os
import warnings

import gradio as gr
import clip
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3d projection on older matplotlib)

import models.vqvae as vqvae
import models.t2m_trans as trans
# recover_from_ric converts HumanML3D rotation-invariant features to joint xyz;
# the import path below is an assumption -- adjust it to wherever it lives in this repo.
from utils.motion_process import recover_from_ric

warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained CLIP text encoder. ViT-B/32 is an assumption here:
# use whichever backbone the transformer was trained against.
clip_model, _ = clip.load("ViT-B/32", device=device)
clip_model.eval()

vqvae_model = vqvae.VQVAE().to(device)
transformer_model = trans.Text2Motion_Transformer().to(device)
vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
vqvae_model.eval()
transformer_model.eval()

mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)
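# Mean/Std above are the training-set normalization statistics; decoder output
# is de-normalized with them before the features are converted to joint positions.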

def generate_motion(text, vqvae_model, transformer_model):
    # Encode the prompt with CLIP, sample motion tokens autoregressively,
    # then decode the tokens back to pose features with the VQ-VAE decoder.
    tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        text_feat = clip_model.encode_text(tokens).float()
        motion_indices = transformer_model.sample(text_feat, False)
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        # De-normalize, then recover joint xyz from the HumanML3D features.
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
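# Hypothetical example: generate_motion("a person walks forward", vqvae_model,
# transformer_model) returns an (n_frames, 22, 3) array of HumanML3D joints.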

def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
    # joints: (n_frames, 22, 3) array of joint positions from generate_motion.
    joints = np.asarray(joints)
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title(title)
    # Fix the axis limits once so the view does not rescale between frames.
    ax.set_xlim(joints[..., 0].min(), joints[..., 0].max())
    ax.set_ylim(joints[..., 1].min(), joints[..., 1].max())
    ax.set_zlim(joints[..., 2].min(), joints[..., 2].max())
    scatter = ax.scatter(joints[0, :, 0], joints[0, :, 1], joints[0, :, 2])

    def update(frame):
        # _offsets3d is the usual (if private) way to move a 3D scatter.
        scatter._offsets3d = (joints[frame, :, 0], joints[frame, :, 1], joints[frame, :, 2])
        return scatter,

    # blit=True is unreliable with 3D axes, so draw full frames instead.
    ani = FuncAnimation(fig, update, frames=len(joints), interval=50, blit=False)
    ani.save(save_path, writer=PillowWriter(fps=20))
    plt.close(fig)
    return save_path
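
# Standalone usage sketch (hypothetical prompt, bypassing the UI):
#   joints = generate_motion("a person waves", vqvae_model, transformer_model)
#   create_animation(joints, title="waving", save_path="static/wave.gif")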
    
examples = [
    "Person doing yoga",
    "A person is dancing ballet",
]

def infer(text):
    motion_data = generate_motion(text, vqvae_model, transformer_model)
    gif_path = create_animation(motion_data)
    return gif_path

with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    with gr.Column():
        gr.Markdown("## 3D Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
        text_input = gr.Textbox(label="Describe the action", placeholder="Enter a text description of the action here...")
        output_image = gr.Image(label="Generated Motion Animation")
        submit_button = gr.Button("Generate Motion")
        gr.Examples(examples=examples, inputs=text_input)

        submit_button.click(
            fn=infer,
            inputs=[text_input],
            outputs=[output_image]
        )

if __name__ == "__main__":
    demo.launch()
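    # To expose a temporary public URL instead, Gradio also supports:
    # demo.launch(share=True)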