Mikael Mieskolainen committed
Commit 9fa359b · Parent(s): 894ea5c
initial commit v0
Browse files
- README.md +9 -0
- hypertrack/models/global_tune-5.py +308 -0
- hypertrack/models/models_tune-5.py +635 -0
- models/tag_f-0p01-hyper-5/model_net_epoch_834169.pt +3 -0
- models/tag_f-0p01-hyper-5/model_pdf_epoch_834169.pt +3 -0
- models/tag_f-0p01-hyper-5/optimizer_net_epoch_834169.pt +3 -0
- models/tag_f-0p01-hyper-5/optimizer_pdf_epoch_834169.pt +3 -0
- models/tag_f-0p01-hyper-5/scheduler_net_epoch_834169.pt +3 -0
- models/tag_f-0p01-hyper-5/scheduler_pdf_epoch_834169.pt +3 -0
- models/tag_f-0p1-hyper-5/model_net_epoch_752039.pt +3 -0
- models/tag_f-0p1-hyper-5/model_pdf_epoch_752039.pt +3 -0
- models/tag_f-0p1-hyper-5/optimizer_net_epoch_752039.pt +3 -0
- models/tag_f-0p1-hyper-5/optimizer_pdf_epoch_752039.pt +3 -0
- models/tag_f-0p1-hyper-5/scheduler_net_epoch_752039.pt +3 -0
- models/tag_f-0p1-hyper-5/scheduler_pdf_epoch_752039.pt +3 -0
- models/tag_f-0p3-hyper-5/model_net_epoch_484878.pt +3 -0
- models/tag_f-0p3-hyper-5/model_pdf_epoch_484878.pt +3 -0
- models/tag_f-0p3-hyper-5/optimizer_net_epoch_484878.pt +3 -0
- models/tag_f-0p3-hyper-5/optimizer_pdf_epoch_484878.pt +3 -0
- models/tag_f-0p3-hyper-5/scheduler_net_epoch_484878.pt +3 -0
- models/tag_f-0p3-hyper-5/scheduler_pdf_epoch_484878.pt +3 -0
- models/voxdyn_node2node_hyper_ncell_131072.pkl +3 -0
- models/voxdyn_node2node_hyper_ncell_262144.pkl +3 -0
- models/voxdyn_node2node_hyper_ncell_32768.pkl +3 -0
- models/voxdyn_node2node_hyper_ncell_524288.pkl +3 -0
- models/voxdyn_node2node_hyper_ncell_65536.pkl +3 -0
README.md
CHANGED
@@ -1,3 +1,12 @@
 ---
 license: mit
 ---
+
+HyperTrack: Neural Combinatorics for High Energy Physics
+
+https://arxiv.org/abs/2309.14113
+
+Download the main repository from:
+github.com/mieskolainen/hypertrack
+
+Download this repository to the same path.
hypertrack/models/global_tune-5.py
ADDED
@@ -0,0 +1,308 @@
# HyperTrack model and training loss parameters
#
# match with the corresponding 'models_<ID>.py' under 'hypertrack/models/'
#
# [email protected], 2023

import numpy as np

# -------------------------------------------------------------------------
# Input normalization
# (e.g. can accelerate training, and mitigate float scale problems, but not necessarily needed)
normalize_input = False

"""
- coord[0] (min,max,mean,std): -1025.3399658203125 | 1025.3399658203125 | 1.0586246252059937 | 266.20428466796875
- coord[1] (min,max,mean,std): -1025.3399658203125 | 1025.3399658203125 | -0.022702794522047043 | 267.56085205078125
- coord[2] (min,max,mean,std): -2955.5 | 2955.5 | 1.6228374242782593 | 1064.4954833984375
"""

def feature_scaler(X):
    mu    = [1.06, -0.023, 1.62]
    sigma = [266.2, 267.6, 1064.5]

    for i in range(len(mu)):
        X[:,i] = (X[:,i] - mu[i]) / sigma[i]

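# Usage sketch (illustrative only, not part of the shipped configuration),
# assuming X is an (N, >=3) float array of hit coordinates:
#
#   X = np.random.randn(1000, 3).astype(np.float32) * 300.0  # hypothetical hits
#   feature_scaler(X)   # standardizes in place: X[:,i] <- (X[:,i] - mu[i]) / sigma[i]
#   # after scaling, each coordinate is roughly zero-mean, unit-variance
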
# -------------------------------------------------------------------------

# ** Training only parameters **
train_param = {

    # Total loss weights per each individual loss
    'beta': {

        'net': {
            'edge_BCE' :           0.2,  # 0.2
            'edge_contrastive' :   1.0,  # 1.0
            'cluster_BCE' :        0.2,  # 0.2
            'cluster_contrastive': 0.2,  # 1.0
            'cluster_neglogpdf':   0.0,  # [EXPERIMENTAL] (keep it zero)
        },

        'pdf': {
            'track_neglogpdf': 1.0,  # [EXPERIMENTAL]
        }
    },

    # Edge loss
    'edge_BCE': {
        'type': 'Focal',             # 'Focal', 'BCE', 'BCE+Hinge'
        'gamma': 1.0,                # For 'Focal' (entropy exponent)
        'delta': 0.05,               # For 'BCE+Hinge' (proportion)
        'remove_self_edges': False,  # Remove self-edges
        'edge_balance': True         # true/false edge balance unity re-weight
    },

    # Contrastive loss per particle
    'edge_contrastive': {
        'weights': True,    # TrackML hit weights ok with this
        'type': 'softmax',
        'tau': 0.3,         # temperature (see: https://arxiv.org/abs/2012.09740, https://openreview.net/pdf?id=vnOHGQY4FP1)
        'sub_sample': 300,  # memory constraint (maximum number of target objects to compute the loss per event)

        'min_prob': 1e-3,   # minimum edge prob. score to be included in the loss [EXPERIMENTAL]
    },                      # (higher values push towards purity, but can weaken efficiency for e.g. high multiplicity clusters)

    # Cluster hit binary cross entropy loss
    'cluster_BCE': {
        'weights': False,  # TrackML hit weights (0 for noise) not exactly compatible
        'type': 'Focal',   # 'BCE', 'BCE+Hinge', 'Focal'
        'gamma': 1.0,      # For 'Focal' (entropy exponent)
        'delta': 0.05,     # For 'BCE+Hinge' (proportion)
    },

    # Cluster set hit loss
    'cluster_contrastive': {
        'weights': False,     # TrackML hit weights (0 for noise) not exactly compatible
        'type': 'intersect',  # 'intersect', 'dice', 'jaccard'
        'smooth': 1.0         # regularization for 'dice' and 'jaccard'
    },

    # Cluster meta-supervision target
    'meta_target': 'pivotmajor'  # 'major' (vote from all nodes ground truth) or 'pivotmajor' (vote from pivots ground truth)
}

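# Note on the 'Focal' loss type (a sketch of the standard focal form of binary
# cross entropy; an assumption here, since the actual loss code lives in the
# main hypertrack repository): FL(p_t) = -(1 - p_t)**gamma * log(p_t), where
# p_t is the predicted probability of the true class and gamma is the entropy
# exponent above (gamma = 0 recovers plain BCE). For example:
#
#   import torch
#   def focal_bce(logits, targets, gamma=1.0):
#       p  = torch.sigmoid(logits)
#       pt = torch.where(targets > 0.5, p, 1.0 - p)  # probability of the true class
#       return -((1.0 - pt)**gamma * torch.log(pt.clamp(min=1e-12))).mean()
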
# -------------------------------------------------------------------------

# These algorithm parameters can be changed after training, but
# note that the transformer network may adapt (learn) its weights according
# to the values set here during the training
cluster_param = {

    # These are set from the command line interface
    'algorithm': None,
    'edge_threshold': None,


    ## Cut clustering & Transformer clustering input
    'min_graph': 4,  # Minimum subgraph size after the threshold and WCC search; the rest are treated as noise

    ## DBSCAN clustering
    'dbscan': {
        'eps': 0.2,
        'min_samples': 3,
    },

    ## HDBSCAN clustering
    # https://hdbscan.readthedocs.io/en/latest/api.html
    'hdbscan': {
        'algorithm': 'generic',
        'cluster_selection_epsilon': 0.0,
        'cluster_selection_method': 'eom',  # 'eom' or 'leaf'
        'alpha': 1.0,
        'min_samples': 2,      # Keep it 2
        'min_cluster_size': 4,
        'max_dist': 1.0        # Keep it 1.0
    },

    ## Transformer clustering
    'worker_split': 4,  # GPU Memory <-> GPU latency tradeoff (no accuracy impact)

    'transformer': {
        'seed_strategy': 'random',  # 'random', 'max' (max norm), 'max_T' (transverse max), 'min' (min norm), 'min_T' (transverse min)
        'seed_ktop_max': 2,         # Number of pivot walk (seed) candidates (higher -> better accuracy but slower)

        'N_pivots': 3,              # Number of pivotal hits to search per cluster (>> 1)
        'random_paths': 1,          # (Put >> 1 for MC sampled random walk, and 1 for greedy max-prob walk)

        'max_failures': 2,          # Maximum number of failures per pivot list nodes (put 1+ for greedy, >> 1 for MC walk)
        'diffuse_threshold': 0.4,   # Diffusion connectivity ~ Pivot quality threshold

        # Micrograph extension type: 'pivot-spanned' (ok with 'hyper' adjacency), 'full' (needed for adjacency other than 'hyper'; more inclusive but possibly unstable)
        'micrograph_mode': 'pivot-spanned',

        'threshold_algo': 'fixed',  # 'soft' (learnable), 'fisher' (batch-by-batch 1D-Fisher rule adaptive) or 'fixed'

        'tau': 0.001,               # 'soft':: Sigmoid 'temperature' (tau -> 0 ~ heaviside step)

        'ktop_max': 30,             # 'fisher':: Maximum cluster size (how many are considered from Transformer output), ranked by mask score
        'fisher_threshold': np.linspace(0.4, 0.6, 0),  # 'fisher':: Threshold values tested

        'fixed_threshold': 0.5,     # 'fixed':: Note, if this is too high -> training may be unstable (first transformer iterations are bad)

        'min_cluster_size': 4,      # Require at least this many constituents per cluster
    }
}


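# Usage sketch for the 'dbscan' block above (illustrative; assumes scikit-learn,
# whereas the actual pipeline wiring is in the main hypertrack repository):
#
#   from sklearn.cluster import DBSCAN
#   db = DBSCAN(eps=cluster_param['dbscan']['eps'],
#               min_samples=cluster_param['dbscan']['min_samples'])
#   labels = db.fit_predict(Z)  # Z: (N, D) node embedding array; label -1 = noise
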
# -------------------------------------------------------------------------

### Geometric adjacency estimator
geom_param = {

    # Use pre-trained 'voxdyn' or 'neurodyn' (experimental)
    'algorithm': 'voxdyn',

    # Print adjacency metrics (this will slow down significantly)
    'verbose': False,

    #'device': 'cuda',  # CUDA not working with Faiss from conda atm (CUDA 11.4)
    'device': 'cpu',

    # 'neurodyn' parameters (PLACEHOLDER; not implemented)
    'neural_param': {
        'layers': [6, 128, 64, 1],
        'act': 'silu',
        'bn': True,
        'dropout': 0.0,
        'last_act': False
    },

    'neural_path': 'models/neurodyn'
}

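# Loading sketch (an assumption about the serialization; the actual loader lives
# in the main hypertrack repository): the pre-trained 'voxdyn' adjacency tables
# ship as pickle files in this repository, one per voxel count, e.g.:
#
#   import pickle
#   with open('models/voxdyn_node2node_hyper_ncell_65536.pkl', 'rb') as f:
#       voxdyn = pickle.load(f)  # pre-computed voxel-dynamic node-to-node adjacency
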

# -------------------------------------------------------------------------

### GNN + Transformer model parameters
net_model_param = {

    # GNN predictor block
    'graph_block_param': {

        'GNN_model': 'SuperEdgeConv',  # 'SuperEdgeConv', 'GaugeEdgeConv'
        'nstack': 5,                   # Number of GNN message passing layers

        'coord_dim': 3,  # Input dimension
        'h_dim': 64,     # Intermediate latent embedding dimension
        'z_dim': 61,     # Final latent embedding dimension

        # https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#aggregation-operators
        'SuperEdgeConv': {
            'm_dim': 64,
            'aggr': ['mean']*5,  # 'mean' (seems best memory/accuracy wise), 'sum', 'max', 'softmax', 'multi-aggregation', 'set-transformer'
            'use_residual': True,
        },

        'GaugeEdgeConv': {
            'm_dim': 64,
            'aggr': ['mean']*5,  # As many as 'nstack'
            'norm_coord': False,
            'norm_coord_scale_init': 1e-2,
        },

        # Edge prediction (correlation MLP) type: 'symmetric-dot', 'symmetrized', 'asymmetric'
        # (clustering Transformer should prefer 'symmetric-dot')
        'edge_type': 'symmetric-dot',

        ## Convolution (message passing) MLPs

        'MLP_GNN_edge': {
            'act': 'silu',  # 'relu', 'tanh', 'silu', 'elu'
            'bn': True,
            'dropout': 0.0,
            'last_act': True,
        },

        #'MLP_GNN_coord': {  # Only for 'GaugeEdgeConv'
        #    'act': 'silu',
        #    'bn': True,
        #    'dropout': 0.0,
        #    'last_act': True,
        #},

        'MLP_GNN_latent': {
            'act': 'silu',
            'bn': True,
            'dropout': 0.0,
            'last_act': True,
        },

        ## Latent Fusion MLP
        'MLP_fusion': {
            'act': 'silu',
            'bn': True,
            'dropout': 0.0,
            'last_act': True,
        },

        ## 2-pt edge correlation MLP
        'MLP_correlate': {
            'act': 'silu',
            'bn': True,
            'dropout': 0.0,
            'last_act': False,
        },
    },

    # Transformer clusterization block
    'cluster_block_param': {
        'in_dim': 64,     # Same as GNN 'zdim' + 3 (for 3D coordinates)
        'h_dim': 64,      # Latent dim, needs to be divisible by num_heads
        'output_dim': 1,  # Always 1
        'nstack_dec': 4,  # Number of self-attention layers

        'MLP_enc': {  # First encoder MLP
            'act': 'silu',
            'bn': False,
            'dropout': 0.0,
            'last_act': False,
        },

        'MAB_dec': {  # Transformer decoder MAB
            'num_heads': 4,
            'ln': True,
            'dropout': 0.0,
            'MLP_param': {
                'act': 'silu',
                'bn': False,
                'dropout': 0.0,
                'last_act': True,
            }
        },

        'SAB_dec': {  # Transformer decoder SAB
            'num_heads': 4,
            'ln': True,
            'dropout': 0.0,
            'MLP_param': {
                'act': 'silu',
                'bn': False,
                'dropout': 0.0,
                'last_act': True,
            }
        },

        'MLP_mask': {  # Mask decoder MLP
            'act': 'silu',
            'bn': False,
            'dropout': 0.0,
            'last_act': False,
        }
    }
}

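# Construction sketch (illustrative; mirrors the InverterNet constructor defined
# in models_tune-5.py, while the full training wiring is in the main repository):
#
#   net = InverterNet(graph_block_param  =net_model_param['graph_block_param'],
#                     cluster_block_param=net_model_param['cluster_block_param'])
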
# -------------------------------------------------------------------------
# [EXPERIMENTAL] -- normalizing flow

# Conditional data array indices (see /hypertrack/trackml.py)
cond_ind = [0,1,2,3,4,5,6]

pdf_model_param = {
    'in_dim': 61,
    'num_cond_inputs': len(cond_ind),
    'h_dim': 196,
    'nblocks': 4,
    'act': 'tanh'
}
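# Construction sketch (illustrative; TrackFlowNet is defined in models_tune-5.py
# and takes exactly these keyword parameters):
#
#   pdf_net = TrackFlowNet(**pdf_model_param)  # MAF-based conditional density model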
hypertrack/models/models_tune-5.py
ADDED
@@ -0,0 +1,635 @@
# HyperTrack neural model torch classes
#
# [email protected], 2023

from typing import Callable, Union, Optional

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric
import torch_geometric.transforms as T

from torch_geometric.nn import MessagePassing
from torch_geometric.typing import Size, Tensor, OptTensor, PairTensor, PairOptTensor, OptPairTensor, Adj

from hypertrack.dmlp import MLP
import hypertrack.flows as fnn


class SuperEdgeConv(MessagePassing):
    r"""
    Custom GNN convolution operator aka 'generalized EdgeConv' (original EdgeConv: arxiv.org/abs/1801.07829)
    """

    def __init__(self, mlp_edge: Callable, mlp_latent: Callable, aggr: str='mean',
                 mp_attn_dim: int=0, use_residual=True, **kwargs):

        if aggr == 'multi-aggregation':
            aggr = torch_geometric.nn.aggr.MultiAggregation(aggrs=['sum', 'mean', 'std', 'max', 'min'], mode='attn',
                mode_kwargs={'in_channels': mp_attn_dim, 'out_channels': mp_attn_dim, 'num_heads': 1})

        if aggr == 'set-transformer':
            aggr = torch_geometric.nn.aggr.SetTransformerAggregation(channels=mp_attn_dim, num_seed_points=1,
                num_encoder_blocks=1, num_decoder_blocks=1, heads=1, concat=False,
                layer_norm=False, dropout=0.0)

        super().__init__(aggr=aggr, **kwargs)
        self.nn       = mlp_edge
        self.nn_final = mlp_latent
        self.use_residual = use_residual

        self.reset_parameters()

        self.apply(self.init_)

    def init_(self, module):
        if type(module) in {nn.Linear}:
            #print(__name__ + f'.SuperEdgeConv: Initializing module: {module}')
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)

    def reset_parameters(self):
        torch_geometric.nn.inits.reset(self.nn)
        torch_geometric.nn.inits.reset(self.nn_final)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, edge_weight: OptTensor = None, size: Size = None) -> Tensor:

        if edge_attr is not None and len(edge_attr.shape) == 1:  # if 1-dim edge_attributes
            edge_attr = edge_attr[:,None]

        # Message passing
        m = self.propagate(edge_index, x=x, edge_attr=edge_attr, edge_weight=edge_weight, size=None)

        # Final MLP
        y = self.nn_final(torch.concat([x, m], dim=-1))

        # Residual connections
        if self.use_residual and (y.shape[-1] == x.shape[-1]):
            y = y + x

        return y

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor, edge_weight: OptTensor) -> Tensor:

        # Edge features
        e1 = torch.norm(x_j - x_i, dim=-1)  # Norm of the difference (invariant under rotations and translations)
        e2 = torch.sum(x_j * x_i,  dim=-1)  # Dot-product (invariant under rotations but not translations)

        if len(e1.shape) == 1:
            e1 = e1[:,None]
            e2 = e2[:,None]

        if edge_attr is not None:
            m = self.nn(torch.cat([x_i, x_j - x_i, x_j * x_i, e1, e2, edge_attr], dim=-1))
        else:
            m = self.nn(torch.cat([x_i, x_j - x_i, x_j * x_i, e1, e2], dim=-1))

        return m if edge_weight is None else m * edge_weight.view(-1, 1)

    def __repr__(self):
        return f'{self.__class__.__name__} (nn={self.nn}, nn_final={self.nn_final})'

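# Quick numeric check (illustrative, not part of the model code): the e1 feature
# above is rotation invariant, since ||R x_j - R x_i|| = ||x_j - x_i|| for any
# rotation matrix R:
#
#   c, s = math.cos(0.7), math.sin(0.7)
#   R = torch.tensor([[c, -s], [s, c]])
#   xi, xj = torch.randn(2), torch.randn(2)
#   assert torch.allclose((xj - xi).norm(), (R @ xj - R @ xi).norm(), atol=1e-5)
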

class CoordNorm(nn.Module):
    """
    Coordinate normalization for stability with GaugeEdgeConv
    """
    def __init__(self, eps = 1e-8, scale_init = 1.0):
        super().__init__()
        self.eps = eps
        scale = torch.zeros(1).fill_(scale_init)
        self.scale = nn.Parameter(scale)

    def forward(self, coord):
        norm = coord.norm(dim = -1, keepdim = True)
        normed_coord = coord / norm.clamp(min = self.eps)
        return normed_coord * self.scale


class GaugeEdgeConv(MessagePassing):
    r"""
    Custom GNN convolution operator aka 'E(N) equivariant GNN' (arxiv.org/abs/2102.09844)
    """

    def __init__(self, mlp_edge: Callable, mlp_coord: Callable, mlp_latent: Callable, coord_dim: int=0,
                 update_coord: bool=True, update_latent: bool=True, aggr: str='mean', mp_attn_dim: int=0,
                 norm_coord=False, norm_coord_scale_init = 1e-2, **kwargs):

        kwargs.setdefault('aggr', aggr)
        super(GaugeEdgeConv, self).__init__(**kwargs)

        self.mlp_edge   = mlp_edge
        self.mlp_coord  = mlp_coord
        self.mlp_latent = mlp_latent

        self.update_coord  = update_coord
        self.update_latent = update_latent

        self.coord_dim = coord_dim

        # Coordinate normalization
        self.coors_norm = CoordNorm(scale_init = norm_coord_scale_init) if norm_coord else nn.Identity()

        self.reset_parameters()

        self.apply(self.init_)

    def init_(self, module):
        if type(module) in {nn.Linear}:
            #print(__name__ + f'.GaugeEdgeConv: Initializing module: {module}')
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)

    def reset_parameters(self):
        torch_geometric.nn.inits.reset(self.mlp_edge)
        torch_geometric.nn.inits.reset(self.mlp_coord)
        torch_geometric.nn.inits.reset(self.mlp_latent)


    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, edge_weight: OptTensor = None, size: Size = None) -> Tensor:
        """
        Forward function
        """

        # Separate spatial (e.g. 3D-coordinates) and features
        coord, feats = x[..., 0:self.coord_dim], x[..., self.coord_dim:]

        # Coordinate difference: x_i - x_j
        diff_coord = coord[edge_index[0]] - coord[edge_index[1]]
        diff_norm2 = (diff_coord ** 2).sum(dim=-1, keepdim=True)

        if edge_attr is not None:
            if len(edge_attr.shape) == 1:  # if 1-dim edge_attributes
                edge_attr = edge_attr[:,None]

            edge_attr_feats = torch.cat([edge_attr, diff_norm2], dim=-1)
        else:
            edge_attr_feats = diff_norm2

        # Propagation
        latent_out, coord_out = self.propagate(edge_index, x=feats, edge_attr=edge_attr_feats,
                                               edge_weight=edge_weight, coord=coord, diff_coord=diff_coord, size=None)

        return torch.cat([coord_out, latent_out], dim=-1)

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor, edge_weight: OptTensor) -> Tensor:
        """
        Message passing core operation between nodes (i,j)
        """

        m_ij = self.mlp_edge(torch.cat([x_i, x_j, edge_attr], dim=-1))

        return m_ij if edge_weight is None else m_ij * edge_weight.view(-1, 1)

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """
        The initial call to start propagating messages.

        Args:
            edge_index: holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size: (tuple, optional) if none, the size will be inferred
                and assumed to be quadratic.
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """

        # Check:
        # https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py

        size          = self._check_input(edge_index, size)
        coll_dict     = self._collect(self._user_args, edge_index, size, kwargs)
        msg_kwargs    = self.inspector.distribute('message', coll_dict)
        aggr_kwargs   = self.inspector.distribute('aggregate', coll_dict)
        update_kwargs = self.inspector.distribute('update', coll_dict)

        # Message passing of node latent embeddings
        m_ij = self.message(**msg_kwargs)

        if self.update_coord:

            # Normalize
            kwargs["diff_coord"] = self.coors_norm(kwargs["diff_coord"])

            # Aggregate weighted coordinates
            mhat_i = self.aggregate(kwargs["diff_coord"] * self.mlp_coord(m_ij), **aggr_kwargs)
            coord_out = kwargs["coord"] + mhat_i  # Residual connection

        else:
            coord_out = kwargs["coord"]


        if self.update_latent:

            # Aggregate message passing results
            m_i = self.aggregate(m_ij, **aggr_kwargs)

            # Update latent representation
            latent_out = self.mlp_latent(torch.cat([kwargs["x"], m_i], dim = -1))
            latent_out = kwargs["x"] + latent_out  # Residual connection
        else:
            latent_out = kwargs["x"]

        # Return tuple
        return self.update((latent_out, coord_out), **update_kwargs)

    def __repr__(self):
        return f'{self.__class__.__name__}(GaugeEdgeConv = {self.mlp_edge} | {self.mlp_coord} | {self.mlp_latent})'

class InverterNet(torch.nn.Module):
    """
    HyperTrack neural model "umbrella" class, encapsulating GNN and Transformer etc.
    """
    def __init__(self, graph_block_param={}, cluster_block_param={}):

        """
        conv_aggr: 'mean' seems very crucial.
        """
        super().__init__()

        self.training_on = True

        self.coord_dim = graph_block_param['coord_dim']  # Input dimension
        self.h_dim     = graph_block_param['h_dim']      # Intermediate latent dimension
        self.z_dim     = graph_block_param['z_dim']      # Final latent dimension

        self.GNN_model = graph_block_param['GNN_model']
        self.nstack    = graph_block_param['nstack']

        self.edge_type = graph_block_param['edge_type']
        MLP_fusion     = graph_block_param['MLP_fusion']

        self.num_edge_attr = 1  # cf. custom edge features constructed in self.encode()
        self.conv_gnn_edx  = nn.ModuleList()

        # Transformer node mask learnable "soft" threshold
        self.thr = nn.Parameter(torch.Tensor(1))
        nn.init.constant_(self.thr, 0.5)

        # 1. GNN encoder

        ## Model type
        if self.GNN_model == 'GaugeEdgeConv':

            self.m_dim = graph_block_param['GaugeEdgeConv']['m_dim']

            MLP_GNN_edge   = graph_block_param['MLP_GNN_edge']
            MLP_GNN_coord  = graph_block_param['MLP_GNN_coord']
            MLP_GNN_latent = graph_block_param['MLP_GNN_latent']

            num_intrinsic_attr = 1  # distance

            for i in range(self.nstack):
                self.conv_gnn_edx.append(GaugeEdgeConv(
                    mlp_edge   = MLP([2*self.h_dim + self.num_edge_attr + num_intrinsic_attr, 2*self.m_dim, self.m_dim], **MLP_GNN_edge),
                    mlp_coord  = MLP([self.m_dim, 2*self.m_dim, 1], **MLP_GNN_coord),
                    mlp_latent = MLP([self.h_dim + self.m_dim, 2*self.h_dim, self.h_dim], **MLP_GNN_latent),
                    aggr=graph_block_param['GaugeEdgeConv']['aggr'][i],
                    norm_coord=graph_block_param['GaugeEdgeConv']['norm_coord'],
                    norm_coord_scale_init=graph_block_param['GaugeEdgeConv']['norm_coord_scale_init'],
                    coord_dim=self.coord_dim))

            self.mlp_fusion_edx = MLP([self.nstack * (self.coord_dim + self.h_dim), self.h_dim, self.z_dim], **MLP_fusion)

        ## Model type
        elif self.GNN_model == 'SuperEdgeConv':

            self.m_dim = graph_block_param['SuperEdgeConv']['m_dim']

            MLP_GNN_edge   = graph_block_param['MLP_GNN_edge']
            MLP_GNN_latent = graph_block_param['MLP_GNN_latent']

            num_intrinsic_attr = 2  # distance and dot-product

            self.conv_gnn_edx.append(SuperEdgeConv(
                mlp_edge   = MLP([3*self.coord_dim + self.num_edge_attr + num_intrinsic_attr, self.m_dim, self.m_dim], **MLP_GNN_edge),
                mlp_latent = MLP([self.coord_dim + self.m_dim, self.h_dim, self.h_dim], **MLP_GNN_latent),
                mp_attn_dim=self.h_dim, aggr=graph_block_param['SuperEdgeConv']['aggr'][0], use_residual=False))

            for i in range(1,self.nstack):
                self.conv_gnn_edx.append(SuperEdgeConv(
                    mlp_edge   = MLP([3*self.h_dim + self.num_edge_attr + num_intrinsic_attr, self.m_dim, self.m_dim], **MLP_GNN_edge),
                    mlp_latent = MLP([self.h_dim + self.m_dim, self.h_dim, self.h_dim], **MLP_GNN_latent),
                    mp_attn_dim=self.h_dim, aggr=graph_block_param['SuperEdgeConv']['aggr'][i], use_residual=graph_block_param['SuperEdgeConv']['use_residual']))

            self.mlp_fusion_edx = MLP([self.nstack * self.h_dim, self.h_dim, self.z_dim], **MLP_fusion)

        else:
            raise Exception(__name__ + '.__init__: Unknown GNN_model chosen')

        # 2. Graph edge predictor
        MLP_correlate = graph_block_param['MLP_correlate']

        if self.edge_type == 'symmetric-dot':
            self.mlp_2pt_edx = MLP([self.z_dim, self.z_dim//2, self.z_dim//2, 1], **MLP_correlate)
        elif self.edge_type == 'symmetrized' or self.edge_type == 'asymmetric':
            self.mlp_2pt_edx = MLP([2*self.z_dim, self.z_dim//2, self.z_dim//2, 1], **MLP_correlate)

        # 3. Clustering predictor
        self.transformer_ccx = STransformer(**cluster_block_param)

    def encode(self, x, edge_index):
        """
        Encoder GNN
        """
        # Compute node degree 'custom feature' between edges
        d = torch_geometric.utils.degree(edge_index[0,:])
        edge_attr = (d[edge_index[0,:]] - d[edge_index[1,:]]) / torch.mean(d)
        edge_attr = edge_attr.to(x.dtype)

        # We take each output for the parallel fusion
        x_out = [None] * self.nstack

        # First input [x; one-vector for the first embeddings (latents)]
        if self.GNN_model == 'GaugeEdgeConv':
            x_ = torch.cat([x, torch.ones((x.shape[0], self.h_dim), device=x.device)], dim=-1)
        else:
            x_ = x

        # Apply GNN layers
        x_out[0] = self.conv_gnn_edx[0](x_, edge_index, edge_attr)
        for i in range(1,self.nstack):
            x_out[i] = self.conv_gnn_edx[i](x_out[i-1], edge_index, edge_attr)

        return self.mlp_fusion_edx(torch.cat(x_out, dim=-1))

    def decode(self, z, edge_index):
        """
        Decoder of two-point correlations (edges)
        """
        if self.edge_type == 'symmetric-dot':
            return self.mlp_2pt_edx(z[edge_index[0],:] * z[edge_index[1],:])

        elif self.edge_type == 'symmetrized':
            a = self.mlp_2pt_edx(torch.cat([z[edge_index[0],:], z[edge_index[1],:]], dim=-1))
            b = self.mlp_2pt_edx(torch.cat([z[edge_index[1],:], z[edge_index[0],:]], dim=-1))
            return (a + b) / 2.0

        elif self.edge_type == 'asymmetric':
            a = self.mlp_2pt_edx(torch.cat([z[edge_index[0],:], z[edge_index[1],:]], dim=-1))
            return a

    def decode_cc_ind(self, X, X_pivot, X_mask=None, X_pivot_mask=None):
        """
        Decoder of N-point node mask
        """

        return self.transformer_ccx(X=X, X_pivot=X_pivot, X_mask=X_mask, X_pivot_mask=X_pivot_mask)

    def set_model_param_grad(self, string_id='edx', requires_grad=True):
        """
        Freeze or unfreeze model parameters (for the gradient descent)
        """

        for name, W in self.named_parameters():
            if string_id in name:
                W.requires_grad = requires_grad
                #print(f'Setting requires_grad={W.requires_grad} of the parameter <{name}>')
        return

    def get_model_param_grad(self, string_id='edx'):
        """
        Get model parameter state (for the gradient descent)
        """

        for name, W in self.named_parameters():
            if string_id in name:  # Return the state of the first
                return W.requires_grad

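# Usage sketch (illustrative; the training loop lives in the main repository):
# with edge_type 'symmetric-dot' the edge score for (i, j) is MLP(z_i * z_j),
# symmetric under i <-> j because the element-wise product commutes:
#
#   z      = net.encode(x, edge_index)   # (N, z_dim) node embeddings
#   logits = net.decode(z, edge_index)   # (E, 1) edge logits
#   p_edge = torch.sigmoid(logits)       # edge probabilities
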
class MAB(nn.Module):
    """
    Attention based set Transformer block (arxiv.org/abs/1810.00825)
    """
    def __init__(self, dim_Q, dim_K, dim_V, num_heads=4, ln=True, dropout=0.0,
                 MLP_param={'act': 'relu', 'bn': False, 'dropout': 0.0, 'last_act': True}):
        super(MAB, self).__init__()
        assert dim_V % num_heads == 0, "MAB: dim_V must be divisible by num_heads"
        self.dim_V     = dim_V
        self.num_heads = num_heads
        self.W_q = nn.Linear(dim_Q, dim_V)
        self.W_k = nn.Linear(dim_K, dim_V)
        self.W_v = nn.Linear(dim_K, dim_V)
        self.W_o = nn.Linear(dim_V, dim_V)  # Projection layer

        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)

        if dropout > 0:
            self.Dropout = nn.Dropout(dropout)

        self.MLP = MLP([dim_V, dim_V, dim_V], **MLP_param)

        # We use torch default initialization here

    def reshape_attention_mask(self, Q, K, mask):
        """
        Reshape attention masks
        """
        total_mask = None

        if mask[0] is not None:
            qmask = mask[0].repeat(self.num_heads,1)[:,:,None]  # New shape = [# heads x # batches, # queries, 1]

        if mask[1] is not None:
            kmask = mask[1].repeat(self.num_heads,1)[:,None,:]  # New shape = [# heads x # batches, 1, # keys]

        if mask[0] is None and mask[1] is not None:
            total_mask = kmask.repeat(1,Q.shape[1],1)
        elif mask[0] is not None and mask[1] is None:
            total_mask = qmask.repeat(1,1,K.shape[1])
        elif mask[0] is not None and mask[1] is not None:
            total_mask = qmask & kmask  # will auto broadcast dimensions to [# heads x # batches, # queries, # keys]

        return total_mask

    def forward(self, Q, K, mask = (None, None)):
        """
        Q:    queries
        K:    keys
        mask: query mask [#batches x #queries], keys mask [#batches x #keys]
        """
        dim_split = self.dim_V // self.num_heads

        # Apply Matrix-vector multiplications
        # and do multihead splittings (reshaping)
        Q_ = torch.cat(self.W_q(Q).split(dim_split, -1), 0)
        K_ = torch.cat(self.W_k(K).split(dim_split, -1), 0)
        V_ = torch.cat(self.W_v(K).split(dim_split, -1), 0)

        # Dot-product attention: softmax(QK^T / sqrt(dim_V))
        QK = Q_.bmm(K_.transpose(-1,-2)) / math.sqrt(self.dim_V)  # bmm does batched matrix multiplication

        # Attention mask
        total_mask = self.reshape_attention_mask(Q=Q, K=K, mask=mask)
        if total_mask is not None:
            QK.masked_fill_(~total_mask, float('-1E6'))

        # Compute attention probabilities
        A = torch.softmax(QK,-1)

        # Residual connection of Q + multi-head attention A weighted V result
        H = Q + self.W_o(torch.cat((A.bmm(V_)).split(Q.size(0), 0),-1))

        # First layer normalization + Dropout
        H = H if getattr(self, 'ln0', None) is None else self.ln0(H)
        H = H if getattr(self, 'Dropout', None) is None else self.Dropout(H)

        # Residual connection of H + feed-forward net
        H = H + self.MLP(H)

        # Second layer normalization + Dropout
        H = H if getattr(self, 'ln1', None) is None else self.ln1(H)
        H = H if getattr(self, 'Dropout', None) is None else self.Dropout(H)

        return H

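# Shape sketch (illustrative): with B batches, n queries, m keys and h heads, the
# multihead split stacks heads onto the batch axis, so Q_: (h*B, n, dim_V/h),
# K_, V_: (h*B, m, dim_V/h), QK: (h*B, n, m), and the concat after A.bmm(V_)
# restores (B, n, dim_V). A minimal smoke test:
#
#   mab  = MAB(dim_Q=64, dim_K=64, dim_V=64, num_heads=4)
#   Q, K = torch.randn(2, 10, 64), torch.randn(2, 7, 64)
#   assert mab(Q, K).shape == (2, 10, 64)
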
class SAB(nn.Module):
    """
    Full self-attention MAB(X,X)
    ~O(N^2)
    """
    def __init__(self, dim_in, dim_out, num_heads=4, ln=True, dropout=0.0,
                 MLP_param={'act': 'relu', 'bn': False, 'dropout': 0.0, 'last_act': True}):
        super(SAB, self).__init__()
        self.mab = MAB(dim_Q=dim_in, dim_K=dim_in, dim_V=dim_out,
                       num_heads=num_heads, ln=ln, dropout=dropout, MLP_param=MLP_param)

    def forward(self, X, mask=None):
        return self.mab(Q=X, K=X, mask=(mask, mask))

class ISAB(nn.Module):
    """
    Faster version of SAB with inducing points
    """
    def __init__(self, dim_in, dim_out, num_inds, num_heads=4, ln=True, dropout=0.0,
                 MLP_param={'act': 'relu', 'bn': False, 'dropout': 0.0, 'last_act': True}):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_Q=dim_out, dim_K=dim_in, dim_V=dim_out,
                        num_heads=num_heads, ln=ln, dropout=dropout, MLP_param=MLP_param)
        self.mab1 = MAB(dim_Q=dim_in, dim_K=dim_out, dim_V=dim_out,
                        num_heads=num_heads, ln=ln, dropout=dropout, MLP_param=MLP_param)

    def forward(self, X, mask=None):
        H = self.mab0(Q=self.I.expand(X.size(0),-1,-1), K=X, mask=(None, mask))
        H = self.mab1(Q=X, K=H, mask=(mask, None))
        return H

class PMA(nn.Module):
    """
    Adaptive pooling with "k" > 1 option (several learnable reference vectors)
    """
    def __init__(self, dim, k=1, num_heads=4, ln=True, dropout=0.0,
                 MLP_param={'act': 'relu', 'bn': False, 'dropout': 0.0, 'last_act': True}):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, k, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim_Q=dim, dim_K=dim, dim_V=dim,
                       num_heads=num_heads, ln=ln, dropout=dropout, MLP_param=MLP_param)

    def forward(self, X, mask):
        return self.mab(Q=self.S.expand(X.size(0),-1,-1), K=X, mask=(None, mask))

class STransformer(nn.Module):
    """
    Set Transformer based clustering network
    """

    class mySequential(nn.Sequential):
        """
        Multiple inputs version of nn.Sequential customized for
        multiple self-attention layers with a (same) mask
        """
        def forward(self, *inputs):

            X, mask = inputs[0], inputs[1]
            for module in self._modules.values():
                X = module(X, mask)
            return X

    def __init__(self, in_dim, h_dim, output_dim, nstack_dec=4,
                 MLP_enc={}, MAB_dec={}, SAB_dec={}, MLP_mask={}):

        super().__init__()

        # Encoder MLP
        self.MLP_E = MLP([in_dim, in_dim, h_dim], **MLP_enc)

        # Decoder
        self.mab_D = MAB(dim_Q=h_dim, dim_K=h_dim, dim_V=h_dim, **MAB_dec)

        # Decoder self-attention layers
        self.sab_stack_D = self.mySequential(*[
            self.mySequential(
                SAB(dim_in=h_dim, dim_out=h_dim, **SAB_dec)
            )
            for i in range(nstack_dec)
        ])

        # Final mask MLP
        self.MLP_m = MLP([h_dim, h_dim//2, h_dim//4, output_dim], **MLP_mask)

    def forward(self, X, X_pivot, X_mask = None, X_pivot_mask = None):
        """
        X:            input data vectors per row
        X_pivot:      pivotal data (at least one per batch)
        X_mask:       boolean (batch) mask for X (set 0 for zero-padded null elements)
        X_pivot_mask: boolean (batch) mask for X_pivot
        """

        # Simple encoder
        G       = self.MLP_E(X)
        G_pivot = self.MLP_E(X_pivot)

        # Compute cross-attention and self-attention
        H_m = self.sab_stack_D(self.mab_D(Q=G, K=G_pivot, mask=(X_mask, X_pivot_mask)), X_mask)

        # Decode logits
        return self.MLP_m(H_m)

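# Usage sketch (illustrative; the clustering driver is in the main repository,
# and this assumes the MLP defaults in hypertrack.dmlp): given batched candidate
# hits X and their pivot hits X_pivot, the network returns one mask logit per
# hit, thresholded downstream (cf. 'fixed_threshold' in global_tune-5.py):
#
#   st = STransformer(in_dim=64, h_dim=64, output_dim=1, nstack_dec=4)
#   X, X_pivot  = torch.randn(8, 50, 64), torch.randn(8, 3, 64)
#   mask_logits = st(X, X_pivot)  # shape (8, 50, 1)
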
class TrackFlowNet(torch.nn.Module):
    """
    Normalizing Flow Network [experimental]
    """
    def __init__(self, in_dim, num_cond_inputs=None, h_dim=64, nblocks=4, act='tanh'):
        """
        conv_aggr: 'mean' in GNN seems to work ok with Flow!
        """
        super().__init__()

        self.training_on = True

        # -----------------------------------
        # MAF density estimator

        modules = []
        for _ in range(nblocks):
            modules += [
                fnn.MADE(num_inputs=in_dim, num_hidden=h_dim, num_cond_inputs=num_cond_inputs, act=act),
                #fnn.BatchNormFlow(in_dim),  # May cause problems in recursive use
                fnn.Reverse(in_dim)
            ]

        self.track_pdf = fnn.FlowSequential(*modules)
        # -----------------------------------

    def set_model_param_grad(self, string_id='pdf', requires_grad=True):
        """
        Freeze or unfreeze model parameters (for the gradient descent)
        """

        for name, W in self.named_parameters():
            if string_id in name:
                W.requires_grad = requires_grad
                #print(f'Setting requires_grad={W.requires_grad} of the parameter <{name}>')
        return
models/tag_f-0p01-hyper-5/model_net_epoch_834169.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57da49d764e4fee13e9556a63b856cc254b28cd99397a99787a841baab1ee3c7
size 143403488
models/tag_f-0p01-hyper-5/model_pdf_epoch_834169.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:365a51ad4bafce3aae0fb1f5442b1aa632d473ed866aca3d720bb0b4ba6401d1
size 144523939
models/tag_f-0p01-hyper-5/optimizer_net_epoch_834169.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91c9573a6a4a193be373cb523dd814c6340f2bfcc3ace913154009928a352944
size 2548315
models/tag_f-0p01-hyper-5/optimizer_pdf_epoch_834169.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f9af437f10a548c60026915058d0202e947243d94638aa55569469e1e0ebf711
size 2461899
models/tag_f-0p01-hyper-5/scheduler_net_epoch_834169.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66896ef7df452a76f6482bd1b62bca415ae69d4fe88131a1c742ad638816d33c
size 789
models/tag_f-0p01-hyper-5/scheduler_pdf_epoch_834169.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80c31de2a96c990896e7a996c4226ad29b9b73d8230d0d7e26d51fee634479b7
size 789
models/tag_f-0p1-hyper-5/model_net_epoch_752039.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b504e40740146fab5b809ad76b1c7814751a6b3594889e6268e5762d15041a5c
size 129397024
models/tag_f-0p1-hyper-5/model_pdf_epoch_752039.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57c7d332c93ed228129c93a115d727baae508835dc3179b2f37535f9f8f90bab
size 130517539
models/tag_f-0p1-hyper-5/optimizer_net_epoch_752039.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:71def7ec1d75381d14d37f6149c98a16ba6d1fee5581f2b234b24a05ba60e5be
size 2543579
models/tag_f-0p1-hyper-5/optimizer_pdf_epoch_752039.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c21ef4caf9621c3f385d00f08fc0a578e8a12ca02d015a67a1b0072c93a9f4a9
size 981
models/tag_f-0p1-hyper-5/scheduler_net_epoch_752039.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8089ea387f0a81570664dbeec2f0a7201b1adea1195a7459f943792b783f6bc1
size 789
models/tag_f-0p1-hyper-5/scheduler_pdf_epoch_752039.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ec10834480a2d82e28eee0a60f05f078262b364b2587dd2d02caa51de97fda3
size 789
models/tag_f-0p3-hyper-5/model_net_epoch_484878.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80c29152b67cd037a07bc0de75e0f15a024f270b1f6b90205196fc8a186f5a75
size 83839328
models/tag_f-0p3-hyper-5/model_pdf_epoch_484878.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8557c0ebffa2f335cdf7d94bd3025b5df220558b948d23484d706fb333d04913
size 84959843
models/tag_f-0p3-hyper-5/optimizer_net_epoch_484878.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62204519b3e2b84bd91396b84919ba7c4c680171988c255599969d2f7f45a734
size 2543579
models/tag_f-0p3-hyper-5/optimizer_pdf_epoch_484878.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bb83aca384ca99600390f01b559ecca04b78fc9b8f2222c5f3d5deb99c1527bf
size 981
models/tag_f-0p3-hyper-5/scheduler_net_epoch_484878.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:688e4074963740abec8f59b5f54c219377299feeabcea21c0f79015f47040e2a
size 789
models/tag_f-0p3-hyper-5/scheduler_pdf_epoch_484878.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d2e1e0f5cc38306cbdb00ae64b1dec0532852790dd48da7503310dfa97e119c
size 789
models/voxdyn_node2node_hyper_ncell_131072.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7744317ce56a104d8d7357272e2b117a50d30db962fc7804fbd1e9dd658d7b15
size 696889444
models/voxdyn_node2node_hyper_ncell_262144.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9689592db5542078f033c1efcbbf3bd61157c0ac6b717542717189ca4d77718
size 1248245246
models/voxdyn_node2node_hyper_ncell_32768.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c13e5c1574714a6e71c7c95c93574936bfa7ddef65938361facb6bc0639340ba
size 215949890
models/voxdyn_node2node_hyper_ncell_524288.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ccc5926bfef73cde831fffb8b5c86b562b3176353cb61153ca6aab58eb0fe61a
size 2135411740
models/voxdyn_node2node_hyper_ncell_65536.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c1d234b35eae47cbc7f99506986b4ed2b9c6f581ab9094d05f1bf4246fb86f04
size 388562698