Spaces:
Runtime error
Runtime error
// Copyright (c) Meta Platforms, Inc. and affiliates. | |
// All rights reserved. | |
// | |
// This source code is licensed under the license found in the | |
// LICENSE file in the root directory of this source tree. | |
__forceinline__ __device__ void compute_aabb_srt( | |
float3 pt, float3 pr0, float3 pr1, float3 pr2, float3 ps, | |
float3 & pmin, float3 & pmax) { | |
float3 p; | |
p = make_float3(-1.f, -1.f, -1.f) / ps; | |
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; | |
pmin = p; | |
pmax = p; | |
p = make_float3(1.f, -1.f, -1.f) / ps; | |
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; | |
pmin = fminf(pmin, p); | |
pmax = fmaxf(pmax, p); | |
p = make_float3(-1.f, 1.f, -1.f) / ps; | |
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; | |
pmin = fminf(pmin, p); | |
pmax = fmaxf(pmax, p); | |
p = make_float3(1.f, 1.f, -1.f) / ps; | |
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; | |
pmin = fminf(pmin, p); | |
pmax = fmaxf(pmax, p); | |
p = make_float3(-1.f, -1.f, 1.f) / ps; | |
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; | |
pmin = fminf(pmin, p); | |
pmax = fmaxf(pmax, p); | |
p = make_float3(1.f, -1.f, 1.f) / ps; | |
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; | |
pmin = fminf(pmin, p); | |
pmax = fmaxf(pmax, p); | |
p = make_float3(-1.f, 1.f, 1.f) / ps; | |
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; | |
pmin = fminf(pmin, p); | |
pmax = fmaxf(pmax, p); | |
p = make_float3(1.f, 1.f, 1.f) / ps; | |
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; | |
pmin = fminf(pmin, p); | |
pmax = fmaxf(pmax, p); | |
} | |
struct PrimTransfDataBase { | |
typedef PrimTransfDataBase base; | |
}; | |
struct PrimTransfSRT { | |
struct Data : public PrimTransfDataBase { | |
int primpos_nstride; | |
float3 * primpos; | |
float3 * grad_primpos; | |
int primrot_nstride; | |
float3 * primrot; | |
float3 * grad_primrot; | |
int primscale_nstride; | |
float3 * primscale; | |
float3 * grad_primscale; | |
__forceinline__ __device__ void n_stride(int n) { | |
primpos += n * primpos_nstride; | |
grad_primpos += n * primpos_nstride; | |
primrot += n * primrot_nstride; | |
grad_primrot += n * primrot_nstride; | |
primscale += n * primscale_nstride; | |
grad_primscale += n * primscale_nstride; | |
} | |
__forceinline__ __device__ float3 get_center(int n, int k) { | |
return primpos[n * primpos_nstride + k]; | |
} | |
__forceinline__ __device__ void compute_aabb(int n, int k, float3 & pmin, float3 & pmax) { | |
float3 pt = primpos[n * primpos_nstride + k]; | |
float3 pr0 = primrot[n * primrot_nstride + k * 3 + 0]; | |
float3 pr1 = primrot[n * primrot_nstride + k * 3 + 1]; | |
float3 pr2 = primrot[n * primrot_nstride + k * 3 + 2]; | |
float3 ps = primscale[n * primscale_nstride + k]; | |
compute_aabb_srt(pt, pr0, pr1, pr2, ps, pmin, pmax); | |
} | |
}; | |
float3 xmt; | |
float3 pr0; | |
float3 pr1; | |
float3 pr2; | |
float3 rxmt; | |
float3 ps; | |
static __forceinline__ __device__ bool valid(float3 pos) { | |
return ( | |
pos.x > -1.f && pos.x < 1.f && | |
pos.y > -1.f && pos.y < 1.f && | |
pos.z > -1.f && pos.z < 1.f); | |
} | |
__forceinline__ __device__ float3 forward( | |
const Data & data, | |
int k, | |
float3 x) { | |
float3 pt = data.primpos[k]; | |
pr0 = data.primrot[(k) * 3 + 0]; | |
pr1 = data.primrot[(k) * 3 + 1]; | |
pr2 = data.primrot[(k) * 3 + 2]; | |
ps = data.primscale[k]; | |
xmt = x - pt; | |
rxmt = pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z; | |
float3 y0 = rxmt * ps; | |
return y0; | |
} | |
static __forceinline__ __device__ void forward2( | |
const Data & data, | |
int k, | |
float3 r, float3 d, float3 & rout, float3 & dout) { | |
float3 pt = data.primpos[k]; | |
float3 pr0 = data.primrot[k * 3 + 0]; | |
float3 pr1 = data.primrot[k * 3 + 1]; | |
float3 pr2 = data.primrot[k * 3 + 2]; | |
float3 ps = data.primscale[k]; | |
float3 xmt = r - pt; | |
float3 dmt = d; | |
float3 rxmt = pr0 * xmt.x; | |
float3 rdmt = pr0 * dmt.x; | |
rxmt += pr1 * xmt.y; | |
rdmt += pr1 * dmt.y; | |
rxmt += pr2 * xmt.z; | |
rdmt += pr2 * dmt.z; | |
rout = rxmt * ps; | |
dout = rdmt * ps; | |
} | |
__forceinline__ __device__ void backward(const Data & data, int k, float3 x, float3 dL_y0, bool validthread) { | |
fastAtomicAdd((float*)data.grad_primscale + k * 3 + 0, validthread ? rxmt.x * dL_y0.x : 0.f); | |
fastAtomicAdd((float*)data.grad_primscale + k * 3 + 1, validthread ? rxmt.y * dL_y0.y : 0.f); | |
fastAtomicAdd((float*)data.grad_primscale + k * 3 + 2, validthread ? rxmt.z * dL_y0.z : 0.f); | |
dL_y0 *= ps; | |
float3 gpr0 = xmt.x * dL_y0; | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 0, validthread ? gpr0.x : 0.f); | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 1, validthread ? gpr0.y : 0.f); | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 2, validthread ? gpr0.z : 0.f); | |
float3 gpr1 = xmt.y * dL_y0; | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 0, validthread ? gpr1.x : 0.f); | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 1, validthread ? gpr1.y : 0.f); | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 2, validthread ? gpr1.z : 0.f); | |
float3 gpr2 = xmt.z * dL_y0; | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 0, validthread ? gpr2.x : 0.f); | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 1, validthread ? gpr2.y : 0.f); | |
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 2, validthread ? gpr2.z : 0.f); | |
fastAtomicAdd((float*)data.grad_primpos + k * 3 + 0, validthread ? -dot(pr0, dL_y0) : 0.f); | |
fastAtomicAdd((float*)data.grad_primpos + k * 3 + 1, validthread ? -dot(pr1, dL_y0) : 0.f); | |
fastAtomicAdd((float*)data.grad_primpos + k * 3 + 2, validthread ? -dot(pr2, dL_y0) : 0.f); | |
} | |
}; | |