from dataclasses import dataclass

import torch
import torch.nn as nn
from torch import distributed as tdist
from torch.nn import functional as F
import math
import mcubes
import numpy as np
from einops import repeat, rearrange
from skimage import measure

from craftsman.utils.base import BaseModule
from craftsman.utils.typing import *
from craftsman.utils.misc import get_world_size
from craftsman.utils.ops import generate_dense_grid_points

VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] |
|
|
|


class FourierEmbedder(nn.Module):
    """Fixed (non-learned) sin/cos positional embedding of input coordinates."""

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        super().__init__()

        if logspace:
            # frequencies 2^0, 2^1, ..., 2^(num_freqs - 1)
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            # num_freqs frequencies spaced linearly between 1 and 2^(num_freqs - 1)
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        # One extra slot per channel when the raw input is kept (or when there are no frequencies).
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.num_freqs > 0:
            # (..., input_dim) -> (..., input_dim * num_freqs)
            embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x
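
# Illustrative usage (a sketch, not part of the original interface):
#
#   embedder = FourierEmbedder(num_freqs=6, input_dim=3)   # out_dim = 3 * (2 * 6 + 1) = 39
#   points = torch.rand(4, 1024, 3)                        # (batch, num_points, xyz)
#   feats = embedder(points)                               # -> (4, 1024, 39)
#
# With logspace=True and include_pi=True the frequencies are pi * [1, 2, 4, ..., 2 ** (num_freqs - 1)].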


class LearnedFourierEmbedder(nn.Module):
    """Fourier embedding whose per-channel frequencies are learned (initialized from N(0, 1))."""

    def __init__(self, input_dim, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        per_channel_dim = half_dim // input_dim
        self.weights = nn.Parameter(torch.randn(per_channel_dim))

        self.out_dim = self.get_dims(input_dim)

    def forward(self, x):
        # (..., input_dim) -> (..., input_dim * per_channel_dim), then concat the raw input with sin/cos
        freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
        fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
        return fouriered

    def get_dims(self, input_dim):
        return input_dim * (self.weights.shape[0] * 2 + 1)
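
# Illustrative usage (sketch): the frequencies are drawn from N(0, 1) and trained with the rest
# of the network.
#
#   embedder = LearnedFourierEmbedder(input_dim=3, dim=48)   # per-channel frequencies: 48 // 2 // 3 = 8
#   feats = embedder(torch.rand(4, 1024, 3))                 # -> (4, 1024, 51)
#
# out_dim = input_dim * (2 * per_channel_dim + 1): the raw input is concatenated to the sin/cos
# features, so out_dim equals dim + input_dim when dim is divisible by 2 * input_dim.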


class Sine(nn.Module):
    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)


class Siren(nn.Module):
    """A single SIREN layer: linear map followed by a sine activation (Sitzmann et al., 2020)."""

    def __init__(
        self,
        in_dim,
        out_dim,
        w0=1.,
        c=6.,
        is_first=False,
        use_bias=True,
        activation=None,
        dropout=0.
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.is_first = is_first

        weight = torch.zeros(out_dim, in_dim)
        bias = torch.zeros(out_dim) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation
        self.dropout = nn.Dropout(dropout)

    def init_(self, weight, bias, c, w0):
        dim = self.in_dim

        # SIREN initialization: U(-1/dim, 1/dim) for the first layer,
        # U(-sqrt(c/dim)/w0, sqrt(c/dim)/w0) for subsequent layers.
        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if bias is not None:
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        out = self.dropout(out)
        return out
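
# Illustrative usage (sketch): stacking SIREN layers. The first layer commonly uses a larger
# w0 (30 in the original paper) so the sine activations span several periods over the input range.
#
#   first = Siren(in_dim=3, out_dim=256, w0=30., is_first=True)
#   hidden = Siren(in_dim=256, out_dim=256)
#   y = hidden(first(torch.rand(4, 1024, 3)))   # -> (4, 1024, 256)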


def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True):
    """Build a coordinate embedder and return it together with its output dimension."""
    if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
        return nn.Identity(), input_dim

    elif embed_type == "fourier":
        embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, include_pi=include_pi)

    elif embed_type == "learned_fourier":
        embedder_obj = LearnedFourierEmbedder(input_dim=input_dim, dim=num_freqs)

    elif embed_type == "siren":
        embedder_obj = Siren(in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim)

    elif embed_type == "hashgrid":
        raise NotImplementedError

    elif embed_type == "sphere_harmonic":
        raise NotImplementedError

    else:
        raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")

    return embedder_obj, embedder_obj.out_dim
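
# Illustrative usage (sketch, following the (embedder, out_dim) convention above):
#
#   embedder, embed_dim = get_embedder("fourier", num_freqs=8, input_dim=3)   # embed_dim = 3 * (2 * 8 + 1) = 51
#   feats = embedder(torch.rand(4, 1024, 3))                                  # -> (4, 1024, 51)
#
#   identity, embed_dim = get_embedder("identity", input_dim=3)               # passthrough, embed_dim == 3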


class AutoEncoder(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = ""
        num_latents: int = 256
        embed_dim: int = 64
        width: int = 768

    cfg: Config

    def configure(self) -> None:
        super().configure()

    def encode(self, surface: torch.FloatTensor, sample_posterior: bool = True):
        # Subclasses return (shape_latents, kl_embed, posterior); see forward().
        raise NotImplementedError

    def decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
        raise NotImplementedError

    def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
        posterior = None
        if self.cfg.embed_dim > 0:
            moments = self.pre_kl(latents)
            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
            if sample_posterior:
                kl_embed = posterior.sample()
            else:
                kl_embed = posterior.mode()
        else:
            kl_embed = latents
        return kl_embed, posterior
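
    # Note: `self.pre_kl` is expected to be defined by subclasses (typically, though not
    # necessarily, a linear projection from `width` to 2 * `embed_dim`), so that `moments`
    # can be split into a mean and a log-variance along the last dimension.
    #
    # Shape sketch (illustrative): latents (B, num_latents, width)
    #   -> moments  (B, num_latents, 2 * embed_dim)
    #   -> kl_embed (B, num_latents, embed_dim)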

    def forward(self,
                surface: torch.FloatTensor,
                queries: torch.FloatTensor,
                sample_posterior: bool = True):
        shape_latents, kl_embed, posterior = self.encode(surface, sample_posterior=sample_posterior)

        latents = self.decode(kl_embed)

        logits = self.query(queries, latents)

        return shape_latents, latents, posterior, logits

    def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor) -> torch.FloatTensor:
        raise NotImplementedError

    @torch.no_grad()
    def extract_geometry(self,
                         latents: torch.FloatTensor,
                         bounds: Union[Tuple[float], List[float], float] = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05),
                         octree_depth: int = 8,
                         num_chunks: int = 10000,
                         ):
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min

        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_depth=octree_depth,
            indexing="ij"
        )
        xyz_samples = torch.FloatTensor(xyz_samples)
        batch_size = latents.shape[0]

        # Query the implicit field in chunks to bound peak memory.
        batch_logits = []
        for start in range(0, xyz_samples.shape[0], num_chunks):
            queries = xyz_samples[start: start + num_chunks, :].to(latents)
            batch_queries = repeat(queries, "p c -> b p c", b=batch_size)

            logits = self.query(batch_queries, latents)
            batch_logits.append(logits.cpu())

        grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float().numpy()

        mesh_v_f = []
        has_surface = np.zeros((batch_size,), dtype=np.bool_)
        for i in range(batch_size):
            try:
                vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner")

                # Map vertices from grid coordinates back into the bounding box.
                vertices = vertices / grid_size * bbox_size + bbox_min
                faces = faces[:, [2, 1, 0]]
                mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces)))
                has_surface[i] = True
            except Exception:
                mesh_v_f.append((None, None))
                has_surface[i] = False

        return mesh_v_f, has_surface
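
# Illustrative usage of AutoEncoder.extract_geometry (a sketch; `autoencoder` and `latents`
# stand in for a concrete subclass instance and its decoded latents):
#
#   mesh_v_f, has_surface = autoencoder.extract_geometry(latents, octree_depth=8)
#   vertices, faces = mesh_v_f[0]   # (None, None) if marching cubes found no zero level set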


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
        self.feat_dim = feat_dim
        self.parameters = parameters

        if isinstance(parameters, list):
            self.mean = parameters[0]
            self.logvar = parameters[1]
        else:
            # A single tensor holds [mean, logvar] concatenated along feat_dim.
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        # Reparameterized sample: mean + std * eps, eps ~ N(0, I).
        x = self.mean + self.std * torch.randn_like(self.mean)
        return x

    def kl(self, other=None, dims=(1, 2)):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                # KL(q || N(0, I))
                return 0.5 * torch.mean(torch.pow(self.mean, 2)
                                        + self.var - 1.0 - self.logvar,
                                        dim=dims)
            else:
                # KL between two diagonal Gaussians.
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=dims)

    def nll(self, sample, dims=(1, 2)):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
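
# Illustrative usage (sketch): the posterior returned by AutoEncoder.encode_kl_embed can be
# regularized with a (usually very small) KL weight during training, e.g.
#
#   posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
#   z = posterior.sample()            # reparameterized sample with the same shape as the mean
#   loss_kl = posterior.kl().mean()   # KL against a standard normal, reduced over dims (1, 2)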