rsax committed (verified)
Commit 02c7eaf · 1 Parent(s): 2425c62

Update app.py

Files changed (1):
  1. app.py +35 -14
app.py CHANGED
@@ -1,22 +1,43 @@
 import gradio as gr
+import clip
 import torch
 import numpy as np
+import models.vqvae as vqvae
+import models.t2m_trans as trans
 import matplotlib.pyplot as plt
 from matplotlib.animation import FuncAnimation
 from mpl_toolkits.mplot3d import Axes3D
+import sys
+import os
+from matplotlib.animation import FuncAnimation
+from mpl_toolkits.mplot3d import Axes3D
 from mpl_toolkits.mplot3d.art3d import Poly3DCollection

+import warnings
+warnings.filterwarnings('ignore')
+
 device = "cuda" if torch.cuda.is_available() else "cpu"

-vqvae_model = VQVAE().to(device)
-transformer_model = Transformer().to(device)
+vqvae_model = vqvae.VQVAE().to(device)
+transformer_model = trans.Text2Motion_Transformer().to(device)
 vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
 transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
 vqvae_model.eval()
 transformer_model.eval()

+mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
+std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)
+
+def generate_motion(text, vqvae_model, transformer_model):
+    clip_text = [text]
+    text_encoded = clip.tokenize(clip_text, truncate=True).to(device)
+    with torch.no_grad():
+        motion_indices = transformer_model.sample(text_encoded, False)
+        pred_pose = vqvae_model.forward_decoder(motion_indices)
+        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
+    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)

-def create_animation(joints, title="3D Motion", save_path="animation.gif"):
+def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
     fig = plt.figure(figsize=(10, 10))
     ax = fig.add_subplot(111, projection='3d')
     data = np.array(joints).T
@@ -28,31 +49,31 @@ def create_animation(joints, title="3D Motion", save_path="animation.gif"):
         return line,

     ani = FuncAnimation(fig, update, frames=len(joints), fargs=(data, line), interval=50, blit=True)
-    ani.save(save_path, writer='pillow')
+    ani.save(save_path, writer=PillowWriter(fps=20))
     plt.close(fig)
     return save_path
-
-def infer(text):
-    motion_data = generate_motion(text, vqvae_model, transformer_model)
-    gif_path = create_animation(motion_data, kinematic_tree)
-    return gif_path
-
+
 examples = [
     "Person doing yoga",
     "A person is dancing ballet",
 ]

+def infer(text):
+    motion_data = generate_motion(text, vqvae_model, transformer_model)
+    gif_path = create_animation(motion_data)
+    return gif_path
+
 with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
     with gr.Column():
-        gr.Markdown("## 3D Motion Generation")
-        text_input = gr.Textbox(label="Describe the action", placeholder="Enter text description of the action here...")
+        gr.Markdown("## 3D Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
+        text_input = gr.Textbox(label="Describe the action", placeholder="Enter text description for the action here...")
         output_image = gr.Image(label="Generated Motion Animation")
         submit_button = gr.Button("Generate Motion")

     submit_button.click(
         fn=infer,
-        inputs=text_input,
-        outputs=output_image
+        inputs=[text_input],
+        outputs=[output_image]
     )

 if __name__ == "__main__":
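
Note on the committed code: `ani.save(save_path, writer=PillowWriter(fps=20))` relies on `PillowWriter`, but only `FuncAnimation` is imported from `matplotlib.animation` (twice, in fact), and `generate_motion` calls `recover_from_ric`, which this file neither defines nor imports. A minimal follow-up sketch is below; the `utils.motion_process` path is an assumption based on the T2M-GPT-style layout the checkpoint paths suggest, so point it at wherever `recover_from_ric` actually lives in this repo.

# Sketch of the imports the committed code appears to need but never adds.
from matplotlib.animation import FuncAnimation, PillowWriter  # PillowWriter backs ani.save(...)
from utils.motion_process import recover_from_ric  # assumed module path; verify against this repo

The duplicated `FuncAnimation`/`Axes3D` imports and the unused `sys` and `os` imports are harmless, but could be cleaned up in the same follow-up.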