# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import trimesh
from PIL import Image

from .camera_utils import (
    transform_pos,
    get_mv_matrix,
    get_orthographic_projection_matrix,
    get_perspective_projection_matrix,
)
from .mesh_processor import meshVerticeInpaint
from .mesh_utils import load_mesh, save_mesh


def stride_from_shape(shape):
    # Row-major strides for a given shape, e.g. (H, W) -> [W, 1].
    stride = [1]
    for x in reversed(shape[1:]):
        stride.append(stride[-1] * x)
    return list(reversed(stride))


def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    # input: [..., C], D dimension + C channel
    # count: [..., 1], D dimension
    # indices: [N, D], long
    # values: [N, C]
    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    stride = stride_from_shape(size)

    assert len(size) == D

    input = input.view(-1, C)  # [HW, C]
    count = count.view(-1, 1)

    flatten_indices = (indices * torch.tensor(
        stride, dtype=torch.long, device=indices.device)).sum(-1)  # [N]

    if weights is None:
        weights = torch.ones_like(values[..., :1])

    input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
    count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)

    return input.view(*size, C), count.view(*size, 1)


def linear_grid_put_2d(H, W, coords, values, return_count=False):
    # coords: [N, 2], float in [0, 1]
    # values: [N, C]

    C = values.shape[-1]

    indices = coords * torch.tensor(
        [H - 1, W - 1], dtype=torch.float32, device=coords.device
    )
    indices_00 = indices.floor().long()  # [N, 2]
    indices_00[:, 0].clamp_(0, H - 2)
    indices_00[:, 1].clamp_(0, W - 2)
    indices_01 = indices_00 + torch.tensor(
        [0, 1], dtype=torch.long, device=indices.device
    )
    indices_10 = indices_00 + torch.tensor(
        [1, 0], dtype=torch.long, device=indices.device
    )
    indices_11 = indices_00 + torch.tensor(
        [1, 1], dtype=torch.long, device=indices.device
    )

    h = indices[..., 0] - indices_00[..., 0].float()
    w = indices[..., 1] - indices_00[..., 1].float()
    w_00 = (1 - h) * (1 - w)
    w_01 = (1 - h) * w
    w_10 = h * (1 - w)
    w_11 = h * w

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(
        result, count, indices_00, values * w_00.unsqueeze(1), weights * w_00.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_01, values * w_01.unsqueeze(1), weights * w_01.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_10, values * w_10.unsqueeze(1), weights * w_10.unsqueeze(1))
    result, count = scatter_add_nd_with_count(
        result, count, indices_11, values * w_11.unsqueeze(1), weights * w_11.unsqueeze(1))

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result
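
# A minimal usage sketch for linear_grid_put_2d (illustrative values, not
# called anywhere in this module): splat two RGB samples, given as [N, 2]
# coords in [0, 1] and [N, C] values, into a 64x64 grid.
def _demo_linear_grid_put_2d():
    coords = torch.tensor([[0.25, 0.25], [0.75, 0.50]], dtype=torch.float32)
    colors = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # one RGB per point
    grid = linear_grid_put_2d(64, 64, coords, colors)  # bilinear splat + normalize
    assert grid.shape == (64, 64, 3)
    return grid
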

class MeshRender():
    def __init__(
            self,
            camera_distance=1.45,
            camera_type='orth',
            default_resolution=1024,
            texture_size=1024,
            use_antialias=True,
            max_mip_level=None,
            filter_mode='linear',
            bake_mode='linear',
            raster_mode='cr',
            device='cuda'):

        self.device = device

        self.set_default_render_resolution(default_resolution)
        self.set_default_texture_resolution(texture_size)

        self.camera_distance = camera_distance
        self.use_antialias = use_antialias
        self.max_mip_level = max_mip_level
        self.filter_mode = filter_mode

        self.bake_angle_thres = 75
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.default_resolution[0], self.default_resolution[1]))
        self.bake_mode = bake_mode

        self.raster_mode = raster_mode
        if self.raster_mode == 'cr':
            import custom_rasterizer as cr
            self.raster = cr
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        if camera_type == 'orth':
            self.ortho_scale = 1.2
            self.camera_proj_mat = get_orthographic_projection_matrix(
                left=-self.ortho_scale * 0.5, right=self.ortho_scale * 0.5,
                bottom=-self.ortho_scale * 0.5, top=self.ortho_scale * 0.5,
                near=0.1, far=100
            )
        elif camera_type == 'perspective':
            self.camera_proj_mat = get_perspective_projection_matrix(
                49.13, self.default_resolution[1] / self.default_resolution[0],
                0.01, 100.0
            )
        else:
            raise ValueError(f'No camera type {camera_type}')

    def raster_rasterize(self, pos, tri, resolution, ranges=None, grad_db=True):

        if self.raster_mode == 'cr':
            rast_out_db = None
            if pos.dim() == 2:
                pos = pos.unsqueeze(0)
            findices, barycentric = self.raster.rasterize(pos, tri, resolution)
            rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1)
            rast_out = rast_out.unsqueeze(0)
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return rast_out, rast_out_db

    def raster_interpolate(self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None):

        if self.raster_mode == 'cr':
            textd = None
            barycentric = rast_out[0, ..., :-1]
            findices = rast_out[0, ..., -1]
            if uv.dim() == 2:
                uv = uv.unsqueeze(0)
            textc = self.raster.interpolate(uv, findices, barycentric, uv_idx)
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return textc, textd

    def raster_texture(self, tex, uv, uv_da=None, mip_level_bias=None, mip=None,
                       filter_mode='auto', boundary_mode='wrap', max_mip_level=None):

        if self.raster_mode == 'cr':
            raise NotImplementedError('Texture sampling is not implemented for the cr backend')
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

    def raster_antialias(self, color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0):

        if self.raster_mode == 'cr':
            # Antialiasing is not supported by the cr backend yet; pass through.
            pass
        else:
            raise ValueError(f'No raster named {self.raster_mode}')

        return color
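    # A sketch of what the backend's barycentric interpolation computes. The
    # custom_rasterizer calls above are treated as a black box; per pixel it
    # yields a face index plus barycentric weights, and interpolation is then
    # a weighted sum of that face's three vertex attributes (assumed
    # semantics, shown here on a single covered pixel):
    @staticmethod
    def _demo_barycentric_interpolate():
        tri_attrs = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # per-vertex UVs
        bary = torch.tensor([0.2, 0.3, 0.5])  # barycentric weights, sum to 1
        pixel_attr = (bary.unsqueeze(-1) * tri_attrs).sum(dim=0)
        assert torch.allclose(pixel_attr, torch.tensor([0.3, 0.5]))
        return pixel_attr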
    def load_mesh(
            self,
            mesh,
            scale_factor=1.15,
            auto_center=True,
    ):
        vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh)
        self.mesh_copy = mesh
        self.set_mesh(vtx_pos, pos_idx,
                      vtx_uv=vtx_uv, uv_idx=uv_idx,
                      scale_factor=scale_factor, auto_center=auto_center)
        if texture_data is not None:
            self.set_texture(texture_data)

    def save_mesh(self):
        texture_data = self.get_texture()
        texture_data = Image.fromarray((texture_data * 255).astype(np.uint8))
        return save_mesh(self.mesh_copy, texture_data)

    def set_mesh(
            self,
            vtx_pos,
            pos_idx,
            vtx_uv=None,
            uv_idx=None,
            scale_factor=1.15,
            auto_center=True
    ):

        self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float()
        self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int)

        if (vtx_uv is not None) and (uv_idx is not None):
            self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float()
            self.uv_idx = torch.from_numpy(uv_idx).to(self.device).to(torch.int)
        else:
            self.vtx_uv = None
            self.uv_idx = None

        self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]]
        self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]]
        if (vtx_uv is not None) and (uv_idx is not None):
            self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1]

        if auto_center:
            max_bb = (self.vtx_pos - 0).max(0)[0]
            min_bb = (self.vtx_pos - 0).min(0)[0]
            center = (max_bb + min_bb) / 2
            scale = torch.norm(self.vtx_pos - center, dim=1).max() * 2.0
            self.vtx_pos = (self.vtx_pos - center) * \
                           (scale_factor / float(scale))
            self.scale_factor = scale_factor

    def set_texture(self, tex):
        if isinstance(tex, np.ndarray):
            tex = Image.fromarray((tex * 255).astype(np.uint8))
        elif isinstance(tex, torch.Tensor):
            tex = tex.cpu().numpy()
            tex = Image.fromarray((tex * 255).astype(np.uint8))

        tex = tex.resize(self.texture_size).convert('RGB')
        tex = np.array(tex) / 255.0
        self.tex = torch.from_numpy(tex).to(self.device)
        self.tex = self.tex.float()

    def set_default_render_resolution(self, default_resolution):
        if isinstance(default_resolution, int):
            default_resolution = (default_resolution, default_resolution)
        self.default_resolution = default_resolution

    def set_default_texture_resolution(self, texture_size):
        if isinstance(texture_size, int):
            texture_size = (texture_size, texture_size)
        self.texture_size = texture_size

    def get_mesh(self):
        vtx_pos = self.vtx_pos.cpu().numpy()
        pos_idx = self.pos_idx.cpu().numpy()
        vtx_uv = self.vtx_uv.cpu().numpy()
        uv_idx = self.uv_idx.cpu().numpy()

        # Invert the coordinate transform applied in set_mesh.
        vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]]
        vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]]
        vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1]
        return vtx_pos, pos_idx, vtx_uv, uv_idx

    def get_texture(self):
        return self.tex.cpu().numpy()

    def to(self, device):
        self.device = device

        for attr_name in dir(self):
            attr_value = getattr(self, attr_name)
            if isinstance(attr_value, torch.Tensor):
                setattr(self, attr_name, attr_value.to(self.device))

    def color_rgb_to_srgb(self, image):
        if isinstance(image, Image.Image):
            image_rgb = torch.tensor(
                np.array(image) / 255.0).float().to(self.device)
        elif isinstance(image, np.ndarray):
            image_rgb = torch.tensor(image).float()
        else:
            image_rgb = image.to(self.device)

        image_srgb = torch.where(
            image_rgb <= 0.0031308,
            12.92 * image_rgb,
            1.055 * torch.pow(image_rgb, 1 / 2.4) - 0.055
        )

        if isinstance(image, Image.Image):
            image_srgb = Image.fromarray(
                (image_srgb.cpu().numpy() * 255).astype(np.uint8))
        elif isinstance(image, np.ndarray):
            image_srgb = image_srgb.cpu().numpy()
        else:
            image_srgb = image_srgb.to(image.device)

        return image_srgb
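    # A standalone check of the piecewise sRGB encoding used by
    # color_rgb_to_srgb above (same constants, a few sample values):
    @staticmethod
    def _demo_srgb_curve():
        x = torch.tensor([0.0, 0.0031308, 0.5, 1.0])
        y = torch.where(x <= 0.0031308,
                        12.92 * x,
                        1.055 * torch.pow(x, 1 / 2.4) - 0.055)
        # Dark values stay near-linear; mid tones brighten (0.5 -> ~0.735).
        assert torch.all(y >= x - 1e-6)
        return y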
    def _render(
            self,
            mvp,
            pos,
            pos_idx,
            uv,
            uv_idx,
            tex,
            resolution,
            max_mip_level,
            keep_alpha,
            filter_mode
    ):
        pos_clip = transform_pos(mvp, pos)
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, pos_idx, resolution=resolution)

        tex = tex.contiguous()
        if filter_mode == 'linear-mipmap-linear':
            texc, texd = self.raster_interpolate(
                uv[None, ...], rast_out, uv_idx, rast_db=rast_out_db, diff_attrs='all')
            color = self.raster_texture(
                tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear',
                max_mip_level=max_mip_level)
        else:
            texc, _ = self.raster_interpolate(uv[None, ...], rast_out, uv_idx)
            color = self.raster_texture(tex[None, ...], texc, filter_mode=filter_mode)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        color = color * visible_mask  # Mask out background.
        if self.use_antialias:
            color = self.raster_antialias(color, rast_out, pos_clip, pos_idx)

        if keep_alpha:
            color = torch.cat([color, visible_mask], dim=-1)
        return color[0, ...]

    def render(
            self,
            elev,
            azim,
            camera_distance=None,
            center=None,
            resolution=None,
            tex=None,
            keep_alpha=True,
            bgcolor=None,
            filter_mode=None,
            return_type='th'
    ):

        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)
        r_mvp = np.matmul(proj, r_mv).astype(np.float32)
        if tex is not None:
            if isinstance(tex, Image.Image):
                tex = torch.tensor(np.array(tex) / 255.0)
            elif isinstance(tex, np.ndarray):
                tex = torch.tensor(tex)
            if tex.dim() == 2:
                tex = tex.unsqueeze(-1)
            tex = tex.float().to(self.device)
        image = self._render(r_mvp, self.vtx_pos, self.pos_idx, self.vtx_uv, self.uv_idx,
                             self.tex if tex is None else tex,
                             self.default_resolution if resolution is None else resolution,
                             self.max_mip_level, True,
                             filter_mode if filter_mode else self.filter_mode)
        mask = (image[..., [-1]] == 1).float()
        if bgcolor is None:
            bgcolor = [0 for _ in range(image.shape[-1] - 1)]
        image = image * mask + (1 - mask) * \
                torch.tensor(bgcolor + [0]).to(self.device)
        if not keep_alpha:
            image = image[..., :-1]
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image
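    # A sketch of the camera composition render performs: the fixed projection
    # is left-multiplied with a per-view model-view matrix to get the MVP
    # handed to _render. Assumes, as the np.matmul call in render suggests,
    # that both camera_utils helpers return 4x4 numpy arrays.
    @staticmethod
    def _demo_mvp_composition():
        proj = get_orthographic_projection_matrix(
            left=-0.6, right=0.6, bottom=-0.6, top=0.6, near=0.1, far=100)
        mv = get_mv_matrix(elev=20, azim=30, camera_distance=1.45, center=None)
        mvp = np.matmul(proj, mv).astype(np.float32)
        assert mvp.shape == (4, 4)
        return mvp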
    def render_normal(
            self,
            elev,
            azim,
            camera_distance=None,
            center=None,
            resolution=None,
            bg_color=[1, 1, 1],
            use_abs_coor=False,
            normalize_rgb=True,
            return_type='th'
    ):

        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)
        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        if use_abs_coor:
            mesh_triangles = self.vtx_pos[self.pos_idx[:, :3], :]
        else:
            pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
            mesh_triangles = pos_camera[self.pos_idx[:, :3], :]
        face_normals = F.normalize(
            torch.cross(mesh_triangles[:, 1, :] - mesh_triangles[:, 0, :],
                        mesh_triangles[:, 2, :] - mesh_triangles[:, 0, :], dim=-1),
            dim=-1)

        vertex_normals = trimesh.geometry.mean_vertex_normals(
            vertex_count=self.vtx_pos.shape[0],
            faces=self.pos_idx.cpu(),
            face_normals=face_normals.cpu(),
        )
        vertex_normals = torch.from_numpy(
            vertex_normals).float().to(self.device).contiguous()

        # Interpolate normal values across the rasterized pixels
        normal, _ = self.raster_interpolate(
            vertex_normals[None, ...], rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        normal = normal * visible_mask + \
                 torch.tensor(bg_color, dtype=torch.float32, device=self.device) \
                 * (1 - visible_mask)  # Mask out background.

        if normalize_rgb:
            normal = (normal + 1) * 0.5
        if self.use_antialias:
            normal = self.raster_antialias(normal, rast_out, pos_clip, self.pos_idx)

        image = normal[0, ...]
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))

        return image

    def convert_normal_map(self, image):
        # blue is front, red is left, green is top
        if isinstance(image, Image.Image):
            image = np.array(image)
        mask = (image == [255, 255, 255]).all(axis=-1)

        image = (image / 255.0) * 2.0 - 1.0

        image[..., [1]] = -image[..., [1]]
        image[..., [1, 2]] = image[..., [2, 1]]
        image[..., [0]] = -image[..., [0]]

        image = (image + 1.0) * 0.5

        image = (image * 255).astype(np.uint8)
        image[mask] = [127, 127, 255]

        return Image.fromarray(image)

    def get_pos_from_mvp(self, elev, azim, camera_distance, center):
        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)

        pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
        pos_clip = transform_pos(proj, pos_camera)

        return pos_camera, pos_clip

    def render_depth(
            self,
            elev,
            azim,
            camera_distance=None,
            center=None,
            resolution=None,
            return_type='th'
    ):
        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)

        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
        tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()

        # Interpolate depth values across the rasterized pixels
        depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        depth_max, depth_min = depth[visible_mask > 0].max(), depth[visible_mask > 0].min()
        depth = (depth - depth_min) / (depth_max - depth_min)

        depth = depth * visible_mask  # Mask out background.
        if self.use_antialias:
            depth = self.raster_antialias(depth, rast_out, pos_clip, self.pos_idx)

        image = depth[0, ...]

        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))

        return image

    def render_position(self, elev, azim, camera_distance=None, center=None,
                        resolution=None, bg_color=[1, 1, 1], return_type='th'):
        pos_camera, pos_clip = self.get_pos_from_mvp(elev, azim, camera_distance, center)
        if resolution is None:
            resolution = self.default_resolution
        if isinstance(resolution, (int, float)):
            resolution = [resolution, resolution]
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)

        tex_position = 0.5 - self.vtx_pos[:, :3] / self.scale_factor
        tex_position = tex_position.contiguous()

        # Interpolate position values across the rasterized pixels
        position, _ = self.raster_interpolate(
            tex_position[None, ...], rast_out, self.pos_idx)

        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)
        position = position * visible_mask + \
                   torch.tensor(bg_color, dtype=torch.float32, device=self.device) \
                   * (1 - visible_mask)  # Mask out background.

        if self.use_antialias:
            position = self.raster_antialias(position, rast_out, pos_clip, self.pos_idx)

        image = position[0, ...]

        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.squeeze(-1).cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))

        return image
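    # The same axis remap convert_normal_map applies, shown on one decoded
    # normal vector (the white-background masking is left out):
    @staticmethod
    def _demo_normal_remap():
        n = np.array([0.0, 0.0, 1.0])  # a decoded normal in [-1, 1]
        n[[1]] = -n[[1]]
        n[[1, 2]] = n[[2, 1]]
        n[[0]] = -n[[0]]
        return ((n + 1.0) * 0.5 * 255).astype(np.uint8)  # re-encoded to 8-bit RGB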
    def render_uvpos(self, return_type='th'):
        image = self.uv_feature_map(self.vtx_pos * 0.5 + 0.5)
        if return_type == 'np':
            image = image.cpu().numpy()
        elif return_type == 'pl':
            image = image.cpu().numpy() * 255
            image = Image.fromarray(image.astype(np.uint8))
        return image

    def uv_feature_map(self, vert_feat, bg=None):
        vtx_uv = self.vtx_uv * 2 - 1.0
        vtx_uv = torch.cat(
            [vtx_uv, torch.zeros_like(self.vtx_uv)], dim=1).unsqueeze(0)
        vtx_uv[..., -1] = 1
        uv_idx = self.uv_idx
        rast_out, rast_out_db = self.raster_rasterize(
            vtx_uv, uv_idx, resolution=self.texture_size)
        feat_map, _ = self.raster_interpolate(vert_feat[None, ...], rast_out, uv_idx)
        feat_map = feat_map[0, ...]
        if bg is not None:
            visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]
            feat_map[visible_mask == 0] = bg
        return feat_map

    def render_sketch_from_geometry(self, normal_image, depth_image):
        normal_image_np = normal_image.cpu().numpy()
        depth_image_np = depth_image.cpu().numpy()

        normal_image_np = (normal_image_np * 255).astype(np.uint8)
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        normal_image_np = cv2.cvtColor(normal_image_np, cv2.COLOR_RGB2GRAY)

        normal_edges = cv2.Canny(normal_image_np, 80, 150)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)

        combined_edges = np.maximum(normal_edges, depth_edges)

        sketch_image = torch.from_numpy(combined_edges).to(
            normal_image.device).float() / 255.0
        sketch_image = sketch_image.unsqueeze(-1)

        return sketch_image

    def render_sketch_from_depth(self, depth_image):
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)

        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        combined_edges = depth_edges
        sketch_image = torch.from_numpy(combined_edges).to(
            depth_image.device).float() / 255.0
        sketch_image = sketch_image.unsqueeze(-1)
        return sketch_image
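    # The edge extraction behind both render_sketch_from_* helpers is plain
    # Canny on an 8-bit image. A synthetic example with the same thresholds
    # as render_sketch_from_depth:
    @staticmethod
    def _demo_depth_edges():
        depth = np.zeros((64, 64), dtype=np.uint8)
        depth[:, 32:] = 200  # a hard depth step -> one vertical edge
        edges = cv2.Canny(depth, 30, 80)
        assert edges[:, 31:33].any()  # the edge fires around the step
        return edges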
    def back_project(self, image, elev, azim, camera_distance=None, center=None, method=None):
        if isinstance(image, Image.Image):
            image = torch.tensor(np.array(image) / 255.0)
        elif isinstance(image, np.ndarray):
            image = torch.tensor(image)
        if image.dim() == 2:
            image = image.unsqueeze(-1)
        image = image.float().to(self.device)
        resolution = image.shape[:2]
        channel = image.shape[-1]
        texture = torch.zeros(self.texture_size + (channel,)).to(self.device)
        cos_map = torch.zeros(self.texture_size + (1,)).to(self.device)

        proj = self.camera_proj_mat
        r_mv = get_mv_matrix(
            elev=elev,
            azim=azim,
            camera_distance=self.camera_distance if camera_distance is None else camera_distance,
            center=center)
        pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
        pos_clip = transform_pos(proj, pos_camera)
        pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]

        v0 = pos_camera[self.pos_idx[:, 0], :]
        v1 = pos_camera[self.pos_idx[:, 1], :]
        v2 = pos_camera[self.pos_idx[:, 2], :]
        face_normals = F.normalize(
            torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
        vertex_normals = trimesh.geometry.mean_vertex_normals(
            vertex_count=self.vtx_pos.shape[0],
            faces=self.pos_idx.cpu(),
            face_normals=face_normals.cpu(),
        )
        vertex_normals = torch.from_numpy(
            vertex_normals).float().to(self.device).contiguous()
        tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()
        rast_out, rast_out_db = self.raster_rasterize(
            pos_clip, self.pos_idx, resolution=resolution)
        visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]

        normal, _ = self.raster_interpolate(
            vertex_normals[None, ...], rast_out, self.pos_idx)
        normal = normal[0, ...]
        uv, _ = self.raster_interpolate(self.vtx_uv[None, ...], rast_out, self.uv_idx)
        depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)
        depth = depth[0, ...]

        depth_max, depth_min = depth[visible_mask > 0].max(), depth[visible_mask > 0].min()
        depth_normalized = (depth - depth_min) / (depth_max - depth_min)
        depth_image = depth_normalized * visible_mask  # Mask out background.

        sketch_image = self.render_sketch_from_depth(depth_image)

        lookat = torch.tensor([[0, 0, -1]], device=self.device)
        cos_image = F.cosine_similarity(lookat, normal.view(-1, 3))
        cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1)

        cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi)
        cos_image[cos_image < cos_thres] = 0

        # Shrink the reliable region: dilate the background and sketch edges,
        # then drop texels that land near either.
        kernel_size = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones(
            (1, 1, kernel_size, kernel_size), dtype=torch.float32).to(sketch_image.device)

        visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float()
        visible_mask = F.conv2d(
            1.0 - visible_mask, kernel, padding=kernel_size // 2)
        visible_mask = 1.0 - (visible_mask > 0).float()  # binarize
        visible_mask = visible_mask.squeeze(0).permute(1, 2, 0)

        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2)
        sketch_image = (sketch_image > 0).float()  # binarize
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        visible_mask = visible_mask * (sketch_image < 0.5)

        cos_image[visible_mask == 0] = 0

        method = self.bake_mode if method is None else method
        if method == 'linear':
            proj_mask = (visible_mask != 0).view(-1)
            uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask]
            image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask]
            cos_image = cos_image.contiguous().view(-1, 1)[proj_mask]
            sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask]

            texture = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image)
            cos_map = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], cos_image)
            boundary_map = linear_grid_put_2d(
                self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], sketch_image)
        else:
            raise ValueError(f'No bake mode {method}')

        return texture, cos_map, boundary_map

    def bake_texture(self, colors, elevs, azims, camera_distance=None, center=None,
                     exp=6, weights=None):
        for i in range(len(colors)):
            if isinstance(colors[i], Image.Image):
                colors[i] = torch.tensor(
                    np.array(colors[i]) / 255.0, device=self.device).float()
        if weights is None:
            weights = [1.0 for _ in range(len(colors))]
        textures = []
        cos_maps = []
        for color, elev, azim, weight in zip(colors, elevs, azims, weights):
            texture, cos_map, _ = self.back_project(
                color, elev, azim, camera_distance, center)
            cos_map = weight * (cos_map ** exp)
            textures.append(texture)
            cos_maps.append(cos_map)

        texture_merge, trust_map_merge = self.fast_bake_texture(
            textures, cos_maps)
        return texture_merge, trust_map_merge

    @torch.no_grad()
    def fast_bake_texture(self, textures, cos_maps):
        channel = textures[0].shape[-1]
        texture_merge = torch.zeros(
            self.texture_size + (channel,)).to(self.device)
        trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device)
        for texture, cos_map in zip(textures, cos_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            if painted_sum / view_sum > 0.99:
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1E-8)

        return texture_merge, trust_map_merge > 1E-8
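    # Why bake_texture raises cos_map to a power before merging: cos(angle)**exp
    # concentrates weight on texels seen head-on, so grazing views contribute
    # almost nothing. Illustrative numbers with the default exp = 6:
    @staticmethod
    def _demo_view_weighting():
        angles = torch.tensor([0.0, 30.0, 60.0, 80.0]) * np.pi / 180
        w = torch.cos(angles) ** 6
        # head-on: 1.0, 30 deg: ~0.42, 60 deg: ~0.016, 80 deg: ~2.7e-5
        return w / w.sum()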
    def uv_inpaint(self, texture, mask):

        if isinstance(texture, torch.Tensor):
            texture_np = texture.cpu().numpy()
        elif isinstance(texture, np.ndarray):
            texture_np = texture
        elif isinstance(texture, Image.Image):
            texture_np = np.array(texture) / 255.0
        else:
            raise TypeError(f'Unsupported texture type {type(texture)}')

        vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh()

        texture_np, mask = meshVerticeInpaint(
            texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx)

        texture_np = cv2.inpaint(
            (texture_np * 255).astype(np.uint8), 255 - mask, 3, cv2.INPAINT_NS)

        return texture_np
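
# An end-to-end usage sketch (hypothetical asset path and view set, not part
# of the class): back-project several view images into UV space, merge them,
# inpaint untrusted texels, and save. Assumes the trust map marks reliable
# texels, matching how uv_inpaint treats mask == 255 as known pixels.
def _demo_bake_pipeline(view_images, mesh_path='demo.obj'):
    renderer = MeshRender()
    renderer.load_mesh(mesh_path)
    elevs = [0, 0, 0, 0]
    azims = [0, 90, 180, 270]  # four views around the object
    texture, trust = renderer.bake_texture(view_images, elevs, azims)
    mask = (trust.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
    texture_np = renderer.uv_inpaint(texture, mask)  # returns 8-bit RGB
    renderer.set_texture(texture_np / 255.0)
    return renderer.save_mesh()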