JiantaoLin
commited on
Commit
·
235efa3
1
Parent(s):
0067ad0
new
Browse files- shader.py +5 -2
- video_render.py +6 -26
shader.py
CHANGED
@@ -3,6 +3,8 @@ from pytorch3d.renderer.mesh.shader import ShaderBase
|
|
3 |
from pytorch3d.renderer import (
|
4 |
SoftPhongShader,
|
5 |
)
|
|
|
|
|
6 |
|
7 |
class MultiOutputShader(ShaderBase):
|
8 |
def __init__(self, device, cameras, lights, materials, ccm_scale=1.0, choices=None):
|
@@ -17,12 +19,13 @@ class MultiOutputShader(ShaderBase):
|
|
17 |
self.choices = ["rgb", "mask", "depth", "normal", "albedo", "ccm"]
|
18 |
else:
|
19 |
self.choices = choices
|
20 |
-
|
21 |
self.phong_shader = SoftPhongShader(
|
22 |
device=self.device,
|
23 |
cameras=self.cameras,
|
24 |
lights=self.lights,
|
25 |
-
materials=self.materials
|
|
|
26 |
)
|
27 |
|
28 |
def forward(self, fragments, meshes, **kwargs):
|
|
|
3 |
from pytorch3d.renderer import (
|
4 |
SoftPhongShader,
|
5 |
)
|
6 |
+
from pytorch3d.renderer import BlendParams
|
7 |
+
|
8 |
|
9 |
class MultiOutputShader(ShaderBase):
|
10 |
def __init__(self, device, cameras, lights, materials, ccm_scale=1.0, choices=None):
|
|
|
19 |
self.choices = ["rgb", "mask", "depth", "normal", "albedo", "ccm"]
|
20 |
else:
|
21 |
self.choices = choices
|
22 |
+
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
|
23 |
self.phong_shader = SoftPhongShader(
|
24 |
device=self.device,
|
25 |
cameras=self.cameras,
|
26 |
lights=self.lights,
|
27 |
+
materials=self.materials,
|
28 |
+
blend_params=blend_params
|
29 |
)
|
30 |
|
31 |
def forward(self, fragments, meshes, **kwargs):
|
video_render.py
CHANGED
@@ -26,57 +26,44 @@ def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, imag
|
|
26 |
if not os.path.exists(input_obj_path):
|
27 |
raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
|
28 |
|
29 |
-
# 加载3D模型
|
30 |
scene_data = trimesh.load(input_obj_path)
|
31 |
|
32 |
-
# 提取或合并网格
|
33 |
if isinstance(scene_data, trimesh.Scene):
|
34 |
mesh_data = trimesh.util.concatenate([geom for geom in scene_data.geometry.values()])
|
35 |
else:
|
36 |
mesh_data = scene_data
|
37 |
|
38 |
-
# 确保顶点法线存在
|
39 |
if not hasattr(mesh_data, 'vertex_normals') or mesh_data.vertex_normals is None:
|
40 |
mesh_data.compute_vertex_normals()
|
41 |
|
42 |
-
# 获取顶点坐标、法线和面
|
43 |
vertices = torch.tensor(mesh_data.vertices, dtype=torch.float32, device=device)
|
44 |
faces = torch.tensor(mesh_data.faces, dtype=torch.int64, device=device)
|
45 |
-
vertex_normals = torch.tensor(mesh_data.vertex_normals, dtype=torch.float32)
|
46 |
|
47 |
-
# 获取顶点颜色
|
48 |
if mesh_data.visual.vertex_colors is None:
|
49 |
-
# 如果没有顶点颜色,可以给定一个默认值(例如,白色)
|
50 |
vertex_colors = torch.ones_like(vertices)[None]
|
51 |
else:
|
52 |
vertex_colors = torch.tensor(mesh_data.visual.vertex_colors[:, :3], dtype=torch.float32)[None]
|
53 |
-
# 创建纹理并分配顶点颜色
|
54 |
textures = TexturesVertex(verts_features=vertex_colors)
|
55 |
textures.to(device)
|
56 |
-
# 创建Mesh对象
|
57 |
mesh = pytorch3d.structures.Meshes(verts=[vertices], faces=[faces], textures=textures)
|
58 |
|
59 |
-
# 设置渲染器
|
60 |
lights = AmbientLights(ambient_color=((2.0,)*3,), device=device)
|
61 |
# lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]], ambient_color=[[0.5, 0.5, 0.5]], diffuse_color=[[1.0, 1.0, 1.0]])
|
62 |
raster_settings = RasterizationSettings(
|
63 |
-
image_size=image_size,
|
64 |
-
blur_radius=0.0,
|
65 |
-
faces_per_pixel=1,
|
66 |
-
# background_color=(1.0, 1.0, 1.0)
|
67 |
)
|
68 |
|
69 |
-
# 设置旋转和渲染参数
|
70 |
frames = []
|
71 |
camera_distance = 6.5
|
72 |
elevs = 0.0
|
73 |
center = (0.0, 0.0, 0.0)
|
74 |
-
# 渲染每一帧
|
75 |
materials = Materials(
|
76 |
device=device,
|
77 |
-
diffuse_color=((
|
78 |
ambient_color=((1.0, 1.0, 1.0),),
|
79 |
-
specular_color=((
|
80 |
shininess=0.0,
|
81 |
)
|
82 |
|
@@ -91,7 +78,6 @@ def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, imag
|
|
91 |
degrees=True
|
92 |
)
|
93 |
|
94 |
-
|
95 |
# 手动设置相机的旋转矩阵
|
96 |
cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=5.0)
|
97 |
cameras.znear = 0.0001
|
@@ -105,32 +91,26 @@ def render_video_from_obj(input_obj_path, output_video_path, num_frames=60, imag
|
|
105 |
)
|
106 |
|
107 |
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
108 |
-
# 渲染RGB图像和Normal图像
|
109 |
render_result = renderer(mesh, cameras=cameras)
|
110 |
rgb_image = render_result["rgb"] * render_result["mask"] + (1 - render_result["mask"]) * torch.ones_like(render_result["rgb"]) * 255.0
|
111 |
normal_map = render_result["normal"]
|
112 |
|
113 |
-
|
114 |
-
rgb = rgb_image[0, ..., :3].cpu().numpy() # RGB图像
|
115 |
normal_map = torch.nn.functional.normalize(normal_map, dim=-1) # Normal map
|
116 |
normal_map = (normal_map + 1) / 2
|
117 |
normal_map = normal_map * render_result["mask"] + (1 - render_result["mask"]) * torch.ones_like(render_result["normal"])
|
118 |
normal = normal_map[0, ..., :3].cpu().numpy() # Normal map
|
119 |
rgb = np.clip(rgb, 0, 255).astype(np.uint8)
|
120 |
normal = np.clip(normal*255, 0, 255).astype(np.uint8)
|
121 |
-
# 将RGB和Normal map合并为一张图,左边RGB,右边Normal map
|
122 |
combined_image = np.concatenate((rgb, normal), axis=1)
|
123 |
|
124 |
-
# 将合并后的图像加入到帧列表
|
125 |
frames.append(combined_image)
|
126 |
|
127 |
-
# 使用imageio保存视频
|
128 |
imageio.mimsave(output_video_path, frames, fps=fps)
|
129 |
|
130 |
print(f"Video saved to {output_video_path}")
|
131 |
|
132 |
if __name__ == '__main__':
|
133 |
-
# 示例调用
|
134 |
input_obj_path = "/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/a_owl_wearing_a_hat/ISOMER/rgb_projected.obj"
|
135 |
output_video_path = "output.mp4"
|
136 |
render_video_from_obj(input_obj_path, output_video_path)
|
|
|
26 |
if not os.path.exists(input_obj_path):
|
27 |
raise FileNotFoundError(f"Input OBJ file not found: {input_obj_path}")
|
28 |
|
|
|
29 |
scene_data = trimesh.load(input_obj_path)
|
30 |
|
|
|
31 |
if isinstance(scene_data, trimesh.Scene):
|
32 |
mesh_data = trimesh.util.concatenate([geom for geom in scene_data.geometry.values()])
|
33 |
else:
|
34 |
mesh_data = scene_data
|
35 |
|
|
|
36 |
if not hasattr(mesh_data, 'vertex_normals') or mesh_data.vertex_normals is None:
|
37 |
mesh_data.compute_vertex_normals()
|
38 |
|
|
|
39 |
vertices = torch.tensor(mesh_data.vertices, dtype=torch.float32, device=device)
|
40 |
faces = torch.tensor(mesh_data.faces, dtype=torch.int64, device=device)
|
|
|
41 |
|
|
|
42 |
if mesh_data.visual.vertex_colors is None:
|
|
|
43 |
vertex_colors = torch.ones_like(vertices)[None]
|
44 |
else:
|
45 |
vertex_colors = torch.tensor(mesh_data.visual.vertex_colors[:, :3], dtype=torch.float32)[None]
|
|
|
46 |
textures = TexturesVertex(verts_features=vertex_colors)
|
47 |
textures.to(device)
|
|
|
48 |
mesh = pytorch3d.structures.Meshes(verts=[vertices], faces=[faces], textures=textures)
|
49 |
|
|
|
50 |
lights = AmbientLights(ambient_color=((2.0,)*3,), device=device)
|
51 |
# lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]], ambient_color=[[0.5, 0.5, 0.5]], diffuse_color=[[1.0, 1.0, 1.0]])
|
52 |
raster_settings = RasterizationSettings(
|
53 |
+
image_size=image_size,
|
54 |
+
blur_radius=0.0,
|
55 |
+
faces_per_pixel=1,
|
|
|
56 |
)
|
57 |
|
|
|
58 |
frames = []
|
59 |
camera_distance = 6.5
|
60 |
elevs = 0.0
|
61 |
center = (0.0, 0.0, 0.0)
|
|
|
62 |
materials = Materials(
|
63 |
device=device,
|
64 |
+
diffuse_color=((1.0, 1.0, 1.0),),
|
65 |
ambient_color=((1.0, 1.0, 1.0),),
|
66 |
+
specular_color=((1.0, 1.0, 1.0),),
|
67 |
shininess=0.0,
|
68 |
)
|
69 |
|
|
|
78 |
degrees=True
|
79 |
)
|
80 |
|
|
|
81 |
# 手动设置相机的旋转矩阵
|
82 |
cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=5.0)
|
83 |
cameras.znear = 0.0001
|
|
|
91 |
)
|
92 |
|
93 |
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
|
|
94 |
render_result = renderer(mesh, cameras=cameras)
|
95 |
rgb_image = render_result["rgb"] * render_result["mask"] + (1 - render_result["mask"]) * torch.ones_like(render_result["rgb"]) * 255.0
|
96 |
normal_map = render_result["normal"]
|
97 |
|
98 |
+
rgb = rgb_image[0, ..., :3].cpu().numpy()
|
|
|
99 |
normal_map = torch.nn.functional.normalize(normal_map, dim=-1) # Normal map
|
100 |
normal_map = (normal_map + 1) / 2
|
101 |
normal_map = normal_map * render_result["mask"] + (1 - render_result["mask"]) * torch.ones_like(render_result["normal"])
|
102 |
normal = normal_map[0, ..., :3].cpu().numpy() # Normal map
|
103 |
rgb = np.clip(rgb, 0, 255).astype(np.uint8)
|
104 |
normal = np.clip(normal*255, 0, 255).astype(np.uint8)
|
|
|
105 |
combined_image = np.concatenate((rgb, normal), axis=1)
|
106 |
|
|
|
107 |
frames.append(combined_image)
|
108 |
|
|
|
109 |
imageio.mimsave(output_video_path, frames, fps=fps)
|
110 |
|
111 |
print(f"Video saved to {output_video_path}")
|
112 |
|
113 |
if __name__ == '__main__':
|
|
|
114 |
input_obj_path = "/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/a_owl_wearing_a_hat/ISOMER/rgb_projected.obj"
|
115 |
output_video_path = "output.mp4"
|
116 |
render_video_from_obj(input_obj_path, output_video_path)
|