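"""TSR system: single-image 3D reconstruction.

Encodes one conditioning image into triplane scene codes, then renders
novel views from those codes or extracts a textured mesh via marching
cubes.
"""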
import math
import os
from dataclasses import dataclass, field
from typing import List, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import trimesh
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image

from .models.isosurface import MarchingCubeHelper
from .utils import (
    BaseModule,
    ImagePreprocessor,
    find_class,
    get_spherical_cameras,
    scale_tensor,
)

class TSR(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        cond_image_size: int

        image_tokenizer_cls: str
        image_tokenizer: dict

        tokenizer_cls: str
        tokenizer: dict

        backbone_cls: str
        backbone: dict

        post_processor_cls: str
        post_processor: dict

        decoder_cls: str
        decoder: dict

        renderer_cls: str
        renderer: dict

    cfg: Config
    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        # Accept either a local directory or a Hugging Face Hub repo id.
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        ckpt = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(ckpt)
        return model
    def configure(self):
        # Instantiate each sub-module from its configured class name.
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
        self.image_processor = ImagePreprocessor()
        self.isosurface_helper = None
    def forward(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        device: str,
    ) -> torch.FloatTensor:
        # Preprocess to conditioning images of shape (B, Nv=1, H, W, C).
        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
            device
        )
        batch_size = rgb_cond.shape[0]

        input_image_tokens: torch.Tensor = self.image_tokenizer(
            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
        )
        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
        )

        # Learned latent tokens cross-attend to the image tokens in the
        # backbone, then are decoded into per-scene codes.
        tokens: torch.Tensor = self.tokenizer(batch_size)
        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
        )
        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
        return scene_codes
    def render(
        self,
        scene_codes,
        n_views: int,
        elevation_deg: float = 0.0,
        camera_distance: float = 1.9,
        fovy_deg: float = 40.0,
        height: int = 256,
        width: int = 256,
        return_type: str = "pil",
    ):
        rays_o, rays_d = get_spherical_cameras(
            n_views, elevation_deg, camera_distance, fovy_deg, height, width
        )
        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)

        def process_output(image: torch.FloatTensor):
            if return_type == "pt":
                return image
            elif return_type == "np":
                return image.detach().cpu().numpy()
            elif return_type == "pil":
                return Image.fromarray(
                    (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
                )
            else:
                raise NotImplementedError(f"Unsupported return_type: {return_type}")

        images = []
        for scene_code in scene_codes:
            images_ = []
            for i in range(n_views):
                with torch.no_grad():
                    image = self.renderer(
                        self.decoder, scene_code, rays_o[i], rays_d[i]
                    )
                images_.append(process_output(image))
            images.append(images_)
        return images
    def set_marching_cubes_resolution(self, resolution: int):
        # Reuse the existing helper if the resolution is unchanged.
        if (
            self.isosurface_helper is not None
            and self.isosurface_helper.resolution == resolution
        ):
            return
        self.isosurface_helper = MarchingCubeHelper(resolution)
    def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
        self.set_marching_cubes_resolution(resolution)
        meshes = []
        for scene_code in scene_codes:
            # Query density on the marching-cubes grid, rescaled from the
            # helper's canonical range to the renderer's bounding radius.
            with torch.no_grad():
                density = self.renderer.query_triplane(
                    self.decoder,
                    scale_tensor(
                        self.isosurface_helper.grid_vertices.to(scene_codes.device),
                        self.isosurface_helper.points_range,
                        (-self.renderer.cfg.radius, self.renderer.cfg.radius),
                    ),
                    scene_code,
                )["density_act"]
            # Marching cubes expects a signed field; negate so the surface
            # sits at density == threshold.
            v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
            v_pos = scale_tensor(
                v_pos,
                self.isosurface_helper.points_range,
                (-self.renderer.cfg.radius, self.renderer.cfg.radius),
            )
            # Query vertex colors at the extracted surface positions.
            with torch.no_grad():
                color = self.renderer.query_triplane(
                    self.decoder,
                    v_pos,
                    scene_code,
                )["color"]
            mesh = trimesh.Trimesh(
                vertices=v_pos.cpu().numpy(),
                faces=t_pos_idx.cpu().numpy(),
                vertex_colors=color.cpu().numpy(),
            )
            meshes.append(mesh)
        return meshes
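

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions:
    # the Hub repo id "stabilityai/TripoSR" and the file names "config.yaml"
    # and "model.ckpt" are illustrative; substitute the checkpoint you use.
    # `model.to(device)` assumes BaseModule is a torch.nn.Module. Because this
    # file uses relative imports, run it as a module (python -m <package>.system).
    model = TSR.from_pretrained(
        "stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt"
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Encode one image into scene codes, then render views and extract a mesh.
    scene_codes = model(Image.open("input.png"), device=device)
    views = model.render(scene_codes, n_views=4, return_type="pil")
    for i, view in enumerate(views[0]):
        view.save(f"view_{i}.png")
    meshes = model.extract_mesh(scene_codes, resolution=256)
    meshes[0].export("mesh.obj")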