File size: 3,029 Bytes
2f9282b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# -*- coding: utf-8 -*-

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

import transformers
from transformers import HfArgumentParser, TrainingArguments

from flame.logging import get_logger

logger = get_logger(__name__)


@dataclass
class TrainingArguments(TrainingArguments):
    """Hugging Face training arguments extended with model, tokenizer and dataset options.

    Subclasses ``transformers.TrainingArguments`` and reuses its name. A single
    ``HfArgumentParser(TrainingArguments)`` therefore exposes both the standard
    Hugging Face training flags and the extra fields declared below.
    """

    # ---- model / tokenizer ----
    # Annotated Optional[str] because the default is None; HfArgumentParser
    # strips the Optional and still parses the flag as a plain string.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    tokenizer: str = field(
        default="mistralai/Mistral-7B-v0.1",
        metadata={"help": "Name of the tokenizer to use."},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    # Defaults to True, i.e. per the help text the model is initialized from scratch.
    from_config: bool = field(
        default=True,
        metadata={"help": "Whether to initialize models from scratch."},
    )

    # ---- dataset selection ----
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of provided dataset(s) to use."},
    )
    # Annotated Optional[str] because the default is None (was mistyped as `str`).
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the cached tokenized dataset."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )

    # ---- preprocessing ----
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    # Per the help text, only relevant when sampling from a streamed dataset.
    buffer_size: int = field(
        default=2048,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    context_length: int = field(
        default=2048,
        metadata={"help": "The context length of the tokenized inputs in the dataset."},
    )


def get_train_args():
    """Build the training configuration from the command-line arguments.

    Parses the command line with an ``HfArgumentParser`` over this module's
    ``TrainingArguments`` dataclass. It then configures transformers logging
    (when ``should_log`` is true) and seeds RNGs through
    ``transformers.set_seed``.

    Returns:
        The populated ``TrainingArguments`` instance.

    Raises:
        ValueError: if any command-line argument is not recognised by the
            parser. The parser's help text and the unrecognised arguments are
            printed to stdout before the error is raised.
    """
    arg_parser = HfArgumentParser(TrainingArguments)
    train_args, leftover = arg_parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )

    # Fail loudly on flags the parser did not consume instead of ignoring them.
    if leftover:
        print(arg_parser.format_help())
        print(f"Got unknown args, potentially deprecated arguments: {leftover}")
        raise ValueError(
            f"Some specified arguments are not used by the HfArgumentParser: {leftover}"
        )

    if train_args.should_log:
        # `should_log` / `get_process_log_level` are inherited from
        # transformers.TrainingArguments; apply that level to transformers'
        # own logger and enable its default handler with explicit formatting.
        hf_logging = transformers.utils.logging
        hf_logging.set_verbosity(train_args.get_process_log_level())
        hf_logging.enable_default_handler()
        hf_logging.enable_explicit_format()

    # Seed RNGs with the parsed seed before handing the args back.
    transformers.set_seed(train_args.seed)
    return train_args