import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoConfig, AutoTokenizer
from optimum.onnxruntime import ORTModel
import pandas as pd
import os

# Hugging Face access token used to download the model and tokenizer.
AUTH_TOKEN = "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"

tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
                                          use_auth_token=AUTH_TOKEN)
pad_token_id = tokenizer.pad_token_id


class PairwiseModel(nn.Module):
    """Cross-encoder re-ranker that scores each (question, passage) pair with a single logit."""

    def __init__(self, model_name, max_length=384, batch_size=16, device="cpu"):
        super(PairwiseModel, self).__init__()
        self.max_length = max_length
        self.batch_size = batch_size
        self.device = device
        # Export the transformers checkpoint to ONNX and run it with ONNX Runtime.
        self.model = ORTModel.from_pretrained(model_name, use_auth_token=AUTH_TOKEN, from_transformers=True)
        self.model.to(self.device)
        self.model.eval()
        self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
        # Map the 768-dim [CLS] representation to a single relevance score.
        self.fc = nn.Linear(768, 1).to(self.device)

    def forward(self, ids, masks):
        out = self.model(input_ids=ids,
                         attention_mask=masks,
                         output_hidden_states=False).last_hidden_state
        # Take the [CLS] token embedding as the sequence representation.
        out = out[:, 0]
        outputs = self.fc(out)
        return outputs

    def stage1_ranking(self, question, texts):
        """Score each candidate passage against the question (first-stage re-ranking)."""
        tmp = pd.DataFrame()
        tmp["text"] = [" ".join(x.split()) for x in texts]
        tmp["question"] = question
        valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
        valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                                  num_workers=0, shuffle=False, pin_memory=True)
        preds = []
        with torch.no_grad():
            for step, data in enumerate(valid_loader):
                ids = data["ids"].to(self.device)
                masks = data["masks"].to(self.device)
                preds.append(torch.sigmoid(self(ids, masks)).view(-1))
        preds = torch.concat(preds)
        return preds.cpu().numpy()

    def stage2_ranking(self, question, answer, titles, texts):
        """Score each (title, candidate) passage against the question plus extracted answer."""
        tmp = pd.DataFrame()
        tmp["candidate"] = texts
        tmp["question"] = question
        tmp["answer"] = answer
        tmp["title"] = titles
        valid_dataset = SiameseDatasetStage2(tmp, tokenizer, self.max_length, is_test=True)
        valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                                  num_workers=0, shuffle=False, pin_memory=True)
        preds = []
        with torch.no_grad():
            for step, data in enumerate(valid_loader):
                ids = data["ids"].to(self.device)
                masks = data["masks"].to(self.device)
                preds.append(torch.sigmoid(self(ids, masks)).view(-1))
        preds = torch.concat(preds)
        return preds.cpu().numpy()


class SiameseDatasetStage1(Dataset):

    def __init__(self, df, tokenizer, max_length, is_test=False):
        self.df = df
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_test = is_test
        # Pre-tokenize questions and passages separately; collate_fn joins them into one sequence.
        self.content1 = tokenizer.batch_encode_plus(list(df.question.values), max_length=max_length, truncation=True)[
            "input_ids"]
        self.content2 = tokenizer.batch_encode_plus(list(df.text.values), max_length=max_length, truncation=True)[
            "input_ids"]
        if not self.is_test:
            self.targets = self.df.label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        return {
            'ids1': torch.tensor(self.content1[index], dtype=torch.long),
            # Drop the leading special token so the concatenated pair stays well formed.
            'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
            'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
        }


class SiameseDatasetStage2(Dataset):

    def __init__(self, df, tokenizer, max_length, is_test=False):
        self.df = df
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_test = is_test
        # Build the two sides of the pair: "question [SEP] answer" and "title [SEP] candidate".
        self.df["content1"] = self.df.apply(lambda row: row.question + f" {tokenizer.sep_token} " + row.answer, axis=1)
        self.df["content2"] = self.df.apply(lambda row: row.title + f" {tokenizer.sep_token} " + row.candidate, axis=1)
        self.content1 = tokenizer.batch_encode_plus(list(df.content1.values), max_length=max_length, truncation=True)[
            "input_ids"]
        self.content2 = tokenizer.batch_encode_plus(list(df.content2.values), max_length=max_length, truncation=True)[
            "input_ids"]
        if not self.is_test:
            self.targets = self.df.label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        return {
            'ids1': torch.tensor(self.content1[index], dtype=torch.long),
            # Drop the leading special token so the concatenated pair stays well formed.
            'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
            'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
        }


def collate_fn(batch):
    # Concatenate the two pre-tokenized halves into a single input sequence per example.
    ids = [torch.cat([x["ids1"], x["ids2"]]) for x in batch]
    targets = [x["target"] for x in batch]
    # Pad every sequence to the longest one in the batch with the tokenizer's pad token.
    max_len = np.max([len(x) for x in ids])
    masks = []
    for i in range(len(ids)):
        if len(ids[i]) < max_len:
            ids[i] = torch.cat((ids[i], torch.tensor([pad_token_id, ] * (max_len - len(ids[i])), dtype=torch.long)))
        # Attention mask: True for real tokens, False for padding.
        masks.append(ids[i] != pad_token_id)

    outputs = {
        "ids": torch.vstack(ids),
        "masks": torch.vstack(masks),
        "target": torch.vstack(targets).view(-1)
    }
    return outputs
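

# Minimal usage sketch (not part of the original module): it assumes the ranker is
# built from the same 'nguyenvulebinh/vi-mrc-base' checkpoint as the tokenizer above,
# and the question/passages below are placeholder examples.
if __name__ == "__main__":
    ranker = PairwiseModel('nguyenvulebinh/vi-mrc-base')
    question = "Ai là tác giả của Truyện Kiều?"  # "Who is the author of The Tale of Kieu?"
    passages = [
        "Truyện Kiều là truyện thơ nổi tiếng của Nguyễn Du.",
        "Hà Nội là thủ đô của Việt Nam.",
    ]
    scores = ranker.stage1_ranking(question, passages)  # one relevance score per passage
    for score, passage in sorted(zip(scores, passages), reverse=True):
        print(f"{score:.3f}  {passage}")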