File size: 14,249 Bytes
ec0c8fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
from typing import *

import torch
import nvdiffrast.torch as dr

from . import utils, transforms, mesh
from ._helpers import batched


# Public API of this module.
# NOTE(review): the module also defines a `texture` helper that is not exported
# here — presumably work-in-progress; confirm before adding it to __all__.
__all__ = [
    'RastContext',
    'rasterize_triangle_faces', 
    'warp_image_by_depth',
    'warp_image_by_forward_flow',
]


class RastContext:
    """
    Rasterization context: a thin wrapper around nvdiffrast.torch.RasterizeCudaContext
    or nvdiffrast.torch.RasterizeGLContext.

    Args:
        nvd_ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext], optional):
            an existing nvdiffrast context to wrap. When given, `backend` and
            `device` are ignored. Defaults to None.
        backend (Literal['cuda', 'gl'], optional): which nvdiffrast backend to
            create when `nvd_ctx` is not supplied. Defaults to 'gl'.
        device (Union[str, torch.device], optional): device to create the context
            on. Defaults to None (nvdiffrast's own default).

    Raises:
        ValueError: if `backend` is neither 'cuda' nor 'gl'.
    """
    def __init__(self, nvd_ctx: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl',  device: Union[str, torch.device] = None):
        # `dr` is already imported at module level; the former redundant
        # function-local re-import has been removed.
        if nvd_ctx is not None:
            self.nvd_ctx = nvd_ctx
            return

        if backend == 'gl':
            self.nvd_ctx = dr.RasterizeGLContext(device=device)
        elif backend == 'cuda':
            self.nvd_ctx = dr.RasterizeCudaContext(device=device)
        else:
            raise ValueError(f'Unknown backend: {backend}')


def rasterize_triangle_faces(
    ctx: RastContext,
    vertices: torch.Tensor,
    faces: torch.Tensor,
    attr: torch.Tensor,
    width: int,
    height: int,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    projection: torch.Tensor = None,
    antialiasing: Union[bool, List[int]] = True,
    diff_attrs: Union[None, List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Rasterize a mesh with vertex attributes.

    Args:
        ctx (RastContext): rasterizer context
        vertices (torch.Tensor): (B, N, 2 or 3 or 4) vertex positions. 2D/3D
            inputs are promoted to homogeneous 4D coordinates (z=0, w=1 as needed).
        faces (torch.Tensor): (T, 3) triangle vertex indices
        attr (torch.Tensor): (B, N, C) per-vertex attributes to interpolate
        width (int): width of the output image
        height (int): height of the output image
        model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity).
        view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity).
        projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity).
        antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing.
            Defaults to True. If a list of channel indices is provided, only those
            channels will be antialiased.
        diff_attrs (Union[None, List[int]], optional): indices of attributes to compute
            screen-space derivatives for. Defaults to None.

    Returns:
        image (torch.Tensor): (B, C, H, W) interpolated attributes
        depth (torch.Tensor): (B, H, W) screen-space depth, ranging from 0 (near) to 1 (far).
            NOTE: Empty pixels will have depth 1., i.e. the far plane.
        image_dr (torch.Tensor): screen-space derivatives of the requested attributes,
            flipped/permuted like `image`. Only returned when `diff_attrs` is not None.
    """
    assert vertices.ndim == 3
    assert faces.ndim == 2

    # Promote vertices to homogeneous (x, y, z, w) coordinates.
    if vertices.shape[-1] == 2:
        vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1)
    elif vertices.shape[-1] == 3:
        vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
    elif vertices.shape[-1] == 4:
        pass
    else:
        raise ValueError(f'Wrong shape of vertices: {vertices.shape}')

    # Compose the full transform as projection @ view @ model
    # (applied right-to-left to column vectors).
    mvp = projection if projection is not None else torch.eye(4).to(vertices)
    if view is not None:
        mvp = mvp @ view
    if model is not None:
        mvp = mvp @ model

    # Row-vector convention: v' = v @ M^T.
    pos_clip = vertices @ mvp.transpose(-1, -2)
    faces = faces.contiguous()
    attr = attr.contiguous()

    rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True)
    image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs)
    # `is True` (rather than `== True`) makes the bool/list branches explicit;
    # behavior is unchanged for the annotated input types.
    if antialiasing is True:
        image = dr.antialias(image, rast_out, pos_clip, faces)
    elif isinstance(antialiasing, list):
        # Antialias only the selected channels, writing them back in place.
        aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces)
        image[..., antialiasing] = aa_image

    # Flip vertically (rasterizer output is bottom-up relative to image
    # convention) and move channels first: (B, H, W, C) -> (B, C, H, W).
    image = image.flip(1).permute(0, 3, 1, 2)

    # rast_out[..., 2] holds clip-space depth; remap covered pixels from
    # [-1, 1] to [0, 1] and force empty pixels (z == 0) to the far plane (1.).
    depth = rast_out[..., 2].flip(1)
    depth = (depth * 0.5 + 0.5) * (depth > 0).float() + (depth == 0).float()
    if diff_attrs is not None:
        image_dr = image_dr.flip(1).permute(0, 3, 1, 2)
        return image, depth, image_dr
    return image, depth


def texture(
    ctx: RastContext,
    uv: torch.Tensor,
    uv_da: torch.Tensor,
    texture: torch.Tensor,
) -> torch.Tensor:
    """
    Sample a texture map at the given uv coordinates with mip-mapped filtering.

    Args:
        ctx (RastContext): rasterizer context (unused; kept for interface symmetry
            with the other functions in this module).
        uv (torch.Tensor): uv coordinates to sample at.
        uv_da (torch.Tensor): screen-space derivatives of uv, used for mip level selection.
        texture (torch.Tensor): the texture to sample.

    Returns:
        torch.Tensor: sampled texture values.
    """
    # BUG FIX: the original discarded the result (missing `return`) and passed
    # `ctx.nvd_ctx` as the first positional argument, although nvdiffrast's
    # dr.texture signature is texture(tex, uv, uv_da=None, ...) and takes no
    # context. NOTE(review): verify tensor layouts against nvdiffrast docs —
    # this helper appears unfinished and is not exported in __all__.
    return dr.texture(texture, uv, uv_da)


def warp_image_by_depth(
    ctx: RastContext,
    depth: torch.FloatTensor,
    image: torch.FloatTensor = None,
    mask: torch.BoolTensor = None,
    width: int = None,
    height: int = None,
    *,
    extrinsics_src: torch.FloatTensor = None,
    extrinsics_tgt: torch.FloatTensor = None,
    intrinsics_src: torch.FloatTensor = None,
    intrinsics_tgt: torch.FloatTensor = None,
    near: float = 0.1,
    far: float = 100.0,
    antialiasing: bool = True,
    backslash: bool = False,
    padding: int = 0,
    return_uv: bool = False,
    return_dr: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.BoolTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
    """
    Warp image by depth.
    NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results.
    Otherwise, image mesh will be triangulated simply for batch rendering.

    Args:
        ctx (RastContext): rasterization context
        depth (torch.Tensor): (B, H, W) linear depth
        image (torch.Tensor): (B, C, H, W). None to use image space uv. Defaults to None.
        mask (torch.Tensor, optional): (B, H, W) mask of valid source pixels; masked-out
            pixels are pushed to the far plane. Defaults to None.
        width (int, optional): width of the output image. None to use the same as depth. Defaults to None.
        height (int, optional): height of the output image. None to use the same as depth. Defaults to None.
        extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None.
        extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None.
        intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None.
        intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None.
        near (float, optional): near plane. Defaults to 0.1.
        far (float, optional): far plane. Defaults to 100.0.
        antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
        backslash (bool, optional): whether to use backslash triangulation. Defaults to False.
        padding (int, optional): padding of the image. Defaults to 0.
        return_uv (bool, optional): whether to return the uv. Defaults to False.
        return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False.

    Returns:
        image: (torch.FloatTensor): (B, C, H, W) rendered image
        depth: (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf
        mask: (torch.BoolTensor): (B, H, W) mask of valid pixels
        uv: (torch.FloatTensor): (B, 2, H, W) image-space uv (only if return_uv)
        dr: (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv (only if return_dr)
    """
    assert depth.ndim == 3
    batch_size = depth.shape[0]

    if width is None:
        width = depth.shape[-1]
    if height is None:
        height = depth.shape[-2]
    if image is not None:
        assert image.shape[-2:] == depth.shape[-2:], f'Shape of image {image.shape} does not match shape of depth {depth.shape}'

    # Fill in missing camera parameters: identity poses, and intrinsics borrowed
    # from the other view when only one side is given.
    if extrinsics_src is None:
        extrinsics_src = torch.eye(4).to(depth)
    if extrinsics_tgt is None:
        extrinsics_tgt = torch.eye(4).to(depth)
    if intrinsics_src is None:
        intrinsics_src = intrinsics_tgt
    if intrinsics_tgt is None:
        intrinsics_tgt = intrinsics_src

    assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters."

    view_tgt = transforms.extrinsics_to_view(extrinsics_tgt)
    perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far)

    if padding > 0:
        # Build an image mesh one pixel larger on every side, then push the
        # border vertices outward by `padding` pixels so the warped image has a
        # margin of replicated content.
        uv, faces = utils.image_mesh(width=width+2, height=height+2)
        uv = (uv - 1 / (width + 2)) * ((width + 2) / width)
        uv_ = uv.clone().reshape(height+2, width+2, 2)
        uv_[0, :, 1] -= padding / height
        uv_[-1, :, 1] += padding / height
        uv_[:, 0, 0] -= padding / width
        uv_[:, -1, 0] += padding / width
        uv_ = uv_.reshape(-1, 2)
        depth = torch.nn.functional.pad(depth, [1, 1, 1, 1], mode='replicate')
        if image is not None:
            image = torch.nn.functional.pad(image, [1, 1, 1, 1], mode='replicate')
        uv, uv_, faces = uv.to(depth.device), uv_.to(depth.device), faces.to(depth.device)
        pts = transforms.unproject_cv(
            uv_,
            depth.flatten(-2, -1),
            extrinsics_src,
            intrinsics_src,
        )
    else:
        uv, faces = utils.image_mesh(width=depth.shape[-1], height=depth.shape[-2])
        if mask is not None:
            # Masked-out pixels are pushed to the far plane so they never occlude.
            depth = torch.where(mask, depth, torch.tensor(far, dtype=depth.dtype, device=depth.device))
        uv, faces = uv.to(depth.device), faces.to(depth.device)
        pts = transforms.unproject_cv(
            uv,
            depth.flatten(-2, -1),
            extrinsics_src,
            intrinsics_src,
        )

    # triangulate (depth-aware for batch size 1, simple otherwise)
    if batch_size == 1:
        faces = mesh.triangulate(faces, vertices=pts[0])
    else:
        faces = mesh.triangulate(faces, backslash=backslash)

    # rasterize attributes
    diff_attrs = None
    if image is not None:
        attr = image.permute(0, 2, 3, 1).flatten(1, 2)
        if return_dr or return_uv:
            if return_dr:
                # The two uv channels are appended after the C image channels.
                diff_attrs = [image.shape[1], image.shape[1]+1]
            if return_uv and antialiasing:
                # Antialias only the image channels; uv must stay exact.
                antialiasing = list(range(image.shape[1]))
            attr = torch.cat([attr, uv.expand(batch_size, -1, -1)], dim=-1)
    else:
        attr = uv.expand(batch_size, -1, -1)
        if antialiasing:
            print("\033[93mWarning: you are performing antialiasing on uv. This may cause artifacts.\033[0m")
        if return_uv:
            return_uv = False
            print("\033[93mWarning: image is None, return_uv is ignored.\033[0m")
        if return_dr:
            diff_attrs = [0, 1]

    if mask is not None:
        # Carry the mask as an extra attribute channel so it gets warped too.
        attr = torch.cat([attr, mask.float().flatten(1, 2).unsqueeze(-1)], dim=-1)

    rast = rasterize_triangle_faces(
        ctx,
        pts,
        faces,
        attr,
        width,
        height,
        view=view_tgt,
        # BUG FIX: was `perspective=perspective_tgt`, but rasterize_triangle_faces
        # has no `perspective` parameter (TypeError at runtime); it is `projection`.
        projection=perspective_tgt,
        antialiasing=antialiasing,
        diff_attrs=diff_attrs,
    )
    if return_dr:
        output_image, screen_depth, output_dr = rast
    else:
        output_image, screen_depth = rast
    output_mask = screen_depth < 1.0

    if mask is not None:
        # Split off the warped mask channel and require (near-)full coverage.
        output_image, rast_mask = output_image[..., :-1, :, :], output_image[..., -1, :, :]
        output_mask &= (rast_mask > 0.9999).reshape(-1, height, width)

    if (return_dr or return_uv) and image is not None:
        # Split off the two appended uv channels.
        output_image, output_uv = output_image[..., :-2, :, :], output_image[..., -2:, :, :]

    output_depth = transforms.depth_buffer_to_linear(screen_depth, near=near, far=far) * output_mask
    output_image = output_image * output_mask.unsqueeze(1)

    outs = [output_image, output_depth, output_mask]
    if return_uv:
        outs.append(output_uv)
    if return_dr:
        outs.append(output_dr)
    return tuple(outs)


def warp_image_by_forward_flow(
    ctx: RastContext,
    image: torch.FloatTensor,
    flow: torch.FloatTensor,
    depth: torch.FloatTensor = None,
    *,
    antialiasing: bool = True,
    backslash: bool = False,
) -> Tuple[torch.FloatTensor, torch.BoolTensor]:
    """
    Warp image by forward flow.
    NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results.
    Otherwise, image mesh will be triangulated simply for batch rendering.

    Args:
        ctx (RastContext): rasterization context
        image (torch.Tensor): (B, C, H, W) image
        flow (torch.Tensor): (B, 2, H, W) forward flow
        depth (torch.Tensor, optional): (B, H, W) linear depth. If None, will use the same for all pixels. Defaults to None.
        antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
        backslash (bool, optional): whether to use backslash triangulation. Defaults to False.

    Returns:
        image: (torch.FloatTensor): (B, C, H, W) rendered image
        mask: (torch.BoolTensor): (B, H, W) mask of valid pixels
    """
    assert image.ndim == 4, f'Wrong shape of image: {image.shape}'
    batch_size, _, height, width = image.shape

    if depth is None:
        depth = torch.ones_like(flow[:, 0])

    # A fixed identity camera is used on both sides of the unproject/project
    # round trip, so the particular fov choice does not affect the warp.
    extrinsics = torch.eye(4).to(image)
    fov = torch.deg2rad(torch.tensor([45.0], device=image.device))
    intrinsics = transforms.intrinsics_from_fov(fov, width, height, normalize=True)[0]

    view = transforms.extrinsics_to_view(extrinsics)
    perspective = transforms.intrinsics_to_perspective(intrinsics, near=0.1, far=100)

    # Displace each image-mesh vertex by its flow, then lift to 3D.
    uv, faces = utils.image_mesh(width=width, height=height)
    uv, faces = uv.to(image.device), faces.to(image.device)
    uv = uv + flow.permute(0, 2, 3, 1).flatten(1, 2)
    pts = transforms.unproject_cv(
        uv,
        depth.flatten(-2, -1),
        extrinsics,
        intrinsics,
    )

    # triangulate (depth-aware for batch size 1, simple otherwise)
    if batch_size == 1:
        faces = mesh.triangulate(faces, vertices=pts[0])
    else:
        faces = mesh.triangulate(faces, backslash=backslash)

    # rasterize attributes
    attr = image.permute(0, 2, 3, 1).flatten(1, 2)
    rast = rasterize_triangle_faces(
        ctx,
        pts,
        faces,
        attr,
        width,
        height,
        view=view,
        # BUG FIX: was `perspective=perspective`, but rasterize_triangle_faces
        # has no `perspective` parameter (TypeError at runtime); it is `projection`.
        projection=perspective,
        antialiasing=antialiasing,
    )
    output_image, screen_depth = rast
    output_mask = screen_depth < 1.0
    output_image = output_image * output_mask.unsqueeze(1)

    outs = [output_image, output_mask]
    return tuple(outs)