from diffusers import DiffusionPipeline
import torch
from sklearn.neighbors import KNeighborsRegressor
import numpy as np


class inDelphiPipeline(DiffusionPipeline):
    """Pipeline pairing an inDelphi model with a k-NN insertion-rate regressor.

    The wrapped model is called to obtain deletion weights; a
    ``KNeighborsRegressor`` fitted at construction time predicts an insertion
    probability from standardized per-sample features; and two row-normalized
    lookup tables (``m654`` and its collapsed form ``m4``) provide the
    per-sample rows returned as ``pre_insert_1bp``.
    """

    def __init__(self, inDelphi_model, onebp_features, insert_probabilities, m654):
        """Register the model, fit the insertion regressor, build lookup tables.

        Args:
            inDelphi_model: Model registered on the pipeline under the name
                ``inDelphi_model``.
            onebp_features: Training feature matrix for the insertion
                regressor; standardized column-wise before fitting.
            insert_probabilities: Regression targets for the insertion
                regressor.
            m654: 2-D array that is reshaped to ``(16, 4, 4)``, so it must
                hold 256 values (presumably shape ``(64, 4)`` -- confirm
                against the caller).
        """
        super().__init__()

        self.register_modules(inDelphi_model=inDelphi_model)

        # Column statistics are stored so that __call__ can standardize
        # inference-time features exactly as the training features were.
        feature_mean = onebp_features.mean(axis=0)
        feature_std = onebp_features.std(axis=0)
        self.onebp_feature_mean = feature_mean
        self.onebp_feature_std = feature_std

        knn = KNeighborsRegressor(weights="distance")
        self.insertion_model = knn.fit(
            (onebp_features - feature_mean) / feature_std,
            insert_probabilities,
        )

        # Each row is divided by its L1 norm; the 1e-6 floor keeps all-zero
        # rows from triggering a division by zero.
        row_l1 = np.linalg.norm(m654, ord=1, axis=1, keepdims=True)
        self.m654 = m654 / np.maximum(row_l1, 1e-6)

        # Collapse the *unnormalized* table: rows whose index is equal
        # modulo 4 are summed together, then the 4x4 result is L1-normalized
        # row by row in the same way.
        collapsed = m654.reshape(16, 4, 4).sum(axis=0)
        collapsed_l1 = np.linalg.norm(collapsed, ord=1, axis=1, keepdims=True)
        self.m4 = collapsed / np.maximum(collapsed_l1, 1e-6)

    @torch.no_grad()
    def __call__(self, batch, use_m654=False):
        """Run the model on ``batch`` and assemble the prediction dictionary.

        Args:
            batch: Mapping read at the keys ``"mh_input"``, ``"mh_del_len"``,
                ``"onebp_feature"`` and ``"m654"``.
            use_m654: When true, look rows up in the full ``m654`` table;
                otherwise use the collapsed ``m4`` table indexed by
                ``batch["m654"] % 4``.

        Returns:
            dict with the keys

            * ``"mh_weight"``: list with one entry per batch item, holding
              that item's row of the model's first output restricted to the
              positions where ``mh_del_len`` is below the model config's
              ``DELLEN_LIMIT``;
            * ``"mhless_weight"``: the model's second output, unchanged;
            * ``"total_del_len_weight"``: the model's third output, unchanged;
            * ``"pre_insert_probability"``: predictions of the insertion
              regressor;
            * ``"pre_insert_1bp"``: rows selected from the chosen lookup
              table.
        """
        model = self.inDelphi_model
        outputs = model(
            batch["mh_input"].to(model.device),
            batch["mh_del_len"].to(model.device),
        )
        mh_weights, mhless_weights, total_del_len_weights = outputs.values()

        if use_m654:
            lookup_table = self.m654
        else:
            lookup_table = self.m4

        log_total_weights = total_del_len_weights.sum(dim=1, keepdim=True).log()

        # Precision = 1 - entropy of the first 28 deletion-length weights,
        # with the entropy divided by log(28) (the maximum entropy of a
        # 28-way categorical distribution).
        del_len_dist = torch.distributions.Categorical(total_del_len_weights[:, :28])
        precisions = 1 - del_len_dist.entropy() / torch.log(torch.tensor(28))

        # Feature row = given features + precision + log total weight,
        # moved to CPU/numpy for the scikit-learn regressor.
        feature_parts = [
            batch["onebp_feature"],
            precisions[:, None].cpu(),
            log_total_weights.cpu(),
        ]
        onebp_features = torch.cat(feature_parts, dim=1).cpu().numpy()
        standardized = (
            onebp_features - self.onebp_feature_mean
        ) / self.onebp_feature_std
        pre_insert_probabilities = self.insertion_model.predict(standardized)

        # A 4-row table is addressed by the index modulo 4; any other table
        # is addressed by the index itself.
        table_index = batch["m654"]
        if lookup_table.shape[0] == 4:
            pre_insert_1bps = lookup_table[table_index % 4]
        else:
            pre_insert_1bps = lookup_table[table_index]

        mh_weight = [
            mh_weights[i, del_len < model.config.DELLEN_LIMIT]
            for i, del_len in enumerate(batch["mh_del_len"])
        ]

        return {
            "mh_weight": mh_weight,
            "mhless_weight": mhless_weights,
            "total_del_len_weight": total_del_len_weights,
            "pre_insert_probability": pre_insert_probabilities,
            "pre_insert_1bp": pre_insert_1bps,
        }