JiantaoLin committed
Commit 235efa3 · 1 Parent(s): 0067ad0

Files changed (2):
  1. shader.py +5 -2
  2. video_render.py +6 -26
shader.py CHANGED
@@ -3,6 +3,8 @@ from pytorch3d.renderer.mesh.shader import ShaderBase
 from pytorch3d.renderer import (
     SoftPhongShader,
 )
+from pytorch3d.renderer import BlendParams
+
 
 class MultiOutputShader(ShaderBase):
     def __init__(self, device, cameras, lights, materials, ccm_scale=1.0, choices=None):
@@ -17,12 +19,13 @@ class MultiOutputShader(ShaderBase):
             self.choices = ["rgb", "mask", "depth", "normal", "albedo", "ccm"]
         else:
             self.choices = choices
-
+        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
         self.phong_shader = SoftPhongShader(
             device=self.device,
             cameras=self.cameras,
             lights=self.lights,
-            materials=self.materials
+            materials=self.materials,
+            blend_params=blend_params
         )
 
     def forward(self, fragments, meshes, **kwargs):
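For context, BlendParams is where pytorch3d takes the sigma/gamma softness of the silhouette blending as well as the blended background colour; RasterizationSettings has no background_color argument, which is why the commented-out line is dropped from video_render.py below. A minimal sketch of the new shader construction, assuming the stock pytorch3d API; the camera, light and material objects and the explicit background_color are illustrative, not taken from the commit:

import torch
from pytorch3d.renderer import (
    AmbientLights,
    BlendParams,
    Materials,
    PerspectiveCameras,
    SoftPhongShader,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Illustrative camera/light/material objects; MultiOutputShader receives these
# through its constructor rather than building them itself.
cameras = PerspectiveCameras(device=device)
lights = AmbientLights(device=device)
materials = Materials(device=device)

# sigma and gamma control how softly face boundaries and depths are blended;
# values this small keep the result close to a hard z-buffer. background_color
# is illustrative here: BlendParams, not RasterizationSettings, is where it lives.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(1.0, 1.0, 1.0))

shader = SoftPhongShader(
    device=device,
    cameras=cameras,
    lights=lights,
    materials=materials,
    blend_params=blend_params,
)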
video_render.py CHANGED
@@ -26,57 +26,44 @@ def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, imag
     if not os.path.exists(input_obj_path):
         raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
 
-    # Load the 3D model
     scene_data = trimesh.load(input_obj_path)
 
-    # Extract or merge the meshes
     if isinstance(scene_data, trimesh.Scene):
         mesh_data = trimesh.util.concatenate([geom for geom in scene_data.geometry.values()])
     else:
         mesh_data = scene_data
 
-    # Make sure vertex normals exist
     if not hasattr(mesh_data, 'vertex_normals') or mesh_data.vertex_normals is None:
         mesh_data.compute_vertex_normals()
 
-    # Get the vertex coordinates, normals, and faces
     vertices = torch.tensor(mesh_data.vertices, dtype=torch.float32, device=device)
     faces = torch.tensor(mesh_data.faces, dtype=torch.int64, device=device)
-    vertex_normals = torch.tensor(mesh_data.vertex_normals, dtype=torch.float32)
 
-    # Get the vertex colors
     if mesh_data.visual.vertex_colors is None:
-        # If there are no vertex colors, use a default value (e.g. white)
         vertex_colors = torch.ones_like(vertices)[None]
     else:
         vertex_colors = torch.tensor(mesh_data.visual.vertex_colors[:, :3], dtype=torch.float32)[None]
-    # Create the textures and assign the vertex colors
     textures = TexturesVertex(verts_features=vertex_colors)
     textures.to(device)
-    # Create the Mesh object
     mesh = pytorch3d.structures.Meshes(verts=[vertices], faces=[faces], textures=textures)
 
-    # Set up the renderer
     lights = AmbientLights(ambient_color=((2.0,)*3,), device=device)
     # lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]], ambient_color=[[0.5, 0.5, 0.5]], diffuse_color=[[1.0, 1.0, 1.0]])
     raster_settings = RasterizationSettings(
-        image_size=image_size, # size of the rendered image
-        blur_radius=0.0, # no blur by default
-        faces_per_pixel=1, # render one face per pixel
-        # background_color=(1.0, 1.0, 1.0)
+        image_size=image_size,
+        blur_radius=0.0,
+        faces_per_pixel=1,
     )
 
-    # Rotation and rendering parameters
     frames = []
     camera_distance = 6.5
     elevs = 0.0
     center = (0.0, 0.0, 0.0)
-    # Render each frame
     materials = Materials(
         device=device,
-        diffuse_color=((0.0, 0.0, 0.0),),
+        diffuse_color=((1.0, 1.0, 1.0),),
         ambient_color=((1.0, 1.0, 1.0),),
-        specular_color=((0.0, 0.0, 0.0),),
+        specular_color=((1.0, 1.0, 1.0),),
         shininess=0.0,
     )
 
@@ -91,7 +78,6 @@ def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, imag
             degrees=True
         )
 
-
         # Manually set the camera rotation matrix
         cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=5.0)
         cameras.znear = 0.0001
@@ -105,32 +91,26 @@ def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, imag
         )
 
         renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
-        # Render the RGB image and the normal image
         render_result = renderer(mesh, cameras=cameras)
         rgb_image = render_result["rgb"] * render_result["mask"] + (1 - render_result["mask"]) * torch.ones_like(render_result["rgb"]) * 255.0
         normal_map = render_result["normal"]
 
-        # Extract the RGB image and the normal map
-        rgb = rgb_image[0, ..., :3].cpu().numpy() # RGB image
+        rgb = rgb_image[0, ..., :3].cpu().numpy()
         normal_map = torch.nn.functional.normalize(normal_map, dim=-1) # Normal map
         normal_map = (normal_map + 1) / 2
         normal_map = normal_map * render_result["mask"] + (1 - render_result["mask"]) * torch.ones_like(render_result["normal"])
         normal = normal_map[0, ..., :3].cpu().numpy() # Normal map
         rgb = np.clip(rgb, 0, 255).astype(np.uint8)
         normal = np.clip(normal*255, 0, 255).astype(np.uint8)
-        # Combine the RGB image and the normal map into one image, RGB on the left and normal map on the right
        combined_image = np.concatenate((rgb, normal), axis=1)
 
-        # Append the combined image to the frame list
        frames.append(combined_image)
 
-    # Save the video with imageio
    imageio.mimsave(output_video_path, frames, fps=fps)
 
    print(f"Video saved to {output_video_path}")
 
 if __name__ == '__main__':
-    # Example call
     input_obj_path = "/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/a_owl_wearing_a_hat/ISOMER/rgb_projected.obj"
     output_video_path = "output.mp4"
     render_video_from_obj(input_obj_path, output_video_path)
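The hunks above show only fragments of the per-frame camera setup (the R, T and degrees=True lines). A minimal sketch of what such an orbit loop can look like with the stock pytorch3d look_at_view_transform; the azimuth schedule and loop variable are assumptions for illustration, not code from the commit:

import torch
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform

device = "cuda:0" if torch.cuda.is_available() else "cpu"
num_frames = 60
camera_distance = 6.5
elevs = 0.0
center = (0.0, 0.0, 0.0)

for frame_idx in range(num_frames):
    # One full orbit over the clip; this schedule is an assumption.
    azims = 360.0 * frame_idx / num_frames
    R, T = look_at_view_transform(
        dist=camera_distance,
        elev=elevs,
        azim=azims,
        at=(center,),
        degrees=True,
    )
    cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=5.0)
    cameras.znear = 0.0001
    # ... rasterize and shade this view, then append the composited frame ...

Each iteration would then build the rasterizer and shader for that view and append the composited frame, as in the last hunk above.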
 
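The construction of rasterizer and shader falls outside the hunks, so the exact arguments used in video_render.py are not visible here. A minimal, self-contained sketch of how the pieces plug together with the MultiOutputShader this commit modifies, using a placeholder triangle and an assumed choices list:

import torch
import pytorch3d.structures
from pytorch3d.renderer import (
    AmbientLights,
    Materials,
    MeshRasterizer,
    MeshRenderer,
    PerspectiveCameras,
    RasterizationSettings,
    TexturesVertex,
)
from shader import MultiOutputShader  # the class modified in this commit

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# A single white triangle stands in for the loaded OBJ mesh.
verts = torch.tensor([[-1.0, 0.0, 2.0], [1.0, 0.0, 2.0], [0.0, 1.0, 2.0]], device=device)
faces = torch.tensor([[0, 1, 2]], dtype=torch.int64, device=device)
textures = TexturesVertex(verts_features=torch.ones_like(verts)[None])
mesh = pytorch3d.structures.Meshes(verts=[verts], faces=[faces], textures=textures)

cameras = PerspectiveCameras(device=device, focal_length=5.0)
lights = AmbientLights(device=device)
materials = Materials(device=device)
raster_settings = RasterizationSettings(image_size=256, blur_radius=0.0, faces_per_pixel=1)

rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
shader = MultiOutputShader(
    device=device,
    cameras=cameras,
    lights=lights,
    materials=materials,
    choices=["rgb", "mask", "normal"],  # assumed subset; None falls back to all outputs
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)

# forward() returns a dict keyed by the requested outputs, which is why
# video_render.py indexes render_result["rgb"], ["mask"] and ["normal"].
render_result = renderer(mesh, cameras=cameras)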