Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# | |
# @File: inference.py | |
# @Author: Haozhe Xie | |
# @Date: 2024-03-02 16:30:00 | |
# @Last Modified by: Haozhe Xie | |
# @Last Modified at: 2024-10-13 15:17:20 | |
# @Email: root@haozhexie.com | |
import cv2 | |
import math | |
import numpy as np | |
import scipy.spatial.transform | |
import torch | |
from tqdm import tqdm | |
# Semantic class IDs stored in the segmentation / instance maps.
CLASSES = {
    "NULL": 0,
    "ROAD": 1,
    "BLDG_FACADE": 2,
    "GREEN_LANDS": 3,
    "CONSTRUCTION": 4,
    "COAST_ZONES": 5,
    "ZONE": 6,
    "BLDG_ROOF": 7,
}
# Per-class integer scale. Used as the sampling stride of the point grid
# (get_point_map), as the per-point Gaussian scale (_get_point_scales) and as
# the per-class lookup table passed to voxlib.maps_to_volume. NULL is absent.
SCALES = dict(
    ROAD=2,
    BLDG_FACADE=1,
    BLDG_ROOF=1,
    GREEN_LANDS=2,
    CONSTRUCTION=1,
    COAST_ZONES=4,
    ZONE=2,
)
CONSTANTS = {
    # Camera intrinsics, row-major 3x3 (fx, 0, cx, 0, fy, cy, 0, 0, 1).
    "CAM_K": [1528.1469407006614, 0, 480, 0, 1528.1469407006614, 270, 0, 0, 1],
    # Rendered image size; presumably [width, height] given cx=480, cy=270.
    "SENSOR_SIZE": [960, 540],
    # Instance IDs >= the first value are treated as buildings.
    "BLDG_INST_RANGE": [100, 16384],
    # Side length (layout pixels) of the local window cropped around (cx, cy).
    "PROJECTION_SIZE": 2048,
    # Multiplier applied to the per-point Gaussian scales.
    "POINT_SCALE_FACTOR": 0.5,
    # Classes whose Gaussians keep a z-scale of 1 (flat ground-like classes).
    "SPECIAL_Z_SCALE_CLASSES": [
        CLASSES[name] for name in ("ROAD", "COAST_ZONES", "ZONE")
    ],
}
def get_instance_seg_map(seg_map):
    """Convert a semantic segmentation map into a building-instance map.

    CONSTRUCTION pixels are merged into BLDG_FACADE. The facade mask is split
    into 4-connected components, and each component (label k >= 1) gets the
    even instance ID ``(k + BLDG_INST_RANGE[0]) * 2``, i.e. >= 202 with the
    default range. Every other pixel keeps its semantic class ID.

    Args:
        seg_map: 2-D integer array of ``CLASSES`` IDs. It is not modified.

    Returns:
        np.ndarray: int32 array with the same shape as ``seg_map``.
    """
    # Work on a copy so the CONSTRUCTION -> BLDG_FACADE remapping does not
    # leak into the caller's array.
    seg_map = seg_map.copy()
    seg_map[seg_map == CLASSES["CONSTRUCTION"]] = CLASSES["BLDG_FACADE"]
    building_mask = seg_map == CLASSES["BLDG_FACADE"]
    # connectedComponents labels background 0 and each facade blob 1..K, so
    # non-zero labels coincide exactly with building_mask.
    _, labels, _, _ = cv2.connectedComponentsWithStats(
        building_mask.astype(np.uint8), connectivity=4
    )
    # Even facade IDs. Roof voxels are assumed to use the odd ID inst + 1
    # (that is where _get_local_layout registers roof centers).
    # NOTE(review): confirm the roof ID convention against voxlib.maps_to_volume.
    # Compute in int64 so the overflow check below is meaningful.
    inst_ids = (labels.astype(np.int64) + CONSTANTS["BLDG_INST_RANGE"][0]) * 2
    assert np.max(inst_ids) < 2147483648
    return np.where(building_mask, inst_ids, seg_map).astype(np.int32)
def get_point_map(seg_map):
    """Mark which layout pixels are sampled as points.

    Each non-NULL class present in ``seg_map`` is sampled on a regular grid
    whose stride is ``SCALES[class_name]``, restricted to that class's pixels.

    Args:
        seg_map: 2-D array of ``CLASSES`` IDs (an unknown ID raises KeyError).

    Returns:
        np.ndarray: Boolean array with the shape of ``seg_map``.
    """
    id_to_name = {cls_id: name for name, cls_id in CLASSES.items()}
    result = np.zeros(seg_map.shape, dtype=bool)
    for cls_id in np.unique(seg_map):
        name = id_to_name[cls_id]
        if name == "NULL":
            continue
        grid = _get_point_map(seg_map.shape, SCALES[name])
        result |= grid & (seg_map == cls_id)
    return result
def _get_point_map(map_size, stride): | |
pts_map = np.zeros(map_size, dtype=bool) | |
ys = np.arange(0, map_size[0], stride) | |
xs = np.arange(0, map_size[1], stride) | |
coords = np.stack(np.meshgrid(ys, xs), axis=-1).reshape(-1, 2) | |
pts_map[coords[:, 0], coords[:, 1]] = True | |
return pts_map | |
def get_centers(ins_map, td_hf):
    """Compute a bounding box and top height for every instance in ``ins_map``.

    Building instances (ID >= BLDG_INST_RANGE[0]) get the bounding box of
    their external contours and ``max(td_hf over the instance) + 1``.
    Other instances span the whole PROJECTION_SIZE window and use the global
    maximum of ``td_hf``.

    Args:
        ins_map: 2-D instance-ID map.
        td_hf: Height field with the same shape as ``ins_map``.

    Returns:
        dict: instance ID -> float32 array ``[cx, cy, w, h, max_z]``, where
        x is the column axis and y the row axis.
    """
    bldg_start = CONSTANTS["BLDG_INST_RANGE"][0]
    window = CONSTANTS["PROJECTION_SIZE"]
    centers = {}
    for inst in tqdm(np.unique(ins_map), desc="Calculating centers ..."):
        if inst < bldg_start:
            x_lo, x_hi = 0, window
            y_lo, y_hi = 0, window
            top = np.max(td_hf)
        else:
            inst_mask = ins_map == inst
            contours, _ = cv2.findContours(
                inst_mask.astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )
            # Contour points are (x, y) pairs.
            xy = np.vstack(contours).reshape(-1, 2)
            x_lo, y_lo = np.min(xy[:, 0]), np.min(xy[:, 1])
            x_hi, y_hi = np.max(xy[:, 0]), np.max(xy[:, 1])
            top = np.max(td_hf[inst_mask]) + 1
        centers[inst] = np.array(
            [
                (x_lo + x_hi) / 2,
                (y_lo + y_hi) / 2,
                x_hi - x_lo,
                y_hi - y_lo,
                top,
            ],
            dtype=np.float32,
        )
    return centers
def generate_city(
    fgm, bgm, city_layout, cx, cy, radius, altitude, azimuth, style_lut=None
):
    """Render one orbit-camera view of the city around layout pixel (cx, cy).

    Pipeline: crop a PROJECTION_SIZE window of ``city_layout``, voxelise it
    into points, keep only the points hit by camera rays, colour them with
    the building (``fgm``) / background (``bgm``) generators, and rasterise
    the resulting Gaussians.

    Args:
        fgm: Generator for building instances (key "BLDG").
        bgm: Generator for all other instances (key "REST").
        city_layout: Dict of full-city maps ("TD_HF", "BU_HF", "SEG", "INS",
            "PTS") plus "CTR", the per-instance centers.
        cx, cy: Window center in layout pixels.
        radius: Horizontal camera distance from the window center.
        altitude: Camera height.
        azimuth: Camera angle around the center, in degrees.
        style_lut: Optional instance-ID -> style-code table. It is sampled
            with ``_get_style_lut`` when omitted.

    Returns:
        np.ndarray: uint8 image laid out as (H, W, C).
    """
    import gaussiancity.extensions.diff_gaussian_rasterization as dgr

    device = torch.device("cuda")
    half_size = CONSTANTS["PROJECTION_SIZE"] // 2
    bldg_range = CONSTANTS["BLDG_INST_RANGE"]
    models = {"BLDG": fgm, "REST": bgm}

    rasterizer = dgr.GaussianRasterizerWrapper(
        np.array(CONSTANTS["CAM_K"], dtype=np.float32).reshape((3, 3)),
        CONSTANTS["SENSOR_SIZE"],
        flip_lr=True,
        flip_ud=False,
        device=device,
    )

    # Crop the window and turn it into volumetric points.
    layout = _get_local_layout(city_layout, cx, cy, half_size, bldg_range, device)
    points = _get_bev_points(layout, SCALES, CLASSES)
    pt_classes = _instances_to_classes(points[:, [3]], bldg_range, CLASSES)
    pt_onehots = _get_onehot_seg(pt_classes, len(CLASSES))
    pt_scales = _get_point_scales(
        pt_classes, SCALES, CLASSES, CONSTANTS["SPECIAL_Z_SCALE_CLASSES"]
    )
    # Columns: [X, Y, Z, instance ID, 3-D scale, one-hot class]
    points = torch.cat([points, pt_scales, pt_onehots], dim=1)

    if style_lut is None:
        style_lut = _get_style_lut(
            layout["CTR"],
            models,
            {"BLDG": bldg_range, "REST": [0, bldg_range[0]]},
            device,
        )

    cam_look_at, cam_pose = _get_orbit_camera_pose(
        radius, altitude, azimuth, half_size, device
    )
    # NOTE(review): points[:, :3] is a view. _get_volume localises it in place
    # (_get_localized_pt_cords uses -=), so the XYZ columns read by
    # _get_gs_attrs below are already shifted. Confirm this is intended.
    visible = _get_visible_points(
        points[:, :3],
        pt_scales,
        CONSTANTS["CAM_K"],
        CONSTANTS["SENSOR_SIZE"],
        cam_pose[:3],
        cam_look_at,
    )
    gs_attrs = _get_gs_attrs(
        points[visible],
        layout["TD_HF"].float(),
        layout["SEG"].float(),
        style_lut,
        layout["CTR"],
        models,
        CONSTANTS["POINT_SCALE_FACTOR"],
        bldg_range,
    )
    return _render(gs_attrs, rasterizer, cam_pose)
def _get_local_layout(city_layout, cx, cy, half_proj_size, bldg_inst_range, device):
    """Crop a square window centred at (cx, cy) from the city layout onto ``device``.

    The maps "TD_HF", "BU_HF", "SEG", "INS" and "PTS" are cropped to
    ``[cy - half, cy + half) x [cx - half, cx + half)`` and given two leading
    singleton axes. "SEG" is expanded to one-hot channels. "CTR" is rebuilt
    for the instances present in the crop:

    * Building instances get their center shifted into window coordinates.
      The same center is also registered for ID ``inst + 1`` (the roof).
    * Other instances get cx/cy overwritten with the window origin.
      NOTE(review): this is an assignment, not a shift. _get_pt_input_attrs
      does not read cx/cy for non-building instances. Confirm intent.

    Args:
        city_layout: Dict of 2-D numpy maps plus "CTR", presumably an
            instance ID -> numpy center array mapping (as built by get_centers).
        cx, cy: Window center in layout pixels.
        half_proj_size: Half of the window side length.
        bldg_inst_range: Instance IDs >= ``bldg_inst_range[0]`` are buildings.
        device: CUDA device to move the tensors to.

    Returns:
        dict: Cropped tensors plus "CTR" (instance ID -> center tensor).

    Raises:
        ValueError: If the window starts before the layout's top-left corner.
            A negative slice start would otherwise silently wrap around.
    """
    x_min, x_max = cx - half_proj_size, cx + half_proj_size
    y_min, y_max = cy - half_proj_size, cy + half_proj_size
    if x_min < 0 or y_min < 0:
        raise ValueError(
            f"Window [{x_min}:{x_max}, {y_min}:{y_max}] extends past the "
            "top-left corner of the city layout"
        )
    # NOTE(review): a window past the bottom/right edge is silently truncated
    # by numpy slicing.
    layout = {
        key: torch.from_numpy(arr[None, None, y_min:y_max, x_min:x_max]).cuda(device)
        for key, arr in city_layout.items()
        if key in ("TD_HF", "BU_HF", "SEG", "INS", "PTS")
    }
    layout["SEG"] = _get_onehot_seg(layout["SEG"], len(CLASSES))

    centers = {}
    for inst in torch.unique(layout["INS"]).tolist():
        # .cuda() copies, so the caller's numpy centers are never modified.
        ctr = torch.from_numpy(city_layout["CTR"][inst]).cuda(device)
        centers[inst] = ctr
        if inst >= bldg_inst_range[0]:
            ctr[0] -= x_min
            ctr[1] -= y_min
            centers[inst + 1] = ctr  # BLDG_ROOF shares the facade's center
        else:
            ctr[0] = x_min
            ctr[1] = y_min
    layout["CTR"] = centers
    return layout
def _get_onehot_seg(seg_map, n_classes): | |
shape = seg_map.shape | |
# shape -> NxCxHxW or NxC | |
# assert shape[1] == 1 | |
output_shape = (shape[0], n_classes, *shape[2:]) | |
one_hot_masks = torch.zeros(output_shape, device=seg_map.device, dtype=torch.bool) | |
for i in range(n_classes): | |
one_hot_masks[:, [i]] = seg_map == i | |
return one_hot_masks | |
def _get_style_lut(centers, models, inst_ranges, device, z_dim=256): | |
lut = {ins: torch.rand(1, z_dim, device=device) for ins in centers.keys()} | |
for k, v in models.items(): | |
if v is None: | |
continue | |
if v.module.cfg.Z_DIM is None: | |
for i in range(*inst_ranges[k]): | |
if i in lut: | |
del lut[i] | |
continue | |
if hasattr(v.module, "z"): | |
zs = v.module.z | |
lut.update( | |
{ | |
ins: zs[np.random.choice(list(zs.keys()))].unsqueeze(0) | |
for ins in centers.keys() | |
} | |
) | |
return lut | |
def _get_orbit_camera_pose(radius, altitude, azimuth, half_proj_size, device):
    """Place a camera on a circle around the window center, looking at it.

    The camera sits ``radius`` away from (half_proj_size, half_proj_size) at
    angle ``azimuth`` (degrees) and height ``altitude``. It looks at the
    window center at z = 1.

    Returns:
        tuple: ``(look_at, pose)`` tensors on ``device``. ``look_at`` is XYZ;
        ``pose`` is XYZ followed by the quaternion from
        ``_get_quat_from_look_at``.
    """
    center = half_proj_size
    angle = np.deg2rad(azimuth)
    position = np.array(
        [
            center + radius * math.cos(angle),
            center + radius * math.sin(angle),
            altitude,
        ],
        dtype=np.float32,
    )
    target = np.array([center, center, 1], dtype=np.float32)
    quaternion = _get_quat_from_look_at(position, target)
    look_at = torch.tensor([*target], device=device)
    pose = torch.tensor([*position, *quaternion], device=device)
    return look_at, pose
def _get_quat_from_look_at(cam_pos, cam_look_at): | |
fwd_vec = cam_look_at - cam_pos | |
fwd_vec /= np.linalg.norm(fwd_vec) | |
up_vec = np.array([0, 0, 1]) | |
right_vec = np.cross(up_vec, fwd_vec) | |
right_vec /= np.linalg.norm(right_vec) | |
up_vec = np.cross(fwd_vec, right_vec) | |
R = np.stack([fwd_vec, right_vec, up_vec], axis=1) | |
return scipy.spatial.transform.Rotation.from_matrix(R).as_quat() | |
def _get_bev_points(layout, scales, classes):
    """Voxelise the cropped layout and return its non-empty voxels as points.

    Args:
        layout: Dict from ``_get_local_layout`` with "INS", "TD_HF", "BU_HF"
            and "PTS" tensors.
        scales: Class name -> integer scale (classes missing here get 0).
        classes: Class name -> class ID. Its key order defines the order of
            the scale table passed to the extension.

    Returns:
        torch.Tensor: (N, 4) tensor. Columns 0-2 are the voxel indices (used
        as X, Y, Z downstream). Column 3 is the voxel value, which callers
        treat as the instance ID.
    """
    import gaussiancity.extensions.voxlib

    ins_map = layout["INS"]
    assert torch.max(ins_map) < 16384
    # Volume-size limits observed with torch.nonzero on CUDA:
    #   2048 x 2048 x 512 -> "nonzero is not supported for tensors with more
    #                         than INT_MAX elements"
    #   2048 x 2048 x 508 -> "an illegal memory access was encountered"
    assert torch.max(layout["TD_HF"]) <= 500
    class_scales = torch.tensor(
        [scales.get(name, 0) for name in classes],
        dtype=torch.int8,
        device=ins_map.device,
    )
    volume = gaussiancity.extensions.voxlib.maps_to_volume(
        ins_map.squeeze().short(),
        layout["TD_HF"].squeeze().short(),
        layout["BU_HF"].squeeze().short(),
        layout["PTS"].squeeze().bool(),
        class_scales,
    )
    coords = torch.nonzero(volume, as_tuple=False)
    values = volume[coords[:, 0], coords[:, 1], coords[:, 2]]
    return torch.cat([coords.short(), values.unsqueeze(dim=1)], dim=1)
def _instances_to_classes(instances, bldg_inst_range, bldg_classes): | |
bldg_facade_idx = (instances >= bldg_inst_range[0]) & (instances % 2 == 0) | |
bldg_roof_idx = (instances >= bldg_inst_range[0]) & (instances % 2 == 1) | |
classes = instances.clone() | |
classes[bldg_facade_idx] = bldg_classes["BLDG_FACADE"] | |
classes[bldg_roof_idx] = bldg_classes["BLDG_ROOF"] | |
return classes | |
def _get_point_scales(pt_classes, scales, classes, special_z_scale_classes=[]): | |
pt_scales = pt_classes.clone() | |
for k, v in scales.items(): | |
pt_scales[pt_classes == classes[k]] = v | |
pt_scales_3d = torch.ones_like(pt_scales).repeat(1, 3) * pt_scales | |
# Set the z-scale = 1 for roads, zones, and waters | |
pt_scales_3d[..., 2][ | |
torch.isin( | |
pt_classes.squeeze(dim=-1), | |
torch.tensor( | |
list(special_z_scale_classes), | |
device=pt_classes.device, | |
), | |
) | |
] = 1 | |
return pt_scales_3d | |
def _get_visible_points(points, scales, K, sensor_size, cam_pos, cam_look_at):
    """Return the indices of the points hit by the camera's primary rays.

    Each point is written into a dense volume under a unique ID. Casting one
    ray per pixel and collecting the IDs of the first voxels hit gives the
    set of visible points.

    Args:
        points: (N, 3) point coordinates. NOTE: ``_get_volume`` shifts these
            in place into its local frame.
        scales: (N, 3) per-point scales passed to the voxeliser.
        K: Flat row-major 3x3 intrinsics.
        sensor_size: Image size as used by ``_get_ray_voxel_intersection``.
        cam_pos: Camera position (XYZ tensor).
        cam_look_at: Camera target (XYZ tensor).

    Returns:
        torch.Tensor: Sorted, unique indices into ``points`` (misses removed).
    """
    volume, offsets = _get_volume(points, scales)
    hit_ids = _get_ray_voxel_intersection(
        K, sensor_size, cam_pos - offsets, cam_look_at - cam_pos, volume
    )
    # The dense volume can be very large; free it right away to avoid OOM.
    del volume
    torch.cuda.empty_cache()
    visible = torch.unique(hit_ids)
    # Rays that hit nothing report -1.
    return visible[visible >= 0]
def _get_volume(points, scales):
    """Voxelise points into a dense volume of point IDs.

    The bounding-box minimum of ``points`` becomes the volume's origin.
    ``points`` is shifted IN PLACE into that local frame by
    ``_get_localized_pt_cords``, which also moves z up by one voxel. Point
    IDs run from 1 to N, so that 0 can denote an empty voxel.

    Args:
        points: (N, 3) integer coordinates (modified in place).
        scales: Per-point scales forwarded to ``voxlib.points_to_volume``.

    Returns:
        tuple: ``(volume, offsets)``, where ``offsets`` is the int16 XYZ
        minimum subtracted from the points.
    """
    import gaussiancity.extensions.voxlib

    # Extents must be measured before the in-place localisation below.
    lo = [torch.min(points[:, axis]).item() for axis in range(3)]
    hi = [torch.max(points[:, axis]).item() for axis in range(3)]
    offsets = torch.tensor(lo, dtype=torch.int16, device=points.device)
    points = _get_localized_pt_cords(points, offsets)

    w = hi[0] - lo[0] + 1
    h = hi[1] - lo[1] + 1
    # One extra z layer because localisation shifts z up by one.
    d = hi[2] - lo[2] + 2

    n_pts = points.shape[0]
    assert n_pts < 2147483648
    pt_ids = torch.arange(
        1, n_pts + 1, dtype=torch.int32, device=points.device
    ).unsqueeze(dim=1)
    volume = gaussiancity.extensions.voxlib.points_to_volume(
        points.contiguous(), pt_ids, scales, h, w, d
    )
    return volume, offsets
def _get_localized_pt_cords(points, offsets): | |
points[:, 0] -= offsets[0] | |
points[:, 1] -= offsets[1] | |
points[:, 2] -= offsets[2] - 1 | |
return points | |
def _get_ray_voxel_intersection(K, sensor_size, cam_origin, viewdir, volume):
    """Cast one ray per pixel into ``volume`` and return the first-hit point IDs.

    The extension receives the camera origin and direction with x/y swapped
    ([1, 0, 2]), the principal point as [K[5], K[2]] and the image size as
    [sensor_size[1], sensor_size[0]]. It presumably expects (row, column)
    order. NOTE(review): confirm against voxlib.

    Returns:
        torch.Tensor: Per-pixel point index. Volume IDs start at 1, so after
        subtracting 1 a miss is -1 and hits are 0..N-1.
    """
    import gaussiancity.extensions.voxlib

    n_max_samples = 1
    yxz = [1, 0, 2]
    focal = K[0]
    principal_point = [K[5], K[2]]
    image_hw = [sensor_size[1], sensor_size[0]]
    voxel_id, _, _ = gaussiancity.extensions.voxlib.ray_voxel_intersection_perspective(
        volume,
        cam_origin[yxz].float(),
        viewdir[yxz].float(),
        torch.tensor([0, 0, 1], dtype=torch.float32),
        focal,
        principal_point,
        image_hw,
        n_max_samples,
    )
    # The extension does not seem to accept negative IDs, so the 1-based IDs
    # are shifted back here.
    return voxel_id.squeeze() - 1
def get_hf_seg_tensor(part_hf, part_seg, layout_cfg, output_device):
    """Stack a normalised height field with a one-hot segmentation map.

    Both numpy maps get two leading singleton axes and are moved to
    ``output_device``. The height field is divided by
    ``CONSTANTS["LAYOUT_MAX_HEIGHT"]``. The segmentation is one-hot encoded
    into ``CONSTANTS["LAYOUT_N_CLASSES"]`` float channels. The result
    concatenates them along dim 1.

    NOTE(review): the CONSTANTS dict defined in this file has neither
    "LAYOUT_MAX_HEIGHT" nor "LAYOUT_N_CLASSES", so this raises KeyError
    unless another module adds them. ``layout_cfg`` is accepted but unused.
    """

    def to_batched_tensor(arr):
        return torch.from_numpy(arr[None, None, ...]).to(output_device)

    hf = to_batched_tensor(part_hf)
    seg = to_batched_tensor(part_seg)
    hf = hf / CONSTANTS["LAYOUT_MAX_HEIGHT"]
    seg = _masks_to_onehots(seg[:, 0, :, :], CONSTANTS["LAYOUT_N_CLASSES"])
    return torch.cat([hf, seg], dim=1)
def _masks_to_onehots(masks, n_class, ignored_classes=[]): | |
b, h, w = masks.shape | |
n_class_actual = n_class - len(ignored_classes) | |
one_hot_masks = torch.zeros( | |
(b, n_class_actual, h, w), dtype=torch.float32, device=masks.device | |
) | |
n_class_cnt = 0 | |
for i in range(n_class): | |
if i not in ignored_classes: | |
one_hot_masks[:, n_class_cnt] = masks == i | |
n_class_cnt += 1 | |
return one_hot_masks | |
def _get_gs_attrs(
    pts,
    proj_hf,
    proj_seg,
    style_lut,
    centers,
    models,
    scale_factor,
    bldg_inst_range,
):
    """Assemble the Gaussian attributes for the visible points.

    Points are split into building (instance ID >= ``bldg_inst_range[0]``)
    and remaining points. Each group is coloured by its own model, and the
    groups are concatenated in that order (buildings first).

    Args:
        pts: (N, 4 + 3 + N_CLASSES) tensor with columns
            [XYZ, instance ID, 3-D scale, one-hot class].
        proj_hf, proj_seg: Height-field and one-hot layout maps for the models.
        style_lut: Instance ID -> style code.
        centers: Instance ID -> [cx, cy, w, h, d] tensor.
        models: Dict with "BLDG" and "REST" models.
        scale_factor: Multiplier applied to the per-point scales.
        bldg_inst_range: ``[lo, hi)`` building instance-ID range.

    Returns:
        torch.Tensor: Rows ordered [buildings, rest]. Columns are
        [XYZ(3), opacity(1) = 1, scale(3), rotation(4) = identity (1, 0, 0, 0),
        RGB from the models].
    """
    n_pts, _ = pts.shape
    is_bldg = pts[:, 3] >= bldg_inst_range[0]
    groups = {"BLDG": pts[is_bldg], "REST": pts[~is_bldg]}
    pt_attrs = {
        name: _get_pt_input_attrs(
            group[:, :4],
            centers,
            style_lut,
            models[name].module.cfg.Z_DIM,
            bldg_inst_range,
        )
        for name, group in groups.items()
    }
    colors = [
        _get_gs_colors(group, pt_attrs[name], proj_hf, proj_seg, models[name])
        for name, group in groups.items()
    ]

    xyz = torch.cat([group[:, :3] for group in groups.values()], dim=0)
    scales = torch.cat([group[:, 4:7] for group in groups.values()], dim=0)
    scales = scales * scale_factor
    rgb = torch.cat(colors, dim=0)

    opacity = torch.ones((n_pts, 1), device=pts.device)
    rotations = torch.cat(
        [
            torch.ones(n_pts, 1, device=pts.device),
            torch.zeros(n_pts, 3, device=pts.device),
        ],
        dim=-1,
    )
    return torch.cat((xyz, opacity, scales, rotations, rgb), dim=-1)
def _get_pt_input_attrs(pts, centers, style_lut, z_dim, bldg_inst_range): | |
n_pts = pts.shape[0] | |
instances = torch.unique(pts[:, -1]) | |
rel_xyz = torch.zeros(1, n_pts, 3, dtype=torch.float32, device=pts.device) | |
batch_idx = torch.zeros(1, n_pts, dtype=torch.int32, device=pts.device) | |
zs = {} if z_dim is not None else None | |
for idx, ins in enumerate(instances): | |
ins = ins.item() | |
is_pts = pts[:, -1] == ins | |
cx, cy, w, h, d = centers[ins] | |
if ins >= bldg_inst_range[0]: | |
rel_xyz[:, is_pts, 0] = (pts[is_pts, 0] - cx) / w * 2 if w > 0 else 0 | |
rel_xyz[:, is_pts, 1] = (pts[is_pts, 1] - cy) / h * 2 if h > 0 else 0 | |
else: | |
# Make the BG contiguous | |
period_x = torch.ceil((pts[is_pts, 0] / w / 2) - 0.5) | |
period_y = torch.ceil((pts[is_pts, 1] / h / 2) - 0.5) | |
rel_xyz[:, is_pts, 0] = ( | |
(pts[is_pts, 0] - 2 * period_x * w) * (-1) ** period_x | |
) / w | |
rel_xyz[:, is_pts, 1] = ( | |
(pts[is_pts, 1] - 2 * period_y * h) * (-1) ** period_y | |
) / h | |
rel_xyz[:, is_pts, 2] = ( | |
torch.clip(pts[is_pts, 2] / d * 2 - 1, -1, 1) if d > 0 else 0 | |
) | |
batch_idx[:, is_pts] = idx | |
if zs is not None: | |
zs[ins] = {"z": style_lut[ins], "idx": is_pts.unsqueeze(dim=0)} | |
return rel_xyz, batch_idx, zs | |
def _get_gs_colors(pts, pt_attrs, proj_hf, proj_seg, model):
    """Run ``model`` on one group of points and return their RGB colours.

    Args:
        pts: (N, 4 + 3 + N_CLASSES) tensor. Columns 0-2 are XYZ and columns
            from 7 on are the one-hot class.
        pt_attrs: ``(rel_xyz, batch_idx, zs)`` from ``_get_pt_input_attrs``.
        proj_hf, proj_seg: Layout maps forwarded to the model.
        model: Generator. Projection UVs are computed only when
            ``model.module.cfg.ENCODER`` is set.

    Returns:
        torch.Tensor: The model's "rgb" output without its batch axis. An
        empty (0, 3) float32 tensor is returned for an empty group.
    """
    if pts.shape[0] == 0:
        return torch.empty(0, 3, dtype=torch.float32, device=pts.device)
    rel_xyz, batch_idx, zs = pt_attrs
    abs_xyz = pts[None, :, :3]
    onehots = pts[None, :, 7:].float()
    if model.module.cfg.ENCODER is None:
        proj_uv = None
    else:
        proj_uv = get_projection_uv(abs_xyz)
    with torch.no_grad():
        # TODO: Optimize the _instance_forward in Generator
        output = model(proj_uv, rel_xyz, batch_idx, onehots, zs, proj_hf, proj_seg)
    return output["rgb"].squeeze(dim=0)
def get_projection_uv(xyz, proj_tlp=None, proj_size=2048):
    """Map point XY coordinates to normalised projection UVs in [-1, 1].

    Args:
        xyz: (B, N, >=2) tensor of point coordinates. It is never modified.
        proj_tlp: Optional (B, 2) top-left corner subtracted from XY first.
        proj_size: Side length of the projection window.

    Returns:
        torch.Tensor: (B, N, 2) floating tensor ``xy / proj_size * 2 - 1``.
        It is float32 when no ``proj_tlp`` is given, or when the difference
        would be integral.
    """
    n_pts = xyz.size(1)
    if proj_tlp is None:
        proj_uv = xyz[..., :2].float()
    else:
        proj_uv = xyz[..., :2] - proj_tlp.unsqueeze(dim=1)
        if not proj_uv.is_floating_point():
            # Integer coordinates would make the division below fail.
            proj_uv = proj_uv.float()
    assert proj_uv.size() == (xyz.size(0), n_pts, 2)
    # Out-of-place, so xyz is never written through a view.
    return proj_uv / proj_size * 2 - 1
def _render(gs_attrs, rasterizator, cam_pose):
    """Rasterise the Gaussians and return an 8-bit image.

    Args:
        gs_attrs: Gaussian attribute tensor from ``_get_gs_attrs``.
        rasterizator: Callable ``(gs_attrs, position, quaternion) -> image``.
        cam_pose: Tensor [XYZ position, quaternion].

    Returns:
        np.ndarray: uint8 array laid out (H, W, C). The raw output goes
        through ``x / 2 + 0.5`` (maps [-1, 1] to [0, 1]), then a 1.2x
        brightness and 1.2x contrast adjustment, before scaling by 255.
    """
    import torchvision.transforms.functional as F

    position, quaternion = cam_pose[:3], cam_pose[3:]
    with torch.no_grad():
        raw = rasterizator(gs_attrs, position, quaternion)
    img = raw.squeeze() / 2 + 0.5
    img = F.adjust_contrast(F.adjust_brightness(img, 1.2), 1.2)
    hwc = (img * 255).permute(1, 2, 0)
    return hwc.cpu().numpy().astype(np.uint8)