# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import torch
import math

from modules.real3d.segformer import SegFormerImg2PlaneBackbone, SegFormerSECC2PlaneBackbone
from modules.real3d.img2plane_baseline import OSAvatar_Img2plane
from modules.img2plane.img2plane_model import Img2PlaneModel
from utils.commons.hparams import hparams


# Idea: switch the plane fusion to attention, using the plane as the value?
class OSAvatarSECC_Img2plane(OSAvatar_Img2plane):
    """Image-to-plane model that fuses two predicted planes.

    A "canonical" plane is predicted from the source image by
    ``cano_img2plane_backbone`` and a second plane is predicted from the
    PNCC condition maps by ``secc_img2plane_backbone``. The two are combined
    (added or multiplied, see ``cal_plane_given_cano``) before rendering.
    """

    def __init__(self, hp=None):
        """Build the model; ``hp`` is forwarded unchanged to the base class."""
        super().__init__(hp=hp)
        hparams = self.hparams

        # The backbone created by the base class extracts the canonical
        # triplane from the source image; expose it under a clearer name and
        # drop the old attribute so it is not registered twice.
        self.cano_img2plane_backbone = self.img2plane_backbone
        del self.img2plane_backbone

        self.secc_img2plane_backbone = SegFormerSECC2PlaneBackbone(
            mode=hparams['secc_segformer_scale'],
            out_channels=3 * self.triplane_hid_dim * self.triplane_depth,
            pncc_cond_mode=hparams['pncc_cond_mode'],
        )

        # Fixed (non-trainable) scalars stored as parameters.
        self.lambda_pertube_blink_secc = torch.nn.Parameter(torch.tensor([0.001]), requires_grad=False)
        self.lambda_pertube_secc = torch.nn.Parameter(torch.tensor([0.001]), requires_grad=False)

    def on_train_full_model(self):
        """Enable gradients on the whole module."""
        # NOTE(review): assuming the base class is a torch.nn.Module, this also
        # flips the two ``lambda_pertube_*`` parameters (created with
        # requires_grad=False) to trainable - confirm that is intended.
        self.requires_grad_(True)

    def on_train_nerf(self):
        """Train both plane backbones and the decoder; freeze superresolution."""
        self.cano_img2plane_backbone.requires_grad_(True)
        self.secc_img2plane_backbone.requires_grad_(True)
        self.decoder.requires_grad_(True)
        self.superresolution.requires_grad_(False)

    def on_train_superresolution(self):
        """Train only superresolution; freeze both backbones and the decoder."""
        self.cano_img2plane_backbone.requires_grad_(False)
        self.secc_img2plane_backbone.requires_grad_(False)
        self.decoder.requires_grad_(False)
        self.superresolution.requires_grad_(True)

    def cal_cano_plane(self, img, cond=None, **kwargs):
        """Predict the canonical plane from the source image.

        Args:
            img: Source image passed to ``cano_img2plane_backbone``.
            cond: Optional condition passed through to the backbone.
            **kwargs: Extra keyword arguments forwarded to the backbone.

        Returns:
            The plane tensor laid out as ``[B, 3, C*D, H, W]``.

        Raises:
            NotImplementedError: If ``hparams['triplane_feature_type']`` is not
                one of ``'triplane'``, ``'trigrid'`` or ``'trigrid_v2'``.
        """
        feature_type = self.hparams.get("triplane_feature_type", "triplane")
        planes = self.cano_img2plane_backbone(img, cond, **kwargs)  # [B, 3, C*D, H, W]

        if feature_type in ('triplane', 'trigrid'):
            planes = planes.view(
                len(planes),
                3,
                self.triplane_hid_dim * self.triplane_depth,
                planes.shape[-2],
                planes.shape[-1],
            )
        elif feature_type == 'trigrid_v2':
            b, k, cd, h, w = planes.shape  # k = 3
            # Flatten the plane axis into channels for plane2grid_module, then
            # restore the original layout.
            planes = planes.reshape([b, k * cd, h, w])
            planes = self.plane2grid_module(planes)
            planes = planes.reshape([b, k, cd, h, w])
        else:
            raise NotImplementedError(f"Unsupported triplane_feature_type: {feature_type!r}")
        return planes

    def cal_secc_plane(self, cond):
        """Predict the SECC plane from the PNCC condition maps.

        Args:
            cond: Mapping with ``'cond_cano'`` and ``'cond_tgt'`` entries, plus
                ``'cond_src'`` when ``hparams['pncc_cond_mode']`` is
                ``'cano_src_tgt'``. The entries are concatenated along dim 1.

        Returns:
            The output of ``secc_img2plane_backbone`` for the concatenated maps.
        """
        cano_pncc, tgt_pncc = cond['cond_cano'], cond['cond_tgt']
        if self.hparams.get("pncc_cond_mode", "cano_tgt") == 'cano_src_tgt':
            # The source map is only needed (and only looked up) in this mode.
            pncc_inputs = [cano_pncc, cond['cond_src'], tgt_pncc]
        else:
            pncc_inputs = [cano_pncc, tgt_pncc]
        inp_pncc = torch.cat(pncc_inputs, dim=1)
        return self.secc_img2plane_backbone(inp_pncc)

    def cal_plane_given_cano(self, cano_planes, cond=None):
        """Fuse an already computed canonical plane with the SECC plane.

        Args:
            cano_planes: Canonical plane, ``[B, 3, C*D, H, W]``.
            cond: Condition mapping consumed by ``cal_secc_plane``.

        Returns:
            ``cano_planes + secc_planes`` or ``cano_planes * secc_planes``,
            depending on ``hparams['phase1_plane_fusion_mode']`` (default
            ``'add'``).

        Raises:
            NotImplementedError: For any other fusion mode.
        """
        secc_planes = self.cal_secc_plane(cond)  # [B, 3, C*D, H, W]
        fusion_mode = self.hparams.get("phase1_plane_fusion_mode", "add")
        if fusion_mode == 'add':
            return cano_planes + secc_planes
        if fusion_mode == 'mul':
            return cano_planes * secc_planes
        raise NotImplementedError(f"Unsupported phase1_plane_fusion_mode: {fusion_mode!r}")

    def cal_plane(self, img, cond, ret=None, **kwargs):
        """Compute the fused plane and the canonical plane for ``img``.

        ``ret`` is accepted for interface compatibility and is not used here.

        Returns:
            Tuple ``(planes, cano_planes)``.
        """
        cano_planes = self.cal_cano_plane(img, cond, **kwargs)  # [B, 3, C*D, H, W]
        planes = self.cal_plane_given_cano(cano_planes, cond)
        return planes, cano_planes

    def sample(self, coordinates, directions, img, cond=None, truncation_psi=1, truncation_cutoff=None,
               update_emas=False, ref_camera=None, **synthesis_kwargs):
        """Query the model at arbitrary 3D coordinates.

        Computes RGB features and density for the given coordinates; mostly
        used for extracting shapes. ``truncation_psi``, ``truncation_cutoff``,
        ``update_emas`` and ``synthesis_kwargs`` are accepted but not used by
        this method.
        """
        planes, _ = self.cal_plane(img, cond, ret={}, ref_camera=ref_camera)
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def synthesis(self, img, camera, cond=None, ret=None, update_emas=False, cache_backbone=True,
                  use_cached_backbone=False, **synthesis_kwargs):
        """Render an image for the given camera(s).

        Args:
            img: Source image (ignored when ``use_cached_backbone`` is true).
            camera: Per-sample camera vector; columns ``0:16`` are read as a
                4x4 cam2world matrix and columns ``16:25`` as 3x3 intrinsics.
            cond: Condition mapping (see ``cal_secc_plane``).
            ret: Optional dict that receives the outputs; created if ``None``.
            update_emas: Unused by this method.
            cache_backbone: Store the canonical plane on the instance so a
                later call can reuse it.
            use_cached_backbone: Reuse the canonical plane stored by an earlier
                call made with ``cache_backbone=True`` instead of running the
                canonical backbone again.
            **synthesis_kwargs: Forwarded to ``cal_plane`` and ``_forward_sr``.

        Returns:
            ``ret`` updated with ``'weights_img'``, ``'image_raw'``,
            ``'image_depth'``, ``'image'``, ``'image_feature'`` and ``'plane'``.

        Raises:
            AttributeError: If ``use_cached_backbone`` is true but no canonical
                plane has been cached yet.
        """
        if ret is None:
            ret = {}

        cam2world_matrix = camera[:, :16].view(-1, 4, 4)
        intrinsics = camera[:, 16:25].view(-1, 3, 3)
        neural_rendering_resolution = self.neural_rendering_resolution

        # Create a batch of rays for volume rendering
        ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
        N, M, _ = ray_origins.shape

        # Build the fused plane, optionally reusing the cached canonical plane.
        if use_cached_backbone:
            cano_planes = getattr(self, '_last_cano_planes', None)
            if cano_planes is None:
                raise AttributeError(
                    "use_cached_backbone=True requires a previous synthesis() call "
                    "with cache_backbone=True; no cached canonical planes found."
                )
            planes = self.cal_plane_given_cano(cano_planes, cond)
        else:
            planes, cano_planes = self.cal_plane(img, cond, ret, **synthesis_kwargs)
        if cache_backbone:
            self._last_cano_planes = cano_planes

        # Perform volume rendering
        feature_samples, depth_samples, weights_samples, is_ray_valid = self.renderer(
            planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs)  # channels last

        # Reshape into 'raw' neural-rendered image
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
        weights_image = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W).contiguous()  # [N, 1, H, W]
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        if self.hparams.get("mask_invalid_rays", False):
            is_ray_valid_mask = is_ray_valid.reshape([feature_samples.shape[0], 1, H, W])  # [B, 1, H, W]
            invalid_mask = ~is_ray_valid_mask
            # Invalid rays get a constant -1 in every feature channel.
            feature_image[invalid_mask.repeat([1, feature_image.shape[1], 1, 1])] = -1
            # Invalid rays take the smallest valid depth. With no valid ray at
            # all there is nothing to take the minimum of (min() of an empty
            # selection raises), so the depth map is left as rendered.
            if is_ray_valid_mask.any():
                depth_image[invalid_mask] = depth_image[is_ray_valid_mask].min().item()

        # Run superresolution to get final image
        rgb_image = feature_image[:, :3]
        ret['weights_img'] = weights_image
        sr_image = self._forward_sr(rgb_image, feature_image, cond, ret, **synthesis_kwargs)
        rgb_image = rgb_image.clamp(-1, 1)
        sr_image = sr_image.clamp(-1, 1)

        ret.update({
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'image': sr_image,
            'image_feature': feature_image[:, 3:],
            'plane': planes,
        })
        return ret