|
import torch |
|
|
|
|
|
def qinv(q: torch.Tensor) -> torch.Tensor:
    """
    Conjugate quaternion(s) q by negating their vector part.

    Expects a tensor of shape (*, 4), where * denotes any number of
    dimensions. Component 0 is left unchanged and components 1..3 are
    negated. Returns a tensor of the same shape and dtype as q.

    Note: this is the quaternion conjugate. It equals the true inverse only
    for unit-length quaternions; no normalization is performed here.
    """
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'

    head = q[..., :1]  # component 0, kept as-is
    tail = q[..., 1:]  # components 1..3, sign-flipped
    return torch.cat((head, -tail), dim=-1)
|
|
|
|
|
def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.

    Quaternions are stored scalar-first, i.e. ``q = (w, x, y, z)``. They are
    assumed to be unit length: the formula below is a pure rotation only for
    unit quaternions, and a non-unit q also scales/distorts v.

    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions. The leading (batch) dimensions
    of q and v must be equal or broadcastable against each other. For
    example, a single quaternion of shape (4,) can rotate vectors of shape
    (N, 3).

    Returns a contiguous tensor of shape (*B, 3), where *B is the broadcast
    of the two batch shapes. This equals v.shape when the batch shapes match.

    Raises:
        AssertionError: if q's last dimension is not 4, v's last dimension
            is not 3, or the batch shapes of q and v are not broadcastable.
    """
    # Explicit checks instead of ``assert`` so that validation survives
    # ``python -O``; AssertionError is kept as the exception type.
    if q.shape[-1] != 4:
        raise AssertionError(
            f'q must be a tensor of shape (*, 4), got {tuple(q.shape)}'
        )
    if v.shape[-1] != 3:
        raise AssertionError(
            f'v must be a tensor of shape (*, 3), got {tuple(v.shape)}'
        )
    try:
        batch_shape = torch.broadcast_shapes(q.shape[:-1], v.shape[:-1])
    except RuntimeError as err:
        raise AssertionError(
            f'batch shapes of q {tuple(q.shape[:-1])} and '
            f'v {tuple(v.shape[:-1])} are not broadcastable'
        ) from err

    # expand() creates broadcast views without copying data.
    q = q.expand(*batch_shape, 4)
    v = v.expand(*batch_shape, 3)

    w = q[..., :1]     # scalar part, kept as (*B, 1) to broadcast over xyz
    qvec = q[..., 1:]  # vector part (x, y, z)

    # Rodrigues-style form of q * v * q^-1 for unit q:
    #   v' = v + 2 * (w * (u x v) + u x (u x v)),  with u = qvec
    uv = torch.cross(qvec, v, dim=-1)
    uuv = torch.cross(qvec, uv, dim=-1)

    # Return a contiguous tensor so callers can safely .view() the result.
    return (v + 2 * (w * uv + uuv)).contiguous()
|
|