# MotionLCM/mld/data/humanml/utils/plot_script.py
# (file-viewer header residue: "wxDai's picture", commit "[Init]", eb339cb)
from textwrap import wrap
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import mld.data.humanml.utils.paramUtil as paramUtil
# Kinematic chains used as the default `kinematic_tree` of plot_3d_motion.
skeleton = paramUtil.t2m_kinematic_chain
def plot_3d_motion(save_path: str, joints: np.ndarray, title: str,
                   figsize: tuple[int, int] = (3, 3),
                   fps: int = 120, radius: int = 3, kinematic_tree: list = skeleton,
                   hint: Optional[np.ndarray] = None) -> None:
    """Render a joint sequence as an animated 3D skeleton and save it.

    The joints are reshaped to ``(seq_len, joints_num, 3)``, scaled by 1.3 for
    visualization, shifted so the lowest y value sits at 0, and re-centred on
    joint 0 in x/z for every frame. The ground plane (and the hint points, if
    given) are moved by the opposite of joint 0's x/z trajectory instead.

    Args:
        save_path: Output path handed to ``FuncAnimation.save``.
        joints: Joint positions; any shape that reshapes to
            ``(len(joints), -1, 3)``. The caller's array is not modified.
        title: Figure title, wrapped at 20 characters per line.
        figsize: Matplotlib figure size.
        fps: Frame rate used both for the animation interval and for saving.
        radius: Controls the axis limits of the 3D view.
        kinematic_tree: Sequence of joint-index chains; each chain is drawn as
            one polyline. Chains are paired with a fixed 15-colour palette, so
            chains beyond the 15th are not drawn.
        hint: Optional points drawn as a scatter in every frame; entries whose
            sum over the last axis is zero are dropped. Assumed to hold xyz on
            the last axis (TODO confirm with callers). Not modified in place.
    """
    title = '\n'.join(wrap(title, 20))

    def init():
        ax.set_xlim3d([-radius / 2, radius / 2])
        ax.set_ylim3d([0, radius])
        ax.set_zlim3d([-radius / 3., radius * 2 / 3.])
        fig.suptitle(title, fontsize=10)
        # Positional flag: the keyword was renamed from `b` to `visible`, so
        # passing it by name breaks on one Matplotlib version or another.
        ax.grid(False)

    def plot_xzPlane(minx, maxx, miny, minz, maxz):
        # Plot a plane XZ
        verts = [
            [minx, miny, minz],
            [minx, miny, maxz],
            [maxx, miny, maxz],
            [maxx, miny, minz]
        ]
        xz_plane = Poly3DCollection([verts])
        xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
        ax.add_collection3d(xz_plane)

    # (seq_len, joints_num, 3); the multiplication allocates a new array, so the
    # caller's `joints` stays untouched and integer input is promoted to float.
    data = joints.reshape(len(joints), -1, 3) * 1.3  # scale for visualization

    if hint is not None:
        mask = hint.sum(-1) != 0
        # Boolean indexing already copies; scale with the same factor as `data`.
        hint = hint[mask] * 1.3

    fig = plt.figure(figsize=figsize)
    plt.tight_layout()
    # Full-figure 3D axes. Constructing Axes3D(fig) directly no longer attaches
    # the axes to the figure on recent Matplotlib releases.
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], projection='3d')
    init()

    # Per-coordinate extrema over all frames and joints.
    MINS = data.min(axis=0).min(axis=0)
    MAXS = data.max(axis=0).max(axis=0)
    colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00",
              "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00",
              "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", ]

    frame_number = data.shape[0]

    # Put the lowest point of the whole sequence on the y = 0 plane.
    height_offset = MINS[1]
    data[:, :, 1] -= height_offset
    if hint is not None:
        hint[..., 1] -= height_offset

    # Fancy indexing copies, so `trajec` keeps joint 0's x/z path even after
    # the in-place re-centring below.
    trajec = data[:, 0, [0, 2]]

    data[..., 0] -= data[:, 0:1, 0]
    data[..., 2] -= data[:, 0:1, 2]

    def update(index):
        # Drop the previous frame's artists. `ax.lines` / `ax.collections` are
        # read-only on current Matplotlib, so remove the artists one by one.
        for artist in list(ax.lines) + list(ax.collections):
            artist.remove()

        ax.view_init(elev=120, azim=-90)
        # NOTE(review): `dist` is deprecated in newer Matplotlib and may have no
        # effect there - confirm the desired zoom on the installed version.
        ax.dist = 7.5
        plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0,
                     MINS[2] - trajec[index, 1], MAXS[2] - trajec[index, 1])

        if hint is not None:
            ax.scatter(hint[..., 0] - trajec[index, 0], hint[..., 1],
                       hint[..., 2] - trajec[index, 1], color="#80B79A")

        for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
            # The first five chains are drawn thicker than the remaining ones.
            linewidth = 4.0 if i < 5 else 2.0
            ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2],
                      linewidth=linewidth, color=color)

        # Act on this axes explicitly instead of pyplot's "current" axes.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

    ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
    ani.save(save_path, fps=fps)
    # Close this figure specifically, not whichever one happens to be current.
    plt.close(fig)