#%%
import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics
import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import sys
sys.path.append("..")
import Parameters

parser = utils.get_argument_parser()
parser.add_argument('--reasonable-rate', type=float, default=0.7, help="The added edge's existence-rank probability must be greater than this rate")
parser.add_argument('--mode', type=str, default='sentence', help='sentence, biogpt or finetune')
parser.add_argument('--init-mode', type=str, default='random', help='How to select target nodes')
parser.add_argument('--ratio', type=str, default='', help='Ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False
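# Paths to the processed GNBR graph and to the precomputed target-node / attack-edge pickles.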
data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl'

# target_data = utils.load_data(target_path)
with open(target_path, 'rb') as fl:
    Target_node_list = pkl.load(fl)
with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)

attack_data = np.array(Attack_edge_list).reshape(-1, 3)
# assert target_data.shape == attack_data.shape
#%%
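# Load GNBR metadata: entity-id -> MeSH-id mapping, raw entity display names,
# sentences indexed by edge type and dependency path, and the raw text of each sentence.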
with open('../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json') as fl:
    id_to_meshid = json.load(fl)
with open(Parameters.GNBRfile + 'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile + 'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile + 'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
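# Mode 'sentence': for every attack triple, collect candidate GNBR sentences for the relation,
# substitute the attack subject/object names, and keep the sentence that BioGPT scores as most
# fluent (lowest mean token cross-entropy).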
if args.mode == 'sentence':

    import torch
    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import AutoTokenizer
    from transformers import BioGptForCausalLM

    criterion = CrossEntropyLoss(reduction="none")

    print('Generating GPT input ...')
    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()

    GPT_batch_size = 24
    single_sentence = {}
    test_text = []
    test_dp = []
    test_parse = []
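    # For each attack triple (s, r, o), sample up to ~500 candidate sentences across the
    # relation's dependency paths and score them with BioGPT in mini-batches.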
    for i, (s, r, o) in enumerate(tqdm(attack_data)):
        s = str(s)
        r = str(r)
        o = str(o)
        if int(s) != -1:
            dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
            candidate_sen = []
            Dp_path = []
            L = len(dependency_sen_dict.keys())
            bound = 500 // L
            if bound == 0:
                bound = 1
            for dp_path, sen_list in dependency_sen_dict.items():
                if len(sen_list) > bound:
                    index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
                    sen_list = [sen_list[aa] for aa in index]
                candidate_sen += sen_list
                Dp_path += [dp_path] * len(sen_list)

            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]

            candidate_text_sen = []
            candidate_ori_sen = []
            candidate_parse_sen = []
            for paper_id, sen_id in candidate_sen:
                sen = raw_text_sen[paper_id][sen_id]
                text = sen['text']
                candidate_ori_sen.append(text)
                ss = sen['start_formatted']
                oo = sen['end_formatted']
                # Restore bracket tokens left over from the constituency parser.
                text = text.replace('-LRB-', '(').replace('-RRB-', ')')
                text = text.replace('-LSB-', '[').replace('-RSB-', ']')
                text = text.replace('-LCB-', '{').replace('-RCB-', '}')
                parse_text = text
                parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
                parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
                text = text.replace(ss, text_s)
                text = text.replace(oo, text_o)
                text = text.replace('_', ' ')
                candidate_text_sen.append(text)
                candidate_parse_sen.append(parse_text)

            tokens = tokenizer(candidate_text_sen,
                               truncation=True,
                               padding=True,
                               max_length=300,
                               return_tensors="pt")
            target_ids = tokens['input_ids'].to(device)
            attention_mask = tokens['attention_mask'].to(device)

            L = len(candidate_text_sen)
            assert L > 0
            ret_log_L = []
            for l in range(0, L, GPT_batch_size):
                R = min(L, l + GPT_batch_size)
                target = target_ids[l:R, :]
                attention = attention_mask[l:R, :]
                outputs = model(input_ids=target,
                                attention_mask=attention,
                                labels=target)
                logits = outputs.logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                Loss = Loss.view(-1, shift_logits.shape[1])
                attention = attention[..., 1:].contiguous()
                # Mean cross-entropy per sequence, masked to real (non-padding) tokens.
                log_Loss = (torch.mean(Loss * attention.float(), dim=1) / torch.mean(attention.float(), dim=1))
                ret_log_L.append(log_Loss.detach())

            ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
            sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
            sen_score.sort(key=lambda x: x[1])
            test_text.append(sen_score[0][2])
            test_dp.append(sen_score[0][3])
            test_parse.append(sen_score[0][4])
            single_sentence.update({f'{s}_{r}_{o}_{i}': sen_score[0][0]})
        else:
            single_sentence.update({f'{s}_{r}_{o}_{i}': ''})
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence.json', 'w') as fl:
        json.dump(single_sentence, fl, indent=4)
    with open(f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_path.json', 'w') as fl:
        fl.write('\n'.join(test_dp))
    with open(f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_temp.json', 'w') as fl:
        fl.write('\n'.join(test_text))
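# Mode 'finetune': load the drafted abstracts from '*_chat.json', mask spans around the injected
# sentence, rewrite them with BioBART, and keep the candidate with the lowest BioGPT loss.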
elif args.mode == 'finetune':

    import spacy
    import pprint
    from transformers import AutoModel, AutoTokenizer, BartForConditionalGeneration

    print('Finetuning ...')
    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json', 'r') as fl:
        draft = json.load(fl)
    with open(f'generate_abstract/path/{args.init_mode}{args.reasonable_rate}_path.json', 'r') as fl:
        dpath = fl.readlines()
    nlp = spacy.load("en_core_web_sm")

    if os.path.exists(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json'):
        with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json', 'r') as fl:
            ret_candidates = json.load(fl)
    else:
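        # Helpers: find_mini_span marks the shortest window of `words` that still contains every
        # term of check_set; mask_func masks random token spans (Poisson-distributed lengths) for
        # BART infilling while never masking tokens that contain periods or brackets.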
        def find_mini_span(vec, words, check_set):
            def cal(text, sset):
                add = 0
                for tt in sset:
                    if tt in text:
                        add += 1
                return add

            text = ' '.join(words)
            max_add = cal(text, check_set)
            minn = 10000000
            span = ''
            rc = None
            for i in range(len(vec)):
                if vec[i] == True:
                    p = -1
                    for j in range(i + 1, len(vec) + 1):
                        if vec[j - 1] == True:
                            text = ' '.join(words[i:j])
                            if cal(text, check_set) == max_add:
                                p = j
                                break
                    if p > 0:
                        if (p - i) < minn:
                            minn = p - i
                            span = ' '.join(words[i:p])
                            rc = (i, p)
            if rc:
                for i in range(rc[0], rc[1]):
                    vec[i] = True
            return vec, span

        def mask_func(tokenized_sen):
            if len(tokenized_sen) == 0:
                return []
            token_list = []
            # for sen in tokenized_sen:
            #     for token in sen:
            #         token_list.append(token)
            for sen in tokenized_sen:
                token_list += sen.text.split(' ')
            if args.ratio == '':
                P = 0.3
            else:
                P = float(args.ratio)
            ret_list = []
            i = 0
            mask_num = 0
            while i < len(token_list):
                t = token_list[i]
                if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
                    ret_list.append(t)
                    i += 1
                    mask_num = 0
                else:
                    length = np.random.poisson(3)
                    if np.random.rand() < P and length > 0:
                        if mask_num < 8:
                            ret_list.append('<mask>')
                            mask_num += 1
                        i += length
                    else:
                        ret_list.append(t)
                        i += 1
                        mask_num = 0
            return [' '.join(ret_list)]
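        # Build candidate abstracts with BioBART: for every sentence position j, either
        # (a) replace the j-th draft sentence with the masked attack prompt, or
        # (b) insert the attack sentence at position j and randomly mask the surrounding
        #     sentences, then infill the <mask> tokens with beam search.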
        model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
        model.eval()
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')

        ret_candidates = {}
        dpath_i = 0
        for i, (k, v) in enumerate(tqdm(draft.items())):
            input = v['in'].replace('\n', '')
            output = v['out'].replace('\n', '')
            s, r, o = attack_data[i]
            s = str(s)
            o = str(o)
            r = str(r)
            if int(s) == -1:
                ret_candidates[str(i)] = {'span': '', 'prompt': '', 'out': [], 'in': [], 'assist': []}
                continue
            path_text = dpath[dpath_i].replace('\n', '')
            dpath_i += 1

            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]
            doc = nlp(output)
            words = input.split(' ')
            tokenized_sens = [sen for sen in doc.sents]
            sens = np.array([sen.text for sen in doc.sents])

            # Terms that must survive masking: the two entity names plus the intermediate
            # nodes on the dependency path.
            checkset = set([text_s, text_o])
            e_entity = set(['start_entity', 'end_entity'])
            for path in path_text.split(' '):
                a, b, c = path.split('|')
                if a not in e_entity:
                    checkset.add(a)
                if c not in e_entity:
                    checkset.add(c)

            # Mark which words of the attack sentence belong to the check set,
            # matching the longest phrase first.
            vec = []
            l = 0
            while l < len(words):
                bo = False
                for j in range(len(words), l, -1):  # reversing is important !!!
                    cc = ' '.join(words[l:j])
                    if cc in checkset:
                        vec += [True] * (j - l)
                        l = j
                        bo = True
                        break
                if not bo:
                    vec.append(False)
                    l += 1
            vec, span = find_mini_span(vec, words, checkset)
            # vec = np.vectorize(lambda x: x in checkset)(words)
            vec[-1] = True

            # Turn the boolean mask into a BART prompt: runs of unmarked words collapse
            # into at most eight <mask> tokens.
            prompt = []
            mask_num = 0
            for j, bo in enumerate(vec):
                if not bo:
                    mask_num += 1
                else:
                    if mask_num > 0:
                        # mask_num = mask_num // 3  # span length ~ Poisson distribution (lambda = 3)
                        mask_num = max(mask_num, 1)
                        mask_num = min(8, mask_num)
                        prompt += ['<mask>'] * mask_num
                    prompt.append(words[j])
                    mask_num = 0
            prompt = ' '.join(prompt)

            Text = []
            Assist = []
            for j in range(len(sens)):
                Bart_input = list(sens[:j]) + [prompt] + list(sens[j + 1:])
                assist = list(sens[:j]) + [input] + list(sens[j + 1:])
                Text.append(' '.join(Bart_input))
                Assist.append(' '.join(assist))
            for j in range(len(sens)):
                Bart_input = mask_func(tokenized_sens[:j]) + [input] + mask_func(tokenized_sens[j + 1:])
                assist = list(sens[:j]) + [input] + list(sens[j + 1:])
                Text.append(' '.join(Bart_input))
                Assist.append(' '.join(assist))

            batch_size = len(Text) // 2
            Outs = []
            for l in range(2):
                A = tokenizer(Text[batch_size * l:batch_size * (l + 1)],
                              truncation=True,
                              padding=True,
                              max_length=1024,
                              return_tensors="pt")
                input_ids = A['input_ids'].to(device)
                attention_mask = A['attention_mask'].to(device)
                aaid = model.generate(input_ids, num_beams=5, max_length=1024)
                outs = tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                Outs += outs
            ret_candidates[str(i)] = {'span': span, 'prompt': prompt, 'out': Outs, 'in': Text, 'assist': Assist}

        with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_candidates.json', 'w') as fl:
            json.dump(ret_candidates, fl, indent=4)
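    # Score every candidate abstract with BioGPT: a lower mean token cross-entropy
    # (log-perplexity) means a more fluent abstract.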
    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import BioGptForCausalLM

    criterion = CrossEntropyLoss(reduction="none")
    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()

    scored = {}
    ret = {}
    case_study = {}
    p_ret = {}
    dpath_i = 0
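    # For each draft: splice the BioBART rewrite back into the GPT abstract via sentence-level
    # alignment, then compare the BioGPT loss of every candidate against the original abstract
    # and the unmasked 'assist' versions, keeping the best-scoring text.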
    for i, (k, v) in enumerate(tqdm(draft.items())):
        span = ret_candidates[str(i)]['span']
        prompt = ret_candidates[str(i)]['prompt']
        sen_list = ret_candidates[str(i)]['out']
        BART_in = ret_candidates[str(i)]['in']
        Assist = ret_candidates[str(i)]['assist']

        s, r, o = attack_data[i]
        s = str(s)
        r = str(r)
        o = str(o)
        if int(s) == -1:
            ret[k] = {'prompt': '', 'in': '', 'out': ''}
            p_ret[k] = {'prompt': '', 'in': '', 'out': ''}
            continue
        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]

        def process(text):
            # Ensure a space after sentence-ending periods that run into a capital letter.
            for i in range(ord('A'), ord('Z') + 1):
                text = text.replace(f'.{chr(i)}', f'. {chr(i)}')
            return text

        sen_list = [process(text) for text in sen_list]
        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1

        checkset = set([text_s, text_o])
        e_entity = set(['start_entity', 'end_entity'])
        for path in path_text.split(' '):
            a, b, c = path.split('|')
            if a not in e_entity:
                checkset.add(a)
            if c not in e_entity:
                checkset.add(c)

        input = v['in'].replace('\n', '')
        output = v['out'].replace('\n', '')
        doc = nlp(output)
        gpt_sens = [sen.text for sen in doc.sents]
        assert len(gpt_sens) == len(sen_list) // 2

        word_sets = []
        for sen in gpt_sens:
            word_sets.append(set(sen.split(' ')))

        def sen_align(word_sets, modified_word_sets):
            # Find the sentence range that differs between the original draft and a BioBART
            # rewrite; sentences with more than 80% word overlap count as unchanged.
            l = 0
            while l < len(modified_word_sets):
                if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
                    l += 1
                else:
                    break
            if l == len(modified_word_sets):
                return -1, -1, -1, -1
            r = l + 1
            r1 = None
            r2 = None
            for pos1 in range(r, len(word_sets)):
                for pos2 in range(r, len(modified_word_sets)):
                    if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
                        r1 = pos1
                        r2 = pos2
                        break
                if r1 is not None:
                    break
            if r1 is None:
                r1 = len(word_sets)
                r2 = len(modified_word_sets)
            return l, r1, l, r2

        replace_sen_list = []
        boundary = []
        assert len(sen_list) % 2 == 0
        for j in range(len(sen_list) // 2):
            doc = nlp(sen_list[j])
            sens = [sen.text for sen in doc.sents]
            modified_word_sets = [set(sen.split(' ')) for sen in sens]
            l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
            boundary.append((l1, r1, l2, r2))
            if l1 == -1:
                replace_sen_list.append(sen_list[j])
                continue
            check_text = ' '.join(sens[l2:r2])
            replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
        sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]

        old_L = len(sen_list)
        sen_list.append(output)
        sen_list += Assist

        tokens = tokenizer(sen_list,
                           truncation=True,
                           padding=True,
                           max_length=1024,
                           return_tensors="pt")
        target_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)

        L = len(sen_list)
        ret_log_L = []
        for l in range(0, L, 5):
            R = min(L, l + 5)
            target = target_ids[l:R, :]
            attention = attention_mask[l:R, :]
            outputs = model(input_ids=target,
                            attention_mask=attention,
                            labels=target)
            logits = outputs.logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = target[..., 1:].contiguous()
            Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
            Loss = Loss.view(-1, shift_logits.shape[1])
            attention = attention[..., 1:].contiguous()
            log_Loss = (torch.mean(Loss * attention.float(), dim=1) / torch.mean(attention.float(), dim=1))
            ret_log_L.append(log_Loss.detach())
        log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()

        real_log_Loss = log_Loss.copy()
        log_Loss = log_Loss[:old_L]
        # sen_list = sen_list[:old_L]
        p = np.argmin(log_Loss)

        content = []
        for j in range(len(real_log_Loss)):
            content.append([sen_list[j], str(real_log_Loss[j])])
        scored[k] = {'path': path_text, 'prompt': prompt, 'in': input, 's': text_s, 'o': text_o, 'out': content, 'bound': boundary}

        # Fall back to the unmasked 'assist' version (index p + 1 + old_L) when it scores better
        # than both the best BioBART candidate and the original abstract (index old_L).
        p_p = p
        if real_log_Loss[p] > real_log_Loss[p + 1 + old_L]:
            p_p = p + 1 + old_L
        if real_log_Loss[p] > real_log_Loss[old_L]:
            if real_log_Loss[p] > real_log_Loss[p + 1 + old_L]:
                p = p + 1 + old_L
        # case_study[k] = {'path': path_text, 'entity_0': text_s, 'entity_1': text_o, 'GPT_in': input, 'Prompt': prompt, 'GPT_out': {'text': output, 'perplexity': str(np.exp(real_log_Loss[old_L]))}, 'BART_in': BART_in[p], 'BART_out': {'text': sen_list[p], 'perplexity': str(np.exp(real_log_Loss[p]))}, 'Assist': {'text': Assist[p], 'perplexity': str(np.exp(real_log_Loss[p+1+old_L]))}}
        ret[k] = {'prompt': prompt, 'in': input, 'out': sen_list[p]}

    with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'w') as fl:
        json.dump(ret, fl, indent=4)
    with open(f'generate_abstract/bioBART/{args.init_mode}{args.reasonable_rate}{args.ratio}_scored.json', 'w') as fl:
        json.dump(scored, fl, indent=4)
else:
    raise Exception(f'Unknown mode: {args.mode}')