kleinhe committed on
Commit
73ba0a5
·
1 Parent(s): a995483
SMPLX/visualize_joint2smpl/joints2smpl/src/smplify.py CHANGED
@@ -8,7 +8,7 @@ from customloss import (camera_fitting_loss_3d,
8
  )
9
  from prior import MaxMixturePrior
10
  from SMPLX.visualize_joint2smpl.joints2smpl.src import config
11
-
12
 
13
  @torch.no_grad()
14
  def guess_init_3d(model_joints,
@@ -41,32 +41,21 @@ class SMPLify3D():
41
 
42
  def __init__(self,
43
  smplxmodel,
44
- step_size=1e-2,
45
- batch_size=1,
46
  num_iters=100,
47
- use_collision=False,
48
- use_lbfgs=True,
49
  joints_category="orig",
50
  device=torch.device('cuda:0'),
51
  ):
52
 
53
  # Store options
54
- self.batch_size = batch_size
55
  self.device = device
56
  self.step_size = step_size
57
 
58
  self.num_iters = num_iters
59
- # --- choose optimizer
60
- self.use_lbfgs = use_lbfgs
61
  # GMM pose prior
62
  self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR,
63
  num_gaussians=8,
64
  dtype=torch.float32).to(device)
65
- # collision part
66
- self.use_collision = use_collision
67
- if self.use_collision:
68
- self.part_segm_fn = config.Part_Seg_DIR
69
-
70
  # reLoad SMPL-X model
71
  self.smpl = smplxmodel
72
 
@@ -103,35 +92,6 @@ class SMPLify3D():
103
  betas: SMPL beta parameters of optimized shape
104
  camera_translation: Camera translation
105
  """
106
-
107
- # # # add the mesh inter-section to avoid
108
- search_tree = None
109
- pen_distance = None
110
- filter_faces = None
111
-
112
- if self.use_collision:
113
- from mesh_intersection.bvh_search_tree import BVH
114
- import mesh_intersection.loss as collisions_loss
115
- from mesh_intersection.filter_faces import FilterFaces
116
-
117
- search_tree = BVH(max_collisions=8)
118
-
119
- pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
120
- sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True)
121
-
122
- if self.part_segm_fn:
123
- # Read the part segmentation
124
- part_segm_fn = os.path.expandvars(self.part_segm_fn)
125
- with open(part_segm_fn, 'rb') as faces_parents_file:
126
- face_segm_data = pickle.load(faces_parents_file, encoding='latin1')
127
- faces_segm = face_segm_data['segm']
128
- faces_parents = face_segm_data['parents']
129
- # Create the module used to filter invalid collision pairs
130
- filter_faces = FilterFaces(
131
- faces_segm=faces_segm, faces_parents=faces_parents,
132
- ign_part_pairs=None).to(device=self.device)
133
-
134
-
135
  # Split SMPL pose to body pose and global orientation
136
  body_pose = init_pose[:, 3:].detach().clone()
137
  global_orient = init_pose[:, :3].detach().clone()
@@ -150,42 +110,29 @@ class SMPLify3D():
150
  # -------------Step 1: Optimize camera translation and body orientation--------
151
  # Optimize only camera translation and body orientation
152
  body_pose.requires_grad = False
153
- betas.requires_grad = False
154
  global_orient.requires_grad = True
155
  camera_translation.requires_grad = True
156
 
157
- camera_opt_params = [global_orient, camera_translation]
158
-
159
- if self.use_lbfgs:
160
- camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters,
161
- lr=self.step_size, line_search_fn='strong_wolfe')
162
- for i in range(10):
163
- def closure():
164
- camera_optimizer.zero_grad()
165
- smpl_output = self.smpl(global_orient=global_orient,
166
- body_pose=body_pose,
167
- betas=betas)
168
- model_joints = smpl_output.joints
169
- loss = camera_fitting_loss_3d(model_joints, camera_translation,
170
- init_cam_t, j3d, self.joints_category)
171
- loss.backward()
172
- return loss
173
 
174
- camera_optimizer.step(closure)
175
- else:
176
- camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999))
177
-
178
- for i in range(20):
 
179
  smpl_output = self.smpl(global_orient=global_orient,
180
  body_pose=body_pose,
181
  betas=betas)
182
  model_joints = smpl_output.joints
 
 
183
 
184
- loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation,
185
- init_cam_t, j3d[:, self.corr_index], self.joints_category)
186
- camera_optimizer.zero_grad()
187
  loss.backward()
188
- camera_optimizer.step()
 
 
189
 
190
  # Fix camera translation after optimizing camera
191
  # --------Step 2: Optimize body joints --------------------------
@@ -193,43 +140,15 @@ class SMPLify3D():
193
  body_pose.requires_grad = True
194
  global_orient.requires_grad = True
195
  camera_translation.requires_grad = True
196
-
197
- # --- if we use the sequence, fix the shape
198
- if seq_ind == 0:
199
- betas.requires_grad = True
200
- body_opt_params = [body_pose, betas, global_orient, camera_translation]
201
- else:
202
- betas.requires_grad = False
203
- body_opt_params = [body_pose, global_orient, camera_translation]
204
-
205
- if self.use_lbfgs:
206
- body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters,
207
- lr=self.step_size, line_search_fn='strong_wolfe')
208
- for i in range(self.num_iters):
209
- def closure():
210
- body_optimizer.zero_grad()
211
- smpl_output = self.smpl(global_orient=global_orient,
212
- body_pose=body_pose,
213
- betas=betas)
214
- model_joints = smpl_output.joints
215
- model_vertices = smpl_output.vertices
216
-
217
- loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation,
218
- j3d[:, self.corr_index], self.pose_prior,
219
- joints3d_conf=conf_3d,
220
- joint_loss_weight=600.0,
221
- pose_preserve_weight=5.0,
222
- use_collision=self.use_collision,
223
- model_vertices=model_vertices, model_faces=self.model_faces,
224
- search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
225
- loss.backward()
226
- return loss
227
-
228
- body_optimizer.step(closure)
229
- else:
230
- body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999))
231
-
232
- for i in range(self.num_iters):
233
  smpl_output = self.smpl(global_orient=global_orient,
234
  body_pose=body_pose,
235
  betas=betas)
@@ -240,31 +159,15 @@ class SMPLify3D():
240
  j3d[:, self.corr_index], self.pose_prior,
241
  joints3d_conf=conf_3d,
242
  joint_loss_weight=600.0,
243
- use_collision=self.use_collision,
 
244
  model_vertices=model_vertices, model_faces=self.model_faces,
245
- search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
246
- body_optimizer.zero_grad()
247
  loss.backward()
248
- body_optimizer.step()
249
 
250
- # Get final loss value
251
- with torch.no_grad():
252
- smpl_output = self.smpl(global_orient=global_orient,
253
- body_pose=body_pose,
254
- betas=betas, return_full_pose=True)
255
- model_joints = smpl_output.joints
256
- model_vertices = smpl_output.vertices
257
 
258
- final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation,
259
- j3d[:, self.corr_index], self.pose_prior,
260
- joints3d_conf=conf_3d,
261
- joint_loss_weight=600.0,
262
- use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces,
263
- search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
264
 
265
- vertices = smpl_output.vertices.detach()
266
- joints = smpl_output.joints.detach()
267
  pose = torch.cat([global_orient, body_pose], dim=-1).detach()
268
- betas = betas.detach()
269
-
270
- return vertices, joints, pose, betas, camera_translation, final_loss
 
8
  )
9
  from prior import MaxMixturePrior
10
  from SMPLX.visualize_joint2smpl.joints2smpl.src import config
11
+ from tqdm import tqdm
12
 
13
  @torch.no_grad()
14
  def guess_init_3d(model_joints,
 
41
 
42
  def __init__(self,
43
  smplxmodel,
44
+ step_size=1.0,
 
45
  num_iters=100,
 
 
46
  joints_category="orig",
47
  device=torch.device('cuda:0'),
48
  ):
49
 
50
  # Store options
 
51
  self.device = device
52
  self.step_size = step_size
53
 
54
  self.num_iters = num_iters
 
 
55
  # GMM pose prior
56
  self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR,
57
  num_gaussians=8,
58
  dtype=torch.float32).to(device)
 
 
 
 
 
59
  # reLoad SMPL-X model
60
  self.smpl = smplxmodel
61
 
 
92
  betas: SMPL beta parameters of optimized shape
93
  camera_translation: Camera translation
94
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  # Split SMPL pose to body pose and global orientation
96
  body_pose = init_pose[:, 3:].detach().clone()
97
  global_orient = init_pose[:, :3].detach().clone()
 
110
  # -------------Step 1: Optimize camera translation and body orientation--------
111
  # Optimize only camera translation and body orientation
112
  body_pose.requires_grad = False
113
+ betas.requires_grad = True
114
  global_orient.requires_grad = True
115
  camera_translation.requires_grad = True
116
 
117
+ camera_opt_params = [betas, global_orient, camera_translation]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=10,
120
+ lr=self.step_size, line_search_fn='strong_wolfe')
121
+ cycle = tqdm(range(10))
122
+ for i in cycle:
123
+ def closure():
124
+ camera_optimizer.zero_grad()
125
  smpl_output = self.smpl(global_orient=global_orient,
126
  body_pose=body_pose,
127
  betas=betas)
128
  model_joints = smpl_output.joints
129
+ loss = camera_fitting_loss_3d(model_joints, camera_translation,
130
+ init_cam_t, j3d, self.joints_category)
131
 
 
 
 
132
  loss.backward()
133
+ return loss
134
+
135
+ camera_optimizer.step(closure)
136
 
137
  # Fix camera translation after optimizing camera
138
  # --------Step 2: Optimize body joints --------------------------
 
140
  body_pose.requires_grad = True
141
  global_orient.requires_grad = True
142
  camera_translation.requires_grad = True
143
+ betas.requires_grad = True
144
+ body_opt_params = [body_pose, betas, global_orient, camera_translation]
145
+
146
+ cycle = tqdm(range(self.num_iters))
147
+ body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters,
148
+ lr=self.step_size, line_search_fn='strong_wolfe')
149
+ for i in cycle:
150
+ def closure():
151
+ body_optimizer.zero_grad()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  smpl_output = self.smpl(global_orient=global_orient,
153
  body_pose=body_pose,
154
  betas=betas)
 
159
  j3d[:, self.corr_index], self.pose_prior,
160
  joints3d_conf=conf_3d,
161
  joint_loss_weight=600.0,
162
+ pose_preserve_weight=5.0,
163
+ use_collision=False,
164
  model_vertices=model_vertices, model_faces=self.model_faces,
165
+ search_tree=None, pen_distance=None, filter_faces=None)
 
166
  loss.backward()
167
+ return loss
168
 
169
+ body_optimizer.step(closure)
 
 
 
 
 
 
170
 
 
 
 
 
 
 
171
 
 
 
172
  pose = torch.cat([global_orient, body_pose], dim=-1).detach()
173
+ return pose
 
 
SMPLX/visualize_joint2smpl/simplify_loc2rot.py CHANGED
@@ -7,7 +7,6 @@ from SMPLX.visualize_joint2smpl.joints2smpl.src.smplify import SMPLify3D
7
  from tqdm import tqdm
8
  import argparse
9
 
10
-
11
  class joints2smpl:
12
 
13
  def __init__(self, num_frames, device, model_path=None, json_dict=None):
@@ -17,8 +16,9 @@ class joints2smpl:
17
  self.batch_size = num_frames
18
  self.num_joints = 22 # for HumanML3D
19
  self.joint_category = "AMASS"
20
- self.num_smplify_iters = 100
21
  self.fix_foot = False
 
22
  smplmodel = smplx.create(self.smpl_dir, model_type="smpl", gender="neutral", ext="pkl",
23
  batch_size=self.batch_size).to(self.device)
24
 
@@ -33,7 +33,6 @@ class joints2smpl:
33
 
34
  # # #-------------initialize SMPLify
35
  self.smplify = SMPLify3D(smplxmodel=smplmodel,
36
- batch_size=self.batch_size,
37
  joints_category=self.joint_category,
38
  num_iters=self.num_smplify_iters,
39
  device=self.device)
@@ -92,18 +91,17 @@ class joints2smpl:
92
  else:
93
  print("Such category not settle down!")
94
 
95
- new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \
96
- new_opt_cam_t, new_opt_joint_loss = self.smplify(
97
  pred_pose.detach(),
98
  pred_betas.detach(),
99
  pred_cam_t.detach(),
100
  keypoints_3d,
101
  conf_3d=confidence_input.to(self.device),
102
- # seq_ind=idx
103
  )
104
 
105
  thetas = new_opt_pose.reshape(self.batch_size, 24 * 3)
106
  vecs = thetas.detach().cpu().numpy()
 
107
  return vecs, root_loc
108
 
109
 
 
7
  from tqdm import tqdm
8
  import argparse
9
 
 
10
  class joints2smpl:
11
 
12
  def __init__(self, num_frames, device, model_path=None, json_dict=None):
 
16
  self.batch_size = num_frames
17
  self.num_joints = 22 # for HumanML3D
18
  self.joint_category = "AMASS"
19
+ self.num_smplify_iters = 15
20
  self.fix_foot = False
21
+
22
  smplmodel = smplx.create(self.smpl_dir, model_type="smpl", gender="neutral", ext="pkl",
23
  batch_size=self.batch_size).to(self.device)
24
 
 
33
 
34
  # # #-------------initialize SMPLify
35
  self.smplify = SMPLify3D(smplxmodel=smplmodel,
 
36
  joints_category=self.joint_category,
37
  num_iters=self.num_smplify_iters,
38
  device=self.device)
 
91
  else:
92
  print("Such category not settle down!")
93
 
94
+ new_opt_pose = self.smplify(
 
95
  pred_pose.detach(),
96
  pred_betas.detach(),
97
  pred_cam_t.detach(),
98
  keypoints_3d,
99
  conf_3d=confidence_input.to(self.device),
 
100
  )
101
 
102
  thetas = new_opt_pose.reshape(self.batch_size, 24 * 3)
103
  vecs = thetas.detach().cpu().numpy()
104
+
105
  return vecs, root_loc
106
 
107
 
app.py CHANGED
@@ -121,7 +121,7 @@ def t2m_demo():
121
 
122
  with gr.Row():
123
  condition = gr.Radio(['text', 'uncond'], value='text', label='Condition', info="If sythesize motion with prompt?")
124
- out_size = gr.Number(value=1024, label="Resolution", info="The resolution of output videos", minimum=224, maximum=2048, precision=0)
125
 
126
  with gr.Row():
127
  render_mode = gr.Radio(['joints','pyrender_fast', 'pyrender_slow'], value='joints', label='Render', info="If render results to 3D meshes? Pyrender need more time.")
 
121
 
122
  with gr.Row():
123
  condition = gr.Radio(['text', 'uncond'], value='text', label='Condition', info="If sythesize motion with prompt?")
124
+ out_size = gr.Number(value=256, label="Resolution", info="The resolution of output videos", minimum=128, maximum=2048, precision=0)
125
 
126
  with gr.Row():
127
  render_mode = gr.Radio(['joints','pyrender_fast', 'pyrender_slow'], value='joints', label='Render', info="If render results to 3D meshes? Pyrender need more time.")