rsax commited on
Commit
0826835
·
verified ·
1 Parent(s): 2472db7

Update visualization/plot_3d_global.py

Browse files
Files changed (1) hide show
  1. visualization/plot_3d_global.py +2 -2
visualization/plot_3d_global.py CHANGED
@@ -98,7 +98,7 @@ def plot_3d_motion(args, figsize=(10, 10), fps=120, radius=4):
98
  plt.savefig(io_buf, format='png', dpi=96)
99
  io_buf.seek(0)
100
  img = Image.open(io_buf)
101
- frame = np.array(img.convert('RGB'), dtype=np.uint8) # Ensure correct data type and remove alpha channel if present
102
  io_buf.close()
103
 
104
  plt.close(fig)
@@ -107,7 +107,7 @@ def plot_3d_motion(args, figsize=(10, 10), fps=120, radius=4):
107
  out = []
108
  for i in range(frame_number):
109
  frame = update(i)
110
- if frame.ndim == 3 and frame.shape[2] == 3: # Check that frame is H x W x 3
111
  out.append(frame)
112
  else:
113
  print(f"Frame {i} has incorrect shape or channels: {frame.shape}")
 
98
  plt.savefig(io_buf, format='png', dpi=96)
99
  io_buf.seek(0)
100
  img = Image.open(io_buf)
101
+ frame = np.array(img.convert('RGB'), dtype=np.uint8)
102
  io_buf.close()
103
 
104
  plt.close(fig)
 
107
  out = []
108
  for i in range(frame_number):
109
  frame = update(i)
110
+ if frame.ndim == 3 and frame.shape[2] == 3:
111
  out.append(frame)
112
  else:
113
  print(f"Frame {i} has incorrect shape or channels: {frame.shape}")