e2eqa-wiki / src /models /pairwise_model.py
foxxy-hm's picture
Update src/models/pairwise_model.py
b2f8848
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoConfig
from transformers import AutoTokenizer
# Hugging Face access token, read from the environment so that no credential
# is stored in source control. ``None`` (no variable set) is passed through
# unchanged to ``from_pretrained``.
# NOTE(review): a literal token used to be hard-coded on this line; it has been
# exposed in the repository history and should be revoked.
AUTH_TOKEN = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or None
)

# Module-level tokenizer shared by the ranking methods of ``PairwiseModel``;
# ``pad_token_id`` is read by ``collate_fn`` when padding batches.
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
                                          use_auth_token=AUTH_TOKEN)
pad_token_id = tokenizer.pad_token_id
class PairwiseModel(nn.Module):
    """Transformer backbone plus a linear head that scores text pairs.

    ``forward`` feeds token ids through the backbone, takes the hidden state at
    sequence position 0 and maps it to a single logit. ``stage1_ranking`` and
    ``stage2_ranking`` build the matching dataset, run batched inference and
    return the sigmoid of those logits.
    """

    def __init__(self, model_name, max_length=384, batch_size=16, device="cpu"):
        """Load the backbone and create the scoring head.

        Args:
            model_name: Model id or path handed to ``AutoModel.from_pretrained``.
            max_length: Truncation length passed to the datasets' tokenizer calls.
            batch_size: Batch size used by the ranking methods.
            device: Device the backbone and the head are moved to.
        """
        super().__init__()
        self.max_length = max_length
        self.batch_size = batch_size
        self.device = device
        self.model = AutoModel.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
        self.model.to(self.device)
        self.model.eval()
        self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
        # Size the head from the backbone config instead of assuming 768, so
        # backbones with a different width work; 768 stays the fallback for
        # configs that do not expose ``hidden_size``.
        hidden_size = getattr(self.config, "hidden_size", 768)
        self.fc = nn.Linear(hidden_size, 1).to(self.device)

    def forward(self, ids, masks):
        """Return one logit per sequence, shape ``(batch, 1)``.

        Args:
            ids: Token ids, passed to the backbone as ``input_ids``.
            masks: Attention mask, passed to the backbone as ``attention_mask``.
        """
        hidden = self.model(input_ids=ids,
                            attention_mask=masks,
                            output_hidden_states=False).last_hidden_state
        # Only the representation at sequence position 0 feeds the head.
        return self.fc(hidden[:, 0])

    def _score_dataset(self, dataset):
        """Run gradient-free inference over *dataset*; return sigmoid scores as a 1-D array."""
        loader = DataLoader(dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                            num_workers=0, shuffle=False, pin_memory=True)
        scores = []
        with torch.no_grad():
            for batch in loader:
                ids = batch["ids"].to(self.device)
                masks = batch["masks"].to(self.device)
                scores.append(torch.sigmoid(self(ids, masks)).view(-1))
        return torch.cat(scores).cpu().numpy()

    def stage1_ranking(self, question, texts):
        """Score every text in *texts* against *question*.

        Args:
            question: The question, broadcast to every row.
            texts: Iterable of candidate strings; runs of whitespace are
                collapsed to single spaces before tokenisation.

        Returns:
            1-D NumPy array of sigmoid scores in the order of *texts*; an empty
            float32 array when *texts* is empty.
        """
        tmp = pd.DataFrame()
        tmp["text"] = [" ".join(x.split()) for x in texts]
        tmp["question"] = question
        if len(tmp) == 0:
            # Nothing to score: concatenating zero batches would raise.
            return np.empty(0, dtype=np.float32)
        dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
        return self._score_dataset(dataset)

    def stage2_ranking(self, question, answer, titles, texts):
        """Score (title, text) candidates against a (question, answer) pair.

        Args:
            question: The question, stored in the ``question`` column.
            answer: The answer, stored in the ``answer`` column.
            titles: Candidate titles, stored in the ``title`` column.
            texts: Candidate texts, stored in the ``candidate`` column.

        Returns:
            1-D NumPy array of sigmoid scores, one per candidate; an empty
            float32 array when there are no candidates.
        """
        tmp = pd.DataFrame()
        tmp["candidate"] = texts
        tmp["question"] = question
        tmp["answer"] = answer
        tmp["title"] = titles
        if len(tmp) == 0:
            # Nothing to score: concatenating zero batches would raise.
            return np.empty(0, dtype=np.float32)
        dataset = SiameseDatasetStage2(tmp, tokenizer, self.max_length, is_test=True)
        return self._score_dataset(dataset)
class SiameseDatasetStage1(Dataset):
    """Dataset of (question, text) pairs, tokenised once at construction.

    Expects *df* to have ``question`` and ``text`` columns, plus a ``label``
    column unless ``is_test`` is true.

    NOTE: each side is truncated to ``max_length`` on its own, so the pair that
    ``collate_fn`` concatenates can be up to ``2 * max_length - 1`` tokens long.
    """

    def __init__(self, df, tokenizer, max_length, is_test=False):
        """Tokenise both columns up front.

        Args:
            df: Source DataFrame (see class docstring for required columns).
            tokenizer: Object exposing ``batch_encode_plus``.
            max_length: Truncation length applied to each side separately.
            is_test: When true, no labels are read and targets are a constant 0.
        """
        self.df = df
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_test = is_test
        self.content1 = tokenizer.batch_encode_plus(
            list(df["question"].values), max_length=max_length, truncation=True)["input_ids"]
        self.content2 = tokenizer.batch_encode_plus(
            list(df["text"].values), max_length=max_length, truncation=True)["input_ids"]
        if not self.is_test:
            self.targets = self.df["label"]

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the token ids of both sides and the target for row *index*."""
        if self.is_test:
            target = torch.tensor(0)
        else:
            # Positional lookup: the DataLoader passes 0..len-1, which only
            # matches the Series labels when the frame has a default index.
            target = torch.tensor(self.targets.iloc[index], dtype=torch.float)
        return {
            'ids1': torch.tensor(self.content1[index], dtype=torch.long),
            # The first token of the second side is dropped before the two
            # sides are joined (presumably the leading special token -- confirm
            # against the tokenizer in use).
            'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
            'target': target,
        }
class SiameseDatasetStage2(Dataset):
    """Dataset pairing "question SEP answer" with "title SEP candidate".

    Expects *df* to have ``question``, ``answer``, ``title`` and ``candidate``
    columns, plus a ``label`` column unless ``is_test`` is true. The joined
    strings are exposed as ``content1`` / ``content2`` columns of ``self.df``.

    NOTE: each side is truncated to ``max_length`` on its own, so the pair that
    ``collate_fn`` concatenates can be up to ``2 * max_length - 1`` tokens long.
    """

    def __init__(self, df, tokenizer, max_length, is_test=False):
        """Build the joined strings and tokenise them up front.

        Args:
            df: Source DataFrame (see class docstring). It is not modified.
            tokenizer: Object exposing ``sep_token`` and ``batch_encode_plus``.
            max_length: Truncation length applied to each side separately.
            is_test: When true, no labels are read and targets are a constant 0.
        """
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_test = is_test

        separator = f" {tokenizer.sep_token} "
        left = [q + separator + a for q, a in zip(df["question"], df["answer"])]
        right = [t + separator + c for t, c in zip(df["title"], df["candidate"])]
        # ``assign`` returns a new frame, so the caller's DataFrame does not
        # silently gain two columns as a side effect of building the dataset.
        self.df = df.assign(content1=left, content2=right)

        self.content1 = tokenizer.batch_encode_plus(
            left, max_length=max_length, truncation=True)["input_ids"]
        self.content2 = tokenizer.batch_encode_plus(
            right, max_length=max_length, truncation=True)["input_ids"]
        if not self.is_test:
            self.targets = self.df["label"]

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the token ids of both sides and the target for row *index*."""
        if self.is_test:
            target = torch.tensor(0)
        else:
            # Positional lookup: the DataLoader passes 0..len-1, which only
            # matches the Series labels when the frame has a default index.
            target = torch.tensor(self.targets.iloc[index], dtype=torch.float)
        return {
            'ids1': torch.tensor(self.content1[index], dtype=torch.long),
            # The first token of the second side is dropped before the two
            # sides are joined (presumably the leading special token -- confirm
            # against the tokenizer in use).
            'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
            'target': target,
        }
def collate_fn(batch):
    """Merge dataset items into one right-padded batch.

    For every item, ``ids1`` and ``ids2`` are concatenated into a single
    sequence; sequences are padded on the right with the module-level
    ``pad_token_id`` up to the longest one in the batch.

    Args:
        batch: List of dicts with ``ids1``, ``ids2`` and ``target`` tensors.

    Returns:
        Dict with ``ids`` (padded token ids), ``masks`` (true wherever the id
        differs from ``pad_token_id``) and ``target`` (flattened targets).
    """
    sequences = [torch.cat([item["ids1"], item["ids2"]]) for item in batch]
    labels = [item["target"] for item in batch]
    longest = max(len(seq) for seq in sequences)

    # Start from a buffer that is all padding, then copy each sequence into
    # the front of its row.
    padded = torch.full((len(sequences), longest), pad_token_id, dtype=torch.long)
    for row, seq in zip(padded, sequences):
        row[:len(seq)] = seq

    return {
        "ids": padded,
        "masks": padded != pad_token_id,
        "target": torch.vstack(labels).view(-1),
    }