# -*- coding: utf-8 -*-
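"""Training-argument definitions and CLI parsing for flame.

Extends `transformers.TrainingArguments` with model, tokenizer, and dataset
options, and exposes `get_train_args()` as the parsing entry point.
"""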
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
import transformers
from transformers import HfArgumentParser
from flame.logging import get_logger
logger = get_logger(__name__)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
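    """Arguments for a training run, extending the Hugging Face defaults with
    model, tokenizer, and dataset options."""
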
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the model weights or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    tokenizer: str = field(
        default="mistralai/Mistral-7B-v0.1",
        metadata={"help": "Name of the tokenizer to use."},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether or not to use a fast tokenizer (backed by the tokenizers library)."},
    )
    from_config: bool = field(
        default=True,
        metadata={"help": "Whether to initialize the model from scratch (i.e., from a config rather than from pretrained weights)."},
    )
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the provided dataset(s) to use."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the cached tokenized dataset."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token for logging in to the Hugging Face Hub."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for preprocessing."},
    )
    buffer_size: int = field(
        default=2048,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    context_length: int = field(
        default=2048,
        metadata={"help": "The context length of the tokenized inputs in the dataset."},
    )


def get_train_args() -> TrainingArguments:
    parser = HfArgumentParser(TrainingArguments)
    # `return_remaining_strings=True` collects unrecognized CLI flags instead of
    # erroring out immediately, so they can all be reported at once.
    args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    if unknown_args:
        print(parser.format_help())
        logger.error(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
    if args.should_log:
        transformers.utils.logging.set_verbosity(args.get_process_log_level())
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    # Set the seed manually for reproducibility across processes.
    transformers.set_seed(args.seed)
    return args
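

if __name__ == "__main__":
    # Minimal usage sketch (assumes the module is run as a script): parse flags
    # such as --model_name_or_path, plus the `output_dir` required by the base
    # `transformers.TrainingArguments`, then inspect the result, e.g.
    #   python -m flame.parser --output_dir exp/test --tokenizer mistralai/Mistral-7B-v0.1
    # (the module path `flame.parser` is a placeholder for this file's actual location).
    args = get_train_args()
    logger.info(f"Parsed training arguments:\n{args}")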