JiantaoLin commited on
Commit
8ab9738
·
1 Parent(s): b5de118
models/lrm/models/geometry/render/neural_render.py CHANGED
@@ -7,6 +7,7 @@
7
  # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
 
9
  import torch
 
10
  import torch.nn.functional as F
11
  import nvdiffrast.torch as dr
12
  from . import Renderer
@@ -66,7 +67,7 @@ def compute_vertex_normal(v_pos, t_pos_idx):
66
 
67
  return v_nrm
68
 
69
-
70
  class NeuralRender(Renderer):
71
  def __init__(self, device='cuda', camera_model=None):
72
  super(NeuralRender, self).__init__()
@@ -146,6 +147,7 @@ class NeuralRender(Renderer):
146
  # - Single light
147
  # - Single material
148
  # ==============================================================================================
 
149
  def render_layer(
150
  self,
151
  rast,
@@ -189,6 +191,7 @@ class NeuralRender(Renderer):
189
 
190
  return gb_pos, gb_normal
191
 
 
192
  def render_mesh(
193
  self,
194
  mesh_v_pos_bxnx3,
@@ -245,6 +248,7 @@ class NeuralRender(Renderer):
245
  # normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
246
  return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal
247
 
 
248
  def render_mesh_light(
249
  self,
250
  mesh_v_pos_bxnx3,
 
7
  # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
 
9
  import torch
10
+ import spaces
11
  import torch.nn.functional as F
12
  import nvdiffrast.torch as dr
13
  from . import Renderer
 
67
 
68
  return v_nrm
69
 
70
+ @spaces.GPU
71
  class NeuralRender(Renderer):
72
  def __init__(self, device='cuda', camera_model=None):
73
  super(NeuralRender, self).__init__()
 
147
  # - Single light
148
  # - Single material
149
  # ==============================================================================================
150
+ @spaces.GPU
151
  def render_layer(
152
  self,
153
  rast,
 
191
 
192
  return gb_pos, gb_normal
193
 
194
+ @spaces.GPU
195
  def render_mesh(
196
  self,
197
  mesh_v_pos_bxnx3,
 
248
  # normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
249
  return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal
250
 
251
+ @spaces.GPU
252
  def render_mesh_light(
253
  self,
254
  mesh_v_pos_bxnx3,