DarthReca commited on
Commit
7ccc452
·
verified ·
1 Parent(s): 527b83f

Create positional_encoding.py

Browse files
Files changed (1) hide show
  1. positional_encoding.py +110 -0
positional_encoding.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+
3
+ import math
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .spherical_armonics import SH as SH_analytic
9
+
10
+
11
+ class SphericalHarmonics(nn.Module):
12
+ """
13
+ Spherical Harmonics locaiton encoder
14
+ """
15
+
16
+ def __init__(self, legendre_polys: int = 10, harmonics_calculation="analytic"):
17
+ """
18
+ legendre_polys: determines the number of legendre polynomials.
19
+ more polynomials lead more fine-grained resolutions
20
+ calculation of spherical harmonics:
21
+ analytic uses pre-computed equations. This is exact, but works only up to degree 50,
22
+ closed-form uses one equation but is computationally slower (especially for high degrees)
23
+ """
24
+ super(SphericalHarmonics, self).__init__()
25
+ self.L, self.M = int(legendre_polys), int(legendre_polys)
26
+ self.embedding_dim = self.L * self.M
27
+
28
+ if harmonics_calculation == "closed-form":
29
+ self.SH = SH_closed_form
30
+ elif harmonics_calculation == "analytic":
31
+ self.SH = SH_analytic
32
+
33
+ def forward(self, lonlat):
34
+ lon, lat = lonlat[:, 0], lonlat[:, 1]
35
+
36
+ # convert degree to rad
37
+ phi = torch.deg2rad(lon + 180)
38
+ theta = torch.deg2rad(lat + 90)
39
+ """
40
+ greater_than_50 = (lon > 50).any() or (lat > 50).any()
41
+ if greater_than_50:
42
+ SH = SH_closed_form
43
+ else:
44
+ SH = SH_analytic
45
+ """
46
+ SH = self.SH
47
+
48
+ Y = []
49
+ for l in range(self.L):
50
+ for m in range(-l, l + 1):
51
+ y = SH(m, l, phi, theta)
52
+ if isinstance(y, float):
53
+ y = y * torch.ones_like(phi)
54
+ if y.isnan().any():
55
+ print(m, l, y)
56
+ Y.append(y)
57
+
58
+ return torch.stack(Y, dim=-1)
59
+
60
+
61
+ ####################### Spherical Harmonics utilities ########################
62
+ # Code copied from https://github.com/BachiLi/redner/blob/master/pyredner/utils.py
63
+ # Code adapted from "Spherical Harmonic Lighting: The Gritty Details", Robin Green
64
+ # http://silviojemma.com/public/papers/lighting/spherical-harmonic-lighting.pdf
65
+ def associated_legendre_polynomial(l, m, x):
66
+ pmm = torch.ones_like(x)
67
+ if m > 0:
68
+ somx2 = torch.sqrt((1 - x) * (1 + x))
69
+ fact = 1.0
70
+ for i in range(1, m + 1):
71
+ pmm = pmm * (-fact) * somx2
72
+ fact += 2.0
73
+ if l == m:
74
+ return pmm
75
+ pmmp1 = x * (2.0 * m + 1.0) * pmm
76
+ if l == m + 1:
77
+ return pmmp1
78
+ pll = torch.zeros_like(x)
79
+ for ll in range(m + 2, l + 1):
80
+ pll = ((2.0 * ll - 1.0) * x * pmmp1 - (ll + m - 1.0) * pmm) / (ll - m)
81
+ pmm = pmmp1
82
+ pmmp1 = pll
83
+ return pll
84
+
85
+
86
+ def SH_renormalization(l, m):
87
+ return math.sqrt(
88
+ (2.0 * l + 1.0) * math.factorial(l - m) / (4 * math.pi * math.factorial(l + m))
89
+ )
90
+
91
+
92
+ def SH_closed_form(m, l, phi, theta):
93
+ if m == 0:
94
+ return SH_renormalization(l, m) * associated_legendre_polynomial(
95
+ l, m, torch.cos(theta)
96
+ )
97
+ elif m > 0:
98
+ return (
99
+ math.sqrt(2.0)
100
+ * SH_renormalization(l, m)
101
+ * torch.cos(m * phi)
102
+ * associated_legendre_polynomial(l, m, torch.cos(theta))
103
+ )
104
+ else:
105
+ return (
106
+ math.sqrt(2.0)
107
+ * SH_renormalization(l, -m)
108
+ * torch.sin(-m * phi)
109
+ * associated_legendre_polynomial(l, -m, torch.cos(theta))
110
+ )