Ali Mohsin commited on
Commit
63bea2f
·
1 Parent(s): 4c539b3

more chnages

Browse files
Files changed (1) hide show
  1. loop.py +80 -59
loop.py CHANGED
@@ -147,9 +147,71 @@ def loop(cfg):
147
 
148
  # output video
149
  video = Video(cfg.output_path)
150
- # GL Context
151
- glctx = dr.RasterizeGLContext()
152
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  load_mesh = get_mesh(cfg.mesh, output_path, cfg.retriangulate, cfg.bsdf)
155
 
@@ -292,17 +354,7 @@ def loop(cfg):
292
  )
293
  rot_ang += 5
294
  log_mesh = mesh.unit_size(render_mesh.eval(params))
295
- log_image = render.render_mesh(
296
- glctx,
297
- log_mesh,
298
- params['mvp'],
299
- params['campos'],
300
- params['lightpos'],
301
- cfg.log_light_power,
302
- cfg.log_res,
303
- 1,
304
- background=torch.ones(1, cfg.log_res, cfg.log_res, 3).to(device)
305
- )
306
 
307
  log_image = video.ready_image(log_image)
308
  logger.add_mesh('predicted_mesh', vertices=log_mesh.v_pos.unsqueeze(0), faces=log_mesh.t_pos_idx.unsqueeze(0), global_step=it)
@@ -322,39 +374,21 @@ def loop(cfg):
322
  params_camera[key] = params_camera[key].to(device)
323
 
324
  final_mesh = render_mesh.eval(params_camera)
325
- train_render = render.render_mesh(
326
- glctx,
327
- final_mesh,
328
- params_camera['mvp'],
329
- params_camera['campos'],
330
- params_camera['lightpos'],
331
- cfg.light_power,
332
- cfg.train_res,
333
- spp=1,
334
- num_layers=1,
335
- msaa=False,
336
- background=params_camera['bkgs']
337
- ).permute(0, 3, 1, 2)
338
  train_render = resize(train_render, out_shape=(224, 224), interp_method=resize_method)
339
 
340
  if use_target_mesh:
341
  final_target_mesh = render_target_mesh.eval(params_camera)
342
- train_target_render = render.render_mesh(
343
- glctx,
344
- final_target_mesh,
345
- params_camera['mvp'],
346
- params_camera['campos'],
347
- params_camera['lightpos'],
348
- cfg.light_power,
349
- cfg.train_res,
350
- spp=1,
351
- num_layers=1,
352
- msaa=False,
353
- background=params_camera['bkgs']
354
- ).permute(0, 3, 1, 2)
355
  train_target_render = resize(train_target_render, out_shape=(224, 224), interp_method=resize_method)
356
 
357
- train_rast_map = render.render_mesh(
358
  glctx,
359
  final_mesh,
360
  params_camera['mvp'],
@@ -362,10 +396,6 @@ def loop(cfg):
362
  params_camera['lightpos'],
363
  cfg.light_power,
364
  cfg.train_res,
365
- spp=1,
366
- num_layers=1,
367
- msaa=False,
368
- background=params_camera['bkgs'],
369
  return_rast_map=True
370
  )
371
 
@@ -373,19 +403,10 @@ def loop(cfg):
373
  params_camera = next(iter(cams))
374
  for key in params_camera:
375
  params_camera[key] = params_camera[key].to(device)
376
- base_render = render.render_mesh(
377
- glctx,
378
- base_mesh.eval(params_camera),
379
- params_camera['mvp'],
380
- params_camera['campos'],
381
- params_camera['lightpos'],
382
- cfg.light_power,
383
- cfg.train_res,
384
- spp=1,
385
- num_layers=1,
386
- msaa=False,
387
- background=params_camera['bkgs'],
388
- ).permute(0, 3, 1, 2)
389
  base_render = resize(base_render, out_shape=(224, 224), interp_method=resize_method)
390
 
391
  if it % cfg.log_interval_im == 0:
@@ -444,7 +465,7 @@ def loop(cfg):
444
  r_loss = (((gt_jacobians) - torch.eye(3, 3, device=device)) ** 2).mean()
445
  logger.add_scalar('jacobian_regularization', r_loss, global_step=it)
446
 
447
- if cfg.consistency_loss_weight != 0 and fe is not None:
448
  consistency_loss = compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device)
449
  else:
450
  consistency_loss = r_loss
 
147
 
148
  # output video
149
  video = Video(cfg.output_path)
150
+ # GL Context - with fallback for headless environments
151
+ print('Initializing nvdiffrast GL context...')
152
+ try:
153
+ glctx = dr.RasterizeGLContext()
154
+ print('nvdiffrast GL context initialized successfully')
155
+ use_gl_rendering = True
156
+ except Exception as e:
157
+ print(f'Error initializing nvdiffrast GL context: {e}')
158
+ print('This is likely due to missing EGL headers in headless environment.')
159
+ print('Using fallback rendering approach...')
160
+ glctx = None
161
+ use_gl_rendering = False
162
+
163
+ def fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs):
164
+ """
165
+ Fallback rendering function when GL context is not available
166
+ Returns a simple colored mesh visualization
167
+ """
168
+ try:
169
+ # Check if return_rast_map is requested
170
+ return_rast_map = kwargs.get('return_rast_map', False)
171
+
172
+ # Create a simple colored mesh visualization
173
+ # This is a basic fallback that creates a colored mesh without proper lighting
174
+ device = mesh.v_pos.device if hasattr(mesh, 'v_pos') and mesh.v_pos is not None else torch.device('cuda')
175
+ batch_size = 1
176
+
177
+ if return_rast_map:
178
+ # Return a dummy rasterization map for consistency
179
+ rast_map = torch.zeros(batch_size, resolution, resolution, 4, device=device)
180
+ rast_map[..., 3] = 1.0 # Set alpha to 1
181
+ return rast_map
182
+ else:
183
+ # Create a simple colored output
184
+ color = torch.ones(batch_size, resolution, resolution, 3, device=device) * 0.5 # Gray color
185
+
186
+ # Add some basic shading based on vertex positions
187
+ if hasattr(mesh, 'v_pos') and mesh.v_pos is not None:
188
+ # Normalize vertex positions for coloring
189
+ v_pos_norm = (mesh.v_pos - mesh.v_pos.min(dim=0)[0]) / (mesh.v_pos.max(dim=0)[0] - mesh.v_pos.min(dim=0)[0] + 1e-8)
190
+ # Use vertex positions to create a simple color gradient
191
+ color = color * 0.3 + v_pos_norm.mean(dim=0).unsqueeze(0).unsqueeze(0).unsqueeze(0) * 0.7
192
+
193
+ return color
194
+ except Exception as e:
195
+ print(f"Fallback rendering failed: {e}")
196
+ # Return a simple colored square as last resort
197
+ device = mesh.v_pos.device if hasattr(mesh, 'v_pos') and mesh.v_pos is not None else torch.device('cuda')
198
+ if kwargs.get('return_rast_map', False):
199
+ return torch.zeros(1, resolution, resolution, 4, device=device)
200
+ else:
201
+ return torch.ones(1, resolution, resolution, 3, device=device) * 0.5
202
+
203
+ def safe_render_mesh(glctx, mesh, mvp, campos, lightpos, light_power, resolution, **kwargs):
204
+ """
205
+ Safe rendering function that uses GL context if available, otherwise falls back
206
+ """
207
+ if glctx is not None and use_gl_rendering:
208
+ try:
209
+ return render.render_mesh(glctx, mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)
210
+ except Exception as e:
211
+ print(f"GL rendering failed, using fallback: {e}")
212
+ return fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)
213
+ else:
214
+ return fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)
215
 
216
  load_mesh = get_mesh(cfg.mesh, output_path, cfg.retriangulate, cfg.bsdf)
217
 
 
354
  )
355
  rot_ang += 5
356
  log_mesh = mesh.unit_size(render_mesh.eval(params))
357
+ log_image = safe_render_mesh(glctx, log_mesh, params['mvp'], params['campos'], params['lightpos'], cfg.log_light_power, cfg.log_res)
 
 
 
 
 
 
 
 
 
 
358
 
359
  log_image = video.ready_image(log_image)
360
  logger.add_mesh('predicted_mesh', vertices=log_mesh.v_pos.unsqueeze(0), faces=log_mesh.t_pos_idx.unsqueeze(0), global_step=it)
 
374
  params_camera[key] = params_camera[key].to(device)
375
 
376
  final_mesh = render_mesh.eval(params_camera)
377
+ train_render = safe_render_mesh(glctx, final_mesh, params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
378
+ # Handle permutation for fallback case
379
+ if train_render.shape[-1] == 3: # If it's already in the right format
380
+ train_render = train_render.permute(0, 3, 1, 2)
 
 
 
 
 
 
 
 
 
381
  train_render = resize(train_render, out_shape=(224, 224), interp_method=resize_method)
382
 
383
  if use_target_mesh:
384
  final_target_mesh = render_target_mesh.eval(params_camera)
385
+ train_target_render = safe_render_mesh(glctx, final_target_mesh, params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
386
+ # Handle permutation for fallback case
387
+ if train_target_render.shape[-1] == 3: # If it's already in the right format
388
+ train_target_render = train_target_render.permute(0, 3, 1, 2)
 
 
 
 
 
 
 
 
 
389
  train_target_render = resize(train_target_render, out_shape=(224, 224), interp_method=resize_method)
390
 
391
+ train_rast_map = safe_render_mesh(
392
  glctx,
393
  final_mesh,
394
  params_camera['mvp'],
 
396
  params_camera['lightpos'],
397
  cfg.light_power,
398
  cfg.train_res,
 
 
 
 
399
  return_rast_map=True
400
  )
401
 
 
403
  params_camera = next(iter(cams))
404
  for key in params_camera:
405
  params_camera[key] = params_camera[key].to(device)
406
+ base_render = safe_render_mesh(glctx, base_mesh.eval(params_camera), params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
407
+ # Handle permutation for fallback case
408
+ if base_render.shape[-1] == 3: # If it's already in the right format
409
+ base_render = base_render.permute(0, 3, 1, 2)
 
 
 
 
 
 
 
 
 
410
  base_render = resize(base_render, out_shape=(224, 224), interp_method=resize_method)
411
 
412
  if it % cfg.log_interval_im == 0:
 
465
  r_loss = (((gt_jacobians) - torch.eye(3, 3, device=device)) ** 2).mean()
466
  logger.add_scalar('jacobian_regularization', r_loss, global_step=it)
467
 
468
+ if cfg.consistency_loss_weight != 0 and fe is not None and train_rast_map is not None:
469
  consistency_loss = compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device)
470
  else:
471
  consistency_loss = r_loss