# Prove_KCL/utils/textual_entailment_module.py
import json
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import re
from transformers import BertTokenizer, BertForSequenceClassification

# Constants and paths
HOME = Path('/users/k2031554')
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
MAX_LEN = 512
CLASSES = ['SUPPORTS', 'REFUTES', 'NOT ENOUGH INFO']
METHODS = ['WEIGHTED_SUM', 'MALON']


def process_sent(sentence):
    """Normalise FEVER-style wiki text: drop LSB ... RSB spans, rewrite
    LRB/RRB tokens as parentheses, and clean doubled punctuation."""
    sentence = re.sub(r"LSB.*?RSB", "", sentence)
    sentence = re.sub(r"LRB\s*?RRB", "", sentence)
    sentence = re.sub(r"(\s*?)LRB((\s*?))", r"\1(\2", sentence)
    sentence = re.sub(r"(\s*?)RRB((\s*?))", r"\1)\2", sentence)
    sentence = re.sub(r"--", "-", sentence)
    sentence = re.sub(r"``", '"', sentence)
    sentence = re.sub(r"''", '"', sentence)
    return sentence
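
# Illustrative example: process_sent('Tool LRB band RRB') returns 'Tool ( band )',
# and a bracketed span such as 'foo LSB 1 RSB' has its 'LSB 1 RSB' part removed.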


class TextualEntailmentModule:
    """Wraps a BERT sequence classifier fine-tuned on FEVER to score
    claim/evidence pairs over CLASSES."""

    def __init__(
        self,
        model_path='base/models/BERT_FEVER_v4_model_PBT',
        tokenizer_path='base/models/BERT_FEVER_v4_tok_PBT'
    ):
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(DEVICE)

    # Single-pair scoring variant, disabled in favour of get_batch_scores below:
    # def get_pair_scores(self, claim, evidence):
    #     encodings = self.tokenizer(
    #         [claim, evidence],
    #         max_length=MAX_LEN,
    #         return_token_type_ids=False,
    #         padding='max_length',
    #         truncation=True,
    #         return_tensors='pt',
    #     ).to(DEVICE)
    #
    #     self.model.eval()
    #     with torch.no_grad():
    #         outputs = self.model(
    #             input_ids=encodings['input_ids'],
    #             attention_mask=encodings['attention_mask']
    #         )
    #
    #     return torch.softmax(outputs.logits, dim=1).cpu().numpy()

    def get_batch_scores(self, claims, evidence):
        """Score aligned lists of claims and evidence sentences; returns an
        (n, len(CLASSES)) array of softmax probabilities."""
        inputs = list(zip(claims, evidence))
        encodings = self.tokenizer(
            inputs,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        ).to(DEVICE)

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask']
            )

        return torch.softmax(outputs.logits, dim=1).cpu().numpy()

    def get_label_from_scores(self, scores):
        return CLASSES[np.argmax(scores)]

    def get_label_malon(self, score_set):
        # Aggregate per-evidence predictions: any SUPPORTS wins, otherwise
        # any REFUTES, otherwise NOT ENOUGH INFO.
        score_labels = [np.argmax(s) for s in score_set]
        if 1 not in score_labels and 0 not in score_labels:
            return CLASSES[2]  # NOT ENOUGH INFO
        elif 0 in score_labels:
            return CLASSES[0]  # SUPPORTS
        elif 1 in score_labels:
            return CLASSES[1]  # REFUTES
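

# Minimal usage sketch, assuming the default model/tokenizer paths above
# resolve locally; the claim and evidence strings are made-up examples.
if __name__ == '__main__':
    te_module = TextualEntailmentModule()

    claims = ['Paris is the capital of France.']
    evidence = [process_sent('Paris LRB France RRB is the capital and most populous city of France.')]

    scores = te_module.get_batch_scores(claims, evidence)
    print(te_module.get_label_from_scores(scores[0]))  # label for the single pair
    print(te_module.get_label_malon(scores))           # label aggregated over all evidence rows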