|
from __future__ import absolute_import |
|
import os |
|
from statistics import mean |
|
import sys |
|
from xml.sax.handler import feature_external_ges |
|
|
|
import pickle |
|
import torch |
|
import csv |
|
import json |
|
import random |
|
import time |
|
import logging |
|
import argparse |
|
|
|
import numpy as np |
|
from io import open |
|
from itertools import cycle |
|
import torch.nn as nn |
|
from model_gen import Seq2Seq |
|
from tqdm import tqdm, trange |
|
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset |
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, |
|
RobertaConfig, RobertaModel, RobertaTokenizer) |
|
|
|
import pathlib |
|
|
|
# Absolute directory of this script; every CLI-supplied path is resolved
# relative to it (see vega_train_main).
folder = str(pathlib.Path(__file__).parent.resolve())


logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


# Per-epoch dev accuracy is computed on a random 1/divide_number sample of
# the dev set to keep evaluation cheap.
divide_number = 6
# Expose all 8 GPUs; DataParallel is enabled when more than one is visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
|
|
|
class Example(object):
    """One training/test sample: source/target code plus its metadata."""

    def __init__(self,
                 idx,
                 source,
                 target,
                 cpuname,
                 funcname,
                 filename,
                 property,
                 vec,
                 exist,
                 module
                 ):
        # Pure record with no behaviour: copy every constructor argument
        # onto an attribute of the same name.
        for field_name, field_value in (
                ('idx', idx),
                ('source', source),
                ('target', target),
                ('cpuname', cpuname),
                ('funcname', funcname),
                ('filename', filename),
                ('property', property),
                ('vec', vec),
                ('exist', exist),
                ('module', module)):
            setattr(self, field_name, field_value)
|
|
|
|
|
def read_examples_no_bracket(filename, is_function_test):
    """Read examples from a JSONL *filename*, skipping lines the model
    never needs to predict.

    Skipped lines:
      * statements whose stripped text starts with "}" (pure closing brackets);
      * Value == "nothing" while the FIR still contains a '#' placeholder;
      * a '1' in the Vector tail while the FIR has no '#' placeholder
        (NOTE(review): the -97 tail width matches the vec[-97:] checks in
        vega_train_main — confirm it tracks the feature-vector layout).

    Args:
        filename: path to a JSONL file, one example object per line.
        is_function_test: when True, only the first 213 lines are read
            (quick smoke test).

    Returns:
        list[Example]: the parsed, filtered examples.
    """
    examples = []
    with open(filename, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            # Function-test mode samples a small fixed prefix of the file.
            if is_function_test and idx > 212:
                break
            js = json.loads(line.strip())
            # startswith also tolerates an empty statement, which the old
            # [0] index would have crashed on.
            if js["Stmt"].strip().startswith("}"):
                continue
            if js["Value"].strip().lower() == "nothing" and '#' in js['FIR']:
                continue
            if '1' in js['Vector'][-97:] and '#' not in js['FIR']:
                continue
            if 'idx' not in js:
                js['idx'] = idx
            # Collapse token lists to single-space-separated strings.
            code = ' '.join(' '.join(js['FIR_token']).replace('\n', ' ').split())
            nl = ' '.join(' '.join(js['Stmt_token']).replace('\n', ' ').split())
            # 'Exist' may be a "true"/"false" string (or bool) or a numeric
            # score; normalise to a 0/1 flag.  A non-"nothing" Value also
            # forces the flag on.
            exist_str = str(js['Exist']).lower()
            if exist_str != "true" and exist_str != "false":
                if int(round(float(js['Exist']))) == 1:
                    exist = 1
                elif js["Value"].strip().lower() != "nothing":
                    exist = 1
                else:
                    exist = 0
            else:
                exist = 1 if exist_str == "true" else 0
            # Binarise the vector string, dropping the "|zm|" separator.
            vec = [1 if int(t) == 1 else 0
                   for t in js['Vector'].replace("|zm|", "")]
            pro = ' '.join(' '.join(js['Value_token']).replace('\n', ' ').split())

            examples.append(
                Example(
                    idx=idx,
                    source=code,
                    target=nl,
                    cpuname=js['Target'],
                    funcname=js['Func'],
                    filename=js['File'],
                    property=pro,
                    vec=vec,
                    exist=exist,
                    module=js.get("Module", "")
                )
            )
    return examples
|
|
|
|
|
def read_examples(filename, is_function_test):
    """Read every example from a JSONL *filename*.

    Same parsing as read_examples_no_bracket but with no filtering: every
    line becomes an Example.  Used to obtain the full example universe so
    that filtered-out lines can be counted as trivially correct.

    Args:
        filename: path to a JSONL file, one example object per line.
        is_function_test: when True, only the first 213 lines are read
            (quick smoke test).

    Returns:
        list[Example]: one Example per (read) input line.
    """
    examples = []
    with open(filename, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            # Function-test mode samples a small fixed prefix of the file.
            if is_function_test and idx > 212:
                break
            js = json.loads(line.strip())
            if 'idx' not in js:
                js['idx'] = idx
            # Collapse token lists to single-space-separated strings.
            code = ' '.join(' '.join(js['FIR_token']).replace('\n', ' ').split())
            nl = ' '.join(' '.join(js['Stmt_token']).replace('\n', ' ').split())
            # 'Exist' may be a "true"/"false" string (or bool) or a numeric
            # score; normalise to a 0/1 flag.  A non-"nothing" Value also
            # forces the flag on.
            exist_str = str(js['Exist']).lower()
            if exist_str != "true" and exist_str != "false":
                if int(round(float(js['Exist']))) == 1:
                    exist = 1
                elif js["Value"].strip().lower() != "nothing":
                    exist = 1
                else:
                    exist = 0
            else:
                exist = 1 if exist_str == "true" else 0
            # Binarise the vector string, dropping the "|zm|" separator.
            vec = [1 if int(t) == 1 else 0
                   for t in js['Vector'].replace("|zm|", "")]
            pro = ' '.join(' '.join(js['Value_token']).replace('\n', ' ').split())

            examples.append(
                Example(
                    idx=idx,
                    source=code,
                    target=nl,
                    cpuname=js['Target'],
                    funcname=js['Func'],
                    filename=js['File'],
                    property=pro,
                    vec=vec,
                    exist=exist,
                    module=js.get("Module", "")
                )
            )
    return examples
|
|
|
|
|
class InputFeatures(object):
    """Tokenised, padded form of one Example, ready to become tensors."""

    def __init__(self,
                 example_id,
                 source_ids,
                 exist,
                 target_ids,
                 ):
        # Plain data holder: keep each argument verbatim.
        self.example_id, self.source_ids = example_id, source_ids
        self.exist, self.target_ids = exist, target_ids
|
|
|
|
|
def convert_examples_to_features(examples, tokenizer, args, stage=None):
    """Tokenise each Example into padded encoder/decoder id sequences.

    Encoder layout: [CLS] <encoder-decoder> [SEP] <mask0> func [CLS] source
    [CLS] value [CLS] vec-bits [SEP], padded to args.max_source_length.
    Decoder layout: [CLS] <mask0> target [SEP], padded to
    args.max_target_length.  *stage* is accepted for API compatibility but
    unused.
    """
    features = []
    for example_index, example in enumerate(examples):
        # --- encoder side ---
        func_tokens = tokenizer.tokenize(example.funcname)
        fir_tokens = tokenizer.tokenize(example.source)
        value_tokens = tokenizer.tokenize(example.property)
        prefix = [tokenizer.cls_token, "<encoder-decoder>",
                  tokenizer.sep_token, "<mask0>"]
        # NOTE(review): example.vec holds raw 0/1 ints spliced into the token
        # list before convert_tokens_to_ids — presumably the tokenizer maps
        # them; confirm against the tokenizer's vocab handling.
        encoder_tokens = (prefix + func_tokens + [tokenizer.cls_token]
                          + fir_tokens + [tokenizer.cls_token] + value_tokens
                          + [tokenizer.cls_token] + list(example.vec)
                          + [tokenizer.sep_token])
        source_ids = tokenizer.convert_tokens_to_ids(encoder_tokens)
        source_ids += [tokenizer.pad_token_id] * (args.max_source_length - len(source_ids))

        # --- decoder side: <mask0> acts as BOS, sep_token as EOS ---
        decoder_tokens = ([tokenizer.cls_token, "<mask0>"]
                          + tokenizer.tokenize(example.target)
                          + [tokenizer.sep_token])
        target_ids = tokenizer.convert_tokens_to_ids(decoder_tokens)
        target_ids += [tokenizer.pad_token_id] * (args.max_target_length - len(target_ids))

        features.append(
            InputFeatures(
                example_index,
                source_ids,
                [example.exist],
                target_ids,
            )
        )
    return features
|
|
|
|
|
def set_seed(seed=991105):
    """Seed every RNG used by this script for reproducible runs.

    Covers Python's `random`, the PYTHONHASHSEED env var, NumPy, and torch
    on both CPU and GPU, and forces deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Bug fix: the variable was misspelled "PYHTONHASHSEED", so the export
    # silently did nothing.  (Note: setting it at runtime only affects
    # subprocesses, not the already-running interpreter.)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
|
|
|
def is_valid_parentheses(s):
    """Drop every unmatched *closing* bracket from *s*.

    Tracks a running open-count per bracket family ((), [], {}); whenever a
    closing bracket would push its family's count negative the character is
    discarded and the count clamped back to zero.  Unmatched opening
    brackets are left in place.

    Returns:
        The input string with unmatched closers removed.
    """
    family_of_open = {"(": 0, "[": 1, "{": 2}
    family_of_close = {")": 0, "]": 1, "}": 2}
    depth = [0, 0, 0]
    kept = []
    for ch in s:
        kept.append(ch)
        if ch in family_of_open:
            depth[family_of_open[ch]] += 1
        elif ch in family_of_close:
            fam = family_of_close[ch]
            depth[fam] -= 1
            if depth[fam] < 0:
                # This closer has no matching opener: discard it.
                depth[fam] = 0
                kept.pop()
    return "".join(kept)
|
|
|
|
|
def rewrite_pred(pred, gt_pred, gt_source, gt_value):
    """Try to reconcile a model prediction with the ground truth.

    Two rescue rewrites are attempted, comparing with all spaces removed:
    1. strip unmatched closing brackets from *pred*;
    2. substitute *gt_value* for the '#' placeholder in *gt_source*
       (only when *gt_value* contains the "zmtarzm" target marker).

    Returns:
        (matched, text): True plus the successful rewrite, or False plus
        the original prediction when neither rewrite matches.
    """
    bracket_fixed = is_valid_parentheses(pred)
    if bracket_fixed.replace(" ", "") == gt_pred.replace(" ", ""):
        return True, bracket_fixed
    substituted = gt_source.replace("#", gt_value)
    if "zmtarzm" in gt_value and substituted.replace(" ", "") == gt_pred.replace(" ", ""):
        return True, substituted
    return False, pred
|
|
|
|
|
def vega_train_main():
    """CLI entry point: build the Seq2Seq model, then optionally run
    training (--do_train) with per-epoch dev evaluation (--do_eval), and a
    final test or function-test inference pass (--do_test /
    --do_function_test) that writes result.jsonl to the output dir.

    All path arguments are resolved relative to this script's directory
    (module-level `folder`).
    """
    parser = argparse.ArgumentParser()

    # Required model / output locations.
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model: e.g. roberta-base")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    # Dataset files (JSONL), one per phase.
    parser.add_argument("--train_filename", default=None, type=str,
                        help="The train filename. Should contain the .jsonl files for this task.")
    parser.add_argument("--dev_filename", default=None, type=str,
                        help="The dev filename. Should contain the .jsonl files for this task.")
    parser.add_argument("--test_filename", default=None, type=str,
                        help="The test filename. Should contain the .jsonl files for this task.")
    parser.add_argument("--max_source_length", default=590, type=int,
                        help="The maximum total source sequence length after tokenization. Sequences longer "
                        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--max_target_length", default=240, type=int,
                        help="The maximum total target sequence length after tokenization. Sequences longer "
                        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_function_test", action='store_true',
                        help="Whether to run eval on the subset of the dev set.")
    # NOTE(review): --no_cuda is accepted but never consulted below; device
    # selection relies solely on torch.cuda.is_available().
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")

    # Optimisation hyper-parameters.
    parser.add_argument("--train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=6e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--beam_size", default=1, type=int,
                        help="beam size for beam search")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    # NOTE(review): --max_grad_norm is parsed but no clipping call appears
    # in the training loop below.
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=30, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--seed', type=int, default=20230420,
                        help="random seed for initialization")

    # Loss-mixing weights forwarded to the Seq2Seq model.
    parser.add_argument("--mse_loss_weight", default=0.9, type=float,
                        help="Weight of Mean Square Error Loss.")
    parser.add_argument("--ce_loss_weight", default=0.1, type=float,
                        help="Weight of Cross Entropy Loss.")

    args = parser.parse_args()

    # Device selection: single device object; DataParallel handles fan-out.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device
    logger.info("device: %s, n_gpu: %s", device, args.n_gpu)

    set_seed(args.seed)

    # Resolve all user-supplied paths relative to the script directory.
    args.output_dir = folder + "/" + args.output_dir
    if os.path.exists(args.output_dir) is False:
        os.makedirs(args.output_dir)
    args.model_name_or_path = folder + "/" + args.model_name_or_path
    if args.train_filename:
        args.train_filename = folder + "/" + args.train_filename
    if args.dev_filename:
        args.dev_filename = folder + "/" + args.dev_filename
    if args.test_filename:
        args.test_filename = folder + "/" + args.test_filename

    tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    config = RobertaConfig.from_pretrained(args.model_name_or_path)

    # The same RoBERTa weights serve as both encoder and decoder (tied);
    # is_decoder enables the causal attention path.
    config.is_decoder = True
    encoder = RobertaModel.from_pretrained(
        args.model_name_or_path, config=config)

    # "<mask0>" doubles as the decoder start-of-sequence token.
    model = Seq2Seq(encoder=encoder, decoder=encoder, config=config,
                    mse_loss_weight=args.mse_loss_weight, ce_loss_weight=args.ce_loss_weight,
                    beam_size=args.beam_size, max_length=args.max_target_length,
                    sos_id=tokenizer.convert_tokens_to_ids(["<mask0>"])[0], eos_id=tokenizer.sep_token_id)

    model.to(args.device)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.do_train:
        # all_examples: the unfiltered universe (only used for logging the
        # true example count); train_examples: the filtered training set.
        all_examples = read_examples(args.train_filename, False)
        train_examples = read_examples_no_bracket(args.train_filename, False)
        train_features = convert_examples_to_features(
            train_examples, tokenizer, args, stage='train')
        all_source_ids = torch.tensor(
            [f.source_ids for f in train_features], dtype=torch.long)
        all_exists = torch.tensor(
            [f.exist for f in train_features], dtype=torch.float32)
        all_target_ids = torch.tensor(
            [f.target_ids for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_source_ids, all_exists, all_target_ids)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                      batch_size=args.train_batch_size // args.gradient_accumulation_steps)

        # Standard transformers recipe: no weight decay on biases/LayerNorm.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': args.weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate, eps=args.adam_epsilon)
        # Linear warmup over the first 10% of all optimisation steps.
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=int(
                                                        len(train_dataloader)*args.num_train_epochs*0.1),
                                                    num_training_steps=len(train_dataloader)*args.num_train_epochs)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(all_examples))
        logger.info("  Batch size = %d", args.train_batch_size *
                    args.gradient_accumulation_steps)
        logger.info("  Num epoch = %d", args.num_train_epochs)

        model.train()
        # Full dev set size; filtered-out dev lines are later credited as
        # trivially correct when computing accuracy.
        eval_examples_all = read_examples(args.dev_filename, False)
        total_eval_all = len(eval_examples_all)
        # dev_dataset caches the tokenised dev data across epochs.
        patience, best_acc, losses, dev_dataset = 0, 0, [], {}
        for epoch in tqdm(range(args.num_train_epochs)):
            for idx, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                source_ids, exist, target_ids = batch
                # Model returns (total_loss, shift_loss, token_count,
                # mse_loss, ce_loss) in training mode; only the combined
                # loss is used for the backward pass.
                loss, _, _, mse_loss, ce_loss = model(
                    source_ids=source_ids, exist=exist, target_ids=target_ids)

                if args.n_gpu > 1:
                    loss = loss.mean()  # DataParallel returns one loss per GPU
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                losses.append(loss.item())
                loss.backward()
                # Step the optimiser once per accumulation window.
                if len(losses) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    # Log a 100-step moving average of the loss.
                    if len(losses) // args.gradient_accumulation_steps % 100 == 0:
                        logger.info("epoch {} step {} loss {}".format(epoch,
                                                                      len(losses)//args.gradient_accumulation_steps,
                                                                      round(np.mean(losses[-100*args.gradient_accumulation_steps:]), 4)))
            if args.do_eval:
                # ---- perplexity on the (filtered) dev set ----
                if 'dev_loss' in dev_dataset:
                    eval_examples, eval_data = dev_dataset['dev_loss']
                else:
                    eval_examples = read_examples_no_bracket(args.dev_filename, False)
                    eval_features = convert_examples_to_features(
                        eval_examples, tokenizer, args, stage='dev')
                    all_source_ids = torch.tensor(
                        [f.source_ids for f in eval_features], dtype=torch.long)
                    all_exists = torch.tensor(
                        [f.exist for f in eval_features], dtype=torch.float32)
                    all_target_ids = torch.tensor(
                        [f.target_ids for f in eval_features], dtype=torch.long)
                    eval_data = TensorDataset(
                        all_source_ids, all_exists, all_target_ids)
                    dev_dataset['dev_loss'] = eval_examples, eval_data
                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(
                    eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", total_eval_all)
                logger.info("  Batch size = %d", args.eval_batch_size)

                model.eval()
                eval_loss, tokens_num = 0, 0
                for batch in eval_dataloader:
                    batch = tuple(t.to(device) for t in batch)
                    source_ids, exist, target_ids = batch

                    with torch.no_grad():
                        # In eval, the second return is the summed token loss
                        # and the third is the token count.
                        _, loss, num, _, _ = model(
                            source_ids=source_ids, exist=exist, target_ids=target_ids)
                    eval_loss += loss.sum().item()
                    tokens_num += num.sum().item()

                model.train()
                eval_loss = eval_loss / tokens_num
                result = {'eval_ppl': round(np.exp(eval_loss), 5)}
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                logger.info("  " + "*" * 20)

                # ---- exact-match accuracy on a random dev sample ----
                if 'dev_acc' in dev_dataset:
                    eval_examples, eval_data = dev_dataset['dev_acc']
                else:
                    # Sample 1/divide_number of the dev set once and reuse it
                    # every epoch so accuracies stay comparable.
                    eval_examples = read_examples_no_bracket(args.dev_filename, False)
                    eval_examples = random.sample(eval_examples, int(len(eval_examples) / divide_number))
                    eval_features = convert_examples_to_features(
                        eval_examples, tokenizer, args, stage='test')
                    all_source_ids = torch.tensor(
                        [f.source_ids for f in eval_features], dtype=torch.long)
                    eval_data = TensorDataset(all_source_ids)
                    dev_dataset['dev_acc'] = eval_examples, eval_data

                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(
                    eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
                model.eval()
                pp = []  # decoded predictions
                pr = []  # predicted existence scores
                for batch in eval_dataloader:
                    batch = tuple(t.to(device) for t in batch)
                    source_ids = batch[0]
                    with torch.no_grad():
                        # Inference mode: model returns (token ids, existence
                        # predicate) per example.
                        preds, predicates = model(source_ids)

                    for pred, predicate in zip(preds, predicates):
                        t = pred[0].cpu().numpy()
                        p = predicate.float().item()
                        t = list(t)

                        # Strip trailing zeros (padding) from the decoded ids.
                        tem_i = 0
                        if 0 in t:
                            for my_i in range(len(t) - 1, 0, -1):
                                if t[my_i] != 0:
                                    break
                                tem_i -= 1
                            if tem_i < 0:
                                t = t[:tem_i]
                        text = tokenizer.decode(
                            t, clean_up_tokenization_spaces=False)
                        pp.append(text)
                        pr.append(p)
                model.train()

                p_wrong_list = []           # predicate (existence) mistakes
                v_wrong_list = []           # value (code text) mistakes
                model_predicate = []
                groundtruth_predicate = []

                # Lines filtered out by read_examples_no_bracket are counted
                # as trivially correct (base_num) in every metric.
                total = int(total_eval_all / divide_number)
                base_num = total - len(eval_examples)
                EM = float(base_num)        # both value and predicate correct
                EM_V = float(base_num)      # value correct
                EM_P = float(base_num)      # predicate correct
                cnt_v = 0
                cnt_p = 0
                cnt_iteration = 0
                for ref, gold in zip(zip(pp, pr), eval_examples):
                    cnt_iteration += 1
                    pred = ref[0].strip()
                    predicate = ref[1]
                    # Rule-based overrides of the predicted existence flag;
                    # NOTE(review): vec[-97:] mirrors the Vector[-97:] filter
                    # in read_examples_no_bracket — confirm both track the
                    # same feature-tail layout.
                    if gold.property.strip().lower() != "nothing":
                        predicate = 1.0
                    else:
                        pred = gold.source.strip()
                        if 1 not in gold.vec:
                            predicate = 0.0
                        if 1 in gold.vec and gold.source.strip()[0] == '}':
                            predicate = 1.0
                        if '#' in gold.source:
                            predicate = 0.0
                        if 1 in gold.vec[-97:]:
                            predicate = 1.0
                    gt_pred = gold.target.strip()
                    gt_predicate = gold.exist

                    if pred == gt_pred and int(round(predicate)) == int(round(gt_predicate)):
                        EM = EM + 1.0
                        EM_V = EM_V + 1.0
                        EM_P = EM_P + 1.0
                    else:
                        if pred == gt_pred:
                            EM_V = EM_V + 1.0
                        else:
                            v_wrong_list.append([gold.filename, gold.funcname, gold.cpuname,
                                                 round(predicate), gt_predicate, pred, gt_pred])
                            cnt_v += 1
                        if int(round(predicate)) == int(round(gt_predicate)):
                            EM_P = EM_P + 1.0
                        else:
                            cnt_p += 1
                            p_wrong_list.append([gold.filename, gold.funcname, gold.cpuname,
                                                 round(predicate), gt_predicate, pred, gt_pred])

                    model_predicate.append(predicate)
                    groundtruth_predicate.append(gt_predicate)
                dev_acc = round((100*EM/total), 2)
                dev_acc_v = round((100*EM_V/total), 2)
                dev_acc_p = round((100*EM_P/total), 2)
                logger.info("  %s = %s " % ("Current Acc", str(dev_acc)))
                logger.info("  "+"*"*20)
                logger.info("  %s = %s " % ("Current Acc V", str(dev_acc_v)))
                logger.info("  "+"*"*20)
                logger.info("  %s = %s " % ("Current Acc P", str(dev_acc_p)))
                logger.info("  "+"*"*20)
                # Checkpoint whenever combined exact-match improves.
                if dev_acc > best_acc:
                    best_acc = dev_acc

                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-best-acc')
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Unwrap DataParallel before saving so keys are portable.
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    output_model_file = os.path.join(
                        output_dir, "pytorch_model.bin")
                    torch.save(model_to_save.state_dict(), output_model_file)
                    logger.info("  Best acc:%s", best_acc)
                    logger.info("  " + "*" * 20)

    if args.do_test or args.do_function_test:
        # Fresh result file per run.
        if os.path.exists(args.output_dir+"/result.jsonl"):
            os.unlink(args.output_dir+"/result.jsonl")
        # Reload the best checkpoint saved during training.
        checkpoint_prefix = 'checkpoint-best-acc/pytorch_model.bin'
        output_dir = os.path.join(args.output_dir, checkpoint_prefix)
        model_to_load = model.module if hasattr(model, 'module') else model
        model_to_load.load_state_dict(torch.load(output_dir), strict=False)

        eval_examples_all = read_examples(args.test_filename, args.do_function_test)
        eval_examples = read_examples_no_bracket(args.test_filename, args.do_function_test)

        # Filtered-out lines are credited as correct (base_test) below.
        total_all = len(eval_examples_all)
        base_test = total_all - len(eval_examples)

        eval_features = convert_examples_to_features(
            eval_examples, tokenizer, args, stage='test')
        all_source_ids = torch.tensor(
            [f.source_ids for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_source_ids)

        # Indices of examples actually run through the model; the rest are
        # emitted verbatim as "Stable" results.
        eval_examples_idx_lis = []
        for ee in eval_examples:
            eval_examples_idx_lis.append(ee.idx)

        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(
            eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        pp = []  # decoded predictions
        pr = []  # predicted existence scores
        if not args.do_function_test:
            print("Start Inferencing!")
        else:
            print("Start Function Test Inferencing!")
        for batch in eval_dataloader:
            batch = tuple(t.to(device) for t in batch)
            source_ids = batch[0]
            with torch.no_grad():
                preds, predicates = model(source_ids)

            for pred, predicate in zip(preds, predicates):
                t = pred[0].cpu().numpy()
                p = predicate.float().item()
                t = list(t)
                # Strip trailing zeros (padding) from the decoded ids.
                tem_i = 0
                if 0 in t:
                    for my_i in range(len(t)-1, 0, -1):
                        if t[my_i] != 0:
                            break
                        tem_i -= 1
                    if tem_i < 0:
                        t = t[:tem_i]
                text = tokenizer.decode(
                    t, clean_up_tokenization_spaces=False)
                pp.append(text)
                pr.append(p)
        if not args.do_function_test:
            print("Finished Inferencing.")
        else:
            print("Finished Function Test Inferencing.")
        model.train()
        EM = float(base_test)       # both value and predicate correct
        EM_P = float(base_test)     # predicate correct
        EM_V = float(base_test)     # value correct
        p_wrong_list = []
        v_wrong_list = []
        edit_sim = 0.0
        total = total_all
        res_dic = {}                # per-idx result records for result.jsonl

        model_predicate = []
        groundtruth_predicate = []

        for ref, gold in zip(zip(pp, pr), eval_examples):
            pred = ref[0].strip()
            predicate = ref[1]
            # Same rule-based predicate overrides as in the dev evaluation.
            if gold.property.strip().lower() != "nothing":
                predicate = 1.0
            else:
                pred = gold.source.strip()
                if 1 not in gold.vec:
                    predicate = 0.0
                if 1 in gold.vec and gold.source.strip()[0] == '}':
                    predicate = 1.0
                if '#' in gold.source:
                    predicate = 0.0
                if 1 in gold.vec[-97:]:
                    predicate = 1.0
            gt_pred = gold.target.strip()
            gt_predicate = gold.exist
            is_re = False
            gt_value = gold.property
            gt_source = gold.source
            if pred == gt_pred and round(predicate) == gt_predicate:
                EM += 1
            if pred == gt_pred and round(predicate) != gt_predicate:
                p_wrong_list.append([gold.filename, gold.funcname, gold.cpuname, gold.idx,
                                     round(predicate), gt_predicate, pred, gt_pred])
            if pred != gt_pred and round(predicate) == gt_predicate:
                # Value mismatch: attempt the bracket/placeholder rescue
                # rewrites before declaring the prediction wrong.
                is_re, re_pred = rewrite_pred(pred, gt_pred, gt_source, gt_value)
                if not is_re:
                    v_wrong_list.append([gold.filename, gold.funcname, gold.cpuname, gold.idx,
                                         round(predicate), gt_predicate, pred, gt_pred])
                else:
                    pred = re_pred
                    EM += 1
            if pred != gt_pred and round(predicate) != gt_predicate:
                v_wrong_list.append([gold.filename, gold.funcname, gold.cpuname, gold.idx,
                                     round(predicate), gt_predicate, pred, gt_pred])
                p_wrong_list.append([gold.filename, gold.funcname, gold.cpuname, gold.idx,
                                     round(predicate), gt_predicate, pred, gt_pred])
            tem_dic = {}
            tem_dic["idx"] = gold.idx
            tem_dic["vega_code"] = pred
            tem_dic["ans_code"] = gt_pred
            tem_dic["vega_pre"] = round(predicate)
            tem_dic["ans_pre"] = gt_predicate
            tem_dic["File"] = gold.filename
            tem_dic["Func"] = gold.funcname
            tem_dic["Module"] = gold.module
            tem_dic["Target"] = gold.cpuname
            res_dic[gold.idx] = tem_dic

            if pred == gt_pred:
                EM_V += 1
            if round(predicate) == gt_predicate:
                EM_P += 1
            model_predicate.append(predicate)
            groundtruth_predicate.append(gt_predicate)
        dev_acc = round((100 * EM / total), 2)
        dev_acc_v = round((100 * EM_V / total), 2)
        dev_acc_p = round((100 * EM_P / total), 2)
        predictions = []

        # Emit one JSON line per example of the FULL test set: filtered-out
        # examples are echoed verbatim ("Stable": True), model-processed ones
        # come from res_dic.  "zmtarzm" is replaced by the real CPU name.
        with open(args.output_dir+"/result.jsonl", 'a') as f2:
            for ee in eval_examples_all:
                if ee.idx not in eval_examples_idx_lis:
                    dic = {}
                    dic["idx"] = ee.idx
                    dic["vega_code"] = ee.source.replace("zmtarzm", ee.cpuname)
                    dic["ans_code"] = ee.source.replace("zmtarzm", ee.cpuname)
                    dic["vega_pre"] = ee.exist
                    dic["ans_pre"] = ee.exist
                    dic["File"] = ee.filename
                    dic["Func"] = ee.funcname
                    dic["Module"] = ee.module
                    dic["Target"] = ee.cpuname
                    dic["Stable"] = "True"
                else:
                    dic = {}
                    dic["idx"] = res_dic[ee.idx]["idx"]
                    dic["vega_code"] = res_dic[ee.idx]["vega_code"].replace("zmtarzm", ee.cpuname)
                    dic["ans_code"] = res_dic[ee.idx]["ans_code"].replace("zmtarzm", ee.cpuname)
                    dic["vega_pre"] = res_dic[ee.idx]["vega_pre"]
                    dic["ans_pre"] = res_dic[ee.idx]["ans_pre"]
                    dic["File"] = res_dic[ee.idx]["File"]
                    dic["Func"] = res_dic[ee.idx]["Func"]
                    dic["Module"] = res_dic[ee.idx]["Module"]
                    dic["Target"] = res_dic[ee.idx]["Target"]
                    dic["Stable"] = "False"

                json.dump(dic, f2)
                f2.write('\n')
|
|
|
|
|
|
|
# Script entry point: run the full train / eval / test pipeline.
if __name__ == "__main__":
    vega_train_main()