# Source snapshot: commit e9e1652 (2,599 bytes).
import os
import json
import shutil
from optimum.exporters.onnx import main_export
import onnx
from onnxconverter_common import float16
import onnxruntime as rt
from onnxruntime.tools.onnx_model_utils import *
from onnxruntime.quantization import quantize_dynamic, QuantType
# Load the export settings from conversion_config.json. All keys below are
# required; a missing key raises KeyError here, before any export work starts.
#   model_id                       - model name/path passed to optimum's main_export
#   number_of_generated_embeddings - read but not referenced elsewhere in this script
#   precision_to_filename_map      - maps "fp32"/"fp16"/"int8"/"uint8" to output paths
#   opset                          - ONNX opset version (export + fp16 rebuild)
#   IR                             - ONNX IR version stamped on the fp16 model
with open("conversion_config.json", encoding="utf-8") as json_file:
    conversion_config = json.load(json_file)
model_id = conversion_config["model_id"]
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
precision_to_filename_map = conversion_config["precision_to_filename_map"]
opset = conversion_config["opset"]
IR = conversion_config["IR"]

# Opset import for the default ("") domain; domain is left at its protobuf
# default. Used to pin the opset version when rebuilding the fp16 model.
op = onnx.OperatorSetIdProto()
op.version = opset
# Output directory for the converted models. exist_ok=True replaces the
# racy exists()-then-makedirs() check.
os.makedirs("onnx", exist_ok=True)

# Export once at fp32. Every later step reads the intermediate ./model.onnx
# (presumably written by main_export into output="./"), and the script
# deletes that file at the end.
print("Exporting the main model version")
main_export(
    model_name_or_path=model_id,
    output="./",
    opset=opset,
    # NOTE(review): trust_remote_code executes code shipped with the model
    # repository; only use this with a trusted model_id.
    trust_remote_code=True,
    task="feature-extraction",
    dtype="fp32",
)

if "fp32" in precision_to_filename_map:
    print("Exporting the fp32 onnx file...")
    # The fp32 output is a plain copy of the intermediate export.
    shutil.copyfile("model.onnx", precision_to_filename_map["fp32"])
    print("Done\n\n")
if "fp16" in precision_to_filename_map:
    print("Exporting the fp16 onnx file...")
    # keep_io_types=True keeps graph inputs/outputs in fp32 so the fp16 model
    # accepts and returns the same dtypes as the fp32 one.
    model_fp16 = float16.convert_float_to_float16(
        onnx.load("model.onnx"),
        min_positive_val=1e-7,
        max_finite_val=1e4,
        keep_io_types=True,
        disable_shape_infer=True,
        op_block_list=None,
        node_block_list=None,
    )
    # Rebuild the model to pin the configured opset and IR version (to be sure
    # they are compatible). Only the default-domain opset is overridden; imports
    # for any other domain (e.g. com.microsoft) are carried over, because
    # dropping them would leave nodes from those domains unresolvable.
    other_domain_opsets = [
        imp for imp in model_fp16.opset_import if imp.domain not in ("", "ai.onnx")
    ]
    model_fp16 = onnx.helper.make_model(
        model_fp16.graph,
        ir_version=IR,
        opset_imports=[op, *other_domain_opsets],
    )
    onnx.save(model_fp16, precision_to_filename_map["fp16"])
    print("Done\n\n")
# Dynamically quantize the fp32 intermediate once per requested integer
# precision, in a fixed order (int8 first, then uint8).
for precision, weight_type in (("int8", QuantType.QInt8), ("uint8", QuantType.QUInt8)):
    if precision in precision_to_filename_map:
        print(f"Quantizing fp32 model to {precision}...")
        quantize_dynamic(
            "model.onnx",
            precision_to_filename_map[precision],
            weight_type=weight_type,
        )
        print("Done\n\n")
# Delete the intermediate fp32 export; the requested outputs were written above.
os.unlink("model.onnx")