"""Utilities for reading and writing data files."""
|
import multiprocessing as mp
import os
from pathlib import PosixPath
from typing import Callable, Optional, Union

from datasets import DatasetDict, load_dataset
from transformers import (
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    default_data_collator,
)

from . import config

# Disable tokenizer-internal parallelism so it does not conflict with the
# worker processes forked by `datasets.map` during preprocessing.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

UBUNTU_ROOT = str(config.root)


def load_datasets(
    tokenizer: PreTrainedTokenizer,
    train_data: Union[str, PosixPath],
    eval_data: Optional[Union[str, PosixPath]] = None,
    test_data: Optional[Union[str, PosixPath]] = None,
    file_type: str = "csv",
    delimiter: str = "\t",
    seq_key: str = "sequence",
    shuffle: bool = True,
    filter_empty: bool = False,
    n_workers: int = mp.cpu_count(),
    **kwargs,
) -> DatasetDict:
"""Load and cache data using Huggingface datasets library |
|
|
|
Args: |
|
tokenizer (PreTrainedTokenizer): tokenizer to apply to the sequences |
|
train_data (Union[str, PosixPath]): location of training data |
|
eval_data (Union[str, PosixPath], optional): location of evaluation data. Defaults to None. |
|
test_data (Union[str, PosixPath], optional): location of test data. Defaults to None. |
|
file_type (str, optional): type of file. Possible values are 'text' and 'csv'. Defaults to 'csv'. |
|
delimiter (str, optional): Defaults to '\t'. |
|
seq_key (str, optional): Column name of sequence data Can be 'sequence', 'seq', or 'text'. Defaults to 'sequence'. |
|
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True. |
|
filter_empty (bool, optional): Whether to filter out empty sequences. Defaults to False. |
|
NOTE: This completes an additional iteration, which can be time-consuming. |
|
Only enable if you have reason to believe that preprocessing steps will |
|
result in empty sequences. |
|
transformation (str, optional): type of transformation to apply. |
|
Options are 'log', 'boxcox'. Defaults to None. |
|
log_offset (Union[float, int]): value to offset gene expression values |
|
by before log transforming. Defaults to 1. |
|
preprocessor (BaseEstimator): preprocessor Yeoh-Johnson transformation. |
|
tissue_subset (Union[str, int, list], optional): tissues to subset labels to. |
|
Defaults to None. |
|
nshards (int, optional): Number of shards to divide data into, only |
|
keeping the first. Defaults to None. |
|
threshold (float, optional): filter out rows where all labels are |
|
below `threshold`. OR if `discretize` is True, see `discretize`. |
|
Defaults to None. |
|
discretize (bool, optional): set gene expression values below |
|
`threshold` to 0, above `threshold` to 1. |
|
kmer (int, optional): whether to run the kmer flip experiment and if so, |
|
how large kmers to flip. Defaults to None. |
|
n_workers (int, optional): number of processes to use for preprocessing. |
|
Defaults to `mp.cpu_count()` (number of available CPUs). |
|
position_buckets (Tuple[int], optional): the different buckets for the bucketed |
|
positional importance experiment |
|
|
|
Returns: |
|
Dataset |
|
""" |
|
data_files = {"train": str(train_data)} |
|
if eval_data: |
|
data_files["eval"] = str(eval_data) |
|
if test_data: |
|
data_files["test"] = str(test_data) |
|
if file_type == "csv": |
|
kwargs.update({"delimiter": delimiter}) |
|
datasets = load_dataset(file_type, data_files=data_files, **kwargs) |
|
|
|
preprocess_fn = make_preprocess_function(tokenizer, seq_key=seq_key) |
|
|
|
datasets = datasets.map(preprocess_fn, batched=True, num_proc=n_workers) |
|
if filter_empty: |
|
datasets = datasets.filter(filter_empty_sequence) |
|
if shuffle: |
|
seed = config.settings["random_seed"] |
|
datasets = datasets.shuffle(seeds={"train": seed, "eval": seed, "test": seed}) |
|
return datasets |
|
|
|
|
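# Example usage (a minimal sketch; the checkpoint name and file paths below
# are illustrative placeholders, not part of this module):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   datasets = load_datasets(tokenizer, "train.tsv", eval_data="dev.tsv")
#   train_split = datasets["train"]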
|
def make_preprocess_function(tokenizer, seq_key: str = "sequence") -> Callable:
    """Make a preprocessing function that selects the appropriate column and
    tokenizes it.

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer to apply to each sequence.
        seq_key (str, optional): column name of the text data. Defaults to 'sequence'.

    Returns:
        Callable: preprocessing function suitable for `Dataset.map`.
    """
|
    def preprocess_function(examples):
        # With `batched=True`, `examples` maps column names to lists of
        # values; select the text column if a column name was given.
        if seq_key:
            seqs = examples[seq_key]
        else:
            seqs = examples
        return tokenizer(
            seqs,
            max_length=tokenizer.model_max_length,
            truncation=True,
            padding="max_length",
        )

    return preprocess_function
|
|
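# Standalone sketch (assumes `ds` is a datasets.Dataset with a 'sequence'
# column and `tokenizer` is any PreTrainedTokenizer):
#
#   preprocess_fn = make_preprocess_function(tokenizer, seq_key="sequence")
#   ds = ds.map(preprocess_fn, batched=True)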
|
def filter_empty_sequence(example: dict) -> bool:
    """Filter out empty sequences."""
    # Keep the example only if the attention mask covers more than the two
    # special tokens (e.g. [CLS] and [SEP]) that tokenization adds even to
    # an empty sequence.
    return sum(example["attention_mask"]) > 2

|
|
def load_data_collator(model_type: str, tokenizer=None, mlm_prob=None):
    """Return a data collator appropriate for `model_type`.

    Language-model training uses `DataCollatorForLanguageModeling`, which
    masks tokens on the fly with probability `mlm_prob`; any other model
    type gets the library's `default_data_collator`.
    """
    if model_type == "language-model":
        assert (
            tokenizer is not None
        ), "tokenizer must not be None if model is type language-model"
        assert (
            mlm_prob is not None
        ), "mlm_prob must not be None if model is type language-model"
        return DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob
        )
    else:
        return default_data_collator
|
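# Illustrative wiring into a transformers Trainer (`model` and `args` are
# placeholders defined elsewhere, not part of this module):
#
#   from transformers import Trainer
#
#   collator = load_data_collator("language-model", tokenizer, mlm_prob=0.15)
#   trainer = Trainer(model=model, args=args, data_collator=collator,
#                     train_dataset=datasets["train"])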
|