Upload 39 files
Browse files- .gitattributes +4 -0
- calib-cocotest2017.tar +3 -0
- depth_anything_v2_vits.onnx +3 -0
- depth_anything_v2_vits_ax620e.axmodel +3 -0
- depth_anything_v2_vits_ax650.axmodel +3 -0
- python/axengine/__init__.py +22 -0
- python/axengine/_axclrt.py +372 -0
- python/axengine/_axclrt_capi.py +198 -0
- python/axengine/_axclrt_types.py +21 -0
- python/axengine/_axe.py +399 -0
- python/axengine/_axe_capi.py +323 -0
- python/axengine/_axe_types.py +29 -0
- python/axengine/_base_session.py +59 -0
- python/axengine/_node.py +13 -0
- python/axengine/_providers.py +31 -0
- python/axengine/_session.py +117 -0
- python/examples/demo01.jpg +0 -0
- python/examples/demo02.jpg +0 -0
- python/examples/demo03.jpg +0 -0
- python/examples/demo04.jpg +0 -0
- python/examples/demo05.jpg +0 -0
- python/examples/demo06.jpg +0 -0
- python/examples/demo07.jpg +0 -0
- python/examples/demo08.jpg +0 -0
- python/examples/demo09.jpg +0 -0
- python/examples/demo10.jpg +0 -0
- python/examples/demo11.jpg +0 -0
- python/examples/demo12.jpg +0 -0
- python/examples/demo13.jpg +0 -0
- python/examples/demo14.jpg +0 -0
- python/examples/demo15.jpg +0 -0
- python/examples/demo16.jpg +0 -0
- python/examples/demo17.jpg +0 -0
- python/examples/demo18.jpg +0 -0
- python/examples/demo19.jpg +3 -0
- python/examples/demo20.jpg +0 -0
- python/infer.py +50 -0
- python/infer_onnx.py +56 -0
- python/output.png +3 -0
- python/requirements.txt +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
depth_anything_v2_vits_ax620e.axmodel filter=lfs diff=lfs merge=lfs -text
|
37 |
+
depth_anything_v2_vits_ax650.axmodel filter=lfs diff=lfs merge=lfs -text
|
38 |
+
python/examples/demo19.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
python/output.png filter=lfs diff=lfs merge=lfs -text
|
calib-cocotest2017.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0fd1652bceab31a66a35e05cc26a2c7633e8c5108f4c07c21b5868b9605cc15a
|
3 |
+
size 20869120
|
depth_anything_v2_vits.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:443e95f17819f347f5f987384b8cb7d7d18ed6af6ac46dec9b0152748ba7dfd0
|
3 |
+
size 98985978
|
depth_anything_v2_vits_ax620e.axmodel
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f13d462a968e309354e23babdaf6a90b26841c58fd02f36531ba5d7bb545bea4
|
3 |
+
size 38448968
|
depth_anything_v2_vits_ax650.axmodel
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4520309cbefa63a4127aae75f7b7aa0dc5cc07fa02ef7a13a4219b88950499a1
|
3 |
+
size 27978862
|
python/axengine/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
# thanks to community contributors list below:
|
9 |
+
# zylo117: https://github.com/zylo117, first implementation of the axclrt backend
|
10 |
+
|
11 |
+
from ._providers import axengine_provider_name, axclrt_provider_name
|
12 |
+
from ._providers import get_all_providers, get_available_providers
|
13 |
+
|
14 |
+
# check if axclrt is installed, or is a supported chip(e.g. AX650, AX620E etc.)
|
15 |
+
_available_providers = get_available_providers()
|
16 |
+
if not _available_providers:
|
17 |
+
raise ImportError(
|
18 |
+
f"No providers found. Please make sure you have installed one of the following: {get_all_providers()}")
|
19 |
+
print("[INFO] Available providers: ", _available_providers)
|
20 |
+
|
21 |
+
from ._node import NodeArg
|
22 |
+
from ._session import SessionOptions, InferenceSession
|
python/axengine/_axclrt.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
# first implementation of AXCLRTSession contributed by zylo117
|
8 |
+
|
9 |
+
import atexit
|
10 |
+
import os
|
11 |
+
import time
|
12 |
+
from typing import Any, Sequence
|
13 |
+
|
14 |
+
import ml_dtypes as mldt
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
from ._axclrt_capi import axclrt_cffi, axclrt_lib
|
18 |
+
from ._axclrt_types import VNPUType, ModelType
|
19 |
+
from ._base_session import Session, SessionOptions
|
20 |
+
from ._node import NodeArg
|
21 |
+
|
22 |
+
# Fixed: the original used annotation syntax (`__all__: [...]`), which only
# records a type hint and never actually defines/assigns __all__.
__all__ = ["AXCLRTSession"]
|
23 |
+
|
24 |
+
_is_axclrt_initialized = False
|
25 |
+
_is_axclrt_engine_initialized = False
|
26 |
+
|
27 |
+
|
28 |
+
def _transform_dtype(dtype):
    """Translate an axclrtEngineDataType value into the matching numpy dtype.

    Raises ValueError for any engine type without a numpy equivalent
    (e.g. the 4-bit and FP8/FP16/FP64 variants are not mapped here).
    """
    # (engine constant, numpy scalar type) pairs, checked in order.
    _mapping = (
        (axclrt_lib.AXCL_DATA_TYPE_UINT8, np.uint8),
        (axclrt_lib.AXCL_DATA_TYPE_INT8, np.int8),
        (axclrt_lib.AXCL_DATA_TYPE_UINT16, np.uint16),
        (axclrt_lib.AXCL_DATA_TYPE_INT16, np.int16),
        (axclrt_lib.AXCL_DATA_TYPE_UINT32, np.uint32),
        (axclrt_lib.AXCL_DATA_TYPE_INT32, np.int32),
        (axclrt_lib.AXCL_DATA_TYPE_FP32, np.float32),
        (axclrt_lib.AXCL_DATA_TYPE_BF16, mldt.bfloat16),
    )
    for engine_type, numpy_type in _mapping:
        # Compare through the same cffi cast the engine API uses.
        if dtype == axclrt_cffi.cast("axclrtEngineDataType", engine_type):
            return np.dtype(numpy_type)
    raise ValueError(f"Unsupported data type '{dtype}'.")
|
47 |
+
|
48 |
+
def _initialize_axclrt():
    """Bring up the axcl runtime once; records success in a module flag."""
    global _is_axclrt_initialized
    status = axclrt_lib.axclInit([])
    if status != 0:
        raise RuntimeError(f"Failed to initialize axcl runtime. {status}.")
    # Remember that we own the runtime so _finalize_axclrt() can tear it down.
    _is_axclrt_initialized = True
|
54 |
+
|
55 |
+
|
56 |
+
def _finalize_axclrt():
    """Tear down the axcl runtime; registered with atexit by this module.

    Safe to call more than once: both steps are guarded by module flags that
    are cleared after the corresponding finalize call.
    """
    global _is_axclrt_initialized, _is_axclrt_engine_initialized
    # The engine must be finalized before the runtime it runs on.
    if _is_axclrt_engine_initialized:
        axclrt_lib.axclrtEngineFinalize()
        _is_axclrt_engine_initialized = False
    if _is_axclrt_initialized:
        axclrt_lib.axclFinalize()
        _is_axclrt_initialized = False
|
64 |
+
|
65 |
+
|
66 |
+
# Initialize the axcl runtime at import time and guarantee symmetric teardown
# when the interpreter exits.
_initialize_axclrt()
atexit.register(_finalize_axclrt)
|
68 |
+
|
69 |
+
|
70 |
+
def _get_vnpu_type() -> VNPUType:
    """Query the engine for the current virtual-NPU partitioning mode."""
    kind_ptr = axclrt_cffi.new("axclrtEngineVNpuKind *")
    if axclrt_lib.axclrtEngineGetVNpuKind(kind_ptr) != 0:
        raise RuntimeError("Failed to get VNPU attribute.")
    return VNPUType(kind_ptr[0])
|
76 |
+
|
77 |
+
|
78 |
+
def _get_version():
    """Return the axcl runtime version as a 'major.minor.patch' string."""
    fields = [axclrt_cffi.new("int32_t *") for _ in range(3)]
    axclrt_lib.axclrtGetVersion(*fields)
    return ".".join(str(field[0]) for field in fields)
|
83 |
+
|
84 |
+
|
85 |
+
class AXCLRTSession(Session):
    """Inference session backed by the AXCL runtime (off-chip accelerator card).

    Mirrors the onnxruntime ``InferenceSession`` surface: construct with a
    model path or blob, then call :meth:`run` with a feed dict.
    """

    def __init__(
            self,
            path_or_bytes: str | bytes | os.PathLike,
            sess_options: SessionOptions | None = None,
            provider_options: dict[Any, Any] | None = None,
            **kwargs,
    ) -> None:
        """Select a device, initialize the NPU engine, and load the model.

        Args:
            path_or_bytes: model file path (str/PathLike) or the model blob.
            sess_options: currently unused; kept for interface compatibility.
            provider_options: optionally carries ``device_id``; historically a
                one-element list of dicts, a bare dict is accepted too.

        Raises:
            RuntimeError: no device found, device/engine init fails, or the
                model cannot be loaded.
        """
        super().__init__()

        # Attributes read by __del__/_unload must exist even if construction
        # fails partway through (the original crashed with AttributeError).
        self._io = None
        self._model_id = None
        self._device_index = 0

        if provider_options:
            # Accept both the historical [{...}] shape and a plain dict.
            options = provider_options[0] if isinstance(provider_options, (list, tuple)) else provider_options
            self._device_index = options.get("device_id", 0)

        lst = axclrt_cffi.new("axclrtDeviceList *")
        ret = axclrt_lib.axclrtGetDeviceList(lst)
        if ret != 0 or lst.num == 0:
            raise RuntimeError(f"Get AXCL device failed 0x{ret:08x}, find total {lst.num} device.")

        if self._device_index >= lst.num:
            raise RuntimeError(f"Device index {self._device_index} is out of range, total {lst.num} device.")

        self._device_id = lst.devices[self._device_index]
        ret = axclrt_lib.axclrtSetDevice(self._device_id)
        # Fixed: the original also tested `lst.num == 0` here, a leftover from
        # the device-list check above.
        if ret != 0:
            raise RuntimeError(f"Set AXCL device failed 0x{ret:08x}.")

        global _is_axclrt_engine_initialized
        vnpu_type = axclrt_cffi.cast(
            "axclrtEngineVNpuKind", VNPUType.DISABLED.value
        )
        # Try to initialize the NPU with virtual-NPU disabled first.
        ret = axclrt_lib.axclrtEngineInit(vnpu_type)
        if 0 != ret:
            # Init failed; query whether another user already initialized it.
            vnpu = axclrt_cffi.new("axclrtEngineVNpuKind *")
            ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu)
            if ret != 0:
                # Query also failed: the NPU is genuinely unavailable.
                # Fixed: `vnpu.value` is not a cffi attribute and raised
                # AttributeError while formatting; dereference with [0].
                raise RuntimeError(f"axclrtEngineInit as {vnpu[0]} failed 0x{ret:08x}.")
            else:
                # The NPU was already initialized (by someone else) in a
                # different mode. We cannot renegotiate through this
                # onnxruntime-like API, and if that user finalizes the NPU
                # underneath us the app may terminate unexpectedly — all we
                # can do is warn.
                print(f"[WARNING] Failed to initialize NPU as {vnpu_type}, NPU is already initialized as {vnpu[0]}.")
        else:
            # We own engine initialization; ensure it is finalized at exit.
            _is_axclrt_engine_initialized = True

        self.soc_name = axclrt_cffi.string(axclrt_lib.axclrtGetSocName()).decode()
        print(f"[INFO] SOC Name: {self.soc_name}")

        # Model handle and execution context, filled in by _load().
        self._model_id = axclrt_cffi.new("uint64_t *")
        self._context_id = axclrt_cffi.new("uint64_t *")

        # Report the effective VNPU mode.
        self._vnpu_type = _get_vnpu_type()
        print(f"[INFO] VNPU type: {self._vnpu_type}")

        # Load the model and create its context.
        ret = self._load(path_or_bytes)
        if 0 != ret:
            raise RuntimeError("Failed to load model.")
        print(f"[INFO] Compiler version: {self._get_model_tool_version()}")

        # Cache IO metadata: per-shape-group input/output NodeArgs.
        self._info = self._get_info()
        self._shape_count = self._get_shape_count()
        self._inputs = self._get_inputs()
        self._outputs = self._get_outputs()

        # Pre-allocate device-side IO buffers sized for the largest group.
        self._io = self._prepare_io()

    def __del__(self):
        # _unload() guards every attribute access, so this is safe even on a
        # half-constructed instance.
        self._unload()

    def _load(self, path_or_bytes):
        """Load the model from a path or a bytes blob and create its context.

        Returns 0 on success; raises on any engine error.
        """
        if isinstance(path_or_bytes, (str, os.PathLike)):
            # Fixed: PathLike objects have no .encode(); go through os.fspath.
            _model_path = axclrt_cffi.new("char[]", os.fspath(path_or_bytes).encode("utf-8"))
            ret = axclrt_lib.axclrtEngineLoadFromFile(_model_path, self._model_id)
            if ret != 0:
                raise RuntimeError("axclrtEngineLoadFromFile failed.")
        elif isinstance(path_or_bytes, bytes):
            _model_buffer = axclrt_cffi.new("char[]", path_or_bytes)
            _model_buffer_size = len(path_or_bytes)

            # The engine consumes device memory: copy the blob to the device,
            # load from there, then free the staging buffer in every path.
            dev_mem_ptr = axclrt_cffi.new('void **', axclrt_cffi.NULL)
            ret = axclrt_lib.axclrtMalloc(dev_mem_ptr, _model_buffer_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
            if ret != 0:
                raise RuntimeError("axclrtMalloc failed.")

            ret = axclrt_lib.axclrtMemcpy(dev_mem_ptr[0], _model_buffer, _model_buffer_size, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE)
            if ret != 0:
                axclrt_lib.axclrtFree(dev_mem_ptr[0])
                raise RuntimeError("axclrtMemcpy failed.")

            ret = axclrt_lib.axclrtEngineLoadFromMem(dev_mem_ptr[0], _model_buffer_size, self._model_id)
            axclrt_lib.axclrtFree(dev_mem_ptr[0])
            if ret != 0:
                raise RuntimeError("axclrtEngineLoadFromMem failed.")
        else:
            raise TypeError(f"Unable to load model from type '{type(path_or_bytes)}'")

        ret = axclrt_lib.axclrtEngineCreateContext(self._model_id[0], self._context_id)
        if ret != 0:
            raise RuntimeError("axclrtEngineCreateContext failed")
        return ret

    def _unload(self):
        """Free device IO buffers, destroy the IO object, unload the model.

        Idempotent, and tolerant of a partially constructed session.
        """
        io_obj = getattr(self, "_io", None)
        info = getattr(self, "_info", None)
        if io_obj is not None and info is not None:
            dev_size = axclrt_cffi.new("uint64_t *")
            dev_prt = axclrt_cffi.new("void **")
            for i in range(axclrt_lib.axclrtEngineGetNumInputs(info[0])):
                axclrt_lib.axclrtEngineGetInputBufferByIndex(io_obj, i, dev_prt, dev_size)
                axclrt_lib.axclrtFree(dev_prt[0])
            for i in range(axclrt_lib.axclrtEngineGetNumOutputs(info[0])):
                axclrt_lib.axclrtEngineGetOutputBufferByIndex(io_obj, i, dev_prt, dev_size)
                axclrt_lib.axclrtFree(dev_prt[0])
            axclrt_lib.axclrtEngineDestroyIO(io_obj)
            self._io = None
        model_id = getattr(self, "_model_id", None)
        # Fixed: the original tested `self._model_id[0] is not None`, which is
        # always true for an integer; compare against 0 (the cleared value).
        if model_id is not None and model_id[0] != 0:
            axclrt_lib.axclrtEngineUnload(model_id[0])
            model_id[0] = 0

    def _get_model_tool_version(self):
        """Return the compiler (pulsar) version string baked into the model."""
        model_tool_version = axclrt_lib.axclrtEngineGetModelCompilerVersion(self._model_id[0])
        return axclrt_cffi.string(model_tool_version).decode()

    def _get_info(self):
        """Fetch the model's axclrtEngineIOInfo handle (owned by the engine)."""
        io_info = axclrt_cffi.new("axclrtEngineIOInfo *")
        ret = axclrt_lib.axclrtEngineGetIOInfo(self._model_id[0], io_info)
        if ret != 0:
            raise RuntimeError("axclrtEngineGetIOInfo failed.")
        return io_info

    def _get_shape_count(self):
        """Return how many shape groups the compiled model supports."""
        count = axclrt_cffi.new("int32_t *")
        ret = axclrt_lib.axclrtEngineGetShapeGroupsCount(self._info[0], count)
        if ret != 0:
            axclrt_lib.axclrtEngineUnload(self._model_id[0])
            raise RuntimeError("axclrtEngineGetShapeGroupsCount failed.")
        return count[0]

    def _get_inputs(self):
        """Build per-shape-group lists of input NodeArg(name, dtype, shape)."""
        inputs = []
        for group in range(self._shape_count):
            one_group_io = []
            for index in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
                cffi_name = axclrt_lib.axclrtEngineGetInputNameByIndex(self._info[0], index)
                name = axclrt_cffi.string(cffi_name).decode("utf-8")

                cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *")
                ret = axclrt_lib.axclrtEngineGetInputDataType(self._info[0], index, cffi_dtype)
                if ret != 0:
                    raise RuntimeError("axclrtEngineGetInputDataType failed.")
                dtype = _transform_dtype(cffi_dtype[0])

                cffi_dims = axclrt_cffi.new("axclrtEngineIODims *")
                ret = axclrt_lib.axclrtEngineGetInputDims(self._info[0], group, index, cffi_dims)
                if ret != 0:
                    raise RuntimeError("axclrtEngineGetInputDims failed.")
                shape = [cffi_dims.dims[i] for i in range(cffi_dims.dimCount)]

                one_group_io.append(NodeArg(name, dtype, shape))
            inputs.append(one_group_io)
        return inputs

    def _get_outputs(self):
        """Build per-shape-group lists of output NodeArg(name, dtype, shape)."""
        outputs = []
        for group in range(self._shape_count):
            one_group_io = []
            for index in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
                cffi_name = axclrt_lib.axclrtEngineGetOutputNameByIndex(self._info[0], index)
                # Fixed: the original stored the raw cffi char* here, so
                # `name in output_names` in run() never matched a Python str.
                name = axclrt_cffi.string(cffi_name).decode("utf-8")

                cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *")
                ret = axclrt_lib.axclrtEngineGetOutputDataType(self._info[0], index, cffi_dtype)
                if ret != 0:
                    raise RuntimeError("axclrtEngineGetOutputDataType failed.")
                dtype = _transform_dtype(cffi_dtype[0])

                cffi_dims = axclrt_cffi.new("axclrtEngineIODims *")
                ret = axclrt_lib.axclrtEngineGetOutputDims(self._info[0], group, index, cffi_dims)
                if ret != 0:
                    raise RuntimeError("axclrtEngineGetOutputDims failed.")
                shape = [cffi_dims.dims[i] for i in range(cffi_dims.dimCount)]

                one_group_io.append(NodeArg(name, dtype, shape))
            outputs.append(one_group_io)
        return outputs

    def _prepare_io(self):
        """Create the engine IO object and device buffers for every port.

        Each buffer is sized for the largest of the model's shape groups so a
        single allocation serves all groups.
        """
        _io = axclrt_cffi.new("axclrtEngineIO *")
        ret = axclrt_lib.axclrtEngineCreateIO(self._info[0], _io)
        if ret != 0:
            raise RuntimeError(f"axclrtEngineCreateIO failed 0x{ret:08x}.")
        for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
            max_size = 0
            for group in range(self._shape_count):
                size = axclrt_lib.axclrtEngineGetInputSizeByIndex(self._info[0], group, i)
                max_size = max(max_size, size)
            dev_ptr = axclrt_cffi.new("void **")
            ret = axclrt_lib.axclrtMalloc(dev_ptr, max_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
            if 0 != ret or dev_ptr[0] == axclrt_cffi.NULL:
                raise RuntimeError(f"axclrtMalloc failed 0x{ret:08x} for input {i}.")
            ret = axclrt_lib.axclrtEngineSetInputBufferByIndex(_io[0], i, dev_ptr[0], max_size)
            if 0 != ret:
                raise RuntimeError(f"axclrtEngineSetInputBufferByIndex failed 0x{ret:08x} for input {i}.")
        for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
            max_size = 0
            for group in range(self._shape_count):
                size = axclrt_lib.axclrtEngineGetOutputSizeByIndex(self._info[0], group, i)
                max_size = max(max_size, size)
            dev_ptr = axclrt_cffi.new("void **")
            ret = axclrt_lib.axclrtMalloc(dev_ptr, max_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
            if 0 != ret or dev_ptr[0] == axclrt_cffi.NULL:
                raise RuntimeError(f"axclrtMalloc failed 0x{ret:08x} for output {i}.")
            ret = axclrt_lib.axclrtEngineSetOutputBufferByIndex(_io[0], i, dev_ptr[0], max_size)
            if 0 != ret:
                raise RuntimeError(f"axclrtEngineSetOutputBufferByIndex failed 0x{ret:08x} for output {i}.")
        return _io[0]

    def run(
            self,
            output_names: list[str],
            input_feed: dict[str, np.ndarray],
            run_options=None
    ):
        """Execute the model on the device.

        Args:
            output_names: names of outputs to return, or None for all.
            input_feed: maps input name -> numpy array matching the model's
                declared shape and dtype exactly.
            run_options: unused; kept for onnxruntime-style compatibility.

        Returns:
            List of numpy arrays, in the model's output order, restricted to
            the requested names.

        Raises:
            RuntimeError: on any device copy or execution failure.
        """
        self._validate_input(input_feed)
        self._validate_output(output_names)

        if None is output_names:
            output_names = [o.name for o in self.get_outputs()]

        # Copy each named input into its pre-allocated device buffer.
        dev_prt = axclrt_cffi.new("void **")
        dev_size = axclrt_cffi.new("uint64_t *")
        for key, npy in input_feed.items():
            for i, one in enumerate(self.get_inputs()):
                if one.name == key:
                    assert (
                            list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
                    ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}"

                    # The device copy reads raw host memory, so the array must
                    # be C-contiguous. Fixed: the original's inverted check
                    # skipped Fortran-ordered arrays and copied their elements
                    # in the wrong order.
                    if not npy.flags.c_contiguous:
                        npy = np.ascontiguousarray(npy)
                    npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
                    ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
                    if 0 != ret:
                        raise RuntimeError(f"axclrtEngineGetInputBufferByIndex failed for input {i}.")
                    ret = axclrt_lib.axclrtMemcpy(dev_prt[0], npy_ptr, npy.nbytes, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE)
                    if 0 != ret:
                        raise RuntimeError(f"axclrtMemcpy failed for input {i}.")

        # Execute with shape group 0.
        # NOTE(review): only group 0 is ever executed even though multiple
        # shape groups are enumerated — presumably dynamic-shape selection is
        # not implemented yet; confirm against the session front end.
        ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io)
        if 0 != ret:
            raise RuntimeError(f"axclrtEngineExecute failed 0x{ret:08x}")

        # Copy every output back to the host; return only the requested ones.
        outputs = []
        for i in range(len(self.get_outputs())):
            ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
            if 0 != ret:
                raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.")
            npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype)
            npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
            ret = axclrt_lib.axclrtMemcpy(npy_ptr, dev_prt[0], npy.nbytes, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST)
            if 0 != ret:
                raise RuntimeError(f"axclrtMemcpy failed for output {i}.")
            if self.get_outputs()[i].name in output_names:
                outputs.append(npy)
        return outputs
|
python/axengine/_axclrt_capi.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
import ctypes.util
|
9 |
+
|
10 |
+
from cffi import FFI
|
11 |
+
|
12 |
+
# Fixed: the original used annotation syntax (`__all__: [...]`), which only
# records a type hint and never actually defines/assigns __all__.
__all__ = ["axclrt_cffi", "axclrt_lib"]
|
13 |
+
|
14 |
+
axclrt_cffi = FFI()
|
15 |
+
|
16 |
+
# axcl_base.h
|
17 |
+
axclrt_cffi.cdef(
|
18 |
+
"""
|
19 |
+
#define AXCL_MAX_DEVICE_COUNT 256
|
20 |
+
typedef int32_t axclError;
|
21 |
+
typedef void *axclrtContext;
|
22 |
+
"""
|
23 |
+
)
|
24 |
+
|
25 |
+
# axcl_rt_type.h
|
26 |
+
axclrt_cffi.cdef(
|
27 |
+
"""
|
28 |
+
typedef struct axclrtDeviceList {
|
29 |
+
uint32_t num;
|
30 |
+
int32_t devices[AXCL_MAX_DEVICE_COUNT];
|
31 |
+
} axclrtDeviceList;
|
32 |
+
|
33 |
+
typedef enum axclrtMemMallocPolicy {
|
34 |
+
AXCL_MEM_MALLOC_HUGE_FIRST,
|
35 |
+
AXCL_MEM_MALLOC_HUGE_ONLY,
|
36 |
+
AXCL_MEM_MALLOC_NORMAL_ONLY
|
37 |
+
} axclrtMemMallocPolicy;
|
38 |
+
|
39 |
+
typedef enum axclrtMemcpyKind {
|
40 |
+
AXCL_MEMCPY_HOST_TO_HOST,
|
41 |
+
AXCL_MEMCPY_HOST_TO_DEVICE, //!< host vir -> device phy
|
42 |
+
AXCL_MEMCPY_DEVICE_TO_HOST, //!< host vir <- device phy
|
43 |
+
AXCL_MEMCPY_DEVICE_TO_DEVICE,
|
44 |
+
AXCL_MEMCPY_HOST_PHY_TO_DEVICE, //!< host phy -> device phy
|
45 |
+
AXCL_MEMCPY_DEVICE_TO_HOST_PHY, //!< host phy <- device phy
|
46 |
+
} axclrtMemcpyKind;
|
47 |
+
"""
|
48 |
+
)
|
49 |
+
|
50 |
+
# axcl_rt_engine_type.h
|
51 |
+
axclrt_cffi.cdef(
|
52 |
+
"""
|
53 |
+
#define AXCLRT_ENGINE_MAX_DIM_CNT 32
|
54 |
+
typedef void* axclrtEngineIOInfo;
|
55 |
+
typedef void* axclrtEngineIO;
|
56 |
+
|
57 |
+
typedef enum axclrtEngineVNpuKind {
|
58 |
+
AXCL_VNPU_DISABLE = 0,
|
59 |
+
AXCL_VNPU_ENABLE = 1,
|
60 |
+
AXCL_VNPU_BIG_LITTLE = 2,
|
61 |
+
AXCL_VNPU_LITTLE_BIG = 3,
|
62 |
+
} axclrtEngineVNpuKind;
|
63 |
+
|
64 |
+
typedef enum axclrtEngineDataType {
|
65 |
+
AXCL_DATA_TYPE_NONE = 0,
|
66 |
+
AXCL_DATA_TYPE_INT4 = 1,
|
67 |
+
AXCL_DATA_TYPE_UINT4 = 2,
|
68 |
+
AXCL_DATA_TYPE_INT8 = 3,
|
69 |
+
AXCL_DATA_TYPE_UINT8 = 4,
|
70 |
+
AXCL_DATA_TYPE_INT16 = 5,
|
71 |
+
AXCL_DATA_TYPE_UINT16 = 6,
|
72 |
+
AXCL_DATA_TYPE_INT32 = 7,
|
73 |
+
AXCL_DATA_TYPE_UINT32 = 8,
|
74 |
+
AXCL_DATA_TYPE_INT64 = 9,
|
75 |
+
AXCL_DATA_TYPE_UINT64 = 10,
|
76 |
+
AXCL_DATA_TYPE_FP4 = 11,
|
77 |
+
AXCL_DATA_TYPE_FP8 = 12,
|
78 |
+
AXCL_DATA_TYPE_FP16 = 13,
|
79 |
+
AXCL_DATA_TYPE_BF16 = 14,
|
80 |
+
AXCL_DATA_TYPE_FP32 = 15,
|
81 |
+
AXCL_DATA_TYPE_FP64 = 16,
|
82 |
+
} axclrtEngineDataType;
|
83 |
+
|
84 |
+
typedef enum axclrtEngineDataLayout {
|
85 |
+
AXCL_DATA_LAYOUT_NONE = 0,
|
86 |
+
AXCL_DATA_LAYOUT_NHWC = 0,
|
87 |
+
AXCL_DATA_LAYOUT_NCHW = 1,
|
88 |
+
} axclrtEngineDataLayout;
|
89 |
+
|
90 |
+
typedef struct axclrtEngineIODims {
|
91 |
+
int32_t dimCount;
|
92 |
+
int32_t dims[AXCLRT_ENGINE_MAX_DIM_CNT];
|
93 |
+
} axclrtEngineIODims;
|
94 |
+
"""
|
95 |
+
)
|
96 |
+
|
97 |
+
# axcl.h
|
98 |
+
axclrt_cffi.cdef(
|
99 |
+
"""
|
100 |
+
axclError axclInit(const char *config);
|
101 |
+
axclError axclFinalize();
|
102 |
+
"""
|
103 |
+
)
|
104 |
+
|
105 |
+
# axcl_rt.h
|
106 |
+
axclrt_cffi.cdef(
|
107 |
+
"""
|
108 |
+
axclError axclrtGetVersion(int32_t *major, int32_t *minor, int32_t *patch);
|
109 |
+
const char *axclrtGetSocName();
|
110 |
+
"""
|
111 |
+
)
|
112 |
+
|
113 |
+
# axcl_rt_device.h
|
114 |
+
axclrt_cffi.cdef(
|
115 |
+
"""
|
116 |
+
axclError axclrtGetDeviceList(axclrtDeviceList *deviceList);
|
117 |
+
axclError axclrtSetDevice(int32_t deviceId);
|
118 |
+
axclError axclrtResetDevice(int32_t deviceId);
|
119 |
+
"""
|
120 |
+
)
|
121 |
+
|
122 |
+
# axcl_rt_context.h
|
123 |
+
axclrt_cffi.cdef(
|
124 |
+
"""
|
125 |
+
axclError axclrtCreateContext(axclrtContext *context, int32_t deviceId);
|
126 |
+
axclError axclrtDestroyContext(axclrtContext context);
|
127 |
+
axclError axclrtSetCurrentContext(axclrtContext context);
|
128 |
+
axclError axclrtGetCurrentContext(axclrtContext *context);
|
129 |
+
axclError axclrtGetDefaultContext(axclrtContext *context, int32_t deviceId);
|
130 |
+
"""
|
131 |
+
)
|
132 |
+
|
133 |
+
# axcl_rt_engine.h
|
134 |
+
axclrt_cffi.cdef(
|
135 |
+
"""
|
136 |
+
axclError axclrtEngineInit(axclrtEngineVNpuKind npuKind);
|
137 |
+
axclError axclrtEngineGetVNpuKind(axclrtEngineVNpuKind *npuKind);
|
138 |
+
axclError axclrtEngineFinalize();
|
139 |
+
|
140 |
+
axclError axclrtEngineLoadFromFile(const char *modelPath, uint64_t *modelId);
|
141 |
+
axclError axclrtEngineLoadFromMem(const void *model, uint64_t modelSize, uint64_t *modelId);
|
142 |
+
const char* axclrtEngineGetModelCompilerVersion(uint64_t modelId);
|
143 |
+
axclError axclrtEngineUnload(uint64_t modelId);
|
144 |
+
|
145 |
+
axclError axclrtEngineGetIOInfo(uint64_t modelId, axclrtEngineIOInfo *ioInfo);
|
146 |
+
axclError axclrtEngineGetShapeGroupsCount(axclrtEngineIOInfo ioInfo, int32_t *count);
|
147 |
+
|
148 |
+
uint32_t axclrtEngineGetNumInputs(axclrtEngineIOInfo ioInfo);
|
149 |
+
uint32_t axclrtEngineGetNumOutputs(axclrtEngineIOInfo ioInfo);
|
150 |
+
|
151 |
+
uint64_t axclrtEngineGetInputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index);
|
152 |
+
uint64_t axclrtEngineGetOutputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index);
|
153 |
+
|
154 |
+
axclError axclrtEngineGetInputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims);
|
155 |
+
axclError axclrtEngineGetOutputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims);
|
156 |
+
|
157 |
+
const char *axclrtEngineGetInputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index);
|
158 |
+
const char *axclrtEngineGetOutputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index);
|
159 |
+
|
160 |
+
int32_t axclrtEngineGetInputDataType(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataType *type);
|
161 |
+
int32_t axclrtEngineGetOutputDataType(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataType *type);
|
162 |
+
|
163 |
+
int32_t axclrtEngineGetInputDataLayout(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataLayout *layout);
|
164 |
+
int32_t axclrtEngineGetOutputDataLayout(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataLayout *layout);
|
165 |
+
|
166 |
+
axclError axclrtEngineCreateIO(axclrtEngineIOInfo ioInfo, axclrtEngineIO *io);
|
167 |
+
axclError axclrtEngineDestroyIO(axclrtEngineIO io);
|
168 |
+
|
169 |
+
axclError axclrtEngineSetInputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size);
|
170 |
+
axclError axclrtEngineSetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size);
|
171 |
+
axclError axclrtEngineGetInputBufferByIndex(axclrtEngineIO io, uint32_t index, void **dataBuffer, uint64_t *size);
|
172 |
+
axclError axclrtEngineGetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, void **dataBuffer, uint64_t *size);
|
173 |
+
|
174 |
+
axclError axclrtEngineCreateContext(uint64_t modelId, uint64_t *contextId);
|
175 |
+
|
176 |
+
axclError axclrtEngineExecute(uint64_t modelId, uint64_t contextId, uint32_t group, axclrtEngineIO io);
|
177 |
+
"""
|
178 |
+
)
|
179 |
+
|
180 |
+
# axcl_rt_memory.h
|
181 |
+
axclrt_cffi.cdef(
|
182 |
+
"""
|
183 |
+
axclError axclrtMalloc(void **devPtr, size_t size, axclrtMemMallocPolicy policy);
|
184 |
+
axclError axclrtMallocCached(void **devPtr, size_t size, axclrtMemMallocPolicy policy);
|
185 |
+
axclError axclrtMemcpy(void *dstPtr, const void *srcPtr, size_t count, axclrtMemcpyKind kind);
|
186 |
+
axclError axclrtFree(void *devPtr);
|
187 |
+
axclError axclrtMemFlush(void *devPtr, size_t size);
|
188 |
+
"""
|
189 |
+
)
|
190 |
+
|
191 |
+
# Locate and load the AXCL runtime shared library.
rt_name = "axcl_rt"
rt_path = ctypes.util.find_library(rt_name)
# Fixed: `assert` is stripped under `python -O`, which would turn a missing
# library into an obscure dlopen failure; raise a real exception instead.
if rt_path is None:
    raise ImportError(
        f"Failed to find library {rt_name}. Please ensure it is installed and in the library path.")

axclrt_lib = axclrt_cffi.dlopen(rt_path)
if axclrt_lib is None:
    raise ImportError(
        f"Failed to load library {rt_path}. Please ensure it is installed and in the library path.")
|
python/axengine/_axclrt_types.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
from enum import Enum
|
9 |
+
|
10 |
+
|
11 |
+
class VNPUType(Enum):
    """Virtual-NPU partition mode reported/configured on the device."""

    DISABLED = 0
    ENABLED = 1
    BIG_LITTLE = 2
    LITTLE_BIG = 3
|
16 |
+
|
17 |
+
|
18 |
+
class ModelType(Enum):
    """How many NPU cores a compiled model occupies."""

    SINGLE = 0
    DUAL = 1
    TRIPLE = 2
|
python/axengine/_axe.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
import atexit
|
9 |
+
import os
|
10 |
+
from typing import Any, Sequence
|
11 |
+
|
12 |
+
import ml_dtypes as mldt
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
from ._axe_capi import sys_lib, engine_cffi, engine_lib
|
16 |
+
from ._axe_types import VNPUType, ModelType, ChipType
|
17 |
+
from ._base_session import Session, SessionOptions
|
18 |
+
from ._node import NodeArg
|
19 |
+
|
20 |
+
# NOTE: this was originally written as `__all__: [...]`, which is a bare
# annotation with no runtime effect — `from axengine._axe import *` would
# not export anything. It must be an assignment.
__all__ = ["AXEngineSession"]

# Track which runtime layers were brought up, so _finalize_engine() only
# tears down what _initialize_engine() actually started.
_is_sys_initialized = False
_is_engine_initialized = False
|
24 |
+
|
25 |
+
|
26 |
+
def _transform_dtype(dtype):
    """Map an ``AX_ENGINE_DATA_TYPE_T`` enum value to the matching numpy dtype.

    Args:
        dtype: value read from an ``AX_ENGINE_IO_META_T.eDataType`` field.

    Returns:
        np.dtype: the equivalent numpy dtype (bfloat16 comes from ml_dtypes).

    Raises:
        ValueError: if the engine data type has no numpy equivalent.
    """
    # A lookup table replaces the original eight-branch if/elif chain of
    # repeated engine_cffi.cast() comparisons.
    mapping = {
        engine_lib.AX_ENGINE_DT_UINT8: np.uint8,
        engine_lib.AX_ENGINE_DT_SINT8: np.int8,
        engine_lib.AX_ENGINE_DT_UINT16: np.uint16,
        engine_lib.AX_ENGINE_DT_SINT16: np.int16,
        engine_lib.AX_ENGINE_DT_UINT32: np.uint32,
        engine_lib.AX_ENGINE_DT_SINT32: np.int32,
        engine_lib.AX_ENGINE_DT_FLOAT32: np.float32,
        engine_lib.AX_ENGINE_DT_BFLOAT16: mldt.bfloat16,
    }
    # Normalize the incoming value (plain int or cffi enum cdata) to int.
    key = int(engine_cffi.cast("int", dtype))
    try:
        return np.dtype(mapping[key])
    except KeyError:
        raise ValueError(f"Unsupported data type '{dtype}'.") from None
|
45 |
+
|
46 |
+
|
47 |
+
def _check_cffi_func_exists(lib, func_name):
|
48 |
+
try:
|
49 |
+
getattr(lib, func_name)
|
50 |
+
return True
|
51 |
+
except AttributeError:
|
52 |
+
return False
|
53 |
+
|
54 |
+
|
55 |
+
def _get_chip_type():
    """Infer the chip family by probing which symbols libax_engine exports.

    - M57H builds lack ``AX_ENGINE_SetAffinity``;
    - MC50 builds have it but lack ``AX_ENGINE_GetTotalOps``;
    - anything exporting both is assumed to be MC20E.
    """
    if not _check_cffi_func_exists(engine_lib, "AX_ENGINE_SetAffinity"):
        return ChipType.M57H
    elif not _check_cffi_func_exists(engine_lib, "AX_ENGINE_GetTotalOps"):
        return ChipType.MC50
    else:
        return ChipType.MC20E
|
62 |
+
|
63 |
+
|
64 |
+
def _get_version():
    """Return the engine library's version as a UTF-8 Python string."""
    raw = engine_lib.AX_ENGINE_GetVersion()
    text = engine_cffi.string(raw)
    return text.decode("utf-8")
|
67 |
+
|
68 |
+
|
69 |
+
def _get_vnpu_type() -> VNPUType:
    """Query the current virtual-NPU hard mode from the engine library.

    Raises:
        RuntimeError: if ``AX_ENGINE_GetVNPUAttr`` returns non-zero.
    """
    vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *")
    ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
    if 0 != ret:
        raise RuntimeError("Failed to get VNPU attribute.")
    return VNPUType(vnpu_type.eHardMode)
|
75 |
+
|
76 |
+
|
77 |
+
def _initialize_engine():
    """Bring up AX_SYS and the AXEngine runtime (called once at import).

    Sets the module flags as each layer succeeds so _finalize_engine()
    knows what to tear down.

    Raises:
        RuntimeError: if AX_SYS_Init or AX_ENGINE_Init fails.
    """
    global _is_sys_initialized, _is_engine_initialized

    ret = sys_lib.AX_SYS_Init()
    if ret != 0:
        raise RuntimeError("Failed to initialize ax sys.")
    _is_sys_initialized = True

    # disabled mode by default
    vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *")
    ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
    if 0 != ret:
        # this means the NPU was not initialized yet; fall back to the
        # DISABLED virtual-NPU mode for the Init call below
        vnpu_type.eHardMode = engine_cffi.cast(
            "AX_ENGINE_NPU_MODE_T", VNPUType.DISABLED.value
        )
    ret = engine_lib.AX_ENGINE_Init(vnpu_type)
    if ret != 0:
        raise RuntimeError("Failed to initialize ax sys engine.")
    _is_engine_initialized = True

    print(f"[INFO] Chip type: {_get_chip_type()}")
    print(f"[INFO] VNPU type: {_get_vnpu_type()}")
    print(f"[INFO] Engine version: {_get_version()}")
|
101 |
+
|
102 |
+
|
103 |
+
def _finalize_engine():
    """atexit hook: tear down whatever _initialize_engine() brought up.

    Clears the module flags after each deinit so a repeated invocation is
    a no-op (the original double-deinitialized when called twice).
    """
    global _is_sys_initialized, _is_engine_initialized

    if _is_engine_initialized:
        engine_lib.AX_ENGINE_Deinit()
        _is_engine_initialized = False
    if _is_sys_initialized:
        sys_lib.AX_SYS_Deinit()
        _is_sys_initialized = False
|
110 |
+
|
111 |
+
|
112 |
+
# Bring the runtime up at import time and guarantee teardown at interpreter
# exit; importing this module therefore has side effects on the device.
_initialize_engine()
atexit.register(_finalize_engine)
|
114 |
+
|
115 |
+
|
116 |
+
class AXEngineSession(Session):
    """Inference session backed by the on-device AXEngine runtime.

    Loads a compiled model (from a path or raw bytes), allocates cached CMM
    input/output buffers once up front (sized for the largest shape group),
    and exposes an onnxruntime-like ``run()`` API.
    """

    def __init__(
        self,
        path_or_bytes: str | bytes | os.PathLike,
        sess_options: SessionOptions | None = None,  # accepted for API parity; unused here
        provider_options: dict[Any, Any] | None = None,  # accepted for API parity; unused here
        **kwargs,
    ) -> None:
        super().__init__()

        self._chip_type = _get_chip_type()
        self._vnpu_type = _get_vnpu_type()

        # handle, context, info, io
        self._handle = engine_cffi.new("uint64_t **")
        self._context = engine_cffi.new("uint64_t **")
        self._io = engine_cffi.new("AX_ENGINE_IO_T *")

        # model buffer, almost copied from onnx runtime
        # (keep the cffi buffer referenced on self so it outlives the C handle)
        if isinstance(path_or_bytes, (str, os.PathLike)):
            self._model_name = os.path.splitext(os.path.basename(path_or_bytes))[0]
            with open(path_or_bytes, "rb") as f:
                data = f.read()
            self._model_buffer = engine_cffi.new("char[]", data)
            self._model_buffer_size = len(data)
        elif isinstance(path_or_bytes, bytes):
            self._model_buffer = engine_cffi.new("char[]", path_or_bytes)
            self._model_buffer_size = len(path_or_bytes)
        else:
            raise TypeError(f"Unable to load model from type '{type(path_or_bytes)}'")

        # get model type (NOTE: ModelType members alias across chip families —
        # HALF/SINGLE share value 0 and FULL/DUAL share value 1 — which is why
        # the same value is reported with different wording per chip below)
        self._model_type = self._get_model_type()
        if self._chip_type is ChipType.MC20E:
            if self._model_type is ModelType.FULL:
                print(f"[INFO] Model type: {self._model_type.value} (full core)")
            if self._model_type is ModelType.HALF:
                print(f"[INFO] Model type: {self._model_type.value} (half core)")
        if self._chip_type is ChipType.MC50:
            if self._model_type is ModelType.SINGLE:
                print(f"[INFO] Model type: {self._model_type.value} (single core)")
            if self._model_type is ModelType.DUAL:
                print(f"[INFO] Model type: {self._model_type.value} (dual core)")
            if self._model_type is ModelType.TRIPLE:
                print(f"[INFO] Model type: {self._model_type.value} (triple core)")
        if self._chip_type is ChipType.M57H:
            print(f"[INFO] Model type: {self._model_type.value} (single core)")

        # check model type against the virtual-NPU partition mode
        if self._chip_type is ChipType.MC50:
            # all types (single or dual or triple) of model are allowed in vnpu mode disabled
            # only single core model is allowed in vnpu mode enabled
            # only triple core model is NOT allowed in vnpu mode big-little or little-big
            if self._vnpu_type is VNPUType.ENABLED:
                if self._model_type is not ModelType.SINGLE:
                    raise ValueError(
                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
                    )
            if (
                self._vnpu_type is VNPUType.BIG_LITTLE
                or self._vnpu_type is VNPUType.LITTLE_BIG
            ):
                if self._model_type is ModelType.TRIPLE:
                    raise ValueError(
                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
                    )
        if self._chip_type is ChipType.MC20E:
            # all types of full or half core model are allowed in vnpu mode disabled
            # only half core model is allowed in vnpu mode enabled
            if self._vnpu_type is VNPUType.ENABLED:
                if self._model_type is ModelType.FULL:
                    raise ValueError(
                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
                    )
        # if self._chip_type is ChipType.M57H:
        # there only one type of model will be compiled, so no need to check

        # load model
        ret = self._load()
        if 0 != ret:
            raise RuntimeError("Failed to load model.")
        print(f"[INFO] Compiler version: {self._get_model_tool_version()}")

        # get shape group count; older runtimes lack the group API, in which
        # case _get_shape_count raises AttributeError and we assume one group
        try:
            self._shape_count = self._get_shape_count()
        except AttributeError as e:
            print(f"[WARNING] {e}")
            self._shape_count = 1

        # get model shape
        self._info = self._get_info()
        self._inputs = self._get_inputs()
        self._outputs = self._get_outputs()

        # fill model io: each buffer is sized to the maximum across all shape
        # groups so a single allocation serves every group
        self._align = 128
        self._cmm_token = engine_cffi.new("AX_S8[]", b"PyEngine")
        self._io[0].nInputSize = len(self.get_inputs())
        self._io[0].nOutputSize = len(self.get_outputs())
        self._io[0].pInputs = engine_cffi.new(
            "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize)
        )
        self._io[0].pOutputs = engine_cffi.new(
            "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize)
        )
        for i in range(len(self.get_inputs())):
            max_buf = 0
            for j in range(self._shape_count):
                max_buf = max(max_buf, self._info[j][0].pInputs[i].nSize)
            self._io[0].pInputs[i].nSize = max_buf
            phy = engine_cffi.new("AX_U64*")
            vir = engine_cffi.new("AX_VOID**")
            ret = sys_lib.AX_SYS_MemAllocCached(
                phy, vir, self._io[0].pInputs[i].nSize, self._align, self._cmm_token
            )
            if 0 != ret:
                raise RuntimeError("Failed to allocate memory for input.")
            self._io[0].pInputs[i].phyAddr = phy[0]
            self._io[0].pInputs[i].pVirAddr = vir[0]
        for i in range(len(self.get_outputs())):
            max_buf = 0
            for j in range(self._shape_count):
                max_buf = max(max_buf, self._info[j][0].pOutputs[i].nSize)
            self._io[0].pOutputs[i].nSize = max_buf
            phy = engine_cffi.new("AX_U64*")
            vir = engine_cffi.new("AX_VOID**")
            ret = sys_lib.AX_SYS_MemAllocCached(
                phy, vir, self._io[0].pOutputs[i].nSize, self._align, self._cmm_token
            )
            if 0 != ret:
                raise RuntimeError("Failed to allocate memory for output.")
            self._io[0].pOutputs[i].phyAddr = phy[0]
            self._io[0].pOutputs[i].pVirAddr = vir[0]

    def __del__(self):
        # NOTE(review): runs during GC / interpreter shutdown and assumes
        # engine_lib is still usable; also, the CMM buffers allocated in
        # __init__ are not freed here — confirm whether that is intentional.
        self._unload()

    def _get_model_type(self) -> ModelType:
        """Ask the runtime what core layout the model buffer was compiled for."""
        model_type = engine_cffi.new("AX_ENGINE_MODEL_TYPE_T *")
        ret = engine_lib.AX_ENGINE_GetModelType(
            self._model_buffer, self._model_buffer_size, model_type
        )
        if 0 != ret:
            raise RuntimeError("Failed to get model type.")
        return ModelType(model_type[0])

    def _get_model_tool_version(self):
        """Return the toolchain (compiler) version string stored in the model."""
        model_tool_version = engine_lib.AX_ENGINE_GetModelToolsVersion(
            self._handle[0]
        )
        return engine_cffi.string(model_tool_version).decode("utf-8")

    def _load(self):
        """Create the engine handle and one context for this model.

        Returns the raw engine status code (0 on success).
        """
        extra = engine_cffi.new("AX_ENGINE_HANDLE_EXTRA_T *")
        extra_name = engine_cffi.new("char[]", self._model_name.encode("utf-8"))
        extra.pName = extra_name

        # ONNX Runtime (which this API mirrors) does not support running one
        # model with multiple contexts across threads as far as we know, so
        # the engine handle and context are created only once.
        ret = engine_lib.AX_ENGINE_CreateHandleV2(
            self._handle, self._model_buffer, self._model_buffer_size, extra
        )
        if 0 == ret:
            ret = engine_lib.AX_ENGINE_CreateContextV2(
                self._handle[0], self._context
            )
        return ret

    def _get_info(self):
        """Fetch the IO info struct(s), one entry per shape group."""
        total_info = []
        if 1 == self._shape_count:
            info = engine_cffi.new("AX_ENGINE_IO_INFO_T **")
            ret = engine_lib.AX_ENGINE_GetIOInfo(self._handle[0], info)
            if 0 != ret:
                raise RuntimeError("Failed to get model shape.")
            total_info.append(info)
        else:
            for i in range(self._shape_count):
                info = engine_cffi.new("AX_ENGINE_IO_INFO_T **")
                ret = engine_lib.AX_ENGINE_GetGroupIOInfo(
                    self._handle[0], i, info
                )
                if 0 != ret:
                    raise RuntimeError(f"Failed to get model the {i}th shape.")
                total_info.append(info)
        return total_info

    def _get_shape_count(self):
        """Return the number of shape groups the loaded model carries."""
        count = engine_cffi.new("AX_U32 *")
        ret = engine_lib.AX_ENGINE_GetGroupIOInfoCount(self._handle[0], count)
        if 0 != ret:
            raise RuntimeError("Failed to get model shape group.")
        return count[0]

    def _unload(self):
        """Destroy the engine handle and null it out.

        NOTE(review): a cffi cdata pointer is never ``None`` even when NULL,
        so this guard always passes — confirm DestroyHandle tolerates NULL
        on a second call.
        """
        if self._handle[0] is not None:
            engine_lib.AX_ENGINE_DestroyHandle(self._handle[0])
        self._handle[0] = engine_cffi.NULL

    def _get_io(self, io_type: str):
        """Convert the C IO metadata into NodeArg lists.

        Args:
            io_type: "Input" or "Output"; selects the n{io_type}Size /
                p{io_type}s fields of each AX_ENGINE_IO_INFO_T.

        Returns:
            list[list[NodeArg]]: one inner list per shape group.
        """
        io_info = []
        for group in range(self._shape_count):
            one_group_io = []
            for index in range(getattr(self._info[group][0], f'n{io_type}Size')):
                current_io = getattr(self._info[group][0], f'p{io_type}s')[index]
                name = engine_cffi.string(current_io.pName).decode("utf-8")
                shape = [current_io.pShape[i] for i in range(current_io.nShapeSize)]
                dtype = _transform_dtype(current_io.eDataType)
                meta = NodeArg(name, dtype, shape)
                one_group_io.append(meta)
            io_info.append(one_group_io)
        return io_info

    def _get_inputs(self):
        """NodeArg metadata for every input, grouped by shape group."""
        return self._get_io('Input')

    def _get_outputs(self):
        """NodeArg metadata for every output, grouped by shape group."""
        return self._get_io('Output')

    def run(
        self,
        output_names: list[str],
        input_feed: dict[str, np.ndarray],
        run_options=None
    ):
        """Execute the model synchronously.

        Args:
            output_names: names of outputs to return, or None for all.
            input_feed: mapping of input name -> numpy array matching the
                model's declared shape and dtype exactly.
            run_options: unused; kept for onnxruntime API parity.

        Returns:
            list[np.ndarray]: requested outputs, in model output order.

        Raises:
            RuntimeError: if AX_ENGINE_RunSyncV2 returns non-zero.
        """
        self._validate_input(input_feed)
        self._validate_output(output_names)

        if None is output_names:
            output_names = [o.name for o in self.get_outputs()]

        # fill model io: copy each feed array into its cached device buffer,
        # then flush the CPU cache so the NPU sees the data
        for key, npy in input_feed.items():
            for i, one in enumerate(self.get_inputs()):
                if one.name == key:
                    assert (
                        list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
                    ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}"

                    # NOTE(review): this condition skips the copy only for
                    # F-contiguous (non-C) arrays, yet the memmove below
                    # assumes C layout — confirm whether plain
                    # `if not npy.flags.c_contiguous` was intended.
                    if not (
                        not npy.flags.c_contiguous
                        and npy.flags.f_contiguous
                        and npy.flags.contiguous
                    ):
                        npy = np.ascontiguousarray(npy)
                    npy_ptr = engine_cffi.cast("void *", npy.ctypes.data)

                    engine_cffi.memmove(
                        self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes
                    )
                    sys_lib.AX_SYS_MflushCache(
                        self._io[0].pInputs[i].phyAddr,
                        self._io[0].pInputs[i].pVirAddr,
                        self._io[0].pInputs[i].nSize,
                    )
                    break

        # execute model
        ret = engine_lib.AX_ENGINE_RunSyncV2(
            self._handle[0], self._context[0], self._io
        )

        # invalidate the CPU cache for each output so we read what the NPU
        # wrote, then wrap the buffers as numpy arrays
        outputs = []
        if 0 == ret:
            for i in range(len(self.get_outputs())):
                sys_lib.AX_SYS_MinvalidateCache(
                    self._io[0].pOutputs[i].phyAddr,
                    self._io[0].pOutputs[i].pVirAddr,
                    self._io[0].pOutputs[i].nSize,
                )
                npy = np.frombuffer(
                    engine_cffi.buffer(
                        self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize
                    ),
                    dtype=self.get_outputs()[i].dtype,
                ).reshape(self.get_outputs()[i].shape)
                name = self.get_outputs()[i].name
                if name in output_names:
                    outputs.append(npy)
            return outputs
        else:
            raise RuntimeError("Failed to run model.")
|
python/axengine/_axe_capi.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
import ctypes.util
|
9 |
+
import platform
|
10 |
+
|
11 |
+
from cffi import FFI
|
12 |
+
|
13 |
+
# NOTE: originally `__all__: [...]` — a bare annotation with no runtime
# effect; it must be an assignment for star-imports to honor it.
__all__ = ["sys_lib", "sys_cffi", "engine_lib", "engine_cffi"]
|
14 |
+
|
15 |
+
sys_cffi = FFI()

# ax_base_type.h -- primitive typedefs referenced by the ax_sys API below.
sys_cffi.cdef(
    """
    typedef int AX_S32;
    typedef unsigned int AX_U32;
    typedef unsigned long long int AX_U64;
    typedef signed char AX_S8;
    typedef void AX_VOID;
    """
)

# ax_sys_api.h -- system bring-up/teardown plus cached CMM memory helpers.
# MflushCache is used after the CPU writes input buffers and
# MinvalidateCache before the CPU reads output buffers (see _axe.py).
sys_cffi.cdef(
    """
    AX_S32 AX_SYS_Init(AX_VOID);
    AX_S32 AX_SYS_Deinit(AX_VOID);
    AX_S32 AX_SYS_MemAllocCached(AX_U64 *phyaddr, AX_VOID **pviraddr, AX_U32 size, AX_U32 align, const AX_S8 *token);
    AX_S32 AX_SYS_MemFree(AX_U64 phyaddr, AX_VOID *pviraddr);
    AX_S32 AX_SYS_MflushCache(AX_U64 phyaddr, AX_VOID *pviraddr, AX_U32 size);
    AX_S32 AX_SYS_MinvalidateCache(AX_U64 phyaddr, AX_VOID *pviraddr, AX_U32 size);
    """
)
|
39 |
+
|
40 |
+
# Locate and load libax_sys.
# NOTE: `assert` is stripped under `python -O`; validate explicitly so a
# missing library cannot fall through to `dlopen(None)`.
sys_name = "ax_sys"
sys_path = ctypes.util.find_library(sys_name)
if sys_path is None:
    raise RuntimeError(
        f"Failed to find library {sys_name}. Please ensure it is installed and in the library path."
    )

sys_lib = sys_cffi.dlopen(sys_path)
if sys_lib is None:
    raise RuntimeError(
        f"Failed to load library {sys_path}. Please ensure it is installed and in the library path."
    )
|
48 |
+
|
49 |
+
engine_cffi = FFI()

# ax_base_type.h -- primitive typedefs and AX_BOOL used by the engine headers.
engine_cffi.cdef(
    """
    typedef unsigned long long int AX_U64;
    typedef unsigned int AX_U32;
    typedef unsigned char AX_U8;
    typedef int AX_S32;
    typedef signed char AX_S8;
    typedef char AX_CHAR;
    typedef void AX_VOID;

    typedef enum {
        AX_FALSE = 0,
        AX_TRUE = 1,
    } AX_BOOL;
    """
)
|
68 |
+
|
69 |
+
# ax_engine_type.h, base type
# Presumably a bitmask selecting NPU core(s), used by
# AX_ENGINE_Set/GetAffinity below -- confirm against ax_engine_type.h.
engine_cffi.cdef(
    """
    typedef AX_U32 AX_ENGINE_NPU_SET_T;
    """
)
|
75 |
+
|
76 |
+
# ax_engine_type.h, enum
# Tensor layout, memory type, element data type, and color space enums
# transcribed verbatim from the SDK header (note the non-monotonic
# COLOR_SPACE values -- they mirror the header, do not "fix" them).
engine_cffi.cdef(
    """
    typedef enum _AX_ENGINE_TENSOR_LAYOUT_E
    {
        AX_ENGINE_TENSOR_LAYOUT_UNKNOWN = 0,
        AX_ENGINE_TENSOR_LAYOUT_NHWC = 1,
        AX_ENGINE_TENSOR_LAYOUT_NCHW = 2,
    } AX_ENGINE_TENSOR_LAYOUT_T;

    typedef enum
    {
        AX_ENGINE_MT_PHYSICAL = 0,
        AX_ENGINE_MT_VIRTUAL = 1,
        AX_ENGINE_MT_OCM = 2,
    } AX_ENGINE_MEMORY_TYPE_T;

    typedef enum
    {
        AX_ENGINE_DT_UNKNOWN = 0,
        AX_ENGINE_DT_UINT8 = 1,
        AX_ENGINE_DT_UINT16 = 2,
        AX_ENGINE_DT_FLOAT32 = 3,
        AX_ENGINE_DT_SINT16 = 4,
        AX_ENGINE_DT_SINT8 = 5,
        AX_ENGINE_DT_SINT32 = 6,
        AX_ENGINE_DT_UINT32 = 7,
        AX_ENGINE_DT_FLOAT64 = 8,
        AX_ENGINE_DT_BFLOAT16 = 9,
        AX_ENGINE_DT_UINT10_PACKED = 100,
        AX_ENGINE_DT_UINT12_PACKED = 101,
        AX_ENGINE_DT_UINT14_PACKED = 102,
        AX_ENGINE_DT_UINT16_PACKED = 103,
    } AX_ENGINE_DATA_TYPE_T;

    typedef enum
    {
        AX_ENGINE_CS_FEATUREMAP = 0,
        AX_ENGINE_CS_RAW8 = 12,
        AX_ENGINE_CS_RAW10 = 1,
        AX_ENGINE_CS_RAW12 = 2,
        AX_ENGINE_CS_RAW14 = 11,
        AX_ENGINE_CS_RAW16 = 3,
        AX_ENGINE_CS_NV12 = 4,
        AX_ENGINE_CS_NV21 = 5,
        AX_ENGINE_CS_RGB = 6,
        AX_ENGINE_CS_BGR = 7,
        AX_ENGINE_CS_RGBA = 8,
        AX_ENGINE_CS_GRAY = 9,
        AX_ENGINE_CS_YUV444 = 10,
    } AX_ENGINE_COLOR_SPACE_T;
    """
)
|
129 |
+
|
130 |
+
# ax_engine_type.h, architecturally agnostic struct
# Structs whose layout does not depend on pointer width. Only the enum
# members the bindings actually need are declared (e.g. a single
# NPU_MODE / MODEL_TYPE value); cffi treats enums as ints, so other
# values coming back from the library are still representable.
engine_cffi.cdef(
    """
    typedef enum {
        AX_ENGINE_VIRTUAL_NPU_DISABLE = 0,
    } AX_ENGINE_NPU_MODE_T;

    typedef enum {
        AX_ENGINE_MODEL_TYPE0 = 0,
    } AX_ENGINE_MODEL_TYPE_T;

    typedef struct {
        AX_ENGINE_NPU_MODE_T eHardMode;
        AX_U32 reserve[8];
    } AX_ENGINE_NPU_ATTR_T;

    typedef struct _AX_ENGINE_IO_META_EX_T
    {
        AX_ENGINE_COLOR_SPACE_T eColorSpace;
        AX_U64 u64Reserved[18];
    } AX_ENGINE_IO_META_EX_T;

    typedef struct {
        AX_ENGINE_NPU_SET_T nNpuSet;
        AX_S8* pName;
        AX_U32 reserve[8];
    } AX_ENGINE_HANDLE_EXTRA_T;

    typedef struct _AX_ENGINE_CMM_INFO_T
    {
        AX_U32 nCMMSize;
    } AX_ENGINE_CMM_INFO_T;

    typedef struct _AX_ENGINE_IO_SETTING_T
    {
        AX_U32 nWbtIndex;
        AX_U64 u64Reserved[7];
    }AX_ENGINE_IO_SETTING_T;
    """
)
|
170 |
+
|
171 |
+
# check architecture, 32bit or 64bit
arch = platform.architecture()[0]

# ax_engine_type.h, struct
# The two branches differ only in the sizes of the u64Reserved padding
# arrays: on 32-bit targets pointers are 4 bytes instead of 8, and the
# header widens the reserved arrays to keep the overall struct sizes in
# step with the C ABI. NOTE(review): sizes transcribed from the SDK
# header -- confirm against the exact SDK version in use.
if arch == "64bit":
    engine_cffi.cdef(
        """
        typedef struct _AX_ENGINE_IO_META_T
        {
            AX_CHAR* pName;
            AX_S32* pShape;
            AX_U8 nShapeSize;
            AX_ENGINE_TENSOR_LAYOUT_T eLayout;
            AX_ENGINE_MEMORY_TYPE_T eMemoryType;
            AX_ENGINE_DATA_TYPE_T eDataType;
            AX_ENGINE_IO_META_EX_T* pExtraMeta;
            AX_U32 nSize;
            AX_U32 nQuantizationValue;
            AX_S32* pStride;
            AX_U64 u64Reserved[9];
        } AX_ENGINE_IO_META_T;

        typedef struct _AX_ENGINE_IO_INFO_T
        {
            AX_ENGINE_IO_META_T* pInputs;
            AX_U32 nInputSize;
            AX_ENGINE_IO_META_T* pOutputs;
            AX_U32 nOutputSize;
            AX_U32 nMaxBatchSize;
            AX_BOOL bDynamicBatchSize;
            AX_U64 u64Reserved[11];
        } AX_ENGINE_IO_INFO_T;

        typedef struct _AX_ENGINE_IO_BUFFER_T
        {
            AX_U64 phyAddr;
            AX_VOID* pVirAddr;
            AX_U32 nSize;
            AX_S32* pStride;
            AX_U8 nStrideSize;
            AX_U64 u64Reserved[11];
        } AX_ENGINE_IO_BUFFER_T;

        typedef struct _AX_ENGINE_IO_T
        {
            AX_ENGINE_IO_BUFFER_T* pInputs;
            AX_U32 nInputSize;
            AX_ENGINE_IO_BUFFER_T* pOutputs;
            AX_U32 nOutputSize;
            AX_U32 nBatchSize;
            AX_ENGINE_IO_SETTING_T* pIoSetting;
            AX_U64 u64Reserved[10];
        } AX_ENGINE_IO_T;
        """
    )
else:
    engine_cffi.cdef(
        """
        typedef struct _AX_ENGINE_IO_META_T
        {
            AX_CHAR* pName;
            AX_S32* pShape;
            AX_U8 nShapeSize;
            AX_ENGINE_TENSOR_LAYOUT_T eLayout;
            AX_ENGINE_MEMORY_TYPE_T eMemoryType;
            AX_ENGINE_DATA_TYPE_T eDataType;
            AX_ENGINE_IO_META_EX_T* pExtraMeta;
            AX_U32 nSize;
            AX_U32 nQuantizationValue;
            AX_S32* pStride;
            AX_U64 u64Reserved[11];
        } AX_ENGINE_IO_META_T;

        typedef struct _AX_ENGINE_IO_INFO_T
        {
            AX_ENGINE_IO_META_T* pInputs;
            AX_U32 nInputSize;
            AX_ENGINE_IO_META_T* pOutputs;
            AX_U32 nOutputSize;
            AX_U32 nMaxBatchSize;
            AX_BOOL bDynamicBatchSize;
            AX_U64 u64Reserved[13];
        } AX_ENGINE_IO_INFO_T;

        typedef struct _AX_ENGINE_IO_BUFFER_T
        {
            AX_U64 phyAddr;
            AX_VOID* pVirAddr;
            AX_U32 nSize;
            AX_S32* pStride;
            AX_U8 nStrideSize;
            AX_U64 u64Reserved[13];
        } AX_ENGINE_IO_BUFFER_T;

        typedef struct _AX_ENGINE_IO_T
        {
            AX_ENGINE_IO_BUFFER_T* pInputs;
            AX_U32 nInputSize;
            AX_ENGINE_IO_BUFFER_T* pOutputs;
            AX_U32 nOutputSize;
            AX_U32 nBatchSize;
            AX_ENGINE_IO_SETTING_T* pIoSetting;
            AX_U64 u64Reserved[12];
        } AX_ENGINE_IO_T;
        """
    )
|
277 |
+
|
278 |
+
# ax_engine_api.h
# Engine entry points. Some symbols (AX_ENGINE_SetAffinity,
# AX_ENGINE_GetTotalOps) are only present on certain chip families;
# _axe.py probes for them via hasattr to detect the chip type.
engine_cffi.cdef(
    """
    const AX_CHAR* AX_ENGINE_GetVersion(AX_VOID);

    AX_VOID AX_ENGINE_NPUReset(AX_VOID);
    AX_S32 AX_ENGINE_Init(AX_ENGINE_NPU_ATTR_T* pNpuAttr);
    AX_S32 AX_ENGINE_GetVNPUAttr(AX_ENGINE_NPU_ATTR_T* pNpuAttr);
    AX_S32 AX_ENGINE_Deinit(AX_VOID);

    AX_S32 AX_ENGINE_GetModelType(const AX_VOID* pData, AX_U32 nDataSize, AX_ENGINE_MODEL_TYPE_T* pModelType);

    AX_S32 AX_ENGINE_CreateHandleV2(uint64_t** pHandle, const AX_VOID* pData, AX_U32 nDataSize, AX_ENGINE_HANDLE_EXTRA_T* pExtraParam);
    AX_S32 AX_ENGINE_DestroyHandle(uint64_t* nHandle);

    AX_S32 AX_ENGINE_GetIOInfo(uint64_t* nHandle, AX_ENGINE_IO_INFO_T** pIO);
    AX_S32 AX_ENGINE_GetGroupIOInfoCount(uint64_t* nHandle, AX_U32* pCount);
    AX_S32 AX_ENGINE_GetGroupIOInfo(uint64_t* nHandle, AX_U32 nIndex, AX_ENGINE_IO_INFO_T** pIO);

    AX_S32 AX_ENGINE_GetHandleModelType(uint64_t* nHandle, AX_ENGINE_MODEL_TYPE_T* pModelType);

    AX_S32 AX_ENGINE_CreateContextV2(uint64_t* nHandle, uint64_t** pContext);

    AX_S32 AX_ENGINE_RunSyncV2(uint64_t* handle, uint64_t* context, AX_ENGINE_IO_T* pIO);
    AX_S32 AX_ENGINE_RunGroupIOSync(uint64_t* handle, uint64_t* context, AX_U32 nIndex, AX_ENGINE_IO_T* pIO);

    AX_S32 AX_ENGINE_SetAffinity(uint64_t* nHandle, AX_ENGINE_NPU_SET_T nNpuSet);
    AX_S32 AX_ENGINE_GetAffinity(uint64_t* nHandle, AX_ENGINE_NPU_SET_T* pNpuSet);

    AX_S32 AX_ENGINE_GetCMMUsage(uint64_t* nHandle, AX_ENGINE_CMM_INFO_T* pCMMInfo);

    const AX_CHAR* AX_ENGINE_GetModelToolsVersion(uint64_t* nHandle);

    // internal use api, remember no question
    AX_S32 AX_ENGINE_GetTotalOps();
    """
)
|
315 |
+
|
316 |
+
# Locate and load libax_engine.
# NOTE: `assert` is stripped under `python -O`; validate explicitly so a
# missing library cannot fall through to `dlopen(None)`.
engine_name = "ax_engine"
engine_path = ctypes.util.find_library(engine_name)
if engine_path is None:
    raise RuntimeError(
        f"Failed to find library {engine_name}. Please ensure it is installed and in the library path."
    )

engine_lib = engine_cffi.dlopen(engine_path)
if engine_lib is None:
    raise RuntimeError(
        f"Failed to load library {engine_path}. Please ensure it is installed and in the library path."
    )
|
python/axengine/_axe_types.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
from enum import Enum
|
9 |
+
|
10 |
+
|
11 |
+
class VNPUType(Enum):
    """Virtual-NPU partition mode of the device at engine init time."""

    DISABLED = 0
    ENABLED = 1
    BIG_LITTLE = 2
    LITTLE_BIG = 3
|
16 |
+
|
17 |
+
|
18 |
+
class ModelType(Enum):
    """Core layout a model was compiled for.

    Members deliberately share values across chip families, so Python's
    Enum aliasing applies: ``SINGLE is HALF`` and ``DUAL is FULL``. Which
    spelling is meaningful depends on the detected ChipType.
    """

    HALF = 0  # for MC20E, which means chip is AX630C(x), or AX620Q(x)
    FULL = 1  # for MC20E
    SINGLE = 0  # for MC50, which means chip is AX650A or AX650N, and M57H
    DUAL = 1  # for MC50
    TRIPLE = 2  # for MC50
|
24 |
+
|
25 |
+
|
26 |
+
class ChipType(Enum):
    """Chip family, detected at runtime by probing library symbols."""

    MC20E = 0
    MC50 = 1
    M57H = 2
|
python/axengine/_base_session.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
from abc import ABC, abstractmethod
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from ._node import NodeArg
|
13 |
+
|
14 |
+
|
15 |
+
class SessionOptions:
    """Session configuration options.

    NOTE(review): currently an empty placeholder kept for API compatibility
    with onnxruntime-style signatures; no options are read by the backends
    visible here.
    """
    pass
|
17 |
+
|
18 |
+
|
19 |
+
class Session(ABC):
|
20 |
+
def __init__(self) -> None:
|
21 |
+
self._shape_count = 0
|
22 |
+
self._inputs = []
|
23 |
+
self._outputs = []
|
24 |
+
|
25 |
+
def _validate_input(self, feed_input_names: dict[str, np.ndarray]):
|
26 |
+
missing_input_names = []
|
27 |
+
for i in self.get_inputs():
|
28 |
+
if i.name not in feed_input_names:
|
29 |
+
missing_input_names.append(i.name)
|
30 |
+
if missing_input_names:
|
31 |
+
raise ValueError(
|
32 |
+
f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names}).")
|
33 |
+
|
34 |
+
def _validate_output(self, output_names: list[str]):
|
35 |
+
if output_names is not None:
|
36 |
+
for name in output_names:
|
37 |
+
if name not in [o.name for o in self.get_outputs()]:
|
38 |
+
raise ValueError(f"Output name '{name}' is not in model outputs name list.")
|
39 |
+
|
40 |
+
def get_inputs(self, shape_group: int = 0) -> list[NodeArg]:
|
41 |
+
if shape_group > self._shape_count:
|
42 |
+
raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
|
43 |
+
selected_info = self._inputs[shape_group]
|
44 |
+
return selected_info
|
45 |
+
|
46 |
+
def get_outputs(self, shape_group: int = 0) -> list[NodeArg]:
|
47 |
+
if shape_group > self._shape_count:
|
48 |
+
raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
|
49 |
+
selected_info = self._outputs[shape_group]
|
50 |
+
return selected_info
|
51 |
+
|
52 |
+
@abstractmethod
|
53 |
+
def run(
|
54 |
+
self,
|
55 |
+
output_names: list[str] | None,
|
56 |
+
input_feed: dict[str, np.ndarray],
|
57 |
+
run_options=None
|
58 |
+
) -> list[np.ndarray]:
|
59 |
+
pass
|
python/axengine/_node.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
|
9 |
+
class NodeArg(object):
    """Metadata for one model input or output tensor.

    Attributes:
        name:  tensor name as declared by the model.
        dtype: element type (backend-specific representation).
        shape: tensor dimensions.
    """

    def __init__(self, name, dtype, shape):
        self.name = name
        self.dtype = dtype
        self.shape = shape

    def __repr__(self):
        # Added for debuggability: plain object subclasses print as opaque
        # addresses, which makes session I/O listings unreadable.
        return f"{type(self).__name__}(name={self.name!r}, dtype={self.dtype!r}, shape={self.shape!r})"
|
python/axengine/_providers.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
import ctypes.util as cutil
|
9 |
+
|
10 |
+
# Public provider identifiers (onnxruntime-style names).
axengine_provider_name = 'AxEngineExecutionProvider'
axclrt_provider_name = 'AXCLRTExecutionProvider'

# Native shared-library basenames backing each provider.
_axengine_lib_name = 'ax_engine'
_axclrt_lib_name = 'axcl_rt'

# Probe for the native libraries at import time. The AXCL runtime is probed
# first so that, when present, it becomes the default provider (index 0).
providers = [
    provider_name
    for provider_name, lib_name in (
        (axclrt_provider_name, _axclrt_lib_name),
        (axengine_provider_name, _axengine_lib_name),
    )
    if cutil.find_library(lib_name) is not None
]


def get_all_providers():
    """Return every provider name this package knows about, available or not."""
    return [axengine_provider_name, axclrt_provider_name]


def get_available_providers():
    """Return the providers whose native libraries were found on this machine."""
    return providers
python/axengine/_session.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# This source file is the property of Axera Semiconductor Co., Ltd. and
|
4 |
+
# may not be copied or distributed in any isomorphic form without the prior
|
5 |
+
# written consent of Axera Semiconductor Co., Ltd.
|
6 |
+
#
|
7 |
+
|
8 |
+
import os
|
9 |
+
from typing import Any, Sequence
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from ._base_session import SessionOptions
|
14 |
+
from ._node import NodeArg
|
15 |
+
from ._providers import axclrt_provider_name, axengine_provider_name
|
16 |
+
from ._providers import get_available_providers
|
17 |
+
|
18 |
+
|
19 |
+
class InferenceSession:
    """onnxruntime-style facade that dispatches to a backend session.

    Selects an execution provider (AXCL runtime or on-board AXEngine) and
    forwards get_inputs/get_outputs/run calls to the concrete session.
    """

    def __init__(
            self,
            path_or_bytes: str | bytes | os.PathLike,
            sess_options: SessionOptions | None = None,
            providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
            provider_options: Sequence[dict[Any, Any]] | None = None, **kwargs,
    ) -> None:
        """Create a session for the model at *path_or_bytes*.

        providers may be None (use first available), a single provider name,
        or a list mixing names and (name, options-dict) tuples; the first
        available entry wins.

        Raises ValueError for unknown/unavailable providers and RuntimeError
        when the backend session cannot be created.
        """
        self._sess = None
        self._sess_options = sess_options
        self._provider = None
        self._provider_options = None
        self._available_providers = get_available_providers()

        # the providers should be available at least one, checked in __init__.py
        if providers is None:
            # using first available provider as default
            self._provider = self._available_providers[0]
        else:
            # if only one provider is specified
            if isinstance(providers, str):
                if providers not in self._available_providers:
                    raise ValueError(f"Selected provider: '{providers}' is not available.")
                self._provider = providers
            # if multiple providers are specified, using the first one as default
            elif isinstance(providers, list):
                _unavailable_provider = []
                for p in providers:
                    assert isinstance(p, str) or isinstance(p, tuple), \
                        f"Invalid provider type: {type(p)}. Must be str or tuple."
                    if isinstance(p, str):
                        if p not in self._available_providers:
                            _unavailable_provider.append(p)
                        elif self._provider is None:
                            self._provider = p
                    if isinstance(p, tuple):
                        assert len(p) == 2, f"Invalid provider type: {p}. Must be tuple with 2 elements."
                        assert isinstance(p[0], str), f"Invalid provider type: {type(p[0])}. Must be str."
                        assert isinstance(p[1], dict), f"Invalid provider type: {type(p[1])}. Must be dict."
                        if p[0] not in self._available_providers:
                            _unavailable_provider.append(p[0])
                        elif self._provider is None:
                            self._provider = p[0]
                            # BUGFIX: record only the options of the provider
                            # that was actually selected. Previously this
                            # assignment ran for EVERY tuple entry, so a later
                            # (even unavailable) provider's options silently
                            # overwrote the chosen provider's options.
                            # NOTE(review): backends below are still handed the
                            # `provider_options` argument, not this value —
                            # presumably intentional pending the FIXME; confirm.
                            self._provider_options = p[1]
                if _unavailable_provider:
                    if self._provider is None:
                        raise ValueError(f"Selected provider(s): {_unavailable_provider} is(are) not available.")
                    else:
                        print(f"[WARNING] Selected provider(s): {_unavailable_provider} is(are) not available.")

        # FIXME: can we remove this check?
        if self._provider is None:
            raise ValueError(f"No available provider found in {providers}.")
        print(f"[INFO] Using provider: {self._provider}")

        # Instantiate the backend matching the selected provider. Imports are
        # deferred so an uninstalled backend never breaks module import.
        if self._provider == axclrt_provider_name:
            from ._axclrt import AXCLRTSession
            self._sess = AXCLRTSession(path_or_bytes, sess_options, provider_options, **kwargs)
        if self._provider == axengine_provider_name:
            from ._axe import AXEngineSession
            self._sess = AXEngineSession(path_or_bytes, sess_options, provider_options, **kwargs)
        if self._sess is None:
            raise RuntimeError(f"Create session failed with provider: {self._provider}")

    # add to support 'with' statement
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # not suppress exceptions
        return False

    def get_session_options(self):
        """
        Return the session options. See :class:`axengine.SessionOptions`.
        """
        return self._sess_options

    def get_providers(self):
        """
        Return list of registered execution providers.

        NOTE(review): despite the wording above (kept for onnxruntime
        parity), this returns the single selected provider name, not a list.
        """
        return self._provider

    def get_inputs(self, shape_group: int = 0) -> list[NodeArg]:
        """Return input metadata for the given shape group (delegated)."""
        return self._sess.get_inputs(shape_group)

    def get_outputs(self, shape_group: int = 0) -> list[NodeArg]:
        """Return output metadata for the given shape group (delegated)."""
        return self._sess.get_outputs(shape_group)

    def run(
            self,
            output_names: list[str] | None,
            input_feed: dict[str, np.ndarray],
            run_options=None
    ) -> list[np.ndarray]:
        """Run inference on the backend session and return its outputs."""
        return self._sess.run(output_names, input_feed, run_options)
|
python/examples/demo01.jpg
ADDED
![]() |
python/examples/demo02.jpg
ADDED
![]() |
python/examples/demo03.jpg
ADDED
![]() |
python/examples/demo04.jpg
ADDED
![]() |
python/examples/demo05.jpg
ADDED
![]() |
python/examples/demo06.jpg
ADDED
![]() |
python/examples/demo07.jpg
ADDED
![]() |
python/examples/demo08.jpg
ADDED
![]() |
python/examples/demo09.jpg
ADDED
![]() |
python/examples/demo10.jpg
ADDED
![]() |
python/examples/demo11.jpg
ADDED
![]() |
python/examples/demo12.jpg
ADDED
![]() |
python/examples/demo13.jpg
ADDED
![]() |
python/examples/demo14.jpg
ADDED
![]() |
python/examples/demo15.jpg
ADDED
![]() |
python/examples/demo16.jpg
ADDED
![]() |
python/examples/demo17.jpg
ADDED
![]() |
python/examples/demo18.jpg
ADDED
![]() |
python/examples/demo19.jpg
ADDED
![]() |
Git LFS Details
|
python/examples/demo20.jpg
ADDED
![]() |
python/infer.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from axengine import InferenceSession
|
5 |
+
|
6 |
+
def parse_args() -> argparse.Namespace:
    """Parse the required --img / --model command-line flags."""
    parser = argparse.ArgumentParser()
    for flag, description in (
        ("--img", "Path to input image."),
        ("--model", "Path to ONNX model."),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)
    return parser.parse_args()
|
22 |
+
|
23 |
+
|
24 |
+
def infer(img: str, model: str, viz: bool = False):
    """Run depth estimation on one image with an axengine InferenceSession.

    Writes a side-by-side visualization (input | colorized depth) to
    "output-ax.png" and returns the uint8 depth map at the original size.
    """
    bgr = cv2.imread(img)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    src_h, src_w = rgb.shape[:2]
    # Network input: 518x518 RGB with a leading batch axis.
    # NOTE(review): no mean/std normalization here — presumably the .axmodel
    # has preprocessing compiled in; confirm against the model conversion.
    net_input = cv2.resize(rgb, (518, 518))[None]

    session = InferenceSession(
        path_or_bytes=model,
        providers=['AxEngineExecutionProvider', 'AXCLRTExecutionProvider'],
    )
    depth = session.run(output_names=["output"], input_feed={"input": net_input})[0]

    # Resize back to the source resolution and min-max scale into [0, 255].
    depth = cv2.resize(depth[0, 0], (src_w, src_h))
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)

    colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
    cv2.imwrite("output-ax.png", cv2.hconcat([bgr, colored]))

    return depth
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == "__main__":
    # CLI entry point: forward parsed flags straight into infer().
    infer(**vars(parse_args()))
|
python/infer_onnx.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import onnxruntime as ort
|
5 |
+
|
6 |
+
def parse_args() -> argparse.Namespace:
    """Parse the required --img / --model command-line flags."""
    parser = argparse.ArgumentParser()
    for flag, description in (
        ("--img", "Path to input image."),
        ("--model", "Path to ONNX model."),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)
    return parser.parse_args()
|
22 |
+
|
23 |
+
|
24 |
+
def infer(img: str, model: str, viz: bool = False):
    """Run depth estimation on one image with onnxruntime.

    Writes a side-by-side visualization (input | colorized depth) to
    "output-onnx.png" and returns the uint8 depth map at the original size.
    """
    bgr = cv2.imread(img)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    src_h, src_w = rgb.shape[:2]
    resized = cv2.resize(rgb, (518, 518))

    # ImageNet mean/std in the 0-255 range, shaped (1, 1, 3) for HWC broadcast.
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, 1, 3)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, 1, 3)

    # Normalize, then HWC -> CHW and add a leading batch axis (NCHW).
    tensor = ((resized - mean) / std).transpose(2, 0, 1)[None]

    session = ort.InferenceSession(
        model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    depth = session.run(None, {"input": tensor})[0]

    # Resize back to the source resolution and min-max scale into [0, 255].
    depth = cv2.resize(depth[0, 0], (src_w, src_h))
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)

    colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
    cv2.imwrite("output-onnx.png", cv2.hconcat([bgr, colored]))

    return depth
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
    # CLI entry point: forward parsed flags straight into infer().
    infer(**vars(parse_args()))
|
python/output.png
ADDED
![]() |
Git LFS Details
|
python/requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
onnx
|
2 |
+
onnxruntime
|
3 |
+
opencv-python
|