# Author: LedZeppe1in (commit 1507360)
"""Custom RuTaBERT pipeline for column type annotation."""
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from transformers import Pipeline
from rutabert_pipeline.dataset import TableDataset
from rutabert_pipeline.sem_types import TYPES_MAPPING
class ColumnTypeAnnotationPipeline(Pipeline):
    """Hugging Face pipeline that annotates table columns with semantic types.

    Preprocessing wraps the input dataframe into a ``TableDataset``; the
    forward pass collects the model's predictions at every [CLS] token of
    each tokenized table; postprocessing maps the first predicted class id
    of each table to a type name via ``TYPES_MAPPING``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split pipeline kwargs into preprocess/forward/postprocess groups.

        No custom parameters are supported, so all three groups are empty.

        :return: tuple of (preprocess_kwargs, forward_kwargs, postprocess_kwargs)
        """
        return {}, {}, {}

    def preprocess(self, inputs, **preprocess_parameters):
        """Wrap the input dataframe into a tokenized ``TableDataset``.

        :param inputs: dataframe of tables to annotate
        :return: ``TableDataset`` built with the pipeline's tokenizer
        """
        return TableDataset(tokenizer=self.tokenizer, dataframe=inputs)

    def _forward(self, model_inputs, **forward_parameters):
        """Run inference and collect per-table predicted label ids.

        :param model_inputs: iterable of samples exposing "data" (token ids)
            and "table_id"
        :return: dataframe with columns ["table_id", "labels"], where
            "labels" is a nested list (one inner list per sample, one class
            id per [CLS] token found in the sequence)
        """
        self._set_rs(2024)
        # NOTE(review): inference is pinned to CPU regardless of model device.
        device = torch.device("cpu")
        rows = []
        self.model.eval()
        with torch.no_grad():
            for sample in model_inputs:
                seq = sample["data"].to(device).unsqueeze(0)
                # Attend to every non-zero position; assumes the pad token
                # id is 0 — TODO confirm against the tokenizer.
                attention_mask = seq != 0
                outputs = self.model(seq, attention_mask=attention_mask)
                # Some model classes return a tuple; logits come first.
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                cls_logits = self._get_token_logits(
                    device, seq, outputs, self.tokenizer.cls_token_id
                )
                # Keep the nested-list shape postprocess() expects (x[0][0]).
                predictions = [cls_logits.argmax(1).cpu().numpy().tolist()]
                rows.append([sample["table_id"], predictions])
        return pd.DataFrame(rows, columns=["table_id", "labels"])

    def postprocess(self, model_outputs, **postprocess_parameters):
        """Map each table's first predicted class id to its type name.

        Only the first [CLS] prediction of each sample is used, i.e. the
        pipeline reports a single column type per table.

        :param model_outputs: dataframe produced by :meth:`_forward`
        :return: the same dataframe with "labels" replaced by type names
        """
        model_outputs["labels"] = model_outputs["labels"].apply(
            lambda x: TYPES_MAPPING.get(x[0][0])
        )
        return model_outputs

    @staticmethod
    def _set_rs(seed: int = 13) -> None:
        """Seed torch and numpy RNGs and force deterministic cuDNN.

        :param seed: random seed
        :return: None
        """
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)

    @staticmethod
    def _get_token_logits(device: torch.device, data: Tensor, logits: Tensor, token_id: int) -> Tensor:
        """
        Get logits at every position where ``data`` equals ``token_id``.

        Boolean-mask indexing selects matching rows in the same row-major
        order as iterating ``torch.nonzero`` indices, replacing the former
        per-index Python loop with a single vectorized op.

        :param device: device (GPU or CPU) the result is moved to
        :param data: model input ids, shape (batch, seq_len)
        :param logits: model logits, shape (batch, seq_len, num_classes)
        :param token_id: token id to select
        :return: tensor of shape (num_matches, num_classes) on ``device``
        """
        return logits[data == token_id].to(device)