Spaces:
Sleeping
Sleeping
File size: 1,316 Bytes
36ed92b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import argparse
import os
from safetensors import safe_open
from safetensors.torch import save_file
def _parse_lora_ups_key(key):
    """Split a key containing a ``lora_ups.<i>`` segment pair.

    Returns ``None`` when ``key`` has no dot-separated ``lora_ups`` segment.
    Otherwise returns ``(i, new_key)`` where ``i`` is the integer branch index
    that follows ``lora_ups`` and ``new_key`` is ``key`` with the
    ``lora_ups.<i>`` pair collapsed into a single ``lora_up`` segment
    (e.g. ``"a.lora_ups.0.weight"`` -> ``(0, "a.lora_up.weight")``).

    Raises:
        ValueError: if ``lora_ups`` is not followed by an integer segment.
    """
    parts = key.split('.')
    if 'lora_ups' not in parts:
        return None
    pos = parts.index('lora_ups')
    try:
        index = int(parts[pos + 1])
    except (IndexError, ValueError) as err:
        raise ValueError(f"malformed lora_ups key (no integer index after 'lora_ups'): {key!r}") from err
    # Rebuild from segments instead of str.replace so that keys such as
    # "x.lora_ups.01.w" or keys ending in "lora_ups.<i>" are renamed too.
    new_key = '.'.join(parts[:pos] + ['lora_up'] + parts[pos + 2:])
    return index, new_key


def main():
    """Keep a single ``lora_ups`` branch from an asylora safetensors file.

    Every key containing a ``lora_ups.<i>`` segment pair belongs to branch
    ``i``. Only branch ``--lora_up - 1`` is kept, and its keys are renamed so
    that ``lora_ups.<i>`` becomes ``lora_up``. All other keys are copied
    unchanged. The result is written with ``safetensors.torch.save_file``.

    Command-line arguments:
        --asylora_path: input safetensors file.
        --output_path: output file; its parent directory is created if missing.
        --lora_up: 1-based index of the branch to keep.

    Exits via ``parser.error`` (status 2) when ``--lora_up`` is < 1, when a
    ``lora_ups`` key has no integer index, or when the file contains
    ``lora_ups`` keys but none for the requested branch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--asylora_path', type=str, required=True, help="Path to the input asylora file.")
    parser.add_argument('--output_path', type=str, required=True, help="Path to save the modified safetensors file.")
    parser.add_argument('--lora_up', type=int, required=True, help="The target lora_up value.")
    args = parser.parse_args()

    # --lora_up is 1-based; 0 or negatives would silently match no branch
    # and produce a file without any up-projection weights.
    if args.lora_up < 1:
        parser.error("--lora_up is 1-based and must be >= 1")
    target_index = args.lora_up - 1

    modified_dict = {}
    seen_indices = set()
    with safe_open(args.asylora_path, framework="pt") as f:
        for key in f.keys():
            try:
                parsed = _parse_lora_ups_key(key)
            except ValueError as err:
                parser.error(str(err))
            if parsed is None:
                modified_dict[key] = f.get_tensor(key)
                continue
            index, new_key = parsed
            seen_indices.add(index)
            # Only read tensors we keep; discarded branches are never loaded.
            if index == target_index:
                modified_dict[new_key] = f.get_tensor(key)

    if seen_indices and target_index not in seen_indices:
        available = ", ".join(str(i + 1) for i in sorted(seen_indices))
        parser.error(
            f"--lora_up {args.lora_up} not found in {args.asylora_path}; available: {available}"
        )

    # dirname() is '' for a bare filename and makedirs('') raises, so only
    # create a directory when there is one; exist_ok avoids a check/create race.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    save_file(modified_dict, args.output_path)
# Script entry point: run the conversion only when executed directly,
# never as a side effect of importing this module.
if "__main__" == __name__:
    main()
|