Spaces:
Running
Running
sunheycho
committed on
Commit
·
00ccfb9
1
Parent(s):
56b6ee6
fix(llama-lora): force safetensors for HF base and PEFT adapter loads to avoid torch.load CVE; update TinyLlama load too
Browse files
api.py
CHANGED
@@ -513,6 +513,7 @@ try:
|
|
513 |
llm_model = AutoModelForCausalLM.from_pretrained(
|
514 |
model_name,
|
515 |
torch_dtype=torch.float16,
|
|
|
516 |
# Removing options that require accelerate package
|
517 |
# device_map="auto",
|
518 |
# load_in_8bit=True
|
@@ -806,7 +807,7 @@ def load_hf_base_and_tokenizer(base_id: str, tok_id: str = None):
|
|
806 |
hf_base_models[base_id] = AutoModelForCausalLM.from_pretrained(
|
807 |
base_id,
|
808 |
torch_dtype=_preferred_dtype(),
|
809 |
-
use_safetensors=
|
810 |
).to(device)
|
811 |
return hf_tokenizers[tok_key], hf_base_models[base_id]
|
812 |
|
@@ -821,9 +822,15 @@ def load_hf_lora_model(base_id: str, adapter_id: str):
|
|
821 |
base = AutoModelForCausalLM.from_pretrained(
|
822 |
base_id,
|
823 |
torch_dtype=_preferred_dtype(),
|
824 |
-
use_safetensors=
|
825 |
).to(device)
|
826 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
827 |
hf_lora_models[key] = lora_model
|
828 |
return lora_model
|
829 |
|
|
|
513 |
llm_model = AutoModelForCausalLM.from_pretrained(
|
514 |
model_name,
|
515 |
torch_dtype=torch.float16,
|
516 |
+
use_safetensors=True,
|
517 |
# Removing options that require accelerate package
|
518 |
# device_map="auto",
|
519 |
# load_in_8bit=True
|
|
|
807 |
hf_base_models[base_id] = AutoModelForCausalLM.from_pretrained(
|
808 |
base_id,
|
809 |
torch_dtype=_preferred_dtype(),
|
810 |
+
use_safetensors=True,
|
811 |
).to(device)
|
812 |
return hf_tokenizers[tok_key], hf_base_models[base_id]
|
813 |
|
|
|
822 |
base = AutoModelForCausalLM.from_pretrained(
|
823 |
base_id,
|
824 |
torch_dtype=_preferred_dtype(),
|
825 |
+
use_safetensors=True,
|
826 |
).to(device)
|
827 |
+
# Prefer safetensors for adapter weights to avoid torch.load vulnerability
|
828 |
+
try:
|
829 |
+
lora_model = PeftModel.from_pretrained(base, adapter_id, use_safetensors=True)
|
830 |
+
except TypeError:
|
831 |
+
# Older PEFT versions may not support use_safetensors flag
|
832 |
+
lora_model = PeftModel.from_pretrained(base, adapter_id)
|
833 |
+
lora_model = lora_model.eval().to(device)
|
834 |
hf_lora_models[key] = lora_model
|
835 |
return lora_model
|
836 |
|