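"""Load the most recent training checkpoint from ``cfg.output_dir`` and save it
out as a PEFT/LoRA adapter.

Usage sketch (the script filename below is an assumption; invoke whatever path
this file actually lives at in the repo):

    python scripts/save_latest_checkpoint_as_lora.py --config configs/my_run.yml

Passing a directory (the default ``configs/``) opens an interactive picker over
its ``*.yml`` files. A minimal illustrative config, limited to keys this script
actually reads -- all values are placeholders, not defaults:

    base_model: some-org/some-base-model
    base_model_config: some-org/some-base-model
    model_type: LlamaForCausalLM
    tokenizer_type: LlamaTokenizer
    adapter: lora
    output_dir: ./lora-out
    batch_size: 8
    micro_batch_size: 2
"""
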
import logging
import os
import sys
from pathlib import Path

import fire
import torch
import yaml
from addict import Dict
from peft import set_peft_model_state_dict

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

from axolotl.utils.models import load_model

logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))


def choose_device(cfg):
    def get_device():
        # prefer CUDA, then Apple's MPS backend, falling back to CPU
        if torch.cuda.is_available():
            return "cuda"
        try:
            if torch.backends.mps.is_available():
                return "mps"
        except Exception:
            pass
        return "cpu"

    cfg.device = get_device()
    if cfg.device == "cuda":
        # place the whole model on this rank's GPU
        cfg.device_map = {"": cfg.local_rank}
    else:
        cfg.device_map = {"": cfg.device}


def choose_config(path: Path):
    yaml_files = list(path.glob("*.yml"))
    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files):
        print(f"{idx + 1}. {file}")

    chosen_file = None
    while chosen_file is None:
        try:
            choice = int(input("Enter the number of your choice: "))
            if 1 <= choice <= len(yaml_files):
                chosen_file = yaml_files[choice - 1]
            else:
                print("Invalid choice. Please choose a number from the list.")
        except ValueError:
            print("Invalid input. Please enter a number.")

    return chosen_file


def save_latest_checkpoint_as_lora(
    config: Path = Path("configs/"),
    **kwargs,
):
    if Path(config).is_dir():
        config = choose_config(config)

    # load the config from the yaml file
    with open(config, "r") as f:
        cfg: Dict = Dict(yaml.safe_load(f))

    # if an option passed on the cli matches a key from the yaml, overwrite the value
    cfg_keys = dict(cfg).keys()
    for k in kwargs:
        if k in cfg_keys:
            # handle booleans
            if isinstance(cfg[k], bool):
                cfg[k] = bool(kwargs[k])
            else:
                cfg[k] = kwargs[k]

    # setup some derived config / hyperparams
    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    assert cfg.local_rank == 0, "Run this with only one device!"

    choose_device(cfg)
    cfg.ddp = False
    if cfg.device == "mps":
        # 8-bit loading and tf32 are CUDA-only; downgrade bf16 to fp16 on Apple silicon
        cfg.load_in_8bit = False
        cfg.tf32 = False
        if cfg.bf16:
            cfg.fp16 = True
        cfg.bf16 = False

    # Load the model and tokenizer
    logging.info("loading model, tokenizer, and lora_config...")
    model, tokenizer, lora_config = load_model(
        cfg.base_model,
        cfg.base_model_config,
        cfg.model_type,
        cfg.tokenizer_type,
        cfg,
        adapter=cfg.adapter,
        inference=True,
    )

    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        logging.info("Compiling torch model")
        model = torch.compile(model)

    # pick the checkpoint directory with the highest step number
    possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
    if len(possible_checkpoints) > 0:
        sorted_paths = sorted(
            possible_checkpoints, key=lambda path: int(path.split("-")[-1])
        )
        resume_from_checkpoint = sorted_paths[-1]
    else:
        raise FileNotFoundError("Checkpoints folder not found")

    pytorch_bin_path = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
    assert os.path.exists(pytorch_bin_path), "Bin not found"

    logging.info(f"Loading {pytorch_bin_path}")
    adapters_weights = torch.load(pytorch_bin_path, map_location="cpu")

    logging.info("Setting peft model state dict")
    set_peft_model_state_dict(model, adapters_weights)

    logging.info(f"Done. Saving LoRA adapter to {cfg.output_dir}")
    model.save_pretrained(cfg.output_dir)


if __name__ == "__main__":
    fire.Fire(save_latest_checkpoint_as_lora)