|
from typing import * |
|
from contextlib import contextmanager |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from tqdm import tqdm |
|
from easydict import EasyDict as edict |
|
from torchvision import transforms |
|
from PIL import Image |
|
import rembg |
|
from .base import Pipeline |
|
from . import samplers |
|
from ..modules import sparse as sp |
|
from ..representations import Gaussian, Strivec, MeshExtractResult |
|
|
|
|
|
class TrellisImageTo3DPipeline(Pipeline):
    """
    Pipeline for inferring Trellis image-to-3D models.

    Generation runs in three stages:
      1. `sample_sparse_structure`: a flow model plus decoder produce the
         coordinates of occupied voxels.
      2. `sample_slat`: a second flow model samples structured-latent (SLAT)
         features on those coordinates.
      3. `decode_slat`: one or more decoders turn the latent into output
         representations.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
            Keys read by this class are "sparse_structure_flow_model",
            "sparse_structure_decoder", "slat_flow_model", "slat_decoder_mesh",
            "slat_decoder_gs" and "slat_decoder_rf". "image_cond_model" is
            filled in by `_init_image_cond_model`.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured
            latent; must contain "mean" and "std".
        image_cond_model (str): The name of the image conditioning model, loaded
            from the `facebookresearch/dinov2` torch.hub repository.
    """

    # Output formats decoded when the caller does not request specific ones.
    _DEFAULT_FORMATS = ("mesh", "gaussian", "radiance_field")

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        sparse_structure_sampler: samplers.Sampler = None,
        slat_sampler: samplers.Sampler = None,
        slat_normalization: dict = None,
        image_cond_model: str = None,
    ):
        # `from_pretrained` creates a bare instance (models=None) and then fills
        # in its attributes directly, so all setup is skipped in that case.
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.slat_sampler = slat_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler_params = {}
        self.slat_normalization = slat_normalization
        # Background-removal session; created lazily by `preprocess_image`.
        self.rembg_session = None
        self._init_image_cond_model(image_cond_model)

    @staticmethod
    def from_pretrained(path: str) -> "TrellisImageTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.

        Returns:
            TrellisImageTo3DPipeline: A pipeline whose samplers, sampler params,
            latent normalization and image conditioning model are built from the
            pretrained configuration (`_pretrained_args`).
        """
        pipeline = super(
            TrellisImageTo3DPipeline, TrellisImageTo3DPipeline
        ).from_pretrained(path)
        # Re-home the loaded state onto an instance of this subclass.
        new_pipeline = TrellisImageTo3DPipeline()
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        new_pipeline.sparse_structure_sampler = getattr(
            samplers, args["sparse_structure_sampler"]["name"]
        )(**args["sparse_structure_sampler"]["args"])
        new_pipeline.sparse_structure_sampler_params = args["sparse_structure_sampler"][
            "params"
        ]

        new_pipeline.slat_sampler = getattr(samplers, args["slat_sampler"]["name"])(
            **args["slat_sampler"]["args"]
        )
        new_pipeline.slat_sampler_params = args["slat_sampler"]["params"]

        new_pipeline.slat_normalization = args["slat_normalization"]

        new_pipeline._init_image_cond_model(args["image_cond_model"])

        return new_pipeline

    def _init_image_cond_model(self, name: str):
        """
        Initialize the image conditioning model.

        Loads `name` from the `facebookresearch/dinov2` torch.hub repository
        with pretrained weights. This may download code and weights. The model
        is put in eval mode and stored as `self.models["image_cond_model"]`.
        Also sets up the mean/std normalization applied to images before
        encoding; the values used are the standard ImageNet statistics.
        """
        dinov2_model = torch.hub.load("facebookresearch/dinov2", name, pretrained=True)
        dinov2_model.eval()
        self.models["image_cond_model"] = dinov2_model
        transform = transforms.Compose(
            [
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        self.image_cond_model_transform = transform

    def preprocess_image(self, input: Image.Image) -> Image.Image:
        """
        Preprocess the input image into a 518x518 RGB, object-centered image.

        The foreground mask is obtained in one of two ways:
          - Existing alpha: if the image is RGBA with at least one non-opaque
            pixel, its alpha channel is used as-is.
          - Background removal: otherwise the image is converted to RGB,
            downscaled so its longer side is at most 1024 px, and passed
            through rembg with a lazily created "u2net" session.

        Pixels with alpha above 80% define the object's bounding box. A square
        region 1.2x the box's longer side, centered on the box, is cropped and
        resized to 518x518. Its RGB values are then multiplied by alpha, which
        turns the background black.

        Args:
            input (Image.Image): The image prompt.

        Returns:
            Image.Image: The preprocessed 518x518 RGB image.

        Raises:
            ValueError: If no pixel has alpha above 80% after masking, i.e.
                there is no foreground object to crop.
        """
        has_alpha = False
        if input.mode == "RGBA":
            alpha = np.array(input)[:, :, 3]
            # A fully opaque alpha channel carries no mask; treat it like RGB.
            if not np.all(alpha == 255):
                has_alpha = True
        if has_alpha:
            output = input
        else:
            input = input.convert("RGB")
            max_size = max(input.size)
            scale = min(1, 1024 / max_size)
            if scale < 1:
                input = input.resize(
                    (int(input.width * scale), int(input.height * scale)),
                    Image.Resampling.LANCZOS,
                )
            # getattr: instances built via `from_pretrained` may lack the attribute.
            if getattr(self, "rembg_session", None) is None:
                self.rembg_session = rembg.new_session("u2net")
            output = rembg.remove(input, session=self.rembg_session)

        output_np = np.array(output)
        alpha = output_np[:, :, 3]
        foreground = np.argwhere(alpha > 0.8 * 255)
        if foreground.size == 0:
            raise ValueError(
                "preprocess_image: no foreground found (no pixel with alpha > 0.8 "
                "after background removal); cannot crop the object."
            )
        # argwhere yields (row, col) = (y, x); PIL boxes are (x0, y0, x1, y1).
        bbox = (
            np.min(foreground[:, 1]),
            np.min(foreground[:, 0]),
            np.max(foreground[:, 1]),
            np.max(foreground[:, 0]),
        )
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1.2)  # add a margin around the object
        bbox = (
            center[0] - size // 2,
            center[1] - size // 2,
            center[0] + size // 2,
            center[1] + size // 2,
        )
        output = output.crop(bbox)
        output = output.resize((518, 518), Image.Resampling.LANCZOS)
        output = np.array(output).astype(np.float32) / 255
        # Premultiply by alpha so the background becomes black.
        output = output[:, :, :3] * output[:, :, 3:4]
        output = Image.fromarray((output * 255).astype(np.uint8))
        return output

    @torch.no_grad()
    def encode_image(
        self, image: Union[torch.Tensor, list[Image.Image]]
    ) -> torch.Tensor:
        """
        Encode the image.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image to encode.
                Either a batched 4-D tensor (B, C, H, W), or a list of PIL
                images. PIL images are resized to 518x518, converted to RGB
                and scaled to [0, 1].

        Returns:
            torch.Tensor: The conditioning model's pre-norm token features
            ("x_prenorm"), layer-normalized over the last dimension without
            affine parameters.

        Raises:
            ValueError: If `image` is neither a tensor nor a list.
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(
                isinstance(i, Image.Image) for i in image
            ), "Image list should be list of PIL images"
            image = [i.resize((518, 518), Image.Resampling.LANCZOS) for i in image]
            image = [np.array(i.convert("RGB")).astype(np.float32) / 255 for i in image]
            # HWC numpy -> CHW tensor, then stack into a batch.
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).to(self.device)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        image = self.image_cond_model_transform(image).to(self.device)
        features = self.models["image_cond_model"](image, is_training=True)["x_prenorm"]
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.

        Returns:
            dict: {"cond": encoded image features, "neg_cond": an all-zero
            tensor of the same shape}. Both are forwarded as keyword arguments
            to the samplers.
        """
        cond = self.encode_image(image)
        neg_cond = torch.zeros_like(cond)
        return {
            "cond": cond,
            "neg_cond": neg_cond,
        }

    def sample_sparse_structure(
        self,
        cond: dict,
        num_samples: int = 1,
        sampler_params: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            num_samples (int): The number of samples to generate.
            sampler_params (dict, optional): Additional parameters for the
                sampler; these override `self.sparse_structure_sampler_params`.

        Returns:
            torch.Tensor: int tensor of occupied-voxel coordinates. It is built
            from `argwhere` on the decoder output, keeping columns 0, 2, 3, 4
            and dropping column 1 (the channel index for a (B, C, D, H, W)
            output).
        """
        sampler_params = {} if sampler_params is None else sampler_params
        flow_model = self.models["sparse_structure_flow_model"]
        reso = flow_model.resolution
        # Noise is drawn on the default device, then moved, so results stay
        # reproducible under the global torch seed.
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(
            self.device
        )
        sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
        z_s = self.sparse_structure_sampler.sample(
            flow_model, noise, **cond, **sampler_params, verbose=True
        ).samples

        decoder = self.models["sparse_structure_decoder"]
        coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int()

        return coords

    def decode_slat(
        self,
        slat: sp.SparseTensor,
        formats: Optional[List[str]] = None,
    ) -> dict:
        """
        Decode the structured latent.

        Args:
            slat (sp.SparseTensor): The structured latent.
            formats (List[str], optional): The formats to decode the structured
                latent to; any of "mesh", "gaussian", "radiance_field".
                Defaults to all three.

        Returns:
            dict: Maps each requested format name to its decoder's output.
            Unrecognized format names are ignored.
        """
        if formats is None:
            formats = list(self._DEFAULT_FORMATS)
        ret = {}
        if "mesh" in formats:
            ret["mesh"] = self.models["slat_decoder_mesh"](slat)
        if "gaussian" in formats:
            ret["gaussian"] = self.models["slat_decoder_gs"](slat)
        if "radiance_field" in formats:
            ret["radiance_field"] = self.models["slat_decoder_rf"](slat)
        return ret

    def sample_slat(
        self,
        cond: dict,
        coords: torch.Tensor,
        sampler_params: Optional[dict] = None,
    ) -> sp.SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict, optional): Additional parameters for the
                sampler; these override `self.slat_sampler_params`.

        Returns:
            sp.SparseTensor: The sampled latent, de-normalized with
            `slat_normalization["std"]` and `["mean"]`.
        """
        sampler_params = {} if sampler_params is None else sampler_params
        flow_model = self.models["slat_flow_model"]
        noise = sp.SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )
        sampler_params = {**self.slat_sampler_params, **sampler_params}
        slat = self.slat_sampler.sample(
            flow_model, noise, **cond, **sampler_params, verbose=True
        ).samples

        # Undo the latent normalization.
        std = torch.tensor(self.slat_normalization["std"])[None].to(slat.device)
        mean = torch.tensor(self.slat_normalization["mean"])[None].to(slat.device)
        slat = slat * std + mean

        return slat

    @torch.no_grad()
    def run(
        self,
        image: Image.Image,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: Optional[dict] = None,
        slat_sampler_params: Optional[dict] = None,
        formats: Optional[List[str]] = None,
        preprocess_image: bool = True,
    ) -> dict:
        """
        Run the pipeline.

        Args:
            image (Image.Image): The image prompt.
            num_samples (int): The number of samples to generate.
            seed (int): Seed passed to `torch.manual_seed` (global RNG) before sampling.
            sparse_structure_sampler_params (dict, optional): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict, optional): Additional parameters for the structured latent sampler.
            formats (List[str], optional): Output formats; defaults to mesh, gaussian and radiance_field.
            preprocess_image (bool): Whether to preprocess the image.

        Returns:
            dict: The output of `decode_slat`.
        """
        if preprocess_image:
            image = self.preprocess_image(image)
        cond = self.get_cond([image])
        torch.manual_seed(seed)
        coords = self.sample_sparse_structure(
            cond, num_samples, sparse_structure_sampler_params
        )
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

    @contextmanager
    def inject_sampler_multi_image(
        self,
        sampler_name: str,
        num_images: int,
        num_steps: int,
        mode: Literal["stochastic", "multidiffusion"] = "stochastic",
    ):
        """
        Inject a sampler with multiple images as condition.

        While the context is active, the sampler's `_inference_model` is
        replaced. The original is restored on exit, including when sampling
        raises.
          - "stochastic": each call uses one conditioning image, cycling
            through the images over `num_steps` calls.
          - "multidiffusion": each call averages the predictions for all
            images. Inside `cfg_interval` it applies classifier-free guidance
            against `neg_cond`.

        Args:
            sampler_name (str): The name of the sampler attribute to inject.
            num_images (int): The number of images to condition on.
            num_steps (int): The number of steps to run the sampler for.
            mode (str): "stochastic" or "multidiffusion".

        Raises:
            ValueError: If `mode` is unsupported, or if `mode` is "stochastic"
                and `num_steps` is None.
        """
        # Validate before touching the sampler so a bad call leaves no trace.
        if mode not in ("stochastic", "multidiffusion"):
            raise ValueError(f"Unsupported mode: {mode}")

        sampler = getattr(self, sampler_name)
        # Captured in a closure (rather than stashed on the sampler) so nested
        # injections cannot clobber each other's saved method.
        original_inference_model = sampler._inference_model

        if mode == "stochastic":
            if num_steps is None:
                raise ValueError(
                    f"num_steps must be set for stochastic multi-image sampling "
                    f"with {sampler_name}"
                )
            if num_images > num_steps:
                print(
                    f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
                    "This may lead to performance degradation.\033[0m"
                )

            # One image index per expected call; raises IndexError if the
            # sampler makes more than `num_steps` calls.
            cond_indices = (np.arange(num_steps) % num_images).tolist()

            def _new_inference_model(self, model, x_t, t, cond, **kwargs):
                cond_idx = cond_indices.pop(0)
                cond_i = cond[cond_idx : cond_idx + 1]
                return original_inference_model(model, x_t, t, cond=cond_i, **kwargs)

        else:  # "multidiffusion"
            from .samplers import FlowEulerSampler

            def _new_inference_model(
                self,
                model,
                x_t,
                t,
                cond,
                neg_cond,
                cfg_strength,
                cfg_interval,
                **kwargs,
            ):
                # Average the per-image predictions (same call order as before:
                # every view first, then the negative condition if needed).
                preds = [
                    FlowEulerSampler._inference_model(
                        self, model, x_t, t, cond[i : i + 1], **kwargs
                    )
                    for i in range(len(cond))
                ]
                pred = sum(preds) / len(preds)
                if cfg_interval[0] <= t <= cfg_interval[1]:
                    neg_pred = FlowEulerSampler._inference_model(
                        self, model, x_t, t, neg_cond, **kwargs
                    )
                    return (1 + cfg_strength) * pred - cfg_strength * neg_pred
                return pred

        # Remember whether `_inference_model` was an instance attribute so it
        # can be restored exactly (otherwise fall back to the class method).
        had_instance_override = "_inference_model" in vars(sampler)
        sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))
        try:
            yield
        finally:
            if had_instance_override:
                sampler._inference_model = original_inference_model
            else:
                del sampler._inference_model

    @torch.no_grad()
    def run_multi_image(
        self,
        images: List[Image.Image],
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: Optional[dict] = None,
        slat_sampler_params: Optional[dict] = None,
        formats: Optional[List[str]] = None,
        preprocess_image: bool = True,
        mode: Literal["stochastic", "multidiffusion"] = "stochastic",
    ) -> dict:
        """
        Run the pipeline with multiple images as condition.

        Args:
            images (List[Image.Image]): The multi-view images of the assets.
            num_samples (int): The number of samples to generate.
            seed (int): Seed passed to `torch.manual_seed` (global RNG) before sampling.
            sparse_structure_sampler_params (dict, optional): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict, optional): Additional parameters for the structured latent sampler.
            formats (List[str], optional): Output formats; defaults to mesh, gaussian and radiance_field.
            preprocess_image (bool): Whether to preprocess the images.
            mode (str): Multi-image strategy passed to `inject_sampler_multi_image`.

        Returns:
            dict: The output of `decode_slat`.
        """
        sparse_structure_sampler_params = (
            {} if sparse_structure_sampler_params is None else sparse_structure_sampler_params
        )
        slat_sampler_params = {} if slat_sampler_params is None else slat_sampler_params
        if preprocess_image:
            images = [self.preprocess_image(image) for image in images]
        cond = self.get_cond(images)
        # All views share a single all-zero negative condition.
        cond["neg_cond"] = cond["neg_cond"][:1]
        torch.manual_seed(seed)
        ss_steps = {
            **self.sparse_structure_sampler_params,
            **sparse_structure_sampler_params,
        }.get("steps")
        with self.inject_sampler_multi_image(
            "sparse_structure_sampler", len(images), ss_steps, mode=mode
        ):
            coords = self.sample_sparse_structure(
                cond, num_samples, sparse_structure_sampler_params
            )
        slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get("steps")
        with self.inject_sampler_multi_image(
            "slat_sampler", len(images), slat_steps, mode=mode
        ):
            slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)
|
|