ljw20180420 commited on
Commit
66376de
·
verified ·
1 Parent(s): 79024d4

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +43 -0
pipeline.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ import torch
3
+ from sklearn.neighbors import KNeighborsRegressor
4
+ import numpy as np
5
+
6
+ class inDelphiPipeline(DiffusionPipeline):
7
+ def __init__(self, inDelphi_model, onebp_features, insert_probabilities, m654):
8
+ super().__init__()
9
+
10
+ self.register_modules(inDelphi_model=inDelphi_model)
11
+ self.onebp_feature_mean = onebp_features.mean(axis=0)
12
+ self.onebp_feature_std = onebp_features.std(axis=0)
13
+ self.insertion_model = KNeighborsRegressor(weights='distance').fit((onebp_features - self.onebp_feature_mean) / self.onebp_feature_std, insert_probabilities)
14
+ self.m654 = m654 / np.maximum(np.linalg.norm(m654, ord=1, axis=1, keepdims=True), 1e-6)
15
+ self.m4 = m654.reshape(16, 4, 4).sum(axis=0)
16
+ self.m4 = self.m4 / np.maximum(np.linalg.norm(self.m4, ord=1, axis=1, keepdims=True), 1e-6)
17
+
18
+ @torch.no_grad()
19
+ def __call__(self, batch, use_m654=False):
20
+ mh_weights, mhless_weights, total_del_len_weights = self.inDelphi_model(
21
+ batch["mh_input"].to(self.inDelphi_model.device),
22
+ batch["mh_del_len"].to(self.inDelphi_model.device)
23
+ ).values()
24
+ mX = self.m654 if use_m654 else self.m4
25
+ log_total_weights = total_del_len_weights.sum(dim=1, keepdim=True).log()
26
+ precisions = 1 - torch.distributions.Categorical(total_del_len_weights[:,:28]).entropy() / torch.log(torch.tensor(28))
27
+ onebp_features = torch.cat([
28
+ batch["onebp_feature"],
29
+ precisions[:, None].cpu(),
30
+ log_total_weights.cpu()
31
+ ], dim=1).cpu().numpy()
32
+ pre_insert_probabilities = self.insertion_model.predict((onebp_features - self.onebp_feature_mean) / self.onebp_feature_std)
33
+ pre_insert_1bps = mX[batch['m654'] % 4] if mX.shape[0] == 4 else mX[batch['m654']]
34
+ return {
35
+ "mh_weight": [
36
+ mh_weights[i, batch["mh_del_len"][i] < self.inDelphi_model.config.DELLEN_LIMIT]
37
+ for i in range(len(batch["mh_del_len"]))
38
+ ],
39
+ "mhless_weight": mhless_weights,
40
+ "total_del_len_weight": total_del_len_weights,
41
+ "pre_insert_probability": pre_insert_probabilities,
42
+ "pre_insert_1bp": pre_insert_1bps
43
+ }