# Commit 1507360 by LedZeppe1in: "Added custom rutabert pipeline for column type annotation"
from itertools import chain
import torch
from pandas import DataFrame
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
class TableDataset(Dataset):
    """Torch dataset of tokenized table columns for column type annotation.

    Each item corresponds to ONE column of a table, serialized as

        [CLS] <target column tokens> [SEP] <other col 1> [SEP] <other col 2> [SEP] ...

    so the model sees the rest of the table as context when classifying the
    target column.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, dataframe: DataFrame = None,
                 max_length: int = 512):
        """Pre-tokenize ``dataframe`` into per-column training examples.

        Args:
            tokenizer: tokenizer exposing ``encode``, ``cls_token_id`` and
                ``sep_token_id`` (duck-typed; only those members are used).
            dataframe: input frame with columns ``table_id``, ``column_data``
                and ``label_id``.
            max_length: total token budget for one serialized table
                (default 512, the usual BERT limit; previously hard-coded).

        Raises:
            ValueError: if ``dataframe`` is None — previously this surfaced
                as an opaque ``AttributeError`` inside ``groupby``.
        """
        if dataframe is None:
            raise ValueError("dataframe must be provided")
        self.dataframe = self._create_dataset(tokenizer, dataframe, max_length)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Fetch the row once instead of three separate .iloc lookups.
        row = self.dataframe.iloc[idx]
        return {
            "data": row["data"],
            "labels": row["labels"],
            "table_id": row["table_id"],
        }

    @staticmethod
    def _create_dataset(tokenizer: PreTrainedTokenizerBase, dataframe: DataFrame,
                        max_length: int = 512) -> DataFrame:
        """Group rows by table and emit one serialized example per column.

        Returns a DataFrame with columns ``table_id``, ``n_cols``,
        ``data`` (LongTensor token ids) and ``labels`` (LongTensor of one id).
        """
        data_list = []
        for table_id, table in tqdm(dataframe.groupby("table_id")):
            num_cols = len(table)
            # Per-column token budget so the whole serialized table fits in
            # max_length; -2 leaves room for [CLS]/[SEP]. max(1, ...) guards
            # very wide tables where the budget would be zero or negative
            # (hoisted out of the lambda so it is computed once per table).
            col_budget = max(1, (max_length // num_cols) - 2)
            tokenized_columns = table["column_data"].apply(
                lambda x: tokenizer.encode(x, add_special_tokens=False,
                                           max_length=col_budget, truncation=True)
            ).tolist()
            labels = table["label_id"].values
            for i in range(num_cols):
                # Context tail: every OTHER column followed by [SEP]; the
                # target column itself contributes an empty slot.
                tail = [tokenized_columns[j] + [tokenizer.sep_token_id] if j != i else []
                        for j in range(num_cols)]
                head = [tokenizer.cls_token_id, *tokenized_columns[i], tokenizer.sep_token_id]
                sequence = torch.LongTensor(head + list(chain.from_iterable(tail)))
                label = torch.LongTensor([labels[i]])
                data_list.append([table_id, num_cols, sequence, label])
        return DataFrame(data_list, columns=["table_id", "n_cols", "data", "labels"])