hzxie's picture
perf: faster initialize env.
1b18401 verified
# -*- coding: utf-8 -*-
#
# @File: inference.py
# @Author: Haozhe Xie
# @Date: 2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-10-13 15:17:20
# @Email: root@haozhexie.com
import cv2
import math
import numpy as np
import scipy.spatial.transform
import torch
from tqdm import tqdm
# Semantic class IDs, numbered consecutively in declaration order
# (NULL = 0, ROAD = 1, ..., BLDG_ROOF = 7). The order matters: other code
# iterates CLASSES.keys() and uses len(CLASSES) as the number of classes.
CLASSES = {
    name: class_id
    for class_id, name in enumerate(
        (
            "NULL",
            "ROAD",
            "BLDG_FACADE",
            "GREEN_LANDS",
            "CONSTRUCTION",
            "COAST_ZONES",
            "ZONE",
            "BLDG_ROOF",
        )
    )
}
# Per-class integer scale: used as the point-sampling stride in get_point_map
# and as the base Gaussian scale in _get_point_scales. "NULL" has no entry.
SCALES = dict(
    ROAD=2,
    BLDG_FACADE=1,
    BLDG_ROOF=1,
    GREEN_LANDS=2,
    CONSTRUCTION=1,
    COAST_ZONES=4,
    ZONE=2,
)
CONSTANTS = dict(
    # Flattened camera intrinsics; reshaped to 3x3 (row-major) in generate_city.
    CAM_K=[1528.1469407006614, 0, 480, 0, 1528.1469407006614, 270, 0, 0, 1],
    # Rendered image size, presumably (width, height) given the principal
    # point (480, 270) in CAM_K.
    SENSOR_SIZE=[960, 540],
    # [start, end) range of building instance IDs; smaller IDs are class IDs.
    BLDG_INST_RANGE=[100, 16384],
    # Side length, in layout pixels, of the local crop rendered around (cx, cy).
    PROJECTION_SIZE=2048,
    # Multiplier applied to the per-point Gaussian scales.
    POINT_SCALE_FACTOR=0.5,
    # Classes whose Gaussian z-scale is forced to 1 (roads, coast zones, zones).
    SPECIAL_Z_SCALE_CLASSES=[CLASSES[name] for name in ("ROAD", "COAST_ZONES", "ZONE")],
)
def get_instance_seg_map(seg_map):
    """Turn a semantic segmentation map into a building-instance map.

    Construction pixels are treated as building facades. Every 4-connected
    facade region becomes one building instance with ID
    ``2 * (label + CONSTANTS["BLDG_INST_RANGE"][0])``, where ``label >= 1`` is
    its connected-component label. Building IDs are therefore even and at
    least ``2 * (BLDG_INST_RANGE[0] + 1)``. The matching roof instance is
    expected to use the odd ID right after it (facade ID + 1), which is how
    ``_get_local_layout`` registers roof centers. All other pixels keep their
    semantic class ID.

    The input array is not modified.

    Args:
        seg_map: 2-D integer array of class IDs from ``CLASSES``.

    Returns:
        ``np.int32`` array of the same shape: class IDs for non-building pixels
        and building instance IDs for facade/construction pixels.
    """
    facade_id = CLASSES["BLDG_FACADE"]
    # Map constructions to facades on a copy; the caller's array is untouched.
    sem_map = np.where(seg_map == CLASSES["CONSTRUCTION"], facade_id, seg_map)
    facade_mask = sem_map == facade_id

    # Use 4-connected components to get building instances (background = 0).
    _, labels, _, _ = cv2.connectedComponentsWithStats(
        facade_mask.astype(np.uint8), connectivity=4
    )
    building_mask = facade_mask & (labels != 0)

    # cv2 returns int32 labels. Do the ID arithmetic in int64 so the range
    # check below runs before any possible wrap-around.
    inst_ids = (labels.astype(np.int64) + CONSTANTS["BLDG_INST_RANGE"][0]) * 2
    assert inst_ids.max(initial=0) < 2147483648
    return np.where(building_mask, inst_ids, sem_map).astype(np.int32)
def get_point_map(seg_map):
    """Return a boolean mask marking the pixels that become BEV sample points.

    For every class present in ``seg_map`` except "NULL", the pixels of that
    class that lie on a lattice of stride ``SCALES[class_name]`` are selected.
    Classes with a larger scale are therefore sampled more sparsely.

    Args:
        seg_map: Array of class IDs from ``CLASSES``; any other value raises
            ``KeyError``.

    Returns:
        ``np.ndarray`` of dtype ``bool`` with the shape of ``seg_map``.
    """
    id_to_name = dict(zip(CLASSES.values(), CLASSES.keys()))
    selected = np.zeros(seg_map.shape, dtype=bool)
    for class_id in np.unique(seg_map):
        name = id_to_name[class_id]
        if name == "NULL":
            continue
        lattice = _get_point_map(seg_map.shape, SCALES[name])
        selected |= lattice & (seg_map == class_id)
    return selected
def _get_point_map(map_size, stride):
    """Return a boolean map with ``True`` on a regular lattice of step *stride*.

    Every element whose row and column indices are both multiples of *stride*
    (starting at 0) is set. Any trailing dimensions of *map_size* are filled
    along with them.

    Args:
        map_size: Shape of the output map (at least 2-D).
        stride: Positive integer lattice step along rows and columns.

    Returns:
        ``np.ndarray`` of dtype ``bool`` and shape *map_size*.
    """
    pts_map = np.zeros(map_size, dtype=bool)
    # A strided slice marks the lattice directly, without materialising a
    # meshgrid of all lattice coordinates first.
    pts_map[::stride, ::stride] = True
    return pts_map
def get_centers(ins_map, td_hf):
    """Compute a bounding-box descriptor for every instance ID in ``ins_map``.

    Building instances (ID >= ``CONSTANTS["BLDG_INST_RANGE"][0]``) use the
    bounding box of their external contours, in cv2 (x = column, y = row)
    pixel coordinates. Their height is ``max(td_hf over the instance) + 1``.
    Every other instance spans the whole ``[0, PROJECTION_SIZE]`` window and
    uses ``max(td_hf)`` as height. Box extents are ``max - min``, without +1.

    Args:
        ins_map: Instance map, e.g. the output of ``get_instance_seg_map``.
        td_hf: Height field indexed with boolean masks of ``ins_map``, so it
            must have a matching shape.

    Returns:
        Dict instance ID -> float32 array ``[cx, cy, width, height, max_z]``.
    """
    bldg_start = CONSTANTS["BLDG_INST_RANGE"][0]
    proj_size = CONSTANTS["PROJECTION_SIZE"]
    result = {}
    for inst_id in tqdm(np.unique(ins_map), desc="Calculating centers ..."):
        if inst_id < bldg_start:
            # Non-building instances cover the full projection window.
            x_lo, x_hi, y_lo, y_hi = 0, proj_size, 0, proj_size
            top = np.max(td_hf)
        else:
            inst_mask = ins_map == inst_id
            contours, _ = cv2.findContours(
                inst_mask.astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )
            outline = np.vstack(contours).reshape(-1, 2)
            x_lo, y_lo = outline.min(axis=0)
            x_hi, y_hi = outline.max(axis=0)
            top = np.max(td_hf[inst_mask]) + 1
        result[inst_id] = np.array(
            [(x_lo + x_hi) / 2, (y_lo + y_hi) / 2, x_hi - x_lo, y_hi - y_lo, top],
            dtype=np.float32,
        )
    return result
def generate_city(
    fgm, bgm, city_layout, cx, cy, radius, altitude, azimuth, style_lut=None
):
    """Render one view of the city crop centred at (cx, cy) with an orbit camera.

    Pipeline:
    1. Crop the layout around (cx, cy) and voxelise it into BEV points.
    2. Attach per-point 3-D scales and class one-hots.
    3. Keep the points visible from the camera, found by ray casting.
    4. Colour the points with the building / background models.
    5. Rasterise the resulting Gaussians.

    Args:
        fgm: Model used for building points (``models["BLDG"]``).
        bgm: Model used for all other points (``models["REST"]``).
        city_layout: Dict of full-city maps. The keys read here and in the
            helpers are "TD_HF", "BU_HF", "SEG", "INS", "PTS" (cropped) and
            "CTR" (per-instance center vectors indexed by instance ID).
        cx, cy: Crop centre in ``city_layout`` pixel coordinates. The crop is
            ``CONSTANTS["PROJECTION_SIZE"]`` pixels wide.
        radius: Horizontal distance of the camera from the crop centre.
        altitude: Camera z coordinate.
        azimuth: Camera angle around the crop centre, in degrees.
        style_lut: Optional instance ID -> style code table. When ``None`` it
            is built by ``_get_style_lut``.

    Returns:
        ``np.uint8`` image in (H, W, C) layout, as produced by ``_render``.
    """
    import gaussiancity.extensions.diff_gaussian_rasterization as dgr
    # NOTE: the whole pipeline is hard-wired to CUDA.
    device = torch.device("cuda")
    gr = dgr.GaussianRasterizerWrapper(
        np.array(CONSTANTS["CAM_K"], dtype=np.float32).reshape((3, 3)),
        CONSTANTS["SENSOR_SIZE"],
        flip_lr=True,
        flip_ud=False,
        device=device,
    )
    layout = _get_local_layout(
        city_layout,
        cx,
        cy,
        CONSTANTS["PROJECTION_SIZE"] // 2,
        CONSTANTS["BLDG_INST_RANGE"],
        device,
    )
    # One row [x, y, z, instance ID] per non-empty voxel of the crop.
    bev_pts = _get_bev_points(layout, SCALES, CLASSES)
    bev_pt_classes = _instances_to_classes(
        bev_pts[:, [3]], CONSTANTS["BLDG_INST_RANGE"], CLASSES
    )
    bev_pt_classes_onehot = _get_onehot_seg(bev_pt_classes, len(CLASSES))
    bev_pt_scales = _get_point_scales(
        bev_pt_classes,
        SCALES,
        CLASSES,
        CONSTANTS["SPECIAL_Z_SCALE_CLASSES"],
    )
    bev_pts = torch.cat([bev_pts, bev_pt_scales, bev_pt_classes_onehot], dim=1)
    # bev_pts columns: [x, y, z, instance, scale_x, scale_y, scale_z, one-hot classes...]
    if style_lut is None:
        style_lut = _get_style_lut(
            layout["CTR"],
            {"BLDG": fgm, "REST": bgm},
            {
                "BLDG": CONSTANTS["BLDG_INST_RANGE"],
                "REST": [0, CONSTANTS["BLDG_INST_RANGE"][0]],
            },
            device,
        )
    # The camera orbits the centre of the local crop.
    cam_look_at, cam_pose = _get_orbit_camera_pose(
        radius, altitude, azimuth, CONSTANTS["PROJECTION_SIZE"] // 2, device
    )
    # NOTE(review): bev_pts[:, :3] is a view. The chain _get_visible_points ->
    # _get_volume -> _get_localized_pt_cords shifts it in place (min x/y -> 0,
    # min z -> 1). From here on, the xyz columns of bev_pts, including those
    # passed to _get_gs_attrs and rendered with cam_pose, are in that local
    # frame. Confirm this is intended.
    vp_idx = _get_visible_points(
        bev_pts[:, :3],
        bev_pt_scales,
        CONSTANTS["CAM_K"],
        CONSTANTS["SENSOR_SIZE"],
        cam_pose[:3],
        cam_look_at,
    )
    gs_attrs = _get_gs_attrs(
        bev_pts[vp_idx],
        layout["TD_HF"].float(),
        layout["SEG"].float(),
        style_lut,
        layout["CTR"],
        {"BLDG": fgm, "REST": bgm},
        CONSTANTS["POINT_SCALE_FACTOR"],
        CONSTANTS["BLDG_INST_RANGE"],
    )
    return _render(gs_attrs, gr, cam_pose)
def _get_local_layout(city_layout, cx, cy, half_proj_size, bldg_inst_range, device):
    """Crop the city layout around (cx, cy) and move it to ``device``.

    The returned dict contains:
    * "TD_HF", "BU_HF", "INS", "PTS": the cropped maps as (1, 1, H, W) tensors
      (whichever of them ``city_layout`` provides);
    * "SEG": the cropped segmentation as a (1, len(CLASSES), H, W) boolean
      one-hot tensor;
    * "CTR": instance ID -> center vector for every instance in the crop.
      Building centers are shifted into crop coordinates and are also
      registered under ``inst + 1`` (the roof instance), sharing one tensor.

    NOTE(review): for non-building instances the first two center entries are
    overwritten with the crop origin (x_min, y_min) rather than shifted like
    building centers. Confirm this is intended.
    """
    x_min, x_max = cx - half_proj_size, cx + half_proj_size
    y_min, y_max = cy - half_proj_size, cy + half_proj_size
    map_keys = ("TD_HF", "BU_HF", "SEG", "INS", "PTS")

    layout = {}
    for key, full_map in city_layout.items():
        if key in map_keys:
            crop = full_map[None, None, y_min:y_max, x_min:x_max]
            layout[key] = torch.from_numpy(crop).cuda(device)
    layout["SEG"] = _get_onehot_seg(layout["SEG"], len(CLASSES))

    centers = {}
    for inst in torch.unique(layout["INS"]).tolist():
        ctr = torch.from_numpy(city_layout["CTR"][inst]).cuda(device)
        centers[inst] = ctr
        if inst >= bldg_inst_range[0]:
            ctr[0] -= x_min
            ctr[1] -= y_min
            # The roof instance (facade ID + 1) shares the facade's center.
            centers[inst + 1] = ctr
        else:
            ctr[0] = x_min
            ctr[1] = y_min
    layout["CTR"] = centers
    return layout
def _get_onehot_seg(seg_map, n_classes):
    """Expand a single-channel class-ID tensor into boolean one-hot channels.

    Args:
        seg_map: (N, 1, ...) tensor of class IDs, e.g. N x 1 x H x W or N x 1.
        n_classes: Number of output channels. IDs outside
            ``0 .. n_classes - 1`` produce all-False rows.

    Returns:
        Boolean tensor of shape (N, n_classes, ...) on ``seg_map.device``.
    """
    batch_size = seg_map.shape[0]
    trailing = tuple(seg_map.shape[2:])
    masks = torch.zeros(
        (batch_size, n_classes) + trailing, dtype=torch.bool, device=seg_map.device
    )
    for class_id in range(n_classes):
        masks[:, class_id : class_id + 1] = seg_map == class_id
    return masks
def _get_style_lut(centers, models, inst_ranges, device, z_dim=256):
    """Build the instance ID -> style code lookup table.

    Every key of ``centers`` first receives a random ``(1, z_dim)`` code. Then,
    for each ``name, model`` in ``models`` (``None`` models are skipped):

    * if ``model.module.cfg.Z_DIM`` is ``None``, the model takes no style code,
      so the entries whose IDs fall in ``inst_ranges[name]`` are removed;
    * otherwise, if the model exposes a ``module.z`` dict of codes, every
      instance in ``inst_ranges[name]`` gets a randomly chosen entry of it.

    Ranges are half-open ``[start, end)``. A model only touches instances in
    its own range, so one model's codes never replace codes meant for another
    model's instances.

    Args:
        centers: Dict keyed by instance ID; only the keys are used.
        models: Dict name -> model or ``None``, e.g. {"BLDG": ..., "REST": ...}.
        inst_ranges: Dict name -> ``[start, end)`` instance-ID range.
        device: Device for the random codes.
        z_dim: Size of the random codes.

    Returns:
        Dict instance ID -> ``(1, D)`` style-code tensor.
    """
    lut = {ins: torch.rand(1, z_dim, device=device) for ins in centers.keys()}
    for name, model in models.items():
        if model is None:
            continue
        start, end = inst_ranges[name]
        if model.module.cfg.Z_DIM is None:
            for ins in [k for k in lut if start <= k < end]:
                del lut[ins]
            continue
        if hasattr(model.module, "z"):
            zs = model.module.z
            z_keys = list(zs.keys())
            # Restrict to this model's range so that, e.g., the REST model's
            # codes do not overwrite the building codes.
            lut.update(
                {
                    ins: zs[np.random.choice(z_keys)].unsqueeze(0)
                    for ins in centers.keys()
                    if start <= ins < end
                }
            )
    return lut
def _get_orbit_camera_pose(radius, altitude, azimuth, half_proj_size, device):
    """Place a camera on a circle around the crop centre, looking at it.

    The crop centre is ``(half_proj_size, half_proj_size)``. The camera sits at
    horizontal distance ``radius`` in direction ``azimuth`` (degrees), at
    height ``altitude``, and looks at the centre point at z = 1.

    Returns:
        ``(look_at, pose)`` tensors on ``device``. ``look_at`` is the 3-D
        target point. ``pose`` is the camera position followed by the
        orientation quaternion from ``_get_quat_from_look_at``.
    """
    center = half_proj_size
    theta = np.deg2rad(azimuth)
    position = np.array(
        [center + radius * math.cos(theta), center + radius * math.sin(theta), altitude],
        dtype=np.float32,
    )
    target = np.array([center, center, 1], dtype=np.float32)
    quaternion = _get_quat_from_look_at(position, target)
    look_at_tensor = torch.tensor([*target], device=device)
    pose_tensor = torch.tensor([*position, *quaternion], device=device)
    return look_at_tensor, pose_tensor
def _get_quat_from_look_at(cam_pos, cam_look_at):
    """Return the quaternion orienting a camera at *cam_pos* toward *cam_look_at*.

    The rotation matrix has columns (forward, right, up), with world +Z as the
    reference up direction. The quaternion is returned in SciPy's
    ``as_quat`` order (x, y, z, w). When the view direction is parallel to +Z,
    the right axis has zero length and its normalisation divides by zero.
    """
    world_up = np.array([0, 0, 1])
    forward = cam_look_at - cam_pos
    forward /= np.linalg.norm(forward)
    right = np.cross(world_up, forward)
    right /= np.linalg.norm(right)
    # Recompute "up" so the three axes are mutually orthogonal.
    up = np.cross(forward, right)
    rot_mat = np.column_stack((forward, right, up))
    return scipy.spatial.transform.Rotation.from_matrix(rot_mat).as_quat()
def _get_bev_points(layout, scales, classes):
    """Extract the non-empty voxels of the local layout as a point list.

    ``maps_to_volume`` builds a voxel volume from the instance map, the two
    height fields and the point mask. Every non-zero voxel becomes one row:
    its three voxel indices followed by the voxel value. The caller uses the
    value as the instance ID.

    Args:
        layout: Local layout dict whose "INS", "TD_HF", "BU_HF" and "PTS"
            entries are (1, 1, H, W) tensors (squeezed here).
        scales: Mapping class name -> integer scale; classes without an entry
            get 0.
        classes: Mapping class name -> class ID. Its key order defines the
            order of the per-class scale table.

    Returns:
        (M, 4) tensor: int16 voxel indices concatenated with voxel values.
    """
    import gaussiancity.extensions.voxlib

    ins_map = layout["INS"]
    # The maps are cast to int16 below. The bound also equals
    # CONSTANTS["BLDG_INST_RANGE"][1].
    assert torch.max(ins_map) < 16384
    # The height cap keeps the dense volume small enough for torch.nonzero:
    #   torch.nonzero(torch.zeros(2048, 2048, 512).cuda())
    #   -> nonzero is not supported for tensors with more than INT_MAX elements
    #   torch.nonzero(torch.zeros(2048, 2048, 508).cuda())
    #   -> an illegal memory access was encountered
    assert torch.max(layout["TD_HF"]) <= 500
    class_scales = torch.tensor(
        [scales.get(name, 0) for name in classes.keys()],
        dtype=torch.int8,
        device=ins_map.device,
    )
    volume = gaussiancity.extensions.voxlib.maps_to_volume(
        ins_map.squeeze().short(),
        layout["TD_HF"].squeeze().short(),
        layout["BU_HF"].squeeze().short(),
        layout["PTS"].squeeze().bool(),
        class_scales,
    )
    nz = torch.nonzero(volume, as_tuple=True)
    coords = torch.stack(nz, dim=1)
    values = volume[nz]
    return torch.cat([coords.short(), values.unsqueeze(dim=1)], dim=1)
def _instances_to_classes(instances, bldg_inst_range, bldg_classes):
    """Replace building instance IDs by their semantic class ID.

    IDs at or above ``bldg_inst_range[0]`` are buildings. Even IDs become
    ``bldg_classes["BLDG_FACADE"]`` and odd IDs become
    ``bldg_classes["BLDG_ROOF"]``. Smaller IDs are copied unchanged; they are
    presumably class IDs already. The input tensor is not modified.
    """
    in_bldg_range = instances >= bldg_inst_range[0]
    parity = instances % 2
    classes = instances.clone()
    classes[in_bldg_range & (parity == 0)] = bldg_classes["BLDG_FACADE"]
    classes[in_bldg_range & (parity == 1)] = bldg_classes["BLDG_ROOF"]
    return classes
def _get_point_scales(pt_classes, scales, classes, special_z_scale_classes=()):
    """Build per-point (x, y, z) Gaussian scales from per-point class IDs.

    Args:
        pt_classes: Per-point class-ID tensor, expected to be (N, 1) (e.g.
            ``bev_pts[:, [3]]`` after ``_instances_to_classes``).
        scales: Mapping class name -> integer scale. Points whose class has no
            entry keep their class ID as their scale.
        classes: Mapping class name -> class ID.
        special_z_scale_classes: Class IDs whose z-scale is forced to 1.
            Defaults to no classes. The default is an immutable tuple.

    Returns:
        (N, 3) tensor with the dtype of ``pt_classes``; the input is not
        modified.
    """
    pt_scales = pt_classes.clone()
    for name, scale in scales.items():
        pt_scales[pt_classes == classes[name]] = scale
    # Use the same scale on all three axes.
    pt_scales_3d = pt_scales.repeat(1, 3)
    # Set the z-scale = 1 for roads, zones, and waters. Build the ID tensor
    # with the class dtype so torch.isin also works for an empty selection.
    special_ids = torch.tensor(
        list(special_z_scale_classes),
        dtype=pt_classes.dtype,
        device=pt_classes.device,
    )
    flat_mask = torch.isin(pt_classes.squeeze(dim=-1), special_ids)
    pt_scales_3d[..., 2][flat_mask] = 1
    return pt_scales_3d
def _get_visible_points(points, scales, K, sensor_size, cam_pos, cam_look_at):
    """Return the indices of the points hit by the camera's rays.

    The points are voxelised with 1-based IDs. One perspective ray per pixel
    is cast into the volume, and the distinct point indices that were hit are
    returned; misses, encoded as -1, are dropped.

    NOTE: ``_get_volume`` shifts ``points`` in place into the volume's local
    frame, so callers passing a view see their tensor modified.

    Returns:
        1-D tensor of unique, non-negative point indices.
    """
    volume, offsets = _get_volume(points, scales)
    # The camera origin is expressed in the same local frame as the volume.
    local_origin = cam_pos - offsets
    view_dir = cam_look_at - cam_pos
    vp_map = _get_ray_voxel_intersection(
        K, sensor_size, local_origin, view_dir, volume
    )
    # Drop the (potentially large) volume before flushing the CUDA cache to
    # avoid OOM.
    del volume
    torch.cuda.empty_cache()
    hit_ids = torch.unique(vp_map)
    return hit_ids[hit_ids >= 0]
def _get_volume(points, scales):
    """Voxelise ``points`` into a dense volume of 1-based point IDs.

    ``points`` is shifted in place by ``_get_localized_pt_cords``, which also
    modifies any tensor it is a view of. After the shift, the minimum x and y
    are 0 and the minimum z is 1. Point i then carries ID i + 1; 0 is
    reserved for empty / NULL.

    Args:
        points: (N, 3) xyz tensor.
        scales: Per-point scales forwarded to ``points_to_volume``.

    Returns:
        ``(volume, offsets)``, where ``offsets`` is the int16 tensor
        ``(x_min, y_min, z_min)`` used for the shift.
    """
    import gaussiancity.extensions.voxlib

    mins = [torch.min(points[:, axis]).item() for axis in range(3)]
    maxs = [torch.max(points[:, axis]).item() for axis in range(3)]
    offsets = torch.tensor(mins, dtype=torch.int16, device=points.device)
    # Normalize point coordinates to the local coordinate system (in place).
    points = _get_localized_pt_cords(points, offsets)
    # Local x/y run from 0 to max - min. Local z runs from 1 to max - min + 1,
    # hence the +2 along z.
    w = maxs[0] - mins[0] + 1
    h = maxs[1] - mins[1] + 1
    d = maxs[2] - mins[2] + 2
    n_pts = points.shape[0]
    # NOTE: The point IDs start from 1 to avoid the conflict with the NULL class.
    assert n_pts < 2147483648
    pt_ids = torch.arange(
        1, n_pts + 1, dtype=torch.int32, device=points.device
    ).unsqueeze(dim=1)
    volume = gaussiancity.extensions.voxlib.points_to_volume(
        points.contiguous(), pt_ids, scales, h, w, d
    )
    return volume, offsets
def _get_localized_pt_cords(points, offsets):
    """Shift ``points`` in place into the local voxel frame and return it.

    x and y are reduced by ``offsets[0]`` and ``offsets[1]``. z is reduced by
    ``offsets[2] - 1``, so the smallest z lands on 1 rather than 0 when
    ``offsets`` holds the per-axis minima. Because the update is in place,
    any tensor that ``points`` is a view of is modified as well.
    """
    shifts = (offsets[0], offsets[1], offsets[2] - 1)
    for axis, shift in enumerate(shifts):
        points[:, axis] -= shift
    return points
def _get_ray_voxel_intersection(K, sensor_size, cam_origin, viewdir, volume):
    """Cast one perspective ray per pixel into ``volume`` and return hit point IDs.

    The volume stores 1-based point IDs, so the result is shifted by -1. Hits
    become 0-based point indices, and the NULL / no-hit ID becomes -1.

    Args:
        K: Flattened row-major 3x3 intrinsics. K[0] is the focal length and
            (K[2], K[5]) is the principal point (x, y).
        sensor_size: Image size, presumably (width, height); it is passed to
            the kernel reversed, like the coordinates.
        cam_origin: Camera position (x, y, z) in volume coordinates.
        viewdir: Viewing direction (x, y, z).
        volume: Voxel grid produced by ``_get_volume``.

    Returns:
        Squeezed per-pixel tensor of point indices.
    """
    import gaussiancity.extensions.voxlib

    N_MAX_SAMPLES = 1
    # The kernel receives vectors in (y, x, z) order.
    yxz = [1, 0, 2]
    focal = K[0]
    principal_yx = [K[5], K[2]]
    image_hw = [sensor_size[1], sensor_size[0]]
    # Fixed (0, 0, 1) vector, presumably the camera "up" direction.
    up_vec = torch.tensor([0, 0, 1], dtype=torch.float32)
    voxel_id, _, _ = gaussiancity.extensions.voxlib.ray_voxel_intersection_perspective(
        volume,
        cam_origin[yxz].float(),
        viewdir[yxz].float(),
        up_vec,
        focal,
        principal_yx,
        image_hw,
        N_MAX_SAMPLES,
    )
    # The kernel seems not to accept negative IDs, hence the 1-based storage.
    return voxel_id.squeeze() - 1
def get_hf_seg_tensor(part_hf, part_seg, layout_cfg, output_device):
    """Stack a normalised height field and a one-hot segmentation into one tensor.

    Args:
        part_hf: (H, W) height-field array.
        part_seg: (H, W) class-ID array.
        layout_cfg: Not used.
        output_device: Device for the returned tensor.

    Returns:
        (1, 1 + K, H, W) tensor. The first channel is the height divided by
        ``CONSTANTS["LAYOUT_MAX_HEIGHT"]``; the rest are the float one-hot
        channels for ``CONSTANTS["LAYOUT_N_CLASSES"]`` classes.

    NOTE(review): the CONSTANTS dict defined in this module has no
    "LAYOUT_MAX_HEIGHT" or "LAYOUT_N_CLASSES" key, so this raises KeyError
    unless those keys are added elsewhere.
    """
    hf = torch.from_numpy(part_hf[None, None]).to(output_device)
    seg = torch.from_numpy(part_seg[None, None]).to(output_device)
    hf_normalised = hf / CONSTANTS["LAYOUT_MAX_HEIGHT"]
    seg_onehot = _masks_to_onehots(seg[:, 0], CONSTANTS["LAYOUT_N_CLASSES"])
    return torch.cat([hf_normalised, seg_onehot], dim=1)
def _masks_to_onehots(masks, n_class, ignored_classes=()):
    """Convert (B, H, W) class-ID masks into float one-hot channels.

    Args:
        masks: (B, H, W) tensor of class IDs.
        n_class: Number of classes; IDs ``0 .. n_class - 1`` are considered.
        ignored_classes: Class IDs to leave out of the output. Duplicates and
            IDs outside ``0 .. n_class - 1`` have no effect.

    Returns:
        (B, K, H, W) float32 tensor, where K is the number of kept classes.
        Channels are in ascending class-ID order.
    """
    b, h, w = masks.shape
    ignored = tuple(ignored_classes)
    # Derive the channel count from the classes actually kept. Out-of-range
    # or duplicate ignored IDs then cannot make the tensor too small.
    kept_classes = [c for c in range(n_class) if c not in ignored]
    one_hot_masks = torch.zeros(
        (b, len(kept_classes), h, w), dtype=torch.float32, device=masks.device
    )
    for channel, class_id in enumerate(kept_classes):
        one_hot_masks[:, channel] = masks == class_id
    return one_hot_masks
def _get_gs_attrs(
    pts,
    proj_hf,
    proj_seg,
    style_lut,
    centers,
    models,
    scale_factor,
    bldg_inst_range,
):
    """Assemble the per-Gaussian attribute tensor for the visible points.

    Points are split into buildings (instance ID >= ``bldg_inst_range[0]``)
    and everything else, and each group is coloured by its own model
    (``models["BLDG"]`` / ``models["REST"]``). In the output, building
    Gaussians come first and the rest follow, so the row order differs from
    ``pts``.

    Args:
        pts: (N, 4 + 3 + N_CLASSES) rows: x, y, z, instance ID, three scales,
            class one-hot.
        proj_hf, proj_seg: Local height field / one-hot segmentation, passed
            to the models.
        style_lut: Instance ID -> style code table.
        centers: Instance ID -> center vector.
        models: Dict with "BLDG" and "REST" models.
        scale_factor: Multiplier applied to the per-point scales.
        bldg_inst_range: Building instance-ID range.

    Returns:
        Tensor whose columns are xyz, opacity (all 1), scale (times
        ``scale_factor``), a 4-component rotation (1, 0, 0, 0) and the
        colours returned by the models.
    """
    n_pts, _n_cols = pts.shape
    device = pts.device
    is_bldg = pts[:, 3] >= bldg_inst_range[0]
    groups = {"BLDG": pts[is_bldg], "REST": pts[~is_bldg]}

    # Input attributes for both groups are computed before either model runs.
    pt_attrs = {
        name: _get_pt_input_attrs(
            group[:, :4],
            centers,
            style_lut,
            models[name].module.cfg.Z_DIM,
            bldg_inst_range,
        )
        for name, group in groups.items()
    }
    colors = {
        name: _get_gs_colors(group, pt_attrs[name], proj_hf, proj_seg, models[name])
        for name, group in groups.items()
    }

    xyz = torch.cat([group[:, :3] for group in groups.values()], dim=0)
    scales = (
        torch.cat([group[:, 4:7] for group in groups.values()], dim=0) * scale_factor
    )
    rgb = torch.cat(list(colors.values()), dim=0)

    # Attributes with fixed values: fully opaque, rotation (1, 0, 0, 0).
    opacity = torch.ones((n_pts, 1), device=device)
    rotations = torch.zeros(n_pts, 4, device=device)
    rotations[:, 0] = 1
    return torch.cat((xyz, opacity, scales, rotations, rgb), dim=-1)
def _get_pt_input_attrs(pts, centers, style_lut, z_dim, bldg_inst_range):
    """Compute the per-point generator inputs for one group of points.

    Args:
        pts: Point rows whose first three columns are x, y, z and whose last
            column is the instance ID (callers pass ``pts[:, :4]``).
        centers: Dict instance ID -> 5-vector ``(cx, cy, w, h, d)``.
        style_lut: Dict instance ID -> style code; read only when ``z_dim`` is
            not ``None``.
        z_dim: Style-code size of the target model. ``None`` means the model
            takes no style codes, and ``zs`` is returned as ``None``.
        bldg_inst_range: IDs ``>= bldg_inst_range[0]`` are building instances.

    Returns:
        ``(rel_xyz, batch_idx, zs)``:
        * rel_xyz: (1, N, 3) float32 coordinates relative to each point's
          instance (see the comments below).
        * batch_idx: (1, N) int32 position of each point's instance in the
          sorted unique instance IDs of ``pts``.
        * zs: ``None``, or a dict instance ID -> {"z": style code,
          "idx": (1, N) boolean mask selecting that instance's points}.
    """
    n_pts = pts.shape[0]
    instances = torch.unique(pts[:, -1])
    rel_xyz = torch.zeros(1, n_pts, 3, dtype=torch.float32, device=pts.device)
    batch_idx = torch.zeros(1, n_pts, dtype=torch.int32, device=pts.device)
    zs = {} if z_dim is not None else None
    for idx, ins in enumerate(instances):
        ins = ins.item()
        is_pts = pts[:, -1] == ins
        cx, cy, w, h, d = centers[ins]
        if ins >= bldg_inst_range[0]:
            # Buildings: x/y relative to the instance center, divided by half
            # the box extent, so points inside the box land in [-1, 1].
            # A zero extent yields 0.
            rel_xyz[:, is_pts, 0] = (pts[is_pts, 0] - cx) / w * 2 if w > 0 else 0
            rel_xyz[:, is_pts, 1] = (pts[is_pts, 1] - cy) / h * 2 if h > 0 else 0
        else:
            # Make the BG contiguous. x/y are folded into [-1, 1] by a triangle
            # wave: the value rises from -1 to 1 over one span of length 2 * w
            # (resp. 2 * h) and falls back over the next, so adjacent spans
            # meet without a jump.
            # NOTE(review): cx and cy are not used here, and there is no guard
            # against w == 0 or h == 0 in this branch.
            period_x = torch.ceil((pts[is_pts, 0] / w / 2) - 0.5)
            period_y = torch.ceil((pts[is_pts, 1] / h / 2) - 0.5)
            rel_xyz[:, is_pts, 0] = (
                (pts[is_pts, 0] - 2 * period_x * w) * (-1) ** period_x
            ) / w
            rel_xyz[:, is_pts, 1] = (
                (pts[is_pts, 1] - 2 * period_y * h) * (-1) ** period_y
            ) / h
        # z: height relative to the instance height d, mapped to [-1, 1] and
        # clipped (0 when d is not positive).
        rel_xyz[:, is_pts, 2] = (
            torch.clip(pts[is_pts, 2] / d * 2 - 1, -1, 1) if d > 0 else 0
        )
        batch_idx[:, is_pts] = idx
        if zs is not None:
            zs[ins] = {"z": style_lut[ins], "idx": is_pts.unsqueeze(dim=0)}
    return rel_xyz, batch_idx, zs
def _get_gs_colors(pts, pt_attrs, proj_hf, proj_seg, model):
    """Run ``model`` on one group of points and return their colours.

    Args:
        pts: (N, 4 + 3 + N_CLASSES) point rows. Columns 0-2 are xyz and
            columns 7 onward are the class one-hot.
        pt_attrs: ``(rel_xyz, batch_idx, zs)`` from ``_get_pt_input_attrs``.
        proj_hf, proj_seg: Local height field and segmentation, forwarded to
            the model unchanged.
        model: Generator. It receives projection UVs only when
            ``model.module.cfg.ENCODER`` is not ``None``.

    Returns:
        The model's "rgb" output without its leading batch dimension. When
        ``pts`` is empty, an empty (0, 3) float32 tensor is returned and the
        model is not called.
    """
    if not pts.shape[0]:
        return torch.empty(0, 3, dtype=torch.float32, device=pts.device)
    rel_xyz, batch_idx, zs = pt_attrs
    batched_pts = pts.unsqueeze(dim=0)
    proj_uv = (
        get_projection_uv(batched_pts[..., :3])
        if model.module.cfg.ENCODER is not None
        else None
    )
    with torch.no_grad():
        # TODO: Optimize the _instance_forward in Generator
        outputs = model(
            proj_uv,
            rel_xyz,
            batch_idx,
            batched_pts[..., 7:].float(),
            zs,
            proj_hf,
            proj_seg,
        )
    return outputs["rgb"].squeeze(dim=0)
def get_projection_uv(xyz, proj_tlp=None, proj_size=2048):
    """Normalise point XY coordinates to projection UVs.

    Args:
        xyz: (B, N, >=2) tensor of point coordinates; only x and y are used.
        proj_tlp: Optional (B, 2) top-left corner of the projection window.
            It is subtracted from the coordinates before normalisation.
        proj_size: Side length of the projection window.

    Returns:
        (B, N, 2) floating-point tensor in which the window ``[0, proj_size]``
        maps to ``[-1, 1]``. ``xyz`` is never modified.
    """
    n_pts = xyz.size(1)
    # Convert to float in both branches so integer coordinates work with or
    # without proj_tlp. All arithmetic below is out of place: ``.float()``
    # returns ``xyz``'s own storage when it is already float32.
    proj_uv = xyz[..., :2].float()
    if proj_tlp is not None:
        proj_uv = proj_uv - proj_tlp.unsqueeze(dim=1)
    assert proj_uv.size() == (xyz.size(0), n_pts, 2)
    # Normalize to [-1, 1]
    return proj_uv / proj_size * 2 - 1
def _render(gs_attrs, rasterizator, cam_pose):
    """Rasterise the Gaussians from ``cam_pose`` and return an 8-bit image.

    ``cam_pose`` holds the camera position in its first three entries and the
    orientation quaternion in the rest. The rasteriser output is mapped with
    ``x / 2 + 0.5``, which assumes it lies in [-1, 1] (TODO confirm). Both
    brightness and contrast are then boosted by 1.2.

    Returns:
        ``np.uint8`` array in (H, W, C) layout.
    """
    import torchvision.transforms.functional as F

    position = cam_pose[:3]
    quaternion = cam_pose[3:]
    with torch.no_grad():
        raw = rasterizator(gs_attrs, position, quaternion)
        img = raw.squeeze() / 2 + 0.5
        img = F.adjust_contrast(F.adjust_brightness(img, 1.2), 1.2)
    hwc = (img * 255).permute(1, 2, 0)
    return hwc.cpu().numpy().astype(np.uint8)