FloraBERT / module /dataio.py
Gurveer05's picture
Added pred func
"""Utilities for reading and writing data files."""
import multiprocessing as mp
import os
from pathlib import PosixPath
from typing import Callable, Dict, List, Optional, Tuple, Union
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import (
from . import config
# Turn off Hugging Face tokenizers parallelism up front to avoid its warning.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Project root (from the local `config` module) exposed as a plain string.
UBUNTU_ROOT = str(config.root)
def load_datasets(
    tokenizer: PreTrainedTokenizer,
    train_data: Union[str, PosixPath],
    eval_data: Optional[Union[str, PosixPath]] = None,
    test_data: Optional[Union[str, PosixPath]] = None,
    file_type: str = "csv",
    delimiter: str = "\t",
    seq_key: str = "sequence",
    shuffle: bool = True,
    filter_empty: bool = False,
    n_workers: int = mp.cpu_count(),
) -> Dataset:
    r"""Load, tokenize and cache data using the Hugging Face datasets library.

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer to apply to the sequences.
        train_data (Union[str, PosixPath]): location of training data
            (becomes the "train" split).
        eval_data (Union[str, PosixPath], optional): location of evaluation
            data (becomes the "eval" split). Defaults to None (no eval split).
        test_data (Union[str, PosixPath], optional): location of test data
            (becomes the "test" split). Defaults to None (no test split).
        file_type (str, optional): loader name passed to ``load_dataset``.
            Possible values are 'text' and 'csv'. Defaults to 'csv'.
        delimiter (str, optional): column delimiter; only used when
            ``file_type == 'csv'``. Defaults to '\t'.
        seq_key (str, optional): column name of the sequence data. Can be
            'sequence', 'seq', or 'text'. Defaults to 'sequence'.
        shuffle (bool, optional): whether to shuffle every split with
            ``config.settings["random_seed"]``. Defaults to True.
        filter_empty (bool, optional): whether to filter out empty sequences
            (see ``filter_empty_sequence``). Defaults to False.
            NOTE: This completes an additional iteration, which can be
            time-consuming. Only enable if you have reason to believe that
            preprocessing steps will result in empty sequences.
        n_workers (int, optional): number of processes used for tokenization.
            Defaults to ``mp.cpu_count()`` (evaluated when the module is
            imported).

    Returns:
        The tokenized splits keyed by split name ("train", plus "eval" and
        "test" when given), as produced by ``load_dataset`` with a
        ``data_files`` dict (a ``datasets.DatasetDict``).
    """
    # Falsy values (None, "") mean "no such split".
    data_files = {"train": str(train_data)}
    if eval_data:
        data_files["eval"] = str(eval_data)
    if test_data:
        data_files["test"] = str(test_data)

    # Extra loader options. The delimiter only applies to the CSV loader.
    load_kwargs = {}
    if file_type == "csv":
        load_kwargs["delimiter"] = delimiter
    dataset_dict = load_dataset(file_type, data_files=data_files, **load_kwargs)

    # Tokenize in batches across worker processes.
    preprocess_fn = make_preprocess_function(tokenizer, seq_key=seq_key)
    dataset_dict = dataset_dict.map(preprocess_fn, batched=True, num_proc=n_workers)

    if filter_empty:
        dataset_dict = dataset_dict.filter(filter_empty_sequence)

    if shuffle:
        # The same seed is applied to every split present, so results are reproducible.
        seed = config.settings["random_seed"]
        dataset_dict = dataset_dict.shuffle(seed=seed)

    return dataset_dict
def make_preprocess_function(tokenizer, seq_key: str = "sequence") -> Callable:
    """Make a preprocessing function that selects the sequence column and
    tokenizes it.

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer to apply to each sequence.
        seq_key (str, optional): column name of the text data. If falsy, the
            ``examples`` object is passed to the tokenizer as-is.
            Defaults to 'sequence'.

    Returns:
        Callable: function mapping a batch of examples to the tokenizer's
        output, intended for ``Dataset.map(..., batched=True)``.
    """

    def preprocess_function(examples):
        # Use the named column when one is given. Otherwise fall back to the raw input.
        seqs = examples[seq_key] if seq_key else examples
        # NOTE(review): the original tokenizer arguments were not recoverable.
        # Fixed-length padding plus truncation is assumed because
        # `load_data_collator` returns `default_data_collator` (which does not
        # pad) for non-LM models, and `filter_empty_sequence` sums the
        # attention mask. Confirm against the training setup.
        return tokenizer(seqs, truncation=True, padding="max_length")

    return preprocess_function
def filter_empty_sequence(example: dict) -> bool:
    """Return True if ``example`` contains more than just its special tokens.

    Summing the attention mask counts the tokens, including the start- and
    end-of-sequence tokens. An example with two or fewer is treated as empty.
    """
    special_token_count = 2  # start-of-sequence + end-of-sequence
    token_count = sum(example["attention_mask"])
    return token_count > special_token_count
def load_data_collator(model_type: str, tokenizer=None, mlm_prob=None):
    """Return the data collator matching ``model_type``.

    Args:
        model_type (str): "language-model" selects masked-language-model
            collation; any other value selects ``default_data_collator``.
        tokenizer (PreTrainedTokenizer, optional): required when
            ``model_type`` is "language-model".
        mlm_prob (float, optional): masking probability; required when
            ``model_type`` is "language-model".

    Returns:
        A ``DataCollatorForLanguageModeling`` with ``mlm=True`` for
        "language-model", otherwise ``default_data_collator``.

    Raises:
        ValueError: if ``model_type`` is "language-model" and either
            ``tokenizer`` or ``mlm_prob`` is None.
    """
    if model_type != "language-model":
        return default_data_collator

    # Use explicit checks rather than `assert`, which is stripped under `python -O`.
    if tokenizer is None:
        raise ValueError("tokenizer must not be None if model is type language-model")
    if mlm_prob is None:
        raise ValueError("mlm_prob must not be None if model is type language-model")

    return DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob
    )