Upload pipeline.py with huggingface_hub
Browse files- pipeline.py +43 -0
pipeline.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import DiffusionPipeline
|
2 |
+
import torch
|
3 |
+
from sklearn.neighbors import KNeighborsRegressor
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
class inDelphiPipeline(DiffusionPipeline):
|
7 |
+
def __init__(self, inDelphi_model, onebp_features, insert_probabilities, m654):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.register_modules(inDelphi_model=inDelphi_model)
|
11 |
+
self.onebp_feature_mean = onebp_features.mean(axis=0)
|
12 |
+
self.onebp_feature_std = onebp_features.std(axis=0)
|
13 |
+
self.insertion_model = KNeighborsRegressor(weights='distance').fit((onebp_features - self.onebp_feature_mean) / self.onebp_feature_std, insert_probabilities)
|
14 |
+
self.m654 = m654 / np.maximum(np.linalg.norm(m654, ord=1, axis=1, keepdims=True), 1e-6)
|
15 |
+
self.m4 = m654.reshape(16, 4, 4).sum(axis=0)
|
16 |
+
self.m4 = self.m4 / np.maximum(np.linalg.norm(self.m4, ord=1, axis=1, keepdims=True), 1e-6)
|
17 |
+
|
18 |
+
@torch.no_grad()
|
19 |
+
def __call__(self, batch, use_m654=False):
|
20 |
+
mh_weights, mhless_weights, total_del_len_weights = self.inDelphi_model(
|
21 |
+
batch["mh_input"].to(self.inDelphi_model.device),
|
22 |
+
batch["mh_del_len"].to(self.inDelphi_model.device)
|
23 |
+
).values()
|
24 |
+
mX = self.m654 if use_m654 else self.m4
|
25 |
+
log_total_weights = total_del_len_weights.sum(dim=1, keepdim=True).log()
|
26 |
+
precisions = 1 - torch.distributions.Categorical(total_del_len_weights[:,:28]).entropy() / torch.log(torch.tensor(28))
|
27 |
+
onebp_features = torch.cat([
|
28 |
+
batch["onebp_feature"],
|
29 |
+
precisions[:, None].cpu(),
|
30 |
+
log_total_weights.cpu()
|
31 |
+
], dim=1).cpu().numpy()
|
32 |
+
pre_insert_probabilities = self.insertion_model.predict((onebp_features - self.onebp_feature_mean) / self.onebp_feature_std)
|
33 |
+
pre_insert_1bps = mX[batch['m654'] % 4] if mX.shape[0] == 4 else mX[batch['m654']]
|
34 |
+
return {
|
35 |
+
"mh_weight": [
|
36 |
+
mh_weights[i, batch["mh_del_len"][i] < self.inDelphi_model.config.DELLEN_LIMIT]
|
37 |
+
for i in range(len(batch["mh_del_len"]))
|
38 |
+
],
|
39 |
+
"mhless_weight": mhless_weights,
|
40 |
+
"total_del_len_weight": total_del_len_weights,
|
41 |
+
"pre_insert_probability": pre_insert_probabilities,
|
42 |
+
"pre_insert_1bp": pre_insert_1bps
|
43 |
+
}
|