|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from torch import Tensor |
|
from transformers import Pipeline |
|
|
|
from rutabert_pipeline.dataset import TableDataset |
|
from rutabert_pipeline.sem_types import TYPES_MAPPING |
|
|
|
|
|
class ColumnTypeAnnotationPipeline(Pipeline):
    """Hugging Face :class:`~transformers.Pipeline` for semantic column type annotation.

    Flow: a pandas DataFrame is wrapped into a tokenized ``TableDataset``
    (:meth:`preprocess`), the model is run once per table sample
    (:meth:`_forward`), logits are collected at every [CLS] token position
    (one [CLS] per table column), and the argmax class id of the first
    column is mapped to a human-readable type name via ``TYPES_MAPPING``
    (:meth:`postprocess`).
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess / forward / postprocess dicts.

        No call-time parameters are currently supported, so all three are empty.
        """
        return {}, {}, {}

    def preprocess(self, inputs, **preprocess_parameters):
        """Wrap the input DataFrame into a tokenized ``TableDataset``.

        :param inputs: pandas DataFrame with the table(s) to annotate
        :return: ``TableDataset`` yielding dicts with "data" and "table_id"
        """
        return TableDataset(tokenizer=self.tokenizer, dataframe=inputs)

    def _forward(self, model_inputs, **forward_parameters):
        """Run the model over every sample and collect per-column predictions.

        :param model_inputs: iterable of samples with "data" (token id tensor)
            and "table_id" keys, as produced by :meth:`preprocess`
        :return: DataFrame with columns ``table_id`` and ``labels``; each
            ``labels`` cell is ``[[class_id, ...]]`` — one class id per
            [CLS] token (i.e. per table column)
        """
        # Fixed seed keeps any stochastic layers (e.g. dropout variants)
        # deterministic across calls.
        self._set_rs(2024)

        device = torch.device("cpu")
        rows = []
        self.model.eval()
        with torch.no_grad():
            for sample in model_inputs:
                seq = sample["data"].to(device).unsqueeze(0)
                # The comparison already allocates a fresh boolean tensor;
                # the original torch.clone(...) wrapper was redundant.
                # NOTE(review): assumes the padding token id is 0 — confirm
                # against the tokenizer configuration.
                attention_mask = seq != 0
                outputs = self.model(seq, attention_mask=attention_mask)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                # Logits at each [CLS] position: one [CLS] per table column.
                cls_logits = self._get_token_logits(
                    device, seq, outputs, self.tokenizer.cls_token_id
                )
                # detach() before the device transfer (no-op under no_grad,
                # but the conventional order).
                preds = cls_logits.argmax(1).detach().cpu().numpy().tolist()
                rows.append([sample["table_id"], [preds]])

        return pd.DataFrame(rows, columns=["table_id", "labels"])

    def postprocess(self, model_outputs, **postprocess_parameters):
        """Map each table's predicted class id to a human-readable type name.

        NOTE(review): only the FIRST column's prediction (``x[0][0]``) is
        surfaced per table, and ``TYPES_MAPPING.get`` silently yields ``None``
        for unknown class ids — confirm both are intended.

        :param model_outputs: DataFrame from :meth:`_forward`
        :return: the same DataFrame with ``labels`` replaced by type names
        """
        model_outputs["labels"] = model_outputs["labels"].apply(
            lambda x: TYPES_MAPPING.get(x[0][0])
        )
        return model_outputs

    @staticmethod
    def _set_rs(seed: int = 13) -> None:
        """Seed torch and numpy RNGs and force deterministic cuDNN behavior.

        :param seed: random seed
        :return: None
        """
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)

    @staticmethod
    def _get_token_logits(device: torch.device, data: Tensor, logits: Tensor, token_id: int) -> Tensor:
        """Gather the logit vectors at every position where ``data == token_id``.

        :param device: device the result tensor is allocated on
        :param data: model input token ids, shape (batch, seq_len)
        :param logits: model output logits, shape (batch, seq_len, num_classes)
        :param token_id: token id to collect logits for (e.g. [CLS])
        :return: tensor of shape (num_matches, num_classes)
        """
        positions = torch.nonzero(data == token_id)
        # Single advanced-indexing gather replaces the original per-row
        # Python loop. The explicit dtype cast reproduces the original's
        # behavior of writing into a default-dtype torch.zeros buffer
        # (which upcast e.g. half-precision logits); zero matches yields
        # an empty (0, num_classes) tensor in both versions.
        gathered = logits[positions[:, 0], positions[:, 1], :]
        return gathered.to(device=device, dtype=torch.get_default_dtype())
|
|