File size: 5,614 Bytes
bf1f674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
""" Utilities for reading and writing data files.
"""
import multiprocessing as mp
import os
from pathlib import PosixPath
from typing import Callable, Dict, List, Optional, Tuple, Union
from datasets import load_dataset
from torch.utils.data import Dataset

from transformers import (
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    default_data_collator,
)

from . import config

# Turn off Hugging Face tokenizers parallelism as soon as this module is
# imported, to avoid the huggingface warning.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# String form of `config.root`, exposed as a module-level constant.
UBUNTU_ROOT = str(config.root)


def load_datasets(
    tokenizer: PreTrainedTokenizer,
    train_data: Union[str, PosixPath],
    eval_data: Optional[Union[str, PosixPath]] = None,
    test_data: Optional[Union[str, PosixPath]] = None,
    file_type: str = "csv",
    delimiter: str = "\t",
    seq_key: str = "sequence",
    shuffle: bool = True,
    filter_empty: bool = False,
    n_workers: int = mp.cpu_count(),
    **kwargs,
) -> Dataset:
    """Load, tokenize and cache data using the Huggingface datasets library.

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer to apply to the sequences.
        train_data (Union[str, PosixPath]): location of the training data,
            loaded as the "train" split.
        eval_data (Union[str, PosixPath], optional): location of the evaluation
            data, loaded as the "eval" split. Defaults to None (no eval split).
        test_data (Union[str, PosixPath], optional): location of the test data,
            loaded as the "test" split. Defaults to None (no test split).
        file_type (str, optional): dataset builder name passed to
            `load_dataset`, e.g. 'text' or 'csv'. Defaults to 'csv'.
        delimiter (str, optional): column delimiter, only forwarded when
            `file_type` is 'csv'. Defaults to '\\t'.
        seq_key (str, optional): column name of the sequence data, e.g.
            'sequence', 'seq' or 'text'. Defaults to 'sequence'.
        shuffle (bool, optional): whether to shuffle every split using
            `config.settings["random_seed"]`. Defaults to True.
        filter_empty (bool, optional): whether to filter out empty sequences.
            Defaults to False.
            NOTE: This completes an additional iteration, which can be time-consuming.
            Only enable if you have reason to believe that preprocessing steps will
            result in empty sequences.
        n_workers (int, optional): number of processes used for tokenizing and
            filtering. Defaults to `mp.cpu_count()`, evaluated once when this
            module is imported.
        **kwargs: forwarded unchanged to `datasets.load_dataset`.

    Returns:
        Dataset: the object produced by `load_dataset` after tokenization
        (and optional filtering/shuffling) -- a split-keyed
        `datasets.DatasetDict` with "train" and, when given, "eval"/"test".
    """
    data_files = {"train": str(train_data)}
    if eval_data:
        data_files["eval"] = str(eval_data)
    if test_data:
        data_files["test"] = str(test_data)
    if file_type == "csv":
        kwargs["delimiter"] = delimiter
    dataset_dict = load_dataset(file_type, data_files=data_files, **kwargs)

    # Tokenize every split in parallel batches.
    preprocess_fn = make_preprocess_function(tokenizer, seq_key=seq_key)
    dataset_dict = dataset_dict.map(preprocess_fn, batched=True, num_proc=n_workers)

    if filter_empty:
        # Extra full pass over the data, so parallelise it like tokenization.
        dataset_dict = dataset_dict.filter(filter_empty_sequence, num_proc=n_workers)

    if shuffle:
        seed = config.settings["random_seed"]
        # Seed exactly the splits that were loaded instead of a fixed name list.
        dataset_dict = dataset_dict.shuffle(
            seeds={split: seed for split in dataset_dict}
        )
    return dataset_dict


def make_preprocess_function(tokenizer, seq_key: str = "sequence") -> Callable:
    """Make a preprocessing function that selects the appropriate column and
    tokenizes it.

    The returned function is meant for ``datasets.map(..., batched=True)``.
    It takes the ``seq_key`` column out of a batch and tokenizes it, truncating
    and padding every sequence to ``tokenizer.model_max_length``.

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer to apply to each sequence. It
            is called with ``max_length``, ``truncation`` and ``padding`` keywords.
        seq_key (str, optional): column name of the text data. If falsy, the
            input is handed to the tokenizer unchanged. Defaults to 'sequence'.

    Returns:
        Callable: preprocessing function returning the tokenizer's output.
    """

    def preprocess_function(examples):
        # NOTE(review): with a falsy seq_key the whole input goes to the
        # tokenizer. A batched `datasets.map` passes a dict of columns, which
        # tokenizers generally reject -- confirm this branch is still needed.
        seqs = examples[seq_key] if seq_key else examples
        return tokenizer(
            seqs,
            max_length=tokenizer.model_max_length,
            truncation=True,
            padding="max_length",
        )

    return preprocess_function

def filter_empty_sequence(example: dict) -> bool:
    """Return True when the tokenized example holds more than its SOS/EOS tokens.

    The sum of the attention mask is the token count, which includes the start
    and end tokens; anything above two therefore has real content.
    """
    n_tokens = sum(example["attention_mask"])
    return n_tokens > 2

def load_data_collator(model_type: str, tokenizer=None, mlm_prob=None):
    """Return the data collator matching ``model_type``.

    Args:
        model_type (str): 'language-model' selects a masked-language-modelling
            collator; any other value selects ``default_data_collator``.
        tokenizer (PreTrainedTokenizer, optional): required when ``model_type``
            is 'language-model'.
        mlm_prob (float, optional): masking probability passed as
            ``mlm_probability``; required when ``model_type`` is 'language-model'.

    Returns:
        A ``DataCollatorForLanguageModeling`` with ``mlm=True`` for
        'language-model', otherwise ``default_data_collator``.

    Raises:
        ValueError: if ``model_type`` is 'language-model' and ``tokenizer`` or
            ``mlm_prob`` is None.
    """
    if model_type != "language-model":
        return default_data_collator

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if tokenizer is None:
        raise ValueError("tokenizer must not be None if model is type language-model")
    if mlm_prob is None:
        raise ValueError("mlm_prob must not be None if model is type language-model")

    return DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob
    )