Spaces: Paused
yjwtheonly committed
Commit fce1f4b · 1 Parent(s): 8ae6390
Agnostic
- DiseaseAgnostic/KG_extractor.py +473 -0
- DiseaseAgnostic/edge_to_abstract.py +652 -0
- DiseaseAgnostic/evaluation.py +219 -0
- DiseaseAgnostic/generate_target_and_attack.py +371 -0
- DiseaseAgnostic/model.py +520 -0
- DiseaseAgnostic/utils.py +187 -0
DiseaseAgnostic/KG_extractor.py
ADDED
@@ -0,0 +1,473 @@
#%%
import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics

import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn

from tqdm import tqdm

import sys
sys.path.append("..")
import Parameters

parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existence rank prob must be greater than this rate')
parser.add_argument('--mode', type=str, default='sentence', help='sentence, finetune, biogpt, bioBART')
parser.add_argument('--action', type=str, default='parse', help='parse or extract')
parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes')
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'
modified_attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.mode}.pkl'

with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
attack_data = np.array(Attack_edge_list).reshape(-1, 3)
#%%
with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
    id_to_meshid = json.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    meshid_to_id = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    full_entity_raw_name = pkl.load(fl)
for k, v in entity_raw_name.items():
    assert v in full_entity_raw_name[k]

# find unique
once_set = set()
twice_set = set()

with open('../DiseaseSpecific/generate_abstract/valid_entity.json', 'r') as fl:
    valid_entity = json.load(fl)
valid_entity = set(valid_entity)

good_name = set()
for k, v, in full_entity_raw_name.items():
    names = list(v)
    for name in names:
        # if name == 'in a':
        #     print(names)
        good_name.add(name)
        # if name not in once_set:
        #     once_set.add(name)
        # else:
        #     twice_set.add(name)
# assert 'WNK4' in once_set
# good_name = set.difference(once_set, twice_set)
# assert 'in a' not in good_name
# assert 'STE20' not in good_name
# assert 'STE20' not in valid_entity
# assert 'STE20-related proline-alanine-rich kinase' not in good_name
# assert 'STE20-related proline-alanine-rich kinase' not in valid_entity
# raise Exception

name_to_type = {}
name_to_meshid = {}

for k, v, in full_entity_raw_name.items():
    names = list(v)
    for name in names:
        if name in good_name:
            name_to_type[name] = k.split('_')[0]
            name_to_meshid[name] = k

import spacy
import networkx as nx
import pprint

def check(p, s):
    # True when position p is outside the string or not alphanumeric,
    # i.e. an entity match starting/ending here does not split a word.
    if p < 1 or p >= len(s):
        return True
    return not((s[p]>='a' and s[p]<='z') or (s[p]>='A' and s[p]<='Z') or (s[p]>='0' and s[p]<='9'))

def raw_to_format(sen):
    # Greedy longest-match rewriting: known entity names (good_name or
    # valid_entity) are fused with underscores so the parser treats each
    # multi-word entity as a single token.
    text = sen
    l = 0
    ret = []
    while(l < len(text)):
        bo = False
        if text[l] != ' ':
            for i in range(len(text), l, -1): # reversing is important !!!
                cc = text[l:i]
                if (cc in good_name or cc in valid_entity) and check(l-1, text) and check(i, text):
                    ret.append(cc.replace(' ', '_'))
                    l = i
                    bo = True
                    break
        if not bo:
            ret.append(text[l])
            l += 1
    return ''.join(ret)

# Load the generated abstracts produced by the selected generation mode.
if args.mode == 'sentence':
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json', 'r') as fl:
        draft = json.load(fl)
elif args.mode == 'finetune':
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence_finetune.json', 'r') as fl:
        draft = json.load(fl)
elif args.mode == 'bioBART':
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'r') as fl:
        draft = json.load(fl)
elif args.mode == 'biogpt':
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_biogpt.json', 'r') as fl:
        draft = json.load(fl)
else:
    raise Exception('No!!!')

nlp = spacy.load("en_core_web_sm")

# Collect every dependency-edge label that occurs in the GNBR paths.
type_set = set()
for aa in range(36):
    dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual']
    tmp_dict = retieve_sentence_through_edgetype[aa]['auto']
    dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys())
    for dependency in dependencys:
        dep_list = dependency.split(' ')
        for sub_dep in dep_list:
            sub_dep_list = sub_dep.split('|')
            assert(len(sub_dep_list) == 3)
            type_set.add(sub_dep_list[1])
# print('Type:', type_set)

if args.action == 'parse':
    # dp_path, sen_list = list(dependency_sen_dict.items())[0]
    # check
    # paper_id, sen_id = sen_list[0]
    # sen = raw_text_sen[paper_id][sen_id]
    # doc = nlp(sen['text'])
    # print(dp_path, '\n')
    # pprint.pprint(sen)
    # print()
    # for token in doc:
    #     print((token.head.text, token.text, token.dep_))

    # Sentence-split each generated abstract, fuse entity names into single
    # tokens, and write the sentences out to the *_parsein.txt file for parsing.
    out = ''
    for k, v_dict in draft.items():
        input = v_dict['in']
        output = v_dict['out']
        if input == '':
            continue
        output = output.replace('\n', ' ')
        doc = nlp(output)
        for sen in doc.sents:
            out += raw_to_format(sen.text) + '\n'
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'w') as fl:
        fl.write(out)
elif args.action == 'extract':

    # dependency_to_type_id = {}
    # for k, v in Parameters.edge_type_to_id.items():
    #     dependency_to_type_id[k] = {}
    #     for type in v:
    #         LL = list(retieve_sentence_through_edgetype[type]['manual'].keys()) + list(retieve_sentence_through_edgetype[type]['auto'].keys())
    #         for dp in LL:
    #             dependency_to_type_id[k][dp] = type
    # Build a dependency-path -> relation-id table from the GNBR
    # theme-distribution files (cached as a pickle after the first run).
    if os.path.exists('generate_abstract/dependency_to_type_id.pickle'):
        with open('generate_abstract/dependency_to_type_id.pickle', 'rb') as fl:
            dependency_to_type_id = pkl.load(fl)
    else:
        dependency_to_type_id = {}
        print('Loading path data ...')
        for k in Parameters.edge_type_to_id.keys():
            start, end = k.split('-')
            dependency_to_type_id[k] = {}
            inner_edge_type_to_id = Parameters.edge_type_to_id[k]
            inner_edge_type_dict = Parameters.edge_type_dict[k]
            cal_manual_num = [0] * len(inner_edge_type_to_id)
            with open('../GNBRdata/part-i-'+start+'-'+end+'-path-theme-distributions.txt', 'r') as fl:
                for i, line in tqdm(list(enumerate(fl.readlines()))):
                    tmp = line.split('\t')
                    if i == 0:
                        head = [tmp[i] for i in range(1, len(tmp), 2)]
                        assert ' '.join(head) == ' '.join(inner_edge_type_dict[0])
                        continue
                    probability = [float(tmp[i]) for i in range(1, len(tmp), 2)]
                    flag_list = [int(tmp[i]) for i in range(2, len(tmp), 2)]
                    indices = np.where(np.asarray(flag_list) == 1)[0]
                    if len(indices) >= 1:
                        tmp_p = [cal_manual_num[i] for i in indices]
                        p = indices[np.argmin(tmp_p)]
                        cal_manual_num[p] += 1
                    else:
                        p = np.argmax(probability)
                    assert tmp[0].lower() not in dependency_to_type_id.keys()
                    dependency_to_type_id[k][tmp[0].lower()] = inner_edge_type_to_id[p]
        with open('generate_abstract/dependency_to_type_id.pickle', 'wb') as fl:
            pkl.dump(dependency_to_type_id, fl)

    # record = []
    # with open(f'generate_abstract/par_parseout.txt', 'r') as fl:
    #     Tmp = []
    #     tmp = []
    #     for i,line in enumerate(fl.readlines()):
    #         # print(len(line), line)
    #         line = line.replace('\n', '')
    #         if len(line) > 1:
    #             tmp.append(line)
    #         else:
    #             Tmp.append(tmp)
    #             tmp = []
    #             if len(Tmp) == 3:
    #                 record.append(Tmp)
    #                 Tmp = []

    # print(len(record))
    # record_index = 0
    # add = 0
    # Attack = []
    # for ii in range(100):

    #     # input = v_dict['in']
    #     # output = v_dict['out']
    #     # output = output.replace('\n', ' ')
    #     s, r, o = attack_data[ii]
    #     dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']

    #     target_dp = set()
    #     for dp_path, sen_list in dependency_sen_dict.items():
    #         target_dp.add(dp_path)
    #     DP_list = []
    #     for _ in range(1):
    #         dp_dict = {}
    #         data = record[record_index]
    #         record_index += 1
    #         dp_paths = data[2]
    #         nodes_list = []
    #         edges_list = []
    #         for line in dp_paths:
    #             ttp, tmp = line.split('(')
    #             assert tmp[-1] == ')'
    #             tmp = tmp[:-1]
    #             e1, e2 = tmp.split(', ')
    #             if not ttp in type_set and ':' in ttp:
    #                 ttp = ttp.split(':')[0]
    #             dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
    #             dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
    #             nodes_list.append(e1)
    #             nodes_list.append(e2)
    #             edges_list.append((e1, e2))
    #         nodes_list = list(set(nodes_list))
    #         pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
    #         graph = nx.Graph(edges_list)

    #         type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
    #         # print(type_list)
    #         # for i in range(len(type_list)):
    #         #     print(pure_name[i], type_list[i])
    #         for i in range(len(nodes_list)):
    #             if type_list[i] != '':
    #                 for j in range(len(nodes_list)):
    #                     if i != j and type_list[j] != '':
    #                         if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
    #                             # print(f'{type_list[i]}_{type_list[j]}')
    #                             ret_path = []
    #                             sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
    #                             start = sp[0]
    #                             end = sp[-1]
    #                             for k in range(len(sp)-1):
    #                                 e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
    #                                 if e1 == start:
    #                                     e1 = 'start_entity-x'
    #                                 if e2 == start:
    #                                     e2 = 'start_entity-x'
    #                                 if e1 == end:
    #                                     e1 = 'end_entity-x'
    #                                 if e2 == end:
    #                                     e2 = 'end_entity-x'
    #                                 ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
    #                             dependency_P = ' '.join(ret_path)
    #                             DP_list.append((f'{type_list[i]}-{type_list[j]}',
    #                                             name_to_meshid[pure_name[i]],
    #                                             name_to_meshid[pure_name[j]],
    #                                             dependency_P))

    #     boo = False
    #     modified_attack = []
    #     for k, ss, tt, dp in DP_list:
    #         if dp in dependency_to_type_id[k].keys():
    #             tp = str(dependency_to_type_id[k][dp])
    #             id_ss = str(meshid_to_id[ss])
    #             id_tt = str(meshid_to_id[tt])
    #             modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
    #             if int(dependency_to_type_id[k][dp]) == int(r):
    #                 # if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
    #                 boo = True
    #     modified_attack = list(set(modified_attack))
    #     modified_attack = [k.split('*') for k in modified_attack]
    #     if boo:
    #         add += 1
    #     # else:
    #     #     print(ii)

    #     # for i in range(len(type_list)):
    #     #     if type_list[i]:
    #     #         print(pure_name[i], type_list[i])
    #     # for k, ss, tt, dp in DP_list:
    #     #     print(k, dp)
    #     # print(record[record_index - 1])
    #     # raise Exception('No!!')
    #     Attack.append(modified_attack)

    # Read the parser output back; each record is expected to hold three
    # blocks, the last one being the dependency paths (data[2]).
    record = []
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parseout.txt', 'r') as fl:
        Tmp = []
        tmp = []
        for i,line in enumerate(fl.readlines()):
            # print(len(line), line)
            line = line.replace('\n', '')
            if len(line) > 1:
                tmp.append(line)
            else:
                if len(Tmp) == 2:
                    if len(tmp) == 1 and '/' in tmp[0].split(' ')[0]:
                        Tmp.append([])
                        record.append(Tmp)
                        Tmp = []
                Tmp.append(tmp)
                if len(Tmp) == 2 and tmp[0][:5] != '(ROOT':
                    print(record[-1][2])
                    raise Exception('??')
                tmp = []
                if len(Tmp) == 3:
                    record.append(Tmp)
                    Tmp = []
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'r') as fl:
        parsin = fl.readlines()

    print('Record len', len(record), 'Parsin len:', len(parsin))
    record_index = 0
    add = 0

    Attack = []
    for ii, (k, v_dict) in enumerate(tqdm(draft.items())):

        input = v_dict['in']
        output = v_dict['out']
        output = output.replace('\n', ' ')
        s, r, o = attack_data[ii]
        s = str(s)
        r = str(r)
        o = str(o)
        assert ii == int(k.split('_')[-1])

        DP_list = []
        if input != '':

            dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
            target_dp = set()
            for dp_path, sen_list in dependency_sen_dict.items():
                target_dp.add(dp_path)
            doc = nlp(output)

            for sen in doc.sents:
                dp_dict = {}
                if record_index >= len(record):
                    break
                data = record[record_index]
                record_index += 1
                dp_paths = data[2]
                nodes_list = []
                edges_list = []
                for line in dp_paths:
                    aa = line.split('(')
                    if len(aa) == 1:
                        print(ii)
                        print(sen)
                        print(data)
                        raise Exception
                    ttp, tmp = aa[0], aa[1]
                    assert tmp[-1] == ')'
                    tmp = tmp[:-1]
                    e1, e2 = tmp.split(', ')
                    if not ttp in type_set and ':' in ttp:
                        ttp = ttp.split(':')[0]
                    dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
                    dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
                    nodes_list.append(e1)
                    nodes_list.append(e2)
                    edges_list.append((e1, e2))
                nodes_list = list(set(nodes_list))
                pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
                graph = nx.Graph(edges_list)

                # For every typed entity pair, walk the shortest dependency
                # path between them and serialize it in GNBR path format.
                type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
                # print(type_list)
                for i in range(len(nodes_list)):
                    if type_list[i] != '':
                        for j in range(len(nodes_list)):
                            if i != j and type_list[j] != '':
                                if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
                                    # print(f'{type_list[i]}_{type_list[j]}')
                                    ret_path = []
                                    sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
                                    start = sp[0]
                                    end = sp[-1]
                                    for k in range(len(sp)-1):
                                        e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
                                        if e1 == start:
                                            e1 = 'start_entity-x'
                                        if e2 == start:
                                            e2 = 'start_entity-x'
                                        if e1 == end:
                                            e1 = 'end_entity-x'
                                        if e2 == end:
                                            e2 = 'end_entity-x'
                                        ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
                                    dependency_P = ' '.join(ret_path)
                                    DP_list.append((f'{type_list[i]}-{type_list[j]}',
                                                    name_to_meshid[pure_name[i]],
                                                    name_to_meshid[pure_name[j]],
                                                    dependency_P))

        boo = False
        modified_attack = []
        for k, ss, tt, dp in DP_list:
            if dp in dependency_to_type_id[k].keys():
                tp = str(dependency_to_type_id[k][dp])
                id_ss = str(meshid_to_id[ss])
                id_tt = str(meshid_to_id[tt])
                modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
                if int(dependency_to_type_id[k][dp]) == int(r):
                    if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
                        boo = True
        modified_attack = list(set(modified_attack))
        modified_attack = [k.split('*') for k in modified_attack]
        if boo:
            # print(DP_list)
            add += 1
        Attack.append(modified_attack)
    print(add)
    print('End record_index:', record_index)
    final_Attack = Attack
    print('Len of Attack:', len(Attack))
    with open(modified_attack_path, 'wb') as fl:
        pkl.dump(final_Attack, fl)
else:
    raise Exception('Wrong action !!')
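For orientation, here is a minimal, self-contained sketch of the greedy longest-match rewriting that `raw_to_format` above performs. The entity set is a hypothetical stand-in for `good_name`/`valid_entity`, and the boundary test is simplified relative to `check`:

```python
# Sketch only: a hypothetical entity set stands in for good_name / valid_entity.
def is_boundary(p, s):
    # A match edge at p is safe when p falls outside the string or on a
    # non-alphanumeric character (so a longer word is never split).
    return p < 0 or p >= len(s) or not s[p].isalnum()

def join_entities(text, entities):
    out, l = [], 0
    while l < len(text):
        matched = False
        if text[l] != ' ':
            # Longest candidate first, mirroring the reversed range() loop above.
            for r in range(len(text), l, -1):
                cand = text[l:r]
                if cand in entities and is_boundary(l - 1, text) and is_boundary(r, text):
                    out.append(cand.replace(' ', '_'))
                    l, matched = r, True
                    break
        if not matched:
            out.append(text[l])
            l += 1
    return ''.join(out)

print(join_entities('aspirin inhibits prostaglandin synthase activity',
                    {'prostaglandin synthase'}))
# -> aspirin inhibits prostaglandin_synthase activity
```

Fusing multi-word names this way lets the downstream parser treat each entity as one token, which is what makes the shortest-path extraction in the `extract` branch line up with GNBR's path format.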
DiseaseAgnostic/edge_to_abstract.py
ADDED
@@ -0,0 +1,652 @@
#%%
import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics

import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn

from tqdm import tqdm

import sys
sys.path.append("..")
import Parameters

parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existence rank prob must be greater than this rate')
parser.add_argument('--mode', type=str, default='sentence', help='sentence, biogpt or finetune')
parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes')
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'

# target_data = utils.load_data(target_path)
with open(target_path, 'rb') as fl:
    Target_node_list = pkl.load(fl)
with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
attack_data = np.array(Attack_edge_list).reshape(-1, 3)
# assert target_data.shape == attack_data.shape
#%%

with open('../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json') as fl:
    id_to_meshid = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)

if args.mode == 'sentence':
    import torch
    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import AutoTokenizer
    from transformers import BioGptForCausalLM
    criterion = CrossEntropyLoss(reduction="none")

    print('Generating GPT input ...')

    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()
    GPT_batch_size = 24
    single_sentence = {}
    test_text = []
    test_dp = []
    test_parse = []
    # For each attack triple, sample candidate GNBR evidence sentences,
    # score them with BioGPT (mean per-token loss), and keep the most
    # fluent one as the seed sentence for abstract generation.
    for i, (s, r, o) in enumerate(tqdm(attack_data)):

        s = str(s)
        r = str(r)
        o = str(o)
        if int(s) != -1:

            dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
            candidate_sen = []
            Dp_path = []
            L = len(dependency_sen_dict.keys())
            bound = 500 // L
            if bound == 0:
                bound = 1
            for dp_path, sen_list in dependency_sen_dict.items():
                if len(sen_list) > bound:
                    index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
                    sen_list = [sen_list[aa] for aa in index]
                candidate_sen += sen_list
                Dp_path += [dp_path] * len(sen_list)

            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]
            candidate_text_sen = []
            candidate_ori_sen = []
            candidate_parse_sen = []

            for paper_id, sen_id in candidate_sen:
                sen = raw_text_sen[paper_id][sen_id]
                text = sen['text']
                candidate_ori_sen.append(text)
                ss = sen['start_formatted']
                oo = sen['end_formatted']
                text = text.replace('-LRB-', '(')
                text = text.replace('-RRB-', ')')
                text = text.replace('-LSB-', '[')
                text = text.replace('-RSB-', ']')
                text = text.replace('-LCB-', '{')
                text = text.replace('-RCB-', '}')
                parse_text = text
                parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
                parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
                text = text.replace(ss, text_s)
                text = text.replace(oo, text_o)
                text = text.replace('_', ' ')
                candidate_text_sen.append(text)
                candidate_parse_sen.append(parse_text)
            tokens = tokenizer(candidate_text_sen,
                               truncation = True,
                               padding = True,
                               max_length = 300,
                               return_tensors="pt")
            target_ids = tokens['input_ids'].to(device)
            attention_mask = tokens['attention_mask'].to(device)

            L = len(candidate_text_sen)
            assert L > 0
            ret_log_L = []
            for l in range(0, L, GPT_batch_size):
                R = min(L, l + GPT_batch_size)
                target = target_ids[l:R, :]
                attention = attention_mask[l:R, :]
                outputs = model(input_ids = target,
                                attention_mask = attention,
                                labels = target)
                logits = outputs.logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                Loss = Loss.view(-1, shift_logits.shape[1])
                attention = attention[..., 1:].contiguous()
                log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
                ret_log_L.append(log_Loss.detach())

            ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
            sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
            sen_score.sort(key = lambda x: x[1])
            test_text.append(sen_score[0][2])
            test_dp.append(sen_score[0][3])
            test_parse.append(sen_score[0][4])
            single_sentence.update({f'{s}_{r}_{o}_{i}': sen_score[0][0]})

        else:
            single_sentence.update({f'{s}_{r}_{o}_{i}': ''})

    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence.json', 'w') as fl:
        json.dump(single_sentence, fl, indent=4)
    # with open('generate_abstract/test.txt', 'w') as fl:
    #     fl.write('\n'.join(test_text))
    # with open('generate_abstract/dp.txt', 'w') as fl:
    #     fl.write('\n'.join(test_dp))
    with open(f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_path.json', 'w') as fl:
        fl.write('\n'.join(test_dp))
    with open(f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_temp.json', 'w') as fl:
        fl.write('\n'.join(test_text))

elif args.mode == 'biogpt':
    pass
    # from biogpt_generate import GPT_eval
    # import spacy

    # model = GPT_eval(args.seed)

    # nlp = spacy.load("en_core_web_sm")
    # with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence.json', 'r') as fl:
    #     data = json.load(fl)

    # KK = []
    # input = []
    # for i,(k, v) in enumerate(data.items()):
    #     KK.append(k)
    #     input.append(v)
    # output = model.eval(input)

    # ret = {}
    # for i, o in enumerate(output):

    #     o = o.replace('<|abstract|>', '')
    #     doc = nlp(o)
    #     sen_list = []
    #     sen_set = set()
    #     for sen in doc.sents:
    #         txt = sen.text
    #         if not (txt.lower() in sen_set):
    #             sen_set.add(txt.lower())
    #             sen_list.append(txt)
    #     O = ' '.join(sen_list)
    #     ret[KK[i]] = {'in' : input[i], 'out' : O}

    # with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_biogpt.json', 'w') as fl:
    #     json.dump(ret, fl, indent=4)

elif args.mode == 'finetune':

    import spacy
    import pprint
    from transformers import AutoModel, AutoTokenizer, BartForConditionalGeneration

    print('Finetuning ...')

    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json', 'r') as fl:
        draft = json.load(fl)
    with open(f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_path.json', 'r') as fl:
        dpath = fl.readlines()

    nlp = spacy.load("en_core_web_sm")
    if os.path.exists(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json'):
        with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json', 'r') as fl:
            ret_candidates = json.load(fl)
    else:

        def find_mini_span(vec, words, check_set):
            # Find the shortest window of words that still contains every
            # check_set phrase present in the full text, then mark that
            # window True in vec so it is never masked.

            def cal(text, sset):
                add = 0
                for tt in sset:
                    if tt in text:
                        add += 1
                return add
            text = ' '.join(words)
            max_add = cal(text, check_set)

            minn = 10000000
            span = ''
            rc = None
            for i in range(len(vec)):
                if vec[i] == True:
                    p = -1
                    for j in range(i+1, len(vec)+1):
                        if vec[j-1] == True:
                            text = ' '.join(words[i:j])
                            if cal(text, check_set) == max_add:
                                p = j
                                break
                    if p > 0:
                        if (p-i) < minn:
                            minn = p-i
                            span = ' '.join(words[i:p])
                            rc = (i, p)
            if rc:
                for i in range(rc[0], rc[1]):
                    vec[i] = True
            return vec, span

        # def mask_func(tokenized_sen, position):

        #     if len(tokenized_sen) == 0:
        #         return []
        #     token_list = []
        #     # for sen in tokenized_sen:
        #     #     for token in sen:
        #     #         token_list.append(token)
        #     for sen in tokenized_sen:
        #         token_list += sen.text.split(' ')
        #     l_p = 0
        #     r_p = 1
        #     assert position == 'front' or position == 'back'
        #     if position == 'back':
        #         l_p, r_p = r_p, l_p
        #     P = np.linspace(start = l_p, stop = r_p, num = len(token_list))
        #     P = (P ** 3) * 0.4

        #     ret_list = []
        #     for t, p in zip(token_list, list(P)):
        #         if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
        #             ret_list.append(t)
        #         else:
        #             if np.random.rand() < p:
        #                 ret_list.append('<mask>')
        #             else:
        #                 ret_list.append(t)
        #     return [' '.join(ret_list)]
        def mask_func(tokenized_sen):
            # BART-style span corruption: with probability P a position
            # starts a <mask> that swallows a Poisson(3)-length span,
            # with at most 8 consecutive masks.
            if len(tokenized_sen) == 0:
                return []
            token_list = []
            # for sen in tokenized_sen:
            #     for token in sen:
            #         token_list.append(token)
            for sen in tokenized_sen:
                token_list += sen.text.split(' ')
            if args.ratio == '':
                P = 0.3
            else:
                P = float(args.ratio)

            ret_list = []
            i = 0
            mask_num = 0
            while i < len(token_list):
                t = token_list[i]
                if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
                    ret_list.append(t)
                    i += 1
                    mask_num = 0
                else:
                    length = np.random.poisson(3)
                    if np.random.rand() < P and length > 0:
                        if mask_num < 8:
                            ret_list.append('<mask>')
                            mask_num += 1
                        i += length
                    else:
                        ret_list.append(t)
                        i += 1
                        mask_num = 0
            return [' '.join(ret_list)]

        model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
        model.eval()
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')

        ret_candidates = {}
        dpath_i = 0

        for i,(k, v) in enumerate(tqdm(draft.items())):

            input = v['in'].replace('\n', '')
            output = v['out'].replace('\n', '')
            s, r, o = attack_data[i]
            s = str(s)
            o = str(o)
            r = str(r)

            if int(s) == -1:
                ret_candidates[str(i)] = {'span': '', 'prompt' : '', 'out' : [], 'in': [], 'assist': []}
                continue

            path_text = dpath[dpath_i].replace('\n', '')
            dpath_i += 1
            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]

            doc = nlp(output)
            words = input.split(' ')
            tokenized_sens = [sen for sen in doc.sents]
            sens = np.array([sen.text for sen in doc.sents])

            # Entities on the dependency path must survive masking.
            checkset = set([text_s, text_o])
            e_entity = set(['start_entity', 'end_entity'])
            for path in path_text.split(' '):
                a, b, c = path.split('|')
                if a not in e_entity:
                    checkset.add(a)
                if c not in e_entity:
                    checkset.add(c)
            vec = []
            l = 0
            while(l < len(words)):
                bo = False
                for j in range(len(words), l, -1): # reversing is important !!!
                    cc = ' '.join(words[l:j])
                    if (cc in checkset):
                        vec += [True] * (j-l)
                        l = j
                        bo = True
                        break
                if not bo:
                    vec.append(False)
                    l += 1
            vec, span = find_mini_span(vec, words, checkset)
            # vec = np.vectorize(lambda x: x in checkset)(words)
            vec[-1] = True
            prompt = []
            mask_num = 0
            for j, bo in enumerate(vec):
                if not bo:
                    mask_num += 1
                else:
                    if mask_num > 0:
                        # mask_num = mask_num // 3 # span length ~ poisson distribution (lambda = 3)
                        mask_num = max(mask_num, 1)
                        mask_num = min(8, mask_num)
                        prompt += ['<mask>'] * mask_num
                    prompt.append(words[j])
                    mask_num = 0
            prompt = ' '.join(prompt)
            Text = []
            Assist = []

            for j in range(len(sens)):
                Bart_input = list(sens[:j]) + [prompt] + list(sens[j+1:])
                assist = list(sens[:j]) + [input] + list(sens[j+1:])
                Text.append(' '.join(Bart_input))
                Assist.append(' '.join(assist))

            for j in range(len(sens)):
                Bart_input = mask_func(tokenized_sens[:j]) + [input] + mask_func(tokenized_sens[j+1:])
                assist = list(sens[:j]) + [input] + list(sens[j+1:])
                Text.append(' '.join(Bart_input))
                Assist.append(' '.join(assist))

            batch_size = len(Text) // 2
            Outs = []
            for l in range(2):
                A = tokenizer(Text[batch_size * l:batch_size * (l+1)],
                              truncation = True,
                              padding = True,
                              max_length = 1024,
                              return_tensors="pt")
                input_ids = A['input_ids'].to(device)
                attention_mask = A['attention_mask'].to(device)
                aaid = model.generate(input_ids, num_beams = 5, max_length = 1024)
                outs = tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                Outs += outs
            ret_candidates[str(i)] = {'span': span, 'prompt' : prompt, 'out' : Outs, 'in': Text, 'assist': Assist}
        with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json', 'w') as fl:
            json.dump(ret_candidates, fl, indent = 4)

    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import BioGptForCausalLM
    criterion = CrossEntropyLoss(reduction="none")

    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()

    scored = {}
    ret = {}
    case_study = {}
    p_ret = {}
    add = 0
    dpath_i = 0
    inner_better = 0
    outter_better = 0
    better_than_gpt = 0
    # Re-score every candidate with BioGPT: the sentence-aligned BART
    # rewrites, the masked rewrites, the original GPT abstract, and the
    # 'assist' variants that splice the raw sentence in directly.
    for i,(k, v) in enumerate(tqdm(draft.items())):

        span = ret_candidates[str(i)]['span']
        prompt = ret_candidates[str(i)]['prompt']
        sen_list = ret_candidates[str(i)]['out']
        BART_in = ret_candidates[str(i)]['in']
        Assist = ret_candidates[str(i)]['assist']

        s, r, o = attack_data[i]
        s = str(s)
        r = str(r)
        o = str(o)

        if int(s) == -1:
            ret[k] = {'prompt': '', 'in':'', 'out': ''}
            p_ret[k] = {'prompt': '', 'in':'', 'out': ''}
            continue

        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]

        def process(text):
            # Reinsert the space after a sentence-final period that BART
            # sometimes drops before a capital letter.
            for i in range(ord('A'), ord('Z')+1):
                text = text.replace(f'.{chr(i)}', f'. {chr(i)}')
            return text

        sen_list = [process(text) for text in sen_list]
        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1

        checkset = set([text_s, text_o])
        e_entity = set(['start_entity', 'end_entity'])
        for path in path_text.split(' '):
            a, b, c = path.split('|')
            if a not in e_entity:
                checkset.add(a)
            if c not in e_entity:
                checkset.add(c)

        input = v['in'].replace('\n', '')
        output = v['out'].replace('\n', '')

        doc = nlp(output)
        gpt_sens = [sen.text for sen in doc.sents]
        assert len(gpt_sens) == len(sen_list) // 2

        word_sets = []
        for sen in gpt_sens:
            word_sets.append(set(sen.split(' ')))

        def sen_align(word_sets, modified_word_sets):
            # Align the rewritten abstract against the original sentence by
            # sentence (>80% word overlap counts as unchanged) and return
            # the boundaries of the modified region.
            l = 0
            while(l < len(modified_word_sets)):
                if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
                    l += 1
                else:
                    break
            if l == len(modified_word_sets):
                return -1, -1, -1, -1
            r = l + 1
            r1 = None
            r2 = None
            for pos1 in range(r, len(word_sets)):
                for pos2 in range(r, len(modified_word_sets)):
                    if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
                        r1 = pos1
                        r2 = pos2
                        break
                if r1 is not None:
                    break
            if r1 is None:
                r1 = len(word_sets)
                r2 = len(modified_word_sets)
            return l, r1, l, r2

        replace_sen_list = []
        boundary = []
        assert len(sen_list) % 2 == 0
        for j in range(len(sen_list) // 2):
            doc = nlp(sen_list[j])
            sens = [sen.text for sen in doc.sents]
            modified_word_sets = [set(sen.split(' ')) for sen in sens]
            l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
            boundary.append((l1, r1, l2, r2))
            if l1 == -1:
                replace_sen_list.append(sen_list[j])
                continue
            check_text = ' '.join(sens[l2: r2])
            replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
        sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]

        old_L = len(sen_list)
        sen_list.append(output)
        sen_list += Assist
        tokens = tokenizer(sen_list,
                           truncation = True,
                           padding = True,
                           max_length = 1024,
                           return_tensors="pt")
        target_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)
        L = len(sen_list)
        ret_log_L = []
        for l in range(0, L, 5):
            R = min(L, l + 5)
            target = target_ids[l:R, :]
            attention = attention_mask[l:R, :]
            outputs = model(input_ids = target,
                            attention_mask = attention,
                            labels = target)
            logits = outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = target[..., 1:].contiguous()
            Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
            Loss = Loss.view(-1, shift_logits.shape[1])
            attention = attention[..., 1:].contiguous()
            log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
            ret_log_L.append(log_Loss.detach())
        log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()

        real_log_Loss = log_Loss.copy()

        log_Loss = log_Loss[:old_L]
        # sen_list = sen_list[:old_L]

        # mini_span should be preserved
        # for j in range(len(log_Loss)):
        #     doc = nlp(sen_list[j])
        #     sens = [sen.text for sen in doc.sents]
        #     Len = len(sen_list)
        #     check_text = ' '.join(sens[j : max(0,len(sens) - Len) + j + 1])
        #     if span not in check_text:
        #         log_Loss[j] += 1

        p = np.argmin(log_Loss)
        if p < old_L // 2:
            inner_better += 1
        else:
            outter_better += 1
        content = []
        for i in range(len(real_log_Loss)):
            content.append([sen_list[i], str(real_log_Loss[i])])
        scored[k] = {'path':path_text, 'prompt': prompt, 'in':input, 's':text_s, 'o':text_o, 'out': content, 'bound': boundary}
        p_p = p
        # print('Old_L:', old_L)

        if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
            p_p = p+1+old_L
        if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
            add += 1

        if real_log_Loss[p] < real_log_Loss[old_L]:
            better_than_gpt += 1
        else:
            if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
                p = p+1+old_L
        # case_study[k] = {'path':path_text, 'entity_0': text_s, 'entity_1': text_o, 'GPT_in': input, 'Prompt': prompt, 'GPT_out': {'text': output, 'perplexity': str(np.exp(real_log_Loss[old_L]))}, 'BART_in': BART_in[p], 'BART_out': {'text': sen_list[p], 'perplexity': str(np.exp(real_log_Loss[p]))}, 'Assist': {'text': Assist[p], 'perplexity': str(np.exp(real_log_Loss[p+1+old_L]))}}
        ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p]}
        p_ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p_p]}
    print(add)
    print('inner_better:', inner_better)
    print('outter_better:', outter_better)
    print('better_than_gpt:', better_than_gpt)
    print('better_than_replace', add)
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'w') as fl:
        json.dump(ret, fl, indent=4)
    # with open(f'generate_abstract/bioBART/case_{args.target_split}_{args.reasonable_rate}_bioBART_finetune.json', 'w') as fl:
    #     json.dump(case_study, fl, indent=4)
    with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_scored.json', 'w') as fl:
        json.dump(scored, fl, indent=4)
    with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_perplexity.json', 'w') as fl:
        json.dump(p_ret, fl, indent=4)

    # with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    #     full_entity_raw_name = pkl.load(fl)
    # for k, v in entity_raw_name.items():
    #     assert v in full_entity_raw_name[k]

    # nlp = spacy.load("en_core_web_sm")
    # type_set = set()
    # for aa in range(36):
    #     dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual']
    #     tmp_dict = retieve_sentence_through_edgetype[aa]['auto']
    #     dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys())
    #     for dependency in dependencys:
    #         dep_list = dependency.split(' ')
    #         for sub_dep in dep_list:
    #             sub_dep_list = sub_dep.split('|')
    #             assert(len(sub_dep_list) == 3)
    #             type_set.add(sub_dep_list[1])

    # fine_dict = {}
    # for k, v_dict in draft.items():

    #     input = v_dict['in']
    #     output = v_dict['out']
    #     fine_dict[k] = {'in':input, 'out': input + ' ' + output}

    # with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence_finetune.json', 'w') as fl:
    #     json.dump(fine_dict, fl, indent=4)
else:
    raise Exception('Wrong mode !!')
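The ranking criterion used twice in this file (first to pick a seed sentence per attack edge, then to pick the best rewrite) is mean per-token negative log-likelihood under BioGPT. As a standalone sketch of that scoring, following the same shift-and-mask computation as the loops above:

```python
# Sketch: rank texts by mean per-token NLL under a causal LM (lower = more fluent).
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, BioGptForCausalLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
tokenizer.pad_token = tokenizer.eos_token
model = BioGptForCausalLM.from_pretrained(
    'microsoft/biogpt', pad_token_id=tokenizer.eos_token_id).to(device).eval()
criterion = CrossEntropyLoss(reduction='none')

@torch.no_grad()
def mean_token_nll(texts):
    tokens = tokenizer(texts, truncation=True, padding=True,
                       max_length=300, return_tensors='pt').to(device)
    ids, mask = tokens['input_ids'], tokens['attention_mask']
    logits = model(input_ids=ids, attention_mask=mask).logits
    # Shift so position t predicts token t+1, as in the script above.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = ids[..., 1:].contiguous()
    loss = criterion(shift_logits.view(-1, shift_logits.size(-1)),
                     shift_labels.view(-1)).view(ids.size(0), -1)
    mask = mask[..., 1:].float()
    # Mask out padding, then average per sequence.
    return ((loss * mask).sum(dim=1) / mask.sum(dim=1)).cpu()

nll = mean_token_nll(['Aspirin inhibits platelet aggregation.',
                      'Aspirin aggregation platelet inhibits.'])
print(nll)  # the fluent sentence should score lower
```

Exponentiating this quantity gives the per-sequence perplexity that the `_perplexity.json` output refers to.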
DiseaseAgnostic/evaluation.py
ADDED
@@ -0,0 +1,219 @@
#%%
import logging
import os
import tempfile
import sys
import pandas as pd
import json
from glob import glob
from tqdm import tqdm
import numpy as np
from pprint import pprint
import torch
import pickle as pkl
from collections import Counter
import networkx as nx
import utils
from torch.nn import functional as F
sys.path.append("..")
import Parameters
from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax

#%%
def load_data(file_name):
    df = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    df = df.drop_duplicates()
    return df.values

parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type=float, default=0.7, help='The added edge\'s existence rank prob must be greater than this rate')
parser.add_argument('--mode', type=str, default='', help='"" or chat or bioBART')
parser.add_argument('--init-mode', type=str, default='random', help='How to select target nodes') # 'single' for case study
parser.add_argument('--added-edge-num', type=str, default='', help='Added edge num')

args = parser.parse_args()
args = utils.set_hyperparams(args)
utils.seed_all(args.seed)
graph_edge_path = '../DiseaseSpecific/processed_data/GNBR/all.txt'
idtomeshid_path = '../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json'
model_path = f'../DiseaseSpecific/saved_models/GNBR_{args.model}_128_0.2_0.3_0.3.model'
data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}{args.mode}.pkl'

with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    full_entity_raw_name = pkl.load(fl)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

args.device = device
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model = utils.load_model(model_path, args, n_ent, n_rel, args.device)

graph_edge = utils.load_data(graph_edge_path)
with open(idtomeshid_path, 'r') as fl:
    idtomeshid = json.load(fl)
print(graph_edge.shape, len(idtomeshid))

divide_bound, data_mean, data_std = calculate_edge_bound(graph_edge, model, args.device, n_ent)
print('Defender ...')
print(divide_bound, data_mean, data_std)

meshids = list(idtomeshid.values())
cal = {
    'chemical': 0,
    'disease': 0,
    'gene': 0
}
for meshid in meshids:
    cal[meshid.split('_')[0]] += 1
# pprint(cal)

def check_reasonable(s, r, o):

    train_trip = np.asarray([[s, r, o]])
    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
    edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
    # edge_loss_log_prob = torch.log(F.softmax(-edge_loss, dim=-1))

    edge_loss = edge_loss.item()
    edge_loss = (edge_loss - data_mean) / data_std
    edge_losses_prob = 1 / (1 + np.exp(edge_loss - divide_bound))
    bound = 1 - args.reasonable_rate

    return (edge_losses_prob > bound), edge_losses_prob
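# How the filter above works: the raw KGE loss of the candidate triple is
# z-normalized with the statistics of the true graph edges (data_mean, data_std)
# and squashed through a sigmoid centred at divide_bound, so edge_losses_prob
# approximates the probability that the edge looks plausible to the
# link-prediction model acting as the defender. A candidate survives only if
# that probability exceeds 1 - reasonable_rate; with the default
# reasonable_rate = 0.7, an added edge is kept only while it still looks at
# least 30%-plausible.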
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for k, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
        edgeid_to_edgetype[str(iid)] = k
        edgeid_to_reversemask[str(iid)] = mask

with open(target_path, 'rb') as fl:
    Target_node_list = pkl.load(fl)
with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
drug_meshid = []
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.append(meshid)
drug_meshid = set(drug_meshid)

if args.init_mode == 'single':
    name_list = []
    for target in Target_node_list:
        name = entity_raw_name[idtomeshid[str(target)]]
        name_list.append(name)
    with open(f'results/name_list_{args.reasonable_rate}{args.init_mode}.txt', 'w') as fl:
        fl.write('\n'.join(name_list))
# print(Target_node_list)
# print(Attack_edge_list)
# addset = set()
# if args.added_edge_num == 1:
#     for edge in Attack_edge_list:
#         addset.add(edge[2])
# else:
#     for edge_list in Attack_edge_list:
#         for edge in edge_list:
#             addset.add(edge[2])
# print(addset)
# print(len(addset))
# typeset = set()
# for iid in addset:
#     typeset.add(idtomeshid[str(iid)].split('_')[0])
# print(typeset)
# raise Exception('done')

if args.init_mode == 'single':
    Target_node_list = [[Target_node_list[i]] for i in range(len(Target_node_list))]
    Attack_edge_list = [[Attack_edge_list[i]] for i in range(len(Attack_edge_list))]
else:
    print(len(Attack_edge_list), len(Target_node_list))
    tmp_target_node_list = []
    tmp_attack_edge_list = []
    for l in range(0, len(Target_node_list), 50):
        r = min(l+50, len(Target_node_list))
        tmp_target_node_list.append(Target_node_list[l:r])
        tmp_attack_edge_list.append(Attack_edge_list[l:r])
    Target_node_list = tmp_target_node_list
    Attack_edge_list = tmp_attack_edge_list

# for i, init_p in enumerate([0.1, 0.3, 0.5, 0.7, 0.9]):
#     target_node_list = Target_node_list[i]
#     attack_edge_list = Attack_edge_list[i]
Init = []
After = []
# final_init = []
# final_after = []
for i, (target_node_list, attack_edge_list) in enumerate(zip(Target_node_list, Attack_edge_list)):

    G = nx.DiGraph()
    for s, r, o in graph_edge:
        assert idtomeshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
        if edgeid_to_reversemask[r] == 1:
            G.add_edge(int(o), int(s))
        else:
            G.add_edge(int(s), int(o))

    pagerank_value_1 = nx.pagerank(G, max_iter=200, tol=1.0e-7)

    for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):
        pr = list(pagerank_value_1.items())
        pr.sort(key=lambda x: x[1])
        list_iid = []
        for iid, score in pr:
            tp = idtomeshid[str(iid)].split('_')[0]
            if tp == 'chemical':
                # if idtomeshid[str(iid)] in drug_meshid:
                list_iid.append(iid)
        init_rank = len(list_iid) - list_iid.index(target)
        # init_rank = 1 - list_iid.index(target) / len(list_iid)
        Init.append(init_rank)

    for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):

        if args.mode == '' and (args.added_edge_num == '' or int(args.added_edge_num) == 1):
            if int(attack_list[0]) == -1:
                attack_list = []
            else:
                attack_list = [attack_list]
        if len(attack_list) > 0:
            for s, r, o in attack_list:
                bo, prob = check_reasonable(s, r, o)
                if bo:
                    if edgeid_to_reversemask[str(r)] == 1:
                        G.add_edge(int(o), int(s))
                    else:
                        G.add_edge(int(s), int(o))
    pagerank_value_1 = nx.pagerank(G, max_iter=200, tol=1.0e-7)
    for target, attack_list in tqdm(list(zip(target_node_list, attack_edge_list))):
        pr = list(pagerank_value_1.items())
        pr.sort(key=lambda x: x[1])
        list_iid = []
        for iid, score in pr:
            tp = idtomeshid[str(iid)].split('_')[0]
            if tp == 'chemical':
                # if idtomeshid[str(iid)] in drug_meshid:
                list_iid.append(iid)
        after_rank = len(list_iid) - list_iid.index(target)
        # after_rank = 1 - list_iid.index(target) / len(list_iid)
        After.append(after_rank)
with open(f'results/Init_{args.reasonable_rate}{args.init_mode}.pkl', 'wb') as fl:
    pkl.dump(Init, fl)
with open(f'results/After_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}{args.mode}.pkl', 'wb') as fl:
    pkl.dump(After, fl)
print(np.mean(Init), np.std(Init))
print(np.mean(After), np.std(After))
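The script above measures attack success as the shift in a target drug's PageRank-derived rank before and after the poison edges are injected; ranks count down toward 1 for the most prominent chemicals, so a successful promotion attack lowers the rank value. A minimal sketch of summarising the pickled Init/After lists it writes (the file names here are hypothetical instances of the patterns above):

import pickle as pkl
import numpy as np

# Hypothetical paths; substitute the Init_*/After_*.pkl files written above.
with open('results/Init_0.7random.pkl', 'rb') as fl:
    init_ranks = np.array(pkl.load(fl), dtype=float)
with open('results/After_distmult_0.7random.pkl', 'rb') as fl:
    after_ranks = np.array(pkl.load(fl), dtype=float)

# A rank drop means the target was promoted; report the mean relative gain.
improvement = (init_ranks - after_ranks) / init_ranks
print('mean rank before:', init_ranks.mean())
print('mean rank after :', after_ranks.mean())
print('mean relative improvement:', improvement.mean())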
DiseaseAgnostic/generate_target_and_attack.py
ADDED
@@ -0,0 +1,371 @@
#%%
import logging
import os
import tempfile
import sys
import pandas as pd
import json
from glob import glob
from tqdm import tqdm
import numpy as np
from pprint import pprint
import torch
import pickle as pkl
from collections import Counter
import networkx as nx
import utils
from torch.nn import functional as F
sys.path.append("..")
import Parameters
from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax

#%%
def load_data(file_name):
    df = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    df = df.drop_duplicates()
    return df.values

parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type=float, default=0.7, help='The added edge\'s existence rank prob must be greater than this rate')
parser.add_argument('--init-mode', type=str, default='single', help='How to select target nodes') # 'single' for case study
parser.add_argument('--added-edge-num', type=str, default='', help='Added edge num')

args = parser.parse_args()
args = utils.set_hyperparams(args)
utils.seed_all(args.seed)
graph_edge_path = '../DiseaseSpecific/processed_data/GNBR/all.txt'
idtomeshid_path = '../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json'
model_path = f'../DiseaseSpecific/saved_models/GNBR_{args.model}_128_0.2_0.3_0.3.model'
data_path = '../DiseaseSpecific/processed_data/GNBR'
with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    full_entity_raw_name = pkl.load(fl)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
print(device)

graph_edge = utils.load_data(graph_edge_path)
with open(idtomeshid_path, 'r') as fl:
    idtomeshid = json.load(fl)
print(graph_edge.shape, len(idtomeshid))

divide_bound, data_mean, data_std = calculate_edge_bound(graph_edge, model, args.device, n_ent)
print('Defender ...')
print(divide_bound, data_mean, data_std)

meshids = list(idtomeshid.values())
cal = {
    'chemical': 0,
    'disease': 0,
    'gene': 0
}
for meshid in meshids:
    cal[meshid.split('_')[0]] += 1
# pprint(cal)

def check_reasonable(s, r, o):

    train_trip = np.asarray([[s, r, o]])
    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
    edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
    # edge_loss_log_prob = torch.log(F.softmax(-edge_loss, dim=-1))

    edge_loss = edge_loss.item()
    edge_loss = (edge_loss - data_mean) / data_std
    edge_losses_prob = 1 / (1 + np.exp(edge_loss - divide_bound))
    bound = 1 - args.reasonable_rate

    return (edge_losses_prob > bound), edge_losses_prob

edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for k, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
        edgeid_to_edgetype[str(iid)] = k
        edgeid_to_reversemask[str(iid)] = mask
reverse_tot = 0
G = nx.DiGraph()
for s, r, o in graph_edge:
    assert idtomeshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
    if edgeid_to_reversemask[r] == 1:
        reverse_tot += 1
        G.add_edge(int(o), int(s))
    else:
        G.add_edge(int(s), int(o))
# print(reverse_tot)
print('Edge num:', G.number_of_edges(), 'Node num:', G.number_of_nodes())
pagerank_value_1 = nx.pagerank(G, max_iter=200, tol=1.0e-7)
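# Reverse-masked relation ids store their triples in flipped (tail, head) order,
# so the graph is re-oriented above before ranking: every edge points from the
# influencing entity to the influenced one. PageRank over this directed graph is
# then the proxy for a node's prominence, and the per-type lists built below give
# both the pool of target drugs ('chemical') and the high-PageRank 'merged' pool
# from which candidate attack edges are drawn.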
#%%
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
drug_meshid = []
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.append(meshid)
drug_meshid = set(drug_meshid)
pr = list(pagerank_value_1.items())
pr.sort(key=lambda x: x[1])
sorted_rank = {'chemical': [],
               'gene': [],
               'disease': [],
               'merged': []}
for iid, score in pr:
    tp = idtomeshid[str(iid)].split('_')[0]
    if tp == 'chemical':
        if idtomeshid[str(iid)] in drug_meshid:
            sorted_rank[tp].append((iid, score))
    else:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4:]
print(len(sorted_rank['chemical']))
print(len(sorted_rank['gene']), len(sorted_rank['disease']), len(sorted_rank['merged']))

#%%
Target_node_list = []
Attack_edge_list = []
if args.init_mode == '':

    if args.added_edge_num != '' and args.added_edge_num != '1':
        raise Exception("added_edge_num must be 1 when init_mode==''")
    for init_p in [0.1, 0.3, 0.5, 0.7, 0.9]:

        p = len(sorted_rank['chemical']) * init_p
        print('Init p:', init_p)
        target_node_list = []
        attack_edge_list = []
        num_max_eq = 0
        mean_rank_of_total_max = 0
        for pp in tqdm(range(int(p)-10, int(p)+10)):
            target = sorted_rank['chemical'][pp][0]
            target_node_list.append(target)

            candidate_list = []
            score_list = []
            loss_list = []
            for iid, score in sorted_rank['merged']:
                a = G.number_of_edges(iid, target) + 1
                if a != 1:
                    continue
                b = G.out_degree(iid) + 1
                tp = idtomeshid[str(iid)].split('_')[0]
                edge_losses = []
                r_list = []
                for r in range(len(edgeid_to_edgetype)):
                    r_tp = edgeid_to_edgetype[str(r)]
                    if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
                        train_trip = np.array([[iid, r, target]])
                        train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                        edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                        edge_losses.append(edge_loss.unsqueeze(0).detach())
                        r_list.append(r)
                    elif (edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
                        train_trip = np.array([[iid, r, target]]) # add batch dim
                        train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                        edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                        edge_losses.append(edge_loss.unsqueeze(0).detach())
                        r_list.append(r)
                if len(edge_losses) == 0:
                    continue
                min_index = torch.argmin(torch.cat(edge_losses, dim=0))
                r = r_list[min_index]
                r_tp = edgeid_to_edgetype[str(r)]

                if (edgeid_to_reversemask[str(r)] == 0):
                    bo, prob = check_reasonable(iid, r, target)
                    if bo:
                        candidate_list.append((iid, r, target))
                        score_list.append(score * a / b)
                        loss_list.append(edge_losses[min_index].item())
                if (edgeid_to_reversemask[str(r)] == 1):
                    bo, prob = check_reasonable(target, r, iid)
                    if bo:
                        candidate_list.append((target, r, iid))
                        score_list.append(score * a / b)
                        loss_list.append(edge_losses[min_index].item())

            if len(candidate_list) == 0:
                attack_edge_list.append((-1, -1, -1))
                continue
            norm_score = np.array(score_list) / np.sum(score_list)
            norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))

            total_score = norm_score * norm_loss
            max_index = np.argmax(total_score)
            attack_edge_list.append(candidate_list[max_index])

            score_max_index = np.argmax(norm_score)
            if score_max_index == max_index:
                num_max_eq += 1

            score_index_list = list(zip(list(range(len(norm_score))), norm_score))
            score_index_list.sort(key=lambda x: x[1], reverse=True)
            max_index_in_score = score_index_list.index((max_index, norm_score[max_index]))
            mean_rank_of_total_max += max_index_in_score / len(norm_score)
        print('num_max_eq:', num_max_eq)
        print('mean_rank_of_total_max:', mean_rank_of_total_max / 20)
        Target_node_list.append(target_node_list)
        Attack_edge_list.append(attack_edge_list)
else:
    assert args.init_mode == 'random' or args.init_mode == 'single'
    print(f'Init mode : {args.init_mode}')
    utils.seed_all(args.seed)

    if args.init_mode == 'random':
        index = np.random.choice(len(sorted_rank['chemical']), 400, replace=False)
    else:
        # index = [5807, 6314, 5799, 5831, 3954, 5654, 5649, 5624, 2412, 2407]

        index = np.random.choice(len(sorted_rank['chemical']), 400, replace=False)
        with open(f'../pagerank/results/After_distmult_0.7random10.pkl', 'rb') as fl:
            edge = pkl.load(fl)
        with open('../pagerank/results/Init_0.7random.pkl', 'rb') as fl:
            init = pkl.load(fl)
        increase = (np.array(init) - np.array(edge)) / np.array(init)
        increase = increase.reshape(-1)
        selected_index = np.argsort(increase)[::-1][:10]
        # print(selected_index)
        # print(increase[selected_index])
        # print(np.array(init)[selected_index])
        # print(np.array(edge)[selected_index])
        index = [index[i] for i in selected_index]
        # llen = len(sorted_rank['chemical'])
        # index = np.random.choice(range(llen//4, llen), 4, replace = False)
        # index = selected_index + list(index)
        # for i in index:
        #     ii = str(sorted_rank['chemical'][i][0])
        #     nm = entity_raw_name[idtomeshid[ii]]
        #     nmset = full_entity_raw_name[idtomeshid[ii]]
        #     print('**'*10)
        #     print(i)
        #     print(nm)
        #     print(nmset)
        # raise Exception('stop')
    target_node_list = []
    attack_edge_list = []
    num_max_eq = 0
    mean_rank_of_total_max = 0
    main_iid = set()  # placeholder: ids of the case-study entities whose ranks are printed below (assumed; not defined in this snapshot)

    for pp in tqdm(index):
        target = sorted_rank['chemical'][pp][0]
        target_node_list.append(target)

        print('Target:', entity_raw_name[idtomeshid[str(target)]])

        candidate_list = []
        score_list = []
        loss_list = []
        main_dict = {}
        for iid, score in sorted_rank['merged']:
            a = G.number_of_edges(iid, target) + 1
            if a != 1:
                continue
            b = G.out_degree(iid) + 1
            tp = idtomeshid[str(iid)].split('_')[0]
            edge_losses = []
            r_list = []
            for r in range(len(edgeid_to_edgetype)):
                r_tp = edgeid_to_edgetype[str(r)]
                if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
                    train_trip = np.array([[iid, r, target]])
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
                elif (edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
                    train_trip = np.array([[iid, r, target]]) # add batch dim
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
            if len(edge_losses) == 0:
                continue
            min_index = torch.argmin(torch.cat(edge_losses, dim=0))
            r = r_list[min_index]
            r_tp = edgeid_to_edgetype[str(r)]

            old_len = len(candidate_list)
            if (edgeid_to_reversemask[str(r)] == 0):
                bo, prob = check_reasonable(iid, r, target)
                if bo:
                    candidate_list.append((iid, r, target))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())
            if (edgeid_to_reversemask[str(r)] == 1):
                bo, prob = check_reasonable(target, r, iid)
                if bo:
                    candidate_list.append((target, r, iid))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())

            if len(candidate_list) != old_len:
                if int(iid) in main_iid:
                    main_dict[iid] = len(candidate_list) - 1

        if len(candidate_list) == 0:
            if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                attack_edge_list.append((-1, -1, -1))
            else:
                attack_edge_list.append([])
            continue
        norm_score = np.array(score_list) / np.sum(score_list)
        norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))

        total_score = norm_score * norm_loss
        total_score_index = list(zip(range(len(total_score)), total_score))
        total_score_index.sort(key=lambda x: x[1], reverse=True)

        norm_score_index = np.argsort(norm_score)[::-1]
        norm_loss_index = np.argsort(norm_loss)[::-1]
        total_index = np.argsort(total_score)[::-1]
        assert total_index[0] == total_score_index[0][0]
        # find rank of main index
        for k, v in main_dict.items():
            k = int(k)
            index = v
            print(f'score rank of {entity_raw_name[idtomeshid[str(k)]]}: ', norm_score_index.tolist().index(index))
            print(f'loss rank of {entity_raw_name[idtomeshid[str(k)]]}: ', norm_loss_index.tolist().index(index))
            print(f'total rank of {entity_raw_name[idtomeshid[str(k)]]}: ', total_index.tolist().index(index))

        max_index = np.argmax(total_score)
        assert max_index == total_score_index[0][0]

        tmp_add = []
        add_num = 1
        if args.added_edge_num == '' or int(args.added_edge_num) == 1:
            attack_edge_list.append(candidate_list[max_index])
        else:
            add_num = int(args.added_edge_num)
            for i in range(add_num):
                tmp_add.append(candidate_list[total_score_index[i][0]])
            attack_edge_list.append(tmp_add)

        score_max_index = np.argmax(norm_score)
        if score_max_index == max_index:
            num_max_eq += 1
        score_index_list = list(zip(list(range(len(norm_score))), norm_score))
        score_index_list.sort(key=lambda x: x[1], reverse=True)
        max_index_in_score = score_index_list.index((max_index, norm_score[max_index]))
        mean_rank_of_total_max += max_index_in_score / len(norm_score)
    print('num_max_eq:', num_max_eq)
    print('mean_rank_of_total_max:', mean_rank_of_total_max / 400)
    Target_node_list = target_node_list
    Attack_edge_list = attack_edge_list
print(np.array(Target_node_list).shape)
print(np.array(Attack_edge_list).shape)
# with open(f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl', 'wb') as fl:
#     pkl.dump(Target_node_list, fl)
# with open(f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}.pkl', 'wb') as fl:
#     pkl.dump(Attack_edge_list, fl)
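Both branches above pick poison edges the same way: a structural score (the candidate neighbour's PageRank mass, boosted when an edge already multiplies and discounted by the neighbour's out-degree) is multiplied by a softmax over the negated KGE losses, so the chosen edge is both influential and plausible. A minimal sketch of that selection rule, factored out of the loops above:

import numpy as np

def pick_attack_edge(candidate_list, score_list, loss_list):
    """Return the candidate maximising pagerank-mass x softmax(-loss)."""
    norm_score = np.array(score_list) / np.sum(score_list)
    # Lower model loss means a more plausible edge, hence softmax over -loss.
    norm_loss = np.exp(-np.array(loss_list))
    norm_loss = norm_loss / norm_loss.sum()
    total_score = norm_score * norm_loss
    return candidate_list[int(np.argmax(total_score))]

# e.g. pick_attack_edge([(7, 3, 42), (9, 1, 42)], [0.8, 0.2], [1.3, 0.4])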
DiseaseAgnostic/model.py
ADDED
@@ -0,0 +1,520 @@
import torch
from torch.nn import functional as F, Parameter
from torch.autograd import Variable
from torch.nn.init import xavier_normal_, xavier_uniform_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Distmult(torch.nn.Module):
    def __init__(self, args, num_entities, num_relations):
        super(Distmult, self).__init__()

        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def score_sr(self, sub, rel, sigmoid=False):
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)

        #sub_emb = self.inp_drop(sub_emb)
        #rel_emb = self.inp_drop(rel_emb)

        pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0))
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def score_or(self, obj, rel, sigmoid=False):
        obj_emb = self.emb_e(obj).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)

        #obj_emb = self.inp_drop(obj_emb)
        #rel_emb = self.inp_drop(rel_emb)

        pred = torch.mm(obj_emb*rel_emb, self.emb_e.weight.transpose(1,0))
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
        For distmult, computations for both modes are equivalent, so we do not need an if-else block
        '''
        sub_emb = self.inp_drop(sub_emb)
        rel_emb = self.inp_drop(rel_emb)

        pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0))

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - score
        '''
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)
        obj_emb = self.emb_e(obj).squeeze(dim=1)

        pred = torch.sum(sub_emb*rel_emb*obj_emb, dim=-1)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - score
        '''
        pred = torch.sum(emb_s*emb_r*emb_o, dim=-1)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - a vector score for the triple instead of reducing over the embedding dimension
        '''
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)
        obj_emb = self.emb_e(obj).squeeze(dim=1)

        pred = sub_emb*rel_emb*obj_emb

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

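# ComplEx stores each entity/relation as a complex vector packed as
# [real | imaginary] halves of a 2*embedding_dim tensor; torch.chunk(-, 2, dim=-1)
# below recovers the two halves. The triple score is Re(<s, r, conj(o)>), which
# expands into the four real-valued products combined as
#   sum(s_re*r_re*o_re) + sum(s_re*r_im*o_im) + sum(s_im*r_re*o_im) - sum(s_im*r_im*o_re)
# exactly as implemented in score_triples below.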
class Complex(torch.nn.Module):
    def __init__(self, args, num_entities, num_relations):
        super(Complex, self).__init__()

        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def score_sr(self, sub, rel, sigmoid=False):
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)

        s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
        emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

        #s_real = self.inp_drop(s_real)
        #s_img = self.inp_drop(s_img)
        #rel_real = self.inp_drop(rel_real)
        #rel_img = self.inp_drop(rel_img)

        # complex space bilinear product (equivalent to HolE)
        # realrealreal = torch.mm(s_real*rel_real, emb_e_real.transpose(1,0))
        # realimgimg = torch.mm(s_real*rel_img, emb_e_img.transpose(1,0))
        # imgrealimg = torch.mm(s_img*rel_real, emb_e_img.transpose(1,0))
        # imgimgreal = torch.mm(s_img*rel_img, emb_e_real.transpose(1,0))
        # pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        realo_realreal = s_real*rel_real
        realo_imgimg = s_img*rel_img
        realo = realo_realreal - realo_imgimg
        real = torch.mm(realo, emb_e_real.transpose(1,0))

        imgo_realimg = s_real*rel_img
        imgo_imgreal = s_img*rel_real
        imgo = imgo_realimg + imgo_imgreal
        img = torch.mm(imgo, emb_e_img.transpose(1,0))

        pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def score_or(self, obj, rel, sigmoid=False):
        obj_emb = self.emb_e(obj).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)

        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
        emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

        #rel_real = self.inp_drop(rel_real)
        #rel_img = self.inp_drop(rel_img)
        #o_real = self.inp_drop(o_real)
        #o_img = self.inp_drop(o_img)

        # complex space bilinear product (equivalent to HolE)
        # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
        # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
        # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
        # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
        # pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        reals_realreal = rel_real*o_real
        reals_imgimg = rel_img*o_img
        reals = reals_realreal + reals_imgimg
        real = torch.mm(reals, emb_e_real.transpose(1,0))

        imgs_realimg = rel_real*o_img
        imgs_imgreal = rel_img*o_real
        imgs = imgs_realimg - imgs_imgreal
        img = torch.mm(imgs, emb_e_img.transpose(1,0))

        pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
        '''
        if mode == 'lhs':
            rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
            o_real, o_img = torch.chunk(sub_emb, 2, dim=-1)
            emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

            rel_real = self.inp_drop(rel_real)
            rel_img = self.inp_drop(rel_img)
            o_real = self.inp_drop(o_real)
            o_img = self.inp_drop(o_img)

            # complex space bilinear product (equivalent to HolE)
            # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
            # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
            # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
            # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
            # pred = realrealreal + realimgimg + imgrealimg - imgimgreal
            reals_realreal = rel_real*o_real
            reals_imgimg = rel_img*o_img
            reals = reals_realreal + reals_imgimg
            real = torch.mm(reals, emb_e_real.transpose(1,0))

            imgs_realimg = rel_real*o_img
            imgs_imgreal = rel_img*o_real
            imgs = imgs_realimg - imgs_imgreal
            img = torch.mm(imgs, emb_e_img.transpose(1,0))

            pred = real + img

        else:
            s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
            rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
            emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

            s_real = self.inp_drop(s_real)
            s_img = self.inp_drop(s_img)
            rel_real = self.inp_drop(rel_real)
            rel_img = self.inp_drop(rel_img)

            # complex space bilinear product (equivalent to HolE)
            # realrealreal = torch.mm(s_real*rel_real, emb_e_real.transpose(1,0))
            # realimgimg = torch.mm(s_real*rel_img, emb_e_img.transpose(1,0))
            # imgrealimg = torch.mm(s_img*rel_real, emb_e_img.transpose(1,0))
            # imgimgreal = torch.mm(s_img*rel_img, emb_e_real.transpose(1,0))
            # pred = realrealreal + realimgimg + imgrealimg - imgimgreal

            realo_realreal = s_real*rel_real
            realo_imgimg = s_img*rel_img
            realo = realo_realreal - realo_imgimg
            real = torch.mm(realo, emb_e_real.transpose(1,0))

            imgo_realimg = s_real*rel_img
            imgo_imgreal = s_img*rel_real
            imgo = imgo_realimg + imgo_imgreal
            img = torch.mm(imgo, emb_e_img.transpose(1,0))

            pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - score
        '''
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)
        obj_emb = self.emb_e(obj).squeeze(dim=1)

        s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)

        realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
        realimgimg = torch.sum(s_real*rel_img*o_img, dim=-1)
        imgrealimg = torch.sum(s_img*rel_real*o_img, dim=-1)
        imgimgreal = torch.sum(s_img*rel_img*o_real, dim=-1)

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - score
        '''

        s_real, s_img = torch.chunk(emb_s, 2, dim=-1)
        rel_real, rel_img = torch.chunk(emb_r, 2, dim=-1)
        o_real, o_img = torch.chunk(emb_o, 2, dim=-1)

        realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
        realimgimg = torch.sum(s_real*rel_img*o_img, dim=-1)
        imgrealimg = torch.sum(s_img*rel_real*o_img, dim=-1)
        imgimgreal = torch.sum(s_img*rel_img*o_real, dim=-1)

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - a vector score for the triple instead of reducing over the embedding dimension
        '''
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)
        obj_emb = self.emb_e(obj).squeeze(dim=1)

        s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)

        realrealreal = s_real*rel_real*o_real
        realimgimg = s_real*rel_img*o_img
        imgrealimg = s_img*rel_real*o_img
        imgimgreal = s_img*rel_img*o_real

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

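# ConvE scores a pair by reshaping the subject and relation embeddings into a 2D
# "image" of shape (2*stack_width, stack_height), convolving it, and projecting
# the flattened feature maps back to embedding_dim through a fully connected
# layer; the result is matched against every entity embedding plus a per-entity
# bias b. conve_architecture below implements exactly that stack -> conv -> fc path.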
class Conve(torch.nn.Module):

    #Too slow !!!!

    def __init__(self, args, num_entities, num_relations):
        super(Conve, self).__init__()

        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.hidden_drop = torch.nn.Dropout(args.hidden_drop)
        self.feature_drop = torch.nn.Dropout2d(args.feat_drop)

        self.embedding_dim = args.embedding_dim # default is 200
        self.num_filters = args.num_filters # default is 32
        self.kernel_size = args.kernel_size # default is 3
        self.stack_width = args.stack_width # default is 20
        self.stack_height = args.embedding_dim // self.stack_width

        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(self.num_filters)
        self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim)

        self.conv1 = torch.nn.Conv2d(1, out_channels=self.num_filters,
                                     kernel_size=(self.kernel_size, self.kernel_size),
                                     stride=1, padding=0, bias=args.use_bias)
        #self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=args.use_bias) # <-- default

        flat_sz_h = int(2*self.stack_width) - self.kernel_size + 1
        flat_sz_w = self.stack_height - self.kernel_size + 1
        self.flat_sz = flat_sz_h*flat_sz_w*self.num_filters
        self.fc = torch.nn.Linear(self.flat_sz, args.embedding_dim)

        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def concat(self, e1_embed, rel_embed, form='plain'):
        if form == 'plain':
            e1_embed = e1_embed.view(-1, 1, self.stack_width, self.stack_height)
            rel_embed = rel_embed.view(-1, 1, self.stack_width, self.stack_height)
            stack_inp = torch.cat([e1_embed, rel_embed], 2)

        elif form == 'alternate':
            e1_embed = e1_embed.view(-1, 1, self.embedding_dim)
            rel_embed = rel_embed.view(-1, 1, self.embedding_dim)
            stack_inp = torch.cat([e1_embed, rel_embed], 1)
            stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.stack_width, self.stack_height))

        else: raise NotImplementedError
        return stack_inp

    def conve_architecture(self, sub_emb, rel_emb):
        stacked_inputs = self.concat(sub_emb, rel_emb)
        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_drop(x)
        #x = x.view(x.shape[0], -1)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)

        return x

    def score_sr(self, sub, rel, sigmoid=False):
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)

        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def score_or(self, obj, rel, sigmoid=False):
        obj_emb = self.emb_e(obj)
        rel_emb = self.emb_rel(rel)

        x = self.conve_architecture(obj_emb, rel_emb)
        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
        For conve, computations for both modes are equivalent, so we do not need an if-else block
        '''
        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - score
        '''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        obj_emb = self.emb_e(obj)
        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, obj_emb.transpose(1,0))
        #print(pred.shape)
        pred += self.b[obj].expand_as(pred) # taking the bias value for object embedding
        # above works fine for single input triples;
        # but if input is a batch of triples, then this is a matrix of (num_trip x num_trip) where the diagonal holds the scores
        # so use torch.diagonal() after calling this function
        pred = torch.diagonal(pred)
        # or could have used: pred = torch.sum(x*obj_emb, dim=-1)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - score
        '''
        x = self.conve_architecture(emb_s, emb_r)

        pred = torch.mm(x, emb_o.transpose(1,0))

        pred = torch.diagonal(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - a vector score for the triple instead of reducing over the embedding dimension
        '''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        obj_emb = self.emb_e(obj)

        x = self.conve_architecture(sub_emb, rel_emb)

        pred = x*obj_emb

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred
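A minimal usage sketch for these scorers; the args namespace and the entity/relation counts here are hypothetical, but the fields match what Distmult.__init__ actually reads:

import argparse
import torch
from model import Distmult

# Hypothetical hyperparameters mirroring the constructor's expected fields.
args = argparse.Namespace(embedding_dim=128, input_drop=0.2, max_norm=False)
model = Distmult(args, num_entities=1000, num_relations=36)
model.eval()

s = torch.tensor([[3]])   # subject id
r = torch.tensor([[5]])   # relation id
o = torch.tensor([[7]])   # object id
with torch.no_grad():
    print(model.score_triples(s, r, o, sigmoid=True))  # plausibility in (0, 1)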
DiseaseAgnostic/utils.py
ADDED
@@ -0,0 +1,187 @@
| 1 |
+
'''
|
| 2 |
+
A file modified on https://github.com/PeruBhardwaj/AttributionAttack/blob/main/KGEAttack/ConvE/utils.py
|
| 3 |
+
'''
|
| 4 |
+
#%%
|
| 5 |
+
import logging
|
| 6 |
+
import time
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import io
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import torch
|
| 16 |
+
import random
|
| 17 |
+
|
| 18 |
+
from yaml import parse
|
| 19 |
+
|
| 20 |
+
from model import Conve, Distmult, Complex
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
#%%
|
| 24 |
+
def generate_dicts(data_path):
|
| 25 |
+
with open (os.path.join(data_path, 'entities_dict.json'), 'r') as f:
|
| 26 |
+
ent_to_id = json.load(f)
|
| 27 |
+
with open (os.path.join(data_path, 'relations_dict.json'), 'r') as f:
|
| 28 |
+
rel_to_id = json.load(f)
|
| 29 |
+
n_ent = len(list(ent_to_id.keys()))
|
| 30 |
+
n_rel = len(list(rel_to_id.keys()))
|
| 31 |
+
|
| 32 |
+
return n_ent, n_rel, ent_to_id, rel_to_id
|
| 33 |
+
|
| 34 |
+
def save_data(file_name, data):
|
| 35 |
+
with open(file_name, 'w') as fl:
|
| 36 |
+
for item in data:
|
| 37 |
+
fl.write("%s\n" % "\t".join(map(str, item)))
|
| 38 |
+
|
| 39 |
+
def load_data(file_name):
|
| 40 |
+
df = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
|
| 41 |
+
df = df.drop_duplicates()
|
| 42 |
+
return df.values
|
| 43 |
+
|
| 44 |
+
def seed_all(seed=1):
|
| 45 |
+
random.seed(seed)
|
| 46 |
+
np.random.seed(seed)
|
| 47 |
+
torch.manual_seed(seed)
|
| 48 |
+
torch.cuda.manual_seed_all(seed)
|
| 49 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 50 |
+
torch.backends.cudnn.deterministic = True
|
| 51 |
+
|
| 52 |
+
def add_model(args, n_ent, n_rel):
|
| 53 |
+
if args.model is None:
|
| 54 |
+
model = Distmult(args, n_ent, n_rel)
|
| 55 |
+
elif args.model == 'distmult':
|
| 56 |
+
model = Distmult(args, n_ent, n_rel)
|
| 57 |
+
elif args.model == 'complex':
|
| 58 |
+
model = Complex(args, n_ent, n_rel)
|
| 59 |
+
elif args.model == 'conve':
|
| 60 |
+
model = Conve(args, n_ent, n_rel)
|
| 61 |
+
else:
|
| 62 |
+
raise Exception("Unknown model!")
|
| 63 |
+
|
| 64 |
+
return model
|
| 65 |
+
|
| 66 |
+
def load_model(model_path, args, n_ent, n_rel, device):
|
| 67 |
+
# add a model and load the pre-trained params
|
| 68 |
+
model = add_model(args, n_ent, n_rel)
|
| 69 |
+
model.to(device)
|
| 70 |
+
logger.info('Loading saved model from {0}'.format(model_path))
|
| 71 |
+
state = torch.load(model_path)
|
| 72 |
+
model_params = state['state_dict']
|
| 73 |
+
params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
|
| 74 |
+
for key, size, count in params:
|
| 75 |
+
logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))
|
| 76 |
+
|
| 77 |
+
model.load_state_dict(model_params)
|
| 78 |
+
model.eval()
|
| 79 |
+
logger.info(model)
|
| 80 |
+
|
| 81 |
+
return model
|
| 82 |
+
|
def add_eval_parameters(parser):

    parser.add_argument('--eval-mode', type=str, default='all', help='Method to evaluate the attack performance. Default: all. (all or single)')
    parser.add_argument('--cuda-name', type=str, required=True, help='CUDA device to start the main thread on.')
    parser.add_argument('--direct', action='store_true', help='Whether to add the adversarial edge directly.')
    parser.add_argument('--seperate', action='store_true', help='Whether to evaluate each target separately.')
    return parser

def add_attack_parameters(parser):

    # parser.add_argument('--target-split', type=str, default='0_100_1', help='Ranks to use for target set. Values are 0 for ranks==1; 1 for ranks<=10; 2 for ranks>10 and ranks<=100. Default: 1')
    parser.add_argument('--target-split', type=str, default='min', help='Method for target triple selection. Default: min. (min or top_?, e.g. top_0.1)')
    parser.add_argument('--target-size', type=int, default=50, help='Number of target triples. Default: 50')
    parser.add_argument('--target-existed', action='store_true', help='Whether the targeted s_?_o already exists.')

    # parser.add_argument('--budget', type=int, default=1, help='Budget for each target triple for each corruption side')

    parser.add_argument('--attack-goal', type=str, default='single', help='Attack goal. Default: single. (single or global)')
    parser.add_argument('--neighbor-num', type=int, default=20, help='Max number of neighbours per side. Default: 20')
    parser.add_argument('--candidate-mode', type=str, default='quadratic', help='Method used to generate candidate edges. Default: quadratic. (quadratic or linear)')
    parser.add_argument('--reasonable-rate', type=float, default=0.7, help="The added edge's existence rank probability must exceed this rate.")
    # parser.add_argument('--neighbor-num', type=int, default=200, help='Max number of neighbours per side. Default: 200')
    # parser.add_argument('--candidate-mode', type=str, default='linear', help='Method used to generate candidate edges. (quadratic or linear)')
    parser.add_argument('--attack-batch-size', type=int, default=256, help='Batch size for processing neighbours of a target')
    parser.add_argument('--template-mode', type=str, default='manual', help='Template mode for turning an edge into a single sentence. Default: manual. (manual or auto)')

    parser.add_argument('--update-lissa', action='store_true', help='Whether to update the LiSSA cache.')

    parser.add_argument('--GPT-batch-size', type=int, default=64, help='Batch size for GPT-2 when calculating the LM score. Default: 64')
    parser.add_argument('--LM-softmax', action='store_true', help='Whether to apply a softmax head to the LM probability.')
    parser.add_argument('--LMprob-mode', type=str, default='relative', help='Use the absolute LM score, or the destruction score when the target word is replaced. Default: relative. (absolute or relative)')

    return parser

def get_argument_parser():
    '''Generate an argument parser'''
    parser = argparse.ArgumentParser(description='Graph embedding')

    parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed (default: 1)')

    parser.add_argument('--data', type=str, default='GNBR', help='Dataset to use: { GNBR }')
    parser.add_argument('--model', type=str, default='distmult', help='Choose from: {distmult, complex, transe, conve}')

    parser.add_argument('--transe-margin', type=float, default=0.0, help='Margin value for TransE scoring function. Default: 0.0')
    parser.add_argument('--transe-norm', type=int, default=2, help='P-norm value for TransE scoring function. Default: 2')

    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--lr-decay', type=float, default=0.0, help='Weight decay value to use in the optimizer. Default: 0.0')
    parser.add_argument('--max-norm', action='store_true', help='Option to add unit max norm constraint to entity embeddings')

    parser.add_argument('--train-batch-size', type=int, default=64, help='Batch size for train split (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=128, help='Batch size for test split (default: 128)')
    parser.add_argument('--valid-batch-size', type=int, default=128, help='Batch size for valid split (default: 128)')
    parser.add_argument('--KG-valid-rate', type=float, default=0.1, help='Validation rate during KG embedding training. (default: 0.1)')

    parser.add_argument('--save-influence-map', action='store_true', help='Save the influence map during training for gradient rollback.')
    parser.add_argument('--add-reciprocals', action='store_true')

    parser.add_argument('--embedding-dim', type=int, default=128, help='The embedding dimension (1D). Default: 128')
    parser.add_argument('--stack-width', type=int, default=16, help='The first dimension of the reshaped/stacked 2D embedding. Second dimension is inferred. Default: 16')
    # parser.add_argument('--stack_height', type=int, default=10, help='The second dimension of the reshaped/stacked 2D embedding. Default: 10')
    parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.')
    parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.')
    parser.add_argument('--feat-drop', type=float, default=0.3, help='Dropout for the convolutional features. Default: 0.3.')
    parser.add_argument('-num-filters', default=32, type=int, help='Number of filters for convolution')
    parser.add_argument('-kernel-size', default=3, type=int, help='Kernel size for convolution')

    parser.add_argument('--use-bias', action='store_true', help='Use a bias in the convolutional layer.')

    parser.add_argument('--reg-weight', type=float, default=5e-2, help='Weight for regularization. Default: 5e-2')
    parser.add_argument('--reg-norm', type=int, default=3, help='Norm for regularization. Default: 3')
    # parser.add_argument('--resume', action='store_true', help='Restore a saved model.')
    # parser.add_argument('--resume-split', type=str, default='test', help='Split to evaluate a restored model')
    # parser.add_argument('--reproduce-results', action='store_true', help='Use the hyperparameters to reproduce the results.')
    # parser.add_argument('--original-data', type=str, default='FB15k-237', help='Dataset to use; this option is needed to set the hyperparams to reproduce the results for training after attack, default: FB15k-237')
    return parser

def set_hyperparams(args):
    # Model-specific training hyperparameters.
    if args.model == 'distmult':
        args.lr = 0.005
        args.train_batch_size = 1024
        args.reg_norm = 3
    elif args.model == 'complex':
        args.lr = 0.005
        args.reg_norm = 3
        args.input_drop = 0.4
        args.train_batch_size = 1024
    elif args.model == 'conve':
        args.lr = 0.005
        args.train_batch_size = 1024
        args.reg_weight = 0.0

    # args.damping = 0.01
    # args.lissa_repeat = 1
    # args.lissa_depth = 1
    # args.scale = 500
    # args.lissa_batch_size = 100

    # LiSSA hyperparameters for influence-function estimation.
    args.damping = 0.01
    args.lissa_repeat = 1
    args.lissa_depth = 1
    args.scale = 400
    args.lissa_batch_size = 300
    return args