import os
import tempfile
import warnings

import clip
import gradio as gr
import numpy as np
import torch

import models.t2m_trans as trans
import models.vqvae as vqvae
import options.option_transformer as option_trans
from utils.motion_process import recover_from_ric
from visualization.plot_3d_global import draw_to_batch

warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the models with the same hyperparameters used at training time.
args = option_trans.get_args_parser()
args.dataname = 't2m'
args.down_t = 2
args.depth = 3
args.block_size = 51

vqvae_model = vqvae.HumanVQVAE(args).to(device)
transformer_model = trans.Text2Motion_Transformer(
    num_vq=args.nb_code,
    embed_dim=1024,
    clip_dim=args.clip_dim,
    block_size=args.block_size,
    num_layers=9,
    n_head=16,
    drop_out_rate=args.drop_out_rate,
    fc_rate=args.ff_rate,
).to(device)

# Load the VQ-VAE and transformer checkpoints, stripping the module prefixes
# ("vqvae." / "trans.") so the keys match the bare models built above.
vqvae_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device)
vqvae_state_dict = {k.replace("vqvae.", ""): v for k, v in vqvae_checkpoint['net'].items()}
vqvae_model.load_state_dict(vqvae_state_dict, strict=False)

transformer_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_best_fid.pth", map_location=device)
transformer_state_dict = {k.replace("trans.", ""): v for k, v in transformer_checkpoint['trans'].items()}
transformer_model.load_state_dict(transformer_state_dict, strict=False)

vqvae_model.eval()
transformer_model.eval()

# CLIP text encoder, loaded once at startup rather than on every request.
clip_model, _ = clip.load("ViT-B/32", device=device)
clip_model.eval()

# Dataset statistics used to de-normalize the predicted motion features.
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)


def generate_motion(text, vqvae_model, transformer_model):
    """Encode the prompt with CLIP, sample motion tokens, and decode to 3D joints."""
    text_encoded = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        clip_features = clip_model.encode_text(text_encoded).float()
        motion_indices = transformer_model.sample(clip_features, False)
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        # De-normalize and recover joint positions (22-joint HumanML3D skeleton).
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)


def infer(text):
    print("Received text:", text)
    try:
        motion_data = generate_motion(text, vqvae_model, transformer_model)
        if motion_data.size == 0:
            raise ValueError("Generated motion data is empty")
    except Exception as e:
        print(f"Failed during motion generation: {str(e)}")
        return "Error in motion generation."

    try:
        gif_path = os.path.join(tempfile.gettempdir(), "output.gif")
        # draw_to_batch renders the joint sequence and writes the GIF to the given path.
        draw_to_batch([motion_data], [text], [gif_path])
        if os.path.exists(gif_path):
            print("GIF successfully saved to:", gif_path)
            return gif_path
        print("Failed to generate GIF data.")
        return "Error generating GIF. Please try again."
    except Exception as e:
        print(f"Error generating GIF: {str(e)}")
        return "Error generating GIF. Please try again."
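
# --- Optional helper (sketch, not part of the original app) ------------------
# Saves the raw (num_frames, 22, 3) joint array returned by generate_motion so
# the motion can be inspected or re-rendered offline. The function name and
# output location are illustrative assumptions; nothing in the app calls it.
def save_motion_npy(motion_data, filename="motion.npy"):
    out_path = os.path.join(tempfile.gettempdir(), filename)
    np.save(out_path, motion_data)
    print("Joint data saved to:", out_path)
    return out_path
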
css = ".container { max-width: 800px; margin: auto; }" with gr.Blocks(css=css) as demo: with gr.Column(elem_id="col-container"): gr.Markdown("## 3D Human Motion Generation") with gr.Row(): text_input = gr.Textbox(label="Enter the human action to generate", placeholder="Enter text description for the action here...", show_label=True) submit_button = gr.Button("Generate Motion") output_image = gr.Image(label="Generated Human Motion", type="filepath", show_label=False) submit_button.click( fn=infer, inputs=[text_input], outputs=[output_image] ) if __name__ == "__main__": demo.launch()