# bf16 to fp16: convert a BF16 safetensors checkpoint to FP16.
# Uploaded by 3v324v23, commit 45be9e6 (original file size: 503 bytes).
import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Convert the floating-point tensors of a safetensors checkpoint to FP16.
#
# Typical use is turning a BF16 checkpoint into FP16 for runtimes or hardware
# without BF16 support. BF16 has a much wider exponent range than FP16 (whose
# largest finite value is 65504). Finite values outside that range are
# therefore clamped instead of silently becoming +/-inf. Non-floating tensors
# (integer buffers, boolean masks) are passed through unchanged.

_FP16_MAX = torch.finfo(torch.float16).max


def _to_fp16(tensor):
    """Return *tensor* as float16, clamping finite values to the FP16 range.

    Non-floating-point tensors are returned unchanged. Existing inf/NaN
    values are preserved as-is.
    """
    if not tensor.is_floating_point():
        return tensor
    # Clamp in a dtype that represents 65504 exactly: in BF16 it would round
    # up to 65536, which then overflows to inf when cast to FP16.
    wide = tensor.to(torch.promote_types(tensor.dtype, torch.float32))
    clamped = torch.where(
        torch.isfinite(wide),
        wide.clamp(-_FP16_MAX, _FP16_MAX),
        wide,
    )
    return clamped.to(torch.float16)


def convert_state_dict_to_fp16(state_dict):
    """Return a new dict with every floating-point tensor converted to FP16.

    Args:
        state_dict: Mapping of tensor name to ``torch.Tensor``.

    Returns:
        A new ``dict`` with the same keys. The input mapping is not modified.
    """
    return {key: _to_fp16(value) for key, value in state_dict.items()}


def main(model_path="model.safetensors", output_path="model_fp16.safetensors"):
    """Load *model_path*, convert it to FP16 and save it to *output_path*.

    The source file's header metadata (if any) is carried over to the output.
    """
    with safe_open(model_path, framework="pt", device="cpu") as f:
        metadata = f.metadata()
        # Convert each tensor as it is read rather than first collecting the
        # whole unconverted state dict next to the converted one.
        fp16_state_dict = {key: _to_fp16(f.get_tensor(key)) for key in f.keys()}
    save_file(fp16_state_dict, output_path, metadata=metadata)


if __name__ == "__main__":
    main()