"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# Runs at import time: apply the logging configuration from axolotl.logging_config
# before the module logger below is created.
configure_logging()
# Module-level logger used by the helpers in this file.
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class TrainerCliArgs:
    """
    CLI-level arguments that are not part of the training configuration itself.

    Every field has a default, so the class can be instantiated with no arguments.
    """

    # debugging options
    debug: bool = False
    debug_text_only: bool = False
    debug_num_examples: int = 5

    # remaining switches; `inference` is forwarded to `load_model` by
    # `load_model_and_tokenizer`
    inference: bool = False
    merge_lora: bool = False
    prompter: Optional[str] = None
    shard: bool = False
@dataclass
class PreprocessCliArgs:
    """
    CLI-level arguments used when only preprocessing is being run.

    Every field has a default, so the class can be instantiated with no arguments.
    """

    # debugging options (note: a single example is the default here)
    debug: bool = False
    debug_text_only: bool = False
    debug_num_examples: int = 1

    prompter: Optional[str] = None
def load_model_and_tokenizer(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    """
    Load the tokenizer and then the model described by ``cfg``.

    Both parameters are keyword-only.

    Args:
        cfg: configuration handed to ``load_tokenizer`` and ``load_model``. Its
            ``tokenizer_config`` (or ``base_model_config`` when that is falsy)
            is written to the log.
        cli_args: CLI arguments; only ``inference`` is read here and it is
            forwarded to ``load_model``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Pass the value as a logging argument instead of an f-string so the
    # message is only formatted when the record is actually emitted.
    LOG.info(
        "loading tokenizer... %s", cfg.tokenizer_config or cfg.base_model_config
    )
    tokenizer = load_tokenizer(cfg)

    LOG.info("loading model and (optionally) peft_config...")
    # load_model returns a pair; the second item (presumably the peft config,
    # going by the log message above) is discarded here.
    model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)

    return model, tokenizer
 |