sunheycho commited on
Commit
00ccfb9
·
1 Parent(s): 56b6ee6

fix(llama-lora): force safetensors for HF base and PEFT adapter loads to avoid torch.load CVE; update TinyLlama load too

Browse files
Files changed (1) hide show
  1. api.py +10 -3
api.py CHANGED
@@ -513,6 +513,7 @@ try:
513
  llm_model = AutoModelForCausalLM.from_pretrained(
514
  model_name,
515
  torch_dtype=torch.float16,
 
516
  # Removing options that require accelerate package
517
  # device_map="auto",
518
  # load_in_8bit=True
@@ -806,7 +807,7 @@ def load_hf_base_and_tokenizer(base_id: str, tok_id: str = None):
806
  hf_base_models[base_id] = AutoModelForCausalLM.from_pretrained(
807
  base_id,
808
  torch_dtype=_preferred_dtype(),
809
- use_safetensors=False,
810
  ).to(device)
811
  return hf_tokenizers[tok_key], hf_base_models[base_id]
812
 
@@ -821,9 +822,15 @@ def load_hf_lora_model(base_id: str, adapter_id: str):
821
  base = AutoModelForCausalLM.from_pretrained(
822
  base_id,
823
  torch_dtype=_preferred_dtype(),
824
- use_safetensors=False,
825
  ).to(device)
826
- lora_model = PeftModel.from_pretrained(base, adapter_id).eval().to(device)
 
 
 
 
 
 
827
  hf_lora_models[key] = lora_model
828
  return lora_model
829
 
 
513
  llm_model = AutoModelForCausalLM.from_pretrained(
514
  model_name,
515
  torch_dtype=torch.float16,
516
+ use_safetensors=True,
517
  # Removing options that require accelerate package
518
  # device_map="auto",
519
  # load_in_8bit=True
 
807
  hf_base_models[base_id] = AutoModelForCausalLM.from_pretrained(
808
  base_id,
809
  torch_dtype=_preferred_dtype(),
810
+ use_safetensors=True,
811
  ).to(device)
812
  return hf_tokenizers[tok_key], hf_base_models[base_id]
813
 
 
822
  base = AutoModelForCausalLM.from_pretrained(
823
  base_id,
824
  torch_dtype=_preferred_dtype(),
825
+ use_safetensors=True,
826
  ).to(device)
827
+ # Prefer safetensors for adapter weights to avoid torch.load vulnerability
828
+ try:
829
+ lora_model = PeftModel.from_pretrained(base, adapter_id, use_safetensors=True)
830
+ except TypeError:
831
+ # Older PEFT versions may not support use_safetensors flag
832
+ lora_model = PeftModel.from_pretrained(base, adapter_id)
833
+ lora_model = lora_model.eval().to(device)
834
  hf_lora_models[key] = lora_model
835
  return lora_model
836