from itertools import chain

import torch
from pandas import DataFrame
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase


class TableDataset(Dataset):
    """Per-column examples for table column annotation.

    Expects ``dataframe`` to hold one row per table column, with columns
    "table_id", "column_data" (the column's cell values serialized as a
    single string), and "label_id" (the column's integer class label).
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, dataframe: DataFrame):
        self.dataframe = self._create_dataset(tokenizer, dataframe)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        return {
            "data": row["data"],
            "labels": row["labels"],
            "table_id": row["table_id"],
        }

    @staticmethod
    def _create_dataset(tokenizer: PreTrainedTokenizerBase, dataframe: DataFrame) -> DataFrame:
        data_list = []
        for table_id, table in tqdm(dataframe.groupby("table_id")):
            # Each row of the group is one column of this table.
            num_cols = len(table)

            # Split the 512-token window evenly across the table's columns,
            # keeping two positions per column free for special tokens.
            tokenized_table_columns = table["column_data"].apply(
                lambda x: tokenizer.encode(
                    x, add_special_tokens=False, max_length=(512 // num_cols) - 2, truncation=True
                )
            ).tolist()

            labels = table["label_id"].values
            for i in range(num_cols):
                # Column i leads the sequence as "[CLS] col_i [SEP]"; the
                # remaining columns follow as context, each closed by [SEP].
                tail = [tokenized_table_columns[j] + [tokenizer.sep_token_id] if j != i else [] for j in range(num_cols)]
                head = [tokenizer.cls_token_id, *tokenized_table_columns[i], tokenizer.sep_token_id]
                tokenized_columns_seq = torch.LongTensor(head + list(chain.from_iterable(tail)))
                label = torch.LongTensor([labels[i]])
                data_list.append([table_id, num_cols, tokenized_columns_seq, label])
        return DataFrame(data_list, columns=["table_id", "n_cols", "data", "labels"])
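

# Not part of the original module: a minimal collate sketch. The sequences
# built above vary in length from table to table, so a DataLoader needs a
# collate_fn that pads "data" to the longest example in the batch. The helper
# name `collate_tables` and its pad_token_id argument are illustrative
# assumptions, not part of the original code.
def collate_tables(batch, pad_token_id: int = 0):
    from torch.nn.utils.rnn import pad_sequence

    # Right-pad every variable-length "data" tensor to the batch maximum.
    data = pad_sequence([item["data"] for item in batch], batch_first=True, padding_value=pad_token_id)
    # Each label tensor has shape (1,); concatenating yields shape (batch,).
    labels = torch.cat([item["labels"] for item in batch])
    table_ids = [item["table_id"] for item in batch]
    return {"data": data, "labels": labels, "table_id": table_ids}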
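

if __name__ == "__main__":
    # Minimal usage sketch, not from the original module: builds a tiny
    # two-column table and batches it. The "bert-base-uncased" checkpoint is
    # an assumption; any tokenizer exposing cls/sep/pad token ids should work.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    frame = DataFrame({
        "table_id": ["t0", "t0"],
        "column_data": ["london paris berlin", "uk france germany"],
        "label_id": [3, 7],
    })
    dataset = TableDataset(tokenizer, frame)
    loader = DataLoader(
        dataset,
        batch_size=2,
        collate_fn=lambda b: collate_tables(b, pad_token_id=tokenizer.pad_token_id),
    )
    for batch in loader:
        print(batch["data"].shape, batch["labels"])  # e.g. torch.Size([2, L]) and tensor([3, 7])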