qqc1989 committed on
Commit 8df2c09 · verified · 1 Parent(s): 60db486

Upload 39 files

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+depth_anything_v2_vits_ax620e.axmodel filter=lfs diff=lfs merge=lfs -text
+depth_anything_v2_vits_ax650.axmodel filter=lfs diff=lfs merge=lfs -text
+python/examples/demo19.jpg filter=lfs diff=lfs merge=lfs -text
+python/output.png filter=lfs diff=lfs merge=lfs -text
calib-cocotest2017.tar ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fd1652bceab31a66a35e05cc26a2c7633e8c5108f4c07c21b5868b9605cc15a
+size 20869120
depth_anything_v2_vits.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:443e95f17819f347f5f987384b8cb7d7d18ed6af6ac46dec9b0152748ba7dfd0
+size 98985978
depth_anything_v2_vits_ax620e.axmodel ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f13d462a968e309354e23babdaf6a90b26841c58fd02f36531ba5d7bb545bea4
+size 38448968
depth_anything_v2_vits_ax650.axmodel ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4520309cbefa63a4127aae75f7b7aa0dc5cc07fa02ef7a13a4219b88950499a1
+size 27978862
python/axengine/__init__.py ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+#
+# This source file is the property of Axera Semiconductor Co., Ltd. and
+# may not be copied or distributed in any isomorphic form without the prior
+# written consent of Axera Semiconductor Co., Ltd.
+#
+
+# thanks to community contributors list below:
+# zylo117: https://github.com/zylo117, first implementation of the axclrt backend
+
+from ._providers import axengine_provider_name, axclrt_provider_name
+from ._providers import get_all_providers, get_available_providers
+
+# check if axclrt is installed, or is a supported chip (e.g. AX650, AX620E etc.)
+_available_providers = get_available_providers()
+if not _available_providers:
+    raise ImportError(
+        f"No providers found. Please make sure you have installed one of the following: {get_all_providers()}")
+print("[INFO] Available providers: ", _available_providers)
+
+from ._node import NodeArg
+from ._session import SessionOptions, InferenceSession
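
Editor's note: for orientation, a minimal usage sketch of the public API exposed by this `__init__.py` follows. It is not part of the commit; it assumes the `python/axengine` directory is importable as `axengine` and that `InferenceSession.run` mirrors the onnxruntime-style `run(output_names, input_feed)` signature seen in the session classes below. The model filename is one of the files uploaded in this commit.

# Illustrative sketch only (not part of this commit).
import numpy as np
import axengine as axe

# The default provider is picked automatically: AXCLRTExecutionProvider if
# axcl_rt is present, otherwise AxEngineExecutionProvider.
sess = axe.InferenceSession("depth_anything_v2_vits_ax650.axmodel")

# Inspect the model's I/O metadata (NodeArg: name, dtype, shape).
for node in sess.get_inputs():
    print(node.name, node.dtype, node.shape)

# Feed a dummy tensor of the expected shape/dtype and fetch all outputs.
inp = sess.get_inputs()[0]
dummy = np.zeros(inp.shape, dtype=inp.dtype)
depth = sess.run(None, {inp.name: dummy})[0]
print(depth.shape)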
python/axengine/_axclrt.py ADDED
@@ -0,0 +1,372 @@
+# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+#
+# This source file is the property of Axera Semiconductor Co., Ltd. and
+# may not be copied or distributed in any isomorphic form without the prior
+# written consent of Axera Semiconductor Co., Ltd.
+#
+# first implementation of AXCLRTSession contributed by zylo117
+
+import atexit
+import os
+import time
+from typing import Any, Sequence
+
+import ml_dtypes as mldt
+import numpy as np
+
+from ._axclrt_capi import axclrt_cffi, axclrt_lib
+from ._axclrt_types import VNPUType, ModelType
+from ._base_session import Session, SessionOptions
+from ._node import NodeArg
+
+__all__ = ["AXCLRTSession"]
+
+_is_axclrt_initialized = False
+_is_axclrt_engine_initialized = False
+
+
+def _transform_dtype(dtype):
+    if dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT8):
+        return np.dtype(np.uint8)
+    elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT8):
+        return np.dtype(np.int8)
+    elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT16):
+        return np.dtype(np.uint16)
+    elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT16):
+        return np.dtype(np.int16)
+    elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT32):
+        return np.dtype(np.uint32)
+    elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT32):
+        return np.dtype(np.int32)
+    elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_FP32):
+        return np.dtype(np.float32)
+    elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_BF16):
+        return np.dtype(mldt.bfloat16)
+    else:
+        raise ValueError(f"Unsupported data type '{dtype}'.")
+
+
+def _initialize_axclrt():
+    global _is_axclrt_initialized
+    ret = axclrt_lib.axclInit([])
+    if ret != 0:
+        raise RuntimeError(f"Failed to initialize axcl runtime. {ret}.")
+    _is_axclrt_initialized = True
+
+
+def _finalize_axclrt():
+    global _is_axclrt_initialized, _is_axclrt_engine_initialized
+    if _is_axclrt_engine_initialized:
+        axclrt_lib.axclrtEngineFinalize()
+        _is_axclrt_engine_initialized = False
+    if _is_axclrt_initialized:
+        axclrt_lib.axclFinalize()
+        _is_axclrt_initialized = False
+
+
+_initialize_axclrt()
+atexit.register(_finalize_axclrt)
+
+
+def _get_vnpu_type() -> VNPUType:
+    vnpu_type = axclrt_cffi.new("axclrtEngineVNpuKind *")
+    ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu_type)
+    if ret != 0:
+        raise RuntimeError("Failed to get VNPU attribute.")
+    return VNPUType(vnpu_type[0])
+
+
+def _get_version():
+    major, minor, patch = axclrt_cffi.new('int32_t *'), axclrt_cffi.new('int32_t *'), axclrt_cffi.new(
+        'int32_t *')
+    axclrt_lib.axclrtGetVersion(major, minor, patch)
+    return f'{major[0]}.{minor[0]}.{patch[0]}'
+
+
+class AXCLRTSession(Session):
+    def __init__(
+        self,
+        path_or_bytes: str | bytes | os.PathLike,
+        sess_options: SessionOptions | None = None,
+        provider_options: dict[Any, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        self._device_index = 0
+
+        if provider_options is not None and "device_id" in provider_options[0]:
+            self._device_index = provider_options[0].get("device_id", 0)
+
+        lst = axclrt_cffi.new("axclrtDeviceList *")
+        ret = axclrt_lib.axclrtGetDeviceList(lst)
+        if ret != 0 or lst.num == 0:
+            raise RuntimeError(f"Get AXCL device failed 0x{ret:08x}, find total {lst.num} device.")
+
+        if self._device_index >= lst.num:
+            raise RuntimeError(f"Device index {self._device_index} is out of range, total {lst.num} device.")
+
+        self._device_id = lst.devices[self._device_index]
+        ret = axclrt_lib.axclrtSetDevice(self._device_id)
+        if ret != 0 or lst.num == 0:
+            raise RuntimeError(f"Set AXCL device failed 0x{ret:08x}.")
+
+        global _is_axclrt_engine_initialized
+        vnpu_type = axclrt_cffi.cast(
+            "axclrtEngineVNpuKind", VNPUType.DISABLED.value
+        )
+        # try to initialize NPU as disabled
+        ret = axclrt_lib.axclrtEngineInit(vnpu_type)
+        # if failed, try to get vnpu type
+        if 0 != ret:
+            vnpu = axclrt_cffi.new("axclrtEngineVNpuKind *")
+            ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu)
+            # if failed, that means the NPU is not available
+            if ret != 0:
+                raise RuntimeError(f"axclrtEngineInit as {vnpu.value} failed 0x{ret:08x}.")
+            # if it succeeds, the NPU has already been initialized (as vnpu.value) by
+            # another user, which is why our own initialization failed. That other user
+            # may also de-initialize the NPU at any moment and terminate this app
+            # unexpectedly; because this API mirrors onnxruntime there is no way to
+            # guard against it, so just print a warning message.
+            else:
+                print(f"[WARNING] Failed to initialize NPU as {vnpu_type}, NPU is already initialized as {vnpu.value}.")
+        # initialize NPU successfully, mark the flag to ensure the engine will be finalized
+        else:
+            _is_axclrt_engine_initialized = True
+
+        self.soc_name = axclrt_cffi.string(axclrt_lib.axclrtGetSocName()).decode()
+        print(f"[INFO] SOC Name: {self.soc_name}")
+
+        # model handle, context, info, io
+        self._model_id = axclrt_cffi.new("uint64_t *")
+        self._context_id = axclrt_cffi.new("uint64_t *")
+
+        # get vnpu type
+        self._vnpu_type = _get_vnpu_type()
+        print(f"[INFO] VNPU type: {self._vnpu_type}")
+
+        # load model
+        ret = self._load(path_or_bytes)
+        if 0 != ret:
+            raise RuntimeError("Failed to load model.")
+        print(f"[INFO] Compiler version: {self._get_model_tool_version()}")
+
+        # get model info
+        self._info = self._get_info()
+        self._shape_count = self._get_shape_count()
+        self._inputs = self._get_inputs()
+        self._outputs = self._get_outputs()
+
+        # prepare io
+        self._io = self._prepare_io()
+
+    def __del__(self):
+        self._unload()
+
+    def _load(self, path_or_bytes):
+        # model buffer, almost copied from onnx runtime
+        if isinstance(path_or_bytes, (str, os.PathLike)):
+            _model_path = axclrt_cffi.new("char[]", path_or_bytes.encode('utf-8'))
+            ret = axclrt_lib.axclrtEngineLoadFromFile(_model_path, self._model_id)
+            if ret != 0:
+                raise RuntimeError("axclrtEngineLoadFromFile failed.")
+        elif isinstance(path_or_bytes, bytes):
+            _model_buffer = axclrt_cffi.new("char[]", path_or_bytes)
+            _model_buffer_size = len(path_or_bytes)
+
+            dev_mem_ptr = axclrt_cffi.new('void **', axclrt_cffi.NULL)
+            ret = axclrt_lib.axclrtMalloc(dev_mem_ptr, _model_buffer_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
+            if ret != 0:
+                raise RuntimeError("axclrtMalloc failed.")
+
+            ret = axclrt_lib.axclrtMemcpy(dev_mem_ptr[0], _model_buffer, _model_buffer_size, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE)
+            if ret != 0:
+                axclrt_lib.axclrtFree(dev_mem_ptr[0])
+                raise RuntimeError("axclrtMemcpy failed.")
+
+            ret = axclrt_lib.axclrtEngineLoadFromMem(dev_mem_ptr[0], _model_buffer_size, self._model_id)
+            axclrt_lib.axclrtFree(dev_mem_ptr[0])
+            if ret != 0:
+                raise RuntimeError("axclrtEngineLoadFromMem failed.")
+        else:
+            raise TypeError(f"Unable to load model from type '{type(path_or_bytes)}'")
+
+        ret = axclrt_lib.axclrtEngineCreateContext(self._model_id[0], self._context_id)
+        if ret != 0:
+            raise RuntimeError("axclrtEngineCreateContext failed")
+        return ret
+
+    def _unload(self):
+        if self._io is not None:
+            dev_size = axclrt_cffi.new("uint64_t *")
+            dev_prt = axclrt_cffi.new("void **")
+            for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
+                axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
+                axclrt_lib.axclrtFree(dev_prt[0])
+            for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
+                axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
+                axclrt_lib.axclrtFree(dev_prt[0])
+            axclrt_lib.axclrtEngineDestroyIO(self._io)
+            self._io = None
+        if self._model_id[0] is not None:
+            axclrt_lib.axclrtEngineUnload(self._model_id[0])
+            self._model_id[0] = 0
+
+    def _get_model_tool_version(self):
+        model_tool_version = axclrt_lib.axclrtEngineGetModelCompilerVersion(self._model_id[0])
+        return axclrt_cffi.string(model_tool_version).decode()
+
+    def _get_info(self):
+        io_info = axclrt_cffi.new("axclrtEngineIOInfo *")
+        ret = axclrt_lib.axclrtEngineGetIOInfo(self._model_id[0], io_info)
+        if ret != 0:
+            raise RuntimeError("axclrtEngineGetIOInfo failed.")
+        return io_info
+
+    def _get_shape_count(self):
+        count = axclrt_cffi.new("int32_t *")
+        ret = axclrt_lib.axclrtEngineGetShapeGroupsCount(self._info[0], count)
+        if ret != 0:
+            axclrt_lib.axclrtEngineUnload(self._model_id[0])
+            raise RuntimeError("axclrtEngineGetShapeGroupsCount failed.")
+        return count[0]
+
+    def _get_inputs(self):
+        inputs = []
+        for group in range(self._shape_count):
+            one_group_io = []
+            for index in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
+                cffi_name = axclrt_lib.axclrtEngineGetInputNameByIndex(self._info[0], index)
+                name = axclrt_cffi.string(cffi_name).decode("utf-8")
+
+                cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *")
+                ret = axclrt_lib.axclrtEngineGetInputDataType(self._info[0], index, cffi_dtype)
+                if ret != 0:
+                    raise RuntimeError("axclrtEngineGetInputDataType failed.")
+                dtype = _transform_dtype(cffi_dtype[0])
+
+                cffi_dims = axclrt_cffi.new("axclrtEngineIODims *")
+                ret = axclrt_lib.axclrtEngineGetInputDims(self._info[0], group, index, cffi_dims)
+                if ret != 0:
+                    raise RuntimeError("axclrtEngineGetInputDims failed.")
+                shape = [cffi_dims.dims[i] for i in range(cffi_dims.dimCount)]
+
+                meta = NodeArg(name, dtype, shape)
+                one_group_io.append(meta)
+            inputs.append(one_group_io)
+        return inputs
+
+    def _get_outputs(self):
+        outputs = []
+        for group in range(self._shape_count):
+            one_group_io = []
+            for index in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
+                name = axclrt_lib.axclrtEngineGetOutputNameByIndex(self._info[0], index)
+
+                cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *")
+                ret = axclrt_lib.axclrtEngineGetOutputDataType(self._info[0], index, cffi_dtype)
+                if ret != 0:
+                    raise RuntimeError("axclrtEngineGetOutputDataType failed.")
+                dtype = _transform_dtype(cffi_dtype[0])
+
+                cffi_dims = axclrt_cffi.new("axclrtEngineIODims *")
+                ret = axclrt_lib.axclrtEngineGetOutputDims(self._info[0], group, index, cffi_dims)
+                if ret != 0:
+                    raise RuntimeError("axclrtEngineGetOutputDims failed.")
+                shape = [cffi_dims.dims[i] for i in range(cffi_dims.dimCount)]
+
+                meta = NodeArg(name, dtype, shape)
+                one_group_io.append(meta)
+            outputs.append(one_group_io)
+        return outputs
+
+    def _prepare_io(self):
+        _io = axclrt_cffi.new("axclrtEngineIO *")
+        ret = axclrt_lib.axclrtEngineCreateIO(self._info[0], _io)
+        if ret != 0:
+            raise RuntimeError(f"axclrtEngineCreateIO failed 0x{ret:08x}.")
+        for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
+            max_size = 0
+            for group in range(self._shape_count):
+                size = axclrt_lib.axclrtEngineGetInputSizeByIndex(self._info[0], group, i)
+                max_size = max(max_size, size)
+            dev_ptr = axclrt_cffi.new("void **")
+            ret = axclrt_lib.axclrtMalloc(dev_ptr, max_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
+            if 0 != ret or dev_ptr[0] == axclrt_cffi.NULL:
+                raise RuntimeError(f"axclrtMalloc failed 0x{ret:08x} for input {i}.")
+            ret = axclrt_lib.axclrtEngineSetInputBufferByIndex(_io[0], i, dev_ptr[0], max_size)
+            if 0 != ret:
+                raise RuntimeError(f"axclrtEngineSetInputBufferByIndex failed 0x{ret:08x} for input {i}.")
+        for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
+            max_size = 0
+            for group in range(self._shape_count):
+                size = axclrt_lib.axclrtEngineGetOutputSizeByIndex(self._info[0], group, i)
+                max_size = max(max_size, size)
+            dev_ptr = axclrt_cffi.new("void **")
+            ret = axclrt_lib.axclrtMalloc(dev_ptr, max_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
+            if 0 != ret or dev_ptr[0] == axclrt_cffi.NULL:
+                raise RuntimeError(f"axclrtMalloc failed 0x{ret:08x} for output {i}.")
+            ret = axclrt_lib.axclrtEngineSetOutputBufferByIndex(_io[0], i, dev_ptr[0], max_size)
+            if 0 != ret:
+                raise RuntimeError(f"axclrtEngineSetOutputBufferByIndex failed 0x{ret:08x} for output {i}.")
+        return _io[0]
+
+    def run(
+        self,
+        output_names: list[str],
+        input_feed: dict[str, np.ndarray],
+        run_options=None
+    ):
+        self._validate_input(input_feed)
+        self._validate_output(output_names)
+
+        if None is output_names:
+            output_names = [o.name for o in self.get_outputs()]
+
+        # fill model io
+        dev_prt = axclrt_cffi.new("void **")
+        dev_size = axclrt_cffi.new("uint64_t *")
+        for key, npy in input_feed.items():
+            for i, one in enumerate(self.get_inputs()):
+                if one.name == key:
+                    assert (
+                        list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
+                    ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}"
+
+                    if not (
+                        not npy.flags.c_contiguous
+                        and npy.flags.f_contiguous
+                        and npy.flags.contiguous
+                    ):
+                        npy = np.ascontiguousarray(npy)
+                    npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
+                    ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
+                    if 0 != ret:
+                        raise RuntimeError(f"axclrtEngineGetInputBufferByIndex failed for input {i}.")
+                    ret = axclrt_lib.axclrtMemcpy(dev_prt[0], npy_ptr, npy.nbytes, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE)
+                    if 0 != ret:
+                        raise RuntimeError(f"axclrtMemcpy failed for input {i}.")
+
+        # execute model
+        ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io)
+
+        # get output
+        outputs = []
+        if 0 == ret:
+            for i in range(len(self.get_outputs())):
+                ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
+                if 0 != ret:
+                    raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.")
+                npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype)
+                npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
+                ret = axclrt_lib.axclrtMemcpy(npy_ptr, dev_prt[0], npy.nbytes, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST)
+                if 0 != ret:
+                    raise RuntimeError(f"axclrtMemcpy failed for output {i}.")
+                name = self.get_outputs()[i].name
+                if name in output_names:
+                    outputs.append(npy)
+            return outputs
+        else:
+            raise RuntimeError(f"axclrtEngineExecute failed 0x{ret:08x}")
python/axengine/_axclrt_capi.py ADDED
@@ -0,0 +1,198 @@
+# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+#
+# This source file is the property of Axera Semiconductor Co., Ltd. and
+# may not be copied or distributed in any isomorphic form without the prior
+# written consent of Axera Semiconductor Co., Ltd.
+#
+
+import ctypes.util
+
+from cffi import FFI
+
+__all__ = ["axclrt_cffi", "axclrt_lib"]
+
+axclrt_cffi = FFI()
+
+# axcl_base.h
+axclrt_cffi.cdef(
+    """
+    #define AXCL_MAX_DEVICE_COUNT 256
+    typedef int32_t axclError;
+    typedef void *axclrtContext;
+    """
+)
+
+# axcl_rt_type.h
+axclrt_cffi.cdef(
+    """
+    typedef struct axclrtDeviceList {
+        uint32_t num;
+        int32_t devices[AXCL_MAX_DEVICE_COUNT];
+    } axclrtDeviceList;
+
+    typedef enum axclrtMemMallocPolicy {
+        AXCL_MEM_MALLOC_HUGE_FIRST,
+        AXCL_MEM_MALLOC_HUGE_ONLY,
+        AXCL_MEM_MALLOC_NORMAL_ONLY
+    } axclrtMemMallocPolicy;
+
+    typedef enum axclrtMemcpyKind {
+        AXCL_MEMCPY_HOST_TO_HOST,
+        AXCL_MEMCPY_HOST_TO_DEVICE,     //!< host vir -> device phy
+        AXCL_MEMCPY_DEVICE_TO_HOST,     //!< host vir <- device phy
+        AXCL_MEMCPY_DEVICE_TO_DEVICE,
+        AXCL_MEMCPY_HOST_PHY_TO_DEVICE, //!< host phy -> device phy
+        AXCL_MEMCPY_DEVICE_TO_HOST_PHY, //!< host phy <- device phy
+    } axclrtMemcpyKind;
+    """
+)
+
+# axcl_rt_engine_type.h
+axclrt_cffi.cdef(
+    """
+    #define AXCLRT_ENGINE_MAX_DIM_CNT 32
+    typedef void* axclrtEngineIOInfo;
+    typedef void* axclrtEngineIO;
+
+    typedef enum axclrtEngineVNpuKind {
+        AXCL_VNPU_DISABLE = 0,
+        AXCL_VNPU_ENABLE = 1,
+        AXCL_VNPU_BIG_LITTLE = 2,
+        AXCL_VNPU_LITTLE_BIG = 3,
+    } axclrtEngineVNpuKind;
+
+    typedef enum axclrtEngineDataType {
+        AXCL_DATA_TYPE_NONE = 0,
+        AXCL_DATA_TYPE_INT4 = 1,
+        AXCL_DATA_TYPE_UINT4 = 2,
+        AXCL_DATA_TYPE_INT8 = 3,
+        AXCL_DATA_TYPE_UINT8 = 4,
+        AXCL_DATA_TYPE_INT16 = 5,
+        AXCL_DATA_TYPE_UINT16 = 6,
+        AXCL_DATA_TYPE_INT32 = 7,
+        AXCL_DATA_TYPE_UINT32 = 8,
+        AXCL_DATA_TYPE_INT64 = 9,
+        AXCL_DATA_TYPE_UINT64 = 10,
+        AXCL_DATA_TYPE_FP4 = 11,
+        AXCL_DATA_TYPE_FP8 = 12,
+        AXCL_DATA_TYPE_FP16 = 13,
+        AXCL_DATA_TYPE_BF16 = 14,
+        AXCL_DATA_TYPE_FP32 = 15,
+        AXCL_DATA_TYPE_FP64 = 16,
+    } axclrtEngineDataType;
+
+    typedef enum axclrtEngineDataLayout {
+        AXCL_DATA_LAYOUT_NONE = 0,
+        AXCL_DATA_LAYOUT_NHWC = 0,
+        AXCL_DATA_LAYOUT_NCHW = 1,
+    } axclrtEngineDataLayout;
+
+    typedef struct axclrtEngineIODims {
+        int32_t dimCount;
+        int32_t dims[AXCLRT_ENGINE_MAX_DIM_CNT];
+    } axclrtEngineIODims;
+    """
+)
+
+# axcl.h
+axclrt_cffi.cdef(
+    """
+    axclError axclInit(const char *config);
+    axclError axclFinalize();
+    """
+)
+
+# axcl_rt.h
+axclrt_cffi.cdef(
+    """
+    axclError axclrtGetVersion(int32_t *major, int32_t *minor, int32_t *patch);
+    const char *axclrtGetSocName();
+    """
+)
+
+# axcl_rt_device.h
+axclrt_cffi.cdef(
+    """
+    axclError axclrtGetDeviceList(axclrtDeviceList *deviceList);
+    axclError axclrtSetDevice(int32_t deviceId);
+    axclError axclrtResetDevice(int32_t deviceId);
+    """
+)
+
+# axcl_rt_context.h
+axclrt_cffi.cdef(
+    """
+    axclError axclrtCreateContext(axclrtContext *context, int32_t deviceId);
+    axclError axclrtDestroyContext(axclrtContext context);
+    axclError axclrtSetCurrentContext(axclrtContext context);
+    axclError axclrtGetCurrentContext(axclrtContext *context);
+    axclError axclrtGetDefaultContext(axclrtContext *context, int32_t deviceId);
+    """
+)
+
+# axcl_rt_engine.h
+axclrt_cffi.cdef(
+    """
+    axclError axclrtEngineInit(axclrtEngineVNpuKind npuKind);
+    axclError axclrtEngineGetVNpuKind(axclrtEngineVNpuKind *npuKind);
+    axclError axclrtEngineFinalize();
+
+    axclError axclrtEngineLoadFromFile(const char *modelPath, uint64_t *modelId);
+    axclError axclrtEngineLoadFromMem(const void *model, uint64_t modelSize, uint64_t *modelId);
+    const char* axclrtEngineGetModelCompilerVersion(uint64_t modelId);
+    axclError axclrtEngineUnload(uint64_t modelId);
+
+    axclError axclrtEngineGetIOInfo(uint64_t modelId, axclrtEngineIOInfo *ioInfo);
+    axclError axclrtEngineGetShapeGroupsCount(axclrtEngineIOInfo ioInfo, int32_t *count);
+
+    uint32_t axclrtEngineGetNumInputs(axclrtEngineIOInfo ioInfo);
+    uint32_t axclrtEngineGetNumOutputs(axclrtEngineIOInfo ioInfo);
+
+    uint64_t axclrtEngineGetInputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index);
+    uint64_t axclrtEngineGetOutputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index);
+
+    axclError axclrtEngineGetInputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims);
+    axclError axclrtEngineGetOutputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims);
+
+    const char *axclrtEngineGetInputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index);
+    const char *axclrtEngineGetOutputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index);
+
+    int32_t axclrtEngineGetInputDataType(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataType *type);
+    int32_t axclrtEngineGetOutputDataType(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataType *type);
+
+    int32_t axclrtEngineGetInputDataLayout(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataLayout *layout);
+    int32_t axclrtEngineGetOutputDataLayout(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataLayout *layout);
+
+    axclError axclrtEngineCreateIO(axclrtEngineIOInfo ioInfo, axclrtEngineIO *io);
+    axclError axclrtEngineDestroyIO(axclrtEngineIO io);
+
+    axclError axclrtEngineSetInputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size);
+    axclError axclrtEngineSetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size);
+    axclError axclrtEngineGetInputBufferByIndex(axclrtEngineIO io, uint32_t index, void **dataBuffer, uint64_t *size);
+    axclError axclrtEngineGetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, void **dataBuffer, uint64_t *size);
+
+    axclError axclrtEngineCreateContext(uint64_t modelId, uint64_t *contextId);
+
+    axclError axclrtEngineExecute(uint64_t modelId, uint64_t contextId, uint32_t group, axclrtEngineIO io);
+    """
+)
+
+# axcl_rt_memory.h
+axclrt_cffi.cdef(
+    """
+    axclError axclrtMalloc(void **devPtr, size_t size, axclrtMemMallocPolicy policy);
+    axclError axclrtMallocCached(void **devPtr, size_t size, axclrtMemMallocPolicy policy);
+    axclError axclrtMemcpy(void *dstPtr, const void *srcPtr, size_t count, axclrtMemcpyKind kind);
+    axclError axclrtFree(void *devPtr);
+    axclError axclrtMemFlush(void *devPtr, size_t size);
+    """
+)
+
+rt_name = "axcl_rt"
+rt_path = ctypes.util.find_library(rt_name)
+assert (
+    rt_path is not None
+), f"Failed to find library {rt_name}. Please ensure it is installed and in the library path."
+
+axclrt_lib = axclrt_cffi.dlopen(rt_path)
+assert axclrt_lib is not None, f"Failed to load library {rt_path}. Please ensure it is installed and in the library path."
python/axengine/_axclrt_types.py ADDED
@@ -0,0 +1,21 @@
+# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+#
+# This source file is the property of Axera Semiconductor Co., Ltd. and
+# may not be copied or distributed in any isomorphic form without the prior
+# written consent of Axera Semiconductor Co., Ltd.
+#
+
+from enum import Enum
+
+
+class VNPUType(Enum):
+    DISABLED = 0
+    ENABLED = 1
+    BIG_LITTLE = 2
+    LITTLE_BIG = 3
+
+
+class ModelType(Enum):
+    SINGLE = 0
+    DUAL = 1
+    TRIPLE = 2
python/axengine/_axe.py ADDED
@@ -0,0 +1,399 @@
+# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+#
+# This source file is the property of Axera Semiconductor Co., Ltd. and
+# may not be copied or distributed in any isomorphic form without the prior
+# written consent of Axera Semiconductor Co., Ltd.
+#
+
+import atexit
+import os
+from typing import Any, Sequence
+
+import ml_dtypes as mldt
+import numpy as np
+
+from ._axe_capi import sys_lib, engine_cffi, engine_lib
+from ._axe_types import VNPUType, ModelType, ChipType
+from ._base_session import Session, SessionOptions
+from ._node import NodeArg
+
+__all__ = ["AXEngineSession"]
+
+_is_sys_initialized = False
+_is_engine_initialized = False
+
+
+def _transform_dtype(dtype):
+    if dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT8):
+        return np.dtype(np.uint8)
+    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT8):
+        return np.dtype(np.int8)
+    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT16):
+        return np.dtype(np.uint16)
+    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT16):
+        return np.dtype(np.int16)
+    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT32):
+        return np.dtype(np.uint32)
+    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT32):
+        return np.dtype(np.int32)
+    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_FLOAT32):
+        return np.dtype(np.float32)
+    elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_BFLOAT16):
+        return np.dtype(mldt.bfloat16)
+    else:
+        raise ValueError(f"Unsupported data type '{dtype}'.")
+
+
+def _check_cffi_func_exists(lib, func_name):
+    try:
+        getattr(lib, func_name)
+        return True
+    except AttributeError:
+        return False
+
+
+def _get_chip_type():
+    if not _check_cffi_func_exists(engine_lib, "AX_ENGINE_SetAffinity"):
+        return ChipType.M57H
+    elif not _check_cffi_func_exists(engine_lib, "AX_ENGINE_GetTotalOps"):
+        return ChipType.MC50
+    else:
+        return ChipType.MC20E
+
+
+def _get_version():
+    engine_version = engine_lib.AX_ENGINE_GetVersion()
+    return engine_cffi.string(engine_version).decode("utf-8")
+
+
+def _get_vnpu_type() -> VNPUType:
+    vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *")
+    ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
+    if 0 != ret:
+        raise RuntimeError("Failed to get VNPU attribute.")
+    return VNPUType(vnpu_type.eHardMode)
+
+
+def _initialize_engine():
+    global _is_sys_initialized, _is_engine_initialized
+
+    ret = sys_lib.AX_SYS_Init()
+    if ret != 0:
+        raise RuntimeError("Failed to initialize ax sys.")
+    _is_sys_initialized = True
+
+    # disabled mode by default
+    vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *")
+    ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
+    if 0 != ret:
+        # this means the NPU was not initialized
+        vnpu_type.eHardMode = engine_cffi.cast(
+            "AX_ENGINE_NPU_MODE_T", VNPUType.DISABLED.value
+        )
+    ret = engine_lib.AX_ENGINE_Init(vnpu_type)
+    if ret != 0:
+        raise RuntimeError("Failed to initialize ax sys engine.")
+    _is_engine_initialized = True
+
+    print(f"[INFO] Chip type: {_get_chip_type()}")
+    print(f"[INFO] VNPU type: {_get_vnpu_type()}")
+    print(f"[INFO] Engine version: {_get_version()}")
+
+
+def _finalize_engine():
+    global _is_sys_initialized, _is_engine_initialized
+
+    if _is_engine_initialized:
+        engine_lib.AX_ENGINE_Deinit()
+    if _is_sys_initialized:
+        sys_lib.AX_SYS_Deinit()
+
+
+_initialize_engine()
+atexit.register(_finalize_engine)
+
+
+class AXEngineSession(Session):
+    def __init__(
+        self,
+        path_or_bytes: str | bytes | os.PathLike,
+        sess_options: SessionOptions | None = None,
+        provider_options: dict[Any, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        self._chip_type = _get_chip_type()
+        self._vnpu_type = _get_vnpu_type()
+
+        # handle, context, info, io
+        self._handle = engine_cffi.new("uint64_t **")
+        self._context = engine_cffi.new("uint64_t **")
+        self._io = engine_cffi.new("AX_ENGINE_IO_T *")
+
+        # model buffer, almost copied from onnx runtime
+        if isinstance(path_or_bytes, (str, os.PathLike)):
+            self._model_name = os.path.splitext(os.path.basename(path_or_bytes))[0]
+            with open(path_or_bytes, "rb") as f:
+                data = f.read()
+            self._model_buffer = engine_cffi.new("char[]", data)
+            self._model_buffer_size = len(data)
+        elif isinstance(path_or_bytes, bytes):
+            self._model_buffer = engine_cffi.new("char[]", path_or_bytes)
+            self._model_buffer_size = len(path_or_bytes)
+        else:
+            raise TypeError(f"Unable to load model from type '{type(path_or_bytes)}'")
+
+        # get model type
+        self._model_type = self._get_model_type()
+        if self._chip_type is ChipType.MC20E:
+            if self._model_type is ModelType.FULL:
+                print(f"[INFO] Model type: {self._model_type.value} (full core)")
+            if self._model_type is ModelType.HALF:
+                print(f"[INFO] Model type: {self._model_type.value} (half core)")
+        if self._chip_type is ChipType.MC50:
+            if self._model_type is ModelType.SINGLE:
+                print(f"[INFO] Model type: {self._model_type.value} (single core)")
+            if self._model_type is ModelType.DUAL:
+                print(f"[INFO] Model type: {self._model_type.value} (dual core)")
+            if self._model_type is ModelType.TRIPLE:
+                print(f"[INFO] Model type: {self._model_type.value} (triple core)")
+        if self._chip_type is ChipType.M57H:
+            print(f"[INFO] Model type: {self._model_type.value} (single core)")
+
+        # check model type
+        if self._chip_type is ChipType.MC50:
+            # all types (single or dual or triple) of model are allowed in vnpu mode disabled
+            # only single core model is allowed in vnpu mode enabled
+            # only triple core model is NOT allowed in vnpu mode big-little or little-big
+            if self._vnpu_type is VNPUType.ENABLED:
+                if self._model_type is not ModelType.SINGLE:
+                    raise ValueError(
+                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
+                    )
+            if (
+                self._vnpu_type is VNPUType.BIG_LITTLE
+                or self._vnpu_type is VNPUType.LITTLE_BIG
+            ):
+                if self._model_type is ModelType.TRIPLE:
+                    raise ValueError(
+                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
+                    )
+        if self._chip_type is ChipType.MC20E:
+            # all types of full or half core model are allowed in vnpu mode disabled
+            # only half core model is allowed in vnpu mode enabled
+            if self._vnpu_type is VNPUType.ENABLED:
+                if self._model_type is ModelType.FULL:
+                    raise ValueError(
+                        f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
+                    )
+        # if self._chip_type is ChipType.M57H:
+        #     only one type of model is compiled for this chip, so there is no need to check
+
+        # load model
+        ret = self._load()
+        if 0 != ret:
+            raise RuntimeError("Failed to load model.")
+        print(f"[INFO] Compiler version: {self._get_model_tool_version()}")
+
+        # get shape group count
+        try:
+            self._shape_count = self._get_shape_count()
+        except AttributeError as e:
+            print(f"[WARNING] {e}")
+            self._shape_count = 1
+
+        # get model shape
+        self._info = self._get_info()
+        self._inputs = self._get_inputs()
+        self._outputs = self._get_outputs()
+
+        # fill model io
+        self._align = 128
+        self._cmm_token = engine_cffi.new("AX_S8[]", b"PyEngine")
+        self._io[0].nInputSize = len(self.get_inputs())
+        self._io[0].nOutputSize = len(self.get_outputs())
+        self._io[0].pInputs = engine_cffi.new(
+            "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize)
+        )
+        self._io[0].pOutputs = engine_cffi.new(
+            "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize)
+        )
+        for i in range(len(self.get_inputs())):
+            max_buf = 0
+            for j in range(self._shape_count):
+                max_buf = max(max_buf, self._info[j][0].pInputs[i].nSize)
+            self._io[0].pInputs[i].nSize = max_buf
+            phy = engine_cffi.new("AX_U64*")
+            vir = engine_cffi.new("AX_VOID**")
+            ret = sys_lib.AX_SYS_MemAllocCached(
+                phy, vir, self._io[0].pInputs[i].nSize, self._align, self._cmm_token
+            )
+            if 0 != ret:
+                raise RuntimeError("Failed to allocate memory for input.")
+            self._io[0].pInputs[i].phyAddr = phy[0]
+            self._io[0].pInputs[i].pVirAddr = vir[0]
+        for i in range(len(self.get_outputs())):
+            max_buf = 0
+            for j in range(self._shape_count):
+                max_buf = max(max_buf, self._info[j][0].pOutputs[i].nSize)
+            self._io[0].pOutputs[i].nSize = max_buf
+            phy = engine_cffi.new("AX_U64*")
+            vir = engine_cffi.new("AX_VOID**")
+            ret = sys_lib.AX_SYS_MemAllocCached(
+                phy, vir, self._io[0].pOutputs[i].nSize, self._align, self._cmm_token
+            )
+            if 0 != ret:
+                raise RuntimeError("Failed to allocate memory for output.")
+            self._io[0].pOutputs[i].phyAddr = phy[0]
+            self._io[0].pOutputs[i].pVirAddr = vir[0]
+
+    def __del__(self):
+        self._unload()
+
+    def _get_model_type(self) -> ModelType:
+        model_type = engine_cffi.new("AX_ENGINE_MODEL_TYPE_T *")
+        ret = engine_lib.AX_ENGINE_GetModelType(
+            self._model_buffer, self._model_buffer_size, model_type
+        )
+        if 0 != ret:
+            raise RuntimeError("Failed to get model type.")
+        return ModelType(model_type[0])
+
+    def _get_model_tool_version(self):
+        model_tool_version = engine_lib.AX_ENGINE_GetModelToolsVersion(
+            self._handle[0]
+        )
+        return engine_cffi.string(model_tool_version).decode("utf-8")
+
+    def _load(self):
+        extra = engine_cffi.new("AX_ENGINE_HANDLE_EXTRA_T *")
+        extra_name = engine_cffi.new("char[]", self._model_name.encode("utf-8"))
+        extra.pName = extra_name
+
+        # as far as I know, onnxruntime does not support running one model with multiple
+        # contexts across threads, so the engine handle and context are created only once
+        ret = engine_lib.AX_ENGINE_CreateHandleV2(
+            self._handle, self._model_buffer, self._model_buffer_size, extra
+        )
+        if 0 == ret:
+            ret = engine_lib.AX_ENGINE_CreateContextV2(
+                self._handle[0], self._context
+            )
+        return ret
+
+    def _get_info(self):
+        total_info = []
+        if 1 == self._shape_count:
+            info = engine_cffi.new("AX_ENGINE_IO_INFO_T **")
+            ret = engine_lib.AX_ENGINE_GetIOInfo(self._handle[0], info)
+            if 0 != ret:
+                raise RuntimeError("Failed to get model shape.")
+            total_info.append(info)
+        else:
+            for i in range(self._shape_count):
+                info = engine_cffi.new("AX_ENGINE_IO_INFO_T **")
+                ret = engine_lib.AX_ENGINE_GetGroupIOInfo(
+                    self._handle[0], i, info
+                )
+                if 0 != ret:
+                    raise RuntimeError(f"Failed to get model the {i}th shape.")
+                total_info.append(info)
+        return total_info
+
+    def _get_shape_count(self):
+        count = engine_cffi.new("AX_U32 *")
+        ret = engine_lib.AX_ENGINE_GetGroupIOInfoCount(self._handle[0], count)
+        if 0 != ret:
+            raise RuntimeError("Failed to get model shape group.")
+        return count[0]
+
+    def _unload(self):
+        if self._handle[0] is not None:
+            engine_lib.AX_ENGINE_DestroyHandle(self._handle[0])
+            self._handle[0] = engine_cffi.NULL
+
+    def _get_io(self, io_type: str):
+        io_info = []
+        for group in range(self._shape_count):
+            one_group_io = []
+            for index in range(getattr(self._info[group][0], f'n{io_type}Size')):
+                current_io = getattr(self._info[group][0], f'p{io_type}s')[index]
+                name = engine_cffi.string(current_io.pName).decode("utf-8")
+                shape = [current_io.pShape[i] for i in range(current_io.nShapeSize)]
+                dtype = _transform_dtype(current_io.eDataType)
+                meta = NodeArg(name, dtype, shape)
+                one_group_io.append(meta)
+            io_info.append(one_group_io)
+        return io_info
+
+    def _get_inputs(self):
+        return self._get_io('Input')
+
+    def _get_outputs(self):
+        return self._get_io('Output')
+
+    def run(
+        self,
+        output_names: list[str],
+        input_feed: dict[str, np.ndarray],
+        run_options=None
+    ):
+        self._validate_input(input_feed)
+        self._validate_output(output_names)
+
+        if None is output_names:
+            output_names = [o.name for o in self.get_outputs()]
+
+        # fill model io
+        for key, npy in input_feed.items():
+            for i, one in enumerate(self.get_inputs()):
+                if one.name == key:
+                    assert (
+                        list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
+                    ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}"
+
+                    if not (
+                        not npy.flags.c_contiguous
+                        and npy.flags.f_contiguous
+                        and npy.flags.contiguous
+                    ):
+                        npy = np.ascontiguousarray(npy)
+                    npy_ptr = engine_cffi.cast("void *", npy.ctypes.data)
+
+                    engine_cffi.memmove(
+                        self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes
+                    )
+                    sys_lib.AX_SYS_MflushCache(
+                        self._io[0].pInputs[i].phyAddr,
+                        self._io[0].pInputs[i].pVirAddr,
+                        self._io[0].pInputs[i].nSize,
+                    )
+                    break
+
+        # execute model
+        ret = engine_lib.AX_ENGINE_RunSyncV2(
+            self._handle[0], self._context[0], self._io
+        )
+
+        # flush output
+        outputs = []
+        if 0 == ret:
+            for i in range(len(self.get_outputs())):
+                sys_lib.AX_SYS_MinvalidateCache(
+                    self._io[0].pOutputs[i].phyAddr,
+                    self._io[0].pOutputs[i].pVirAddr,
+                    self._io[0].pOutputs[i].nSize,
+                )
+                npy = np.frombuffer(
+                    engine_cffi.buffer(
+                        self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize
+                    ),
+                    dtype=self.get_outputs()[i].dtype,
+                ).reshape(self.get_outputs()[i].shape)
+                name = self.get_outputs()[i].name
+                if name in output_names:
+                    outputs.append(npy)
+            return outputs
+        else:
+            raise RuntimeError("Failed to run model.")
python/axengine/_axe_capi.py ADDED
@@ -0,0 +1,323 @@
+# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+#
+# This source file is the property of Axera Semiconductor Co., Ltd. and
+# may not be copied or distributed in any isomorphic form without the prior
+# written consent of Axera Semiconductor Co., Ltd.
+#
+
+import ctypes.util
+import platform
+
+from cffi import FFI
+
+__all__ = ["sys_lib", "sys_cffi", "engine_lib", "engine_cffi"]
+
+sys_cffi = FFI()
+
+# ax_base_type.h
+sys_cffi.cdef(
+    """
+    typedef int AX_S32;
+    typedef unsigned int AX_U32;
+    typedef unsigned long long int AX_U64;
+    typedef signed char AX_S8;
+    typedef void AX_VOID;
+    """
+)
+
+# ax_sys_api.h
+sys_cffi.cdef(
+    """
+    AX_S32 AX_SYS_Init(AX_VOID);
+    AX_S32 AX_SYS_Deinit(AX_VOID);
+    AX_S32 AX_SYS_MemAllocCached(AX_U64 *phyaddr, AX_VOID **pviraddr, AX_U32 size, AX_U32 align, const AX_S8 *token);
+    AX_S32 AX_SYS_MemFree(AX_U64 phyaddr, AX_VOID *pviraddr);
+    AX_S32 AX_SYS_MflushCache(AX_U64 phyaddr, AX_VOID *pviraddr, AX_U32 size);
+    AX_S32 AX_SYS_MinvalidateCache(AX_U64 phyaddr, AX_VOID *pviraddr, AX_U32 size);
+    """
+)
+
+sys_name = "ax_sys"
+sys_path = ctypes.util.find_library(sys_name)
+assert (
+    sys_path is not None
+), f"Failed to find library {sys_name}. Please ensure it is installed and in the library path."
+
+sys_lib = sys_cffi.dlopen(sys_path)
+assert sys_lib is not None, f"Failed to load library {sys_path}. Please ensure it is installed and in the library path."
+
+engine_cffi = FFI()
+
+# ax_base_type.h
+engine_cffi.cdef(
+    """
+    typedef unsigned long long int AX_U64;
+    typedef unsigned int AX_U32;
+    typedef unsigned char AX_U8;
+    typedef int AX_S32;
+    typedef signed char AX_S8;
+    typedef char AX_CHAR;
+    typedef void AX_VOID;
+
+    typedef enum {
+        AX_FALSE = 0,
+        AX_TRUE = 1,
+    } AX_BOOL;
+    """
+)
+
+# ax_engine_type.h, base type
+engine_cffi.cdef(
+    """
+    typedef AX_U32 AX_ENGINE_NPU_SET_T;
+    """
+)
+
+# ax_engine_type.h, enum
+engine_cffi.cdef(
+    """
+    typedef enum _AX_ENGINE_TENSOR_LAYOUT_E
+    {
+        AX_ENGINE_TENSOR_LAYOUT_UNKNOWN = 0,
+        AX_ENGINE_TENSOR_LAYOUT_NHWC = 1,
+        AX_ENGINE_TENSOR_LAYOUT_NCHW = 2,
+    } AX_ENGINE_TENSOR_LAYOUT_T;
+
+    typedef enum
+    {
+        AX_ENGINE_MT_PHYSICAL = 0,
+        AX_ENGINE_MT_VIRTUAL = 1,
+        AX_ENGINE_MT_OCM = 2,
+    } AX_ENGINE_MEMORY_TYPE_T;
+
+    typedef enum
+    {
+        AX_ENGINE_DT_UNKNOWN = 0,
+        AX_ENGINE_DT_UINT8 = 1,
+        AX_ENGINE_DT_UINT16 = 2,
+        AX_ENGINE_DT_FLOAT32 = 3,
+        AX_ENGINE_DT_SINT16 = 4,
+        AX_ENGINE_DT_SINT8 = 5,
+        AX_ENGINE_DT_SINT32 = 6,
+        AX_ENGINE_DT_UINT32 = 7,
+        AX_ENGINE_DT_FLOAT64 = 8,
+        AX_ENGINE_DT_BFLOAT16 = 9,
+        AX_ENGINE_DT_UINT10_PACKED = 100,
+        AX_ENGINE_DT_UINT12_PACKED = 101,
+        AX_ENGINE_DT_UINT14_PACKED = 102,
+        AX_ENGINE_DT_UINT16_PACKED = 103,
+    } AX_ENGINE_DATA_TYPE_T;
+
+    typedef enum
+    {
+        AX_ENGINE_CS_FEATUREMAP = 0,
+        AX_ENGINE_CS_RAW8 = 12,
+        AX_ENGINE_CS_RAW10 = 1,
+        AX_ENGINE_CS_RAW12 = 2,
+        AX_ENGINE_CS_RAW14 = 11,
+        AX_ENGINE_CS_RAW16 = 3,
+        AX_ENGINE_CS_NV12 = 4,
+        AX_ENGINE_CS_NV21 = 5,
+        AX_ENGINE_CS_RGB = 6,
+        AX_ENGINE_CS_BGR = 7,
+        AX_ENGINE_CS_RGBA = 8,
+        AX_ENGINE_CS_GRAY = 9,
+        AX_ENGINE_CS_YUV444 = 10,
+    } AX_ENGINE_COLOR_SPACE_T;
+    """
+)
+
+# ax_engine_type.h, architecturally agnostic struct
+engine_cffi.cdef(
+    """
+    typedef enum {
+        AX_ENGINE_VIRTUAL_NPU_DISABLE = 0,
+    } AX_ENGINE_NPU_MODE_T;
+
+    typedef enum {
+        AX_ENGINE_MODEL_TYPE0 = 0,
+    } AX_ENGINE_MODEL_TYPE_T;
+
+    typedef struct {
+        AX_ENGINE_NPU_MODE_T eHardMode;
+        AX_U32 reserve[8];
+    } AX_ENGINE_NPU_ATTR_T;
+
+    typedef struct _AX_ENGINE_IO_META_EX_T
+    {
+        AX_ENGINE_COLOR_SPACE_T eColorSpace;
+        AX_U64 u64Reserved[18];
+    } AX_ENGINE_IO_META_EX_T;
+
+    typedef struct {
+        AX_ENGINE_NPU_SET_T nNpuSet;
+        AX_S8* pName;
+        AX_U32 reserve[8];
+    } AX_ENGINE_HANDLE_EXTRA_T;
+
+    typedef struct _AX_ENGINE_CMM_INFO_T
+    {
+        AX_U32 nCMMSize;
+    } AX_ENGINE_CMM_INFO_T;
+
+    typedef struct _AX_ENGINE_IO_SETTING_T
+    {
+        AX_U32 nWbtIndex;
+        AX_U64 u64Reserved[7];
+    } AX_ENGINE_IO_SETTING_T;
+    """
+)
+
+# check architecture, 32bit or 64bit
+arch = platform.architecture()[0]
+
+# ax_engine_type.h, struct
+if arch == "64bit":
+    engine_cffi.cdef(
+        """
+        typedef struct _AX_ENGINE_IO_META_T
+        {
+            AX_CHAR* pName;
+            AX_S32* pShape;
+            AX_U8 nShapeSize;
+            AX_ENGINE_TENSOR_LAYOUT_T eLayout;
+            AX_ENGINE_MEMORY_TYPE_T eMemoryType;
+            AX_ENGINE_DATA_TYPE_T eDataType;
+            AX_ENGINE_IO_META_EX_T* pExtraMeta;
+            AX_U32 nSize;
+            AX_U32 nQuantizationValue;
+            AX_S32* pStride;
+            AX_U64 u64Reserved[9];
+        } AX_ENGINE_IO_META_T;
+
+        typedef struct _AX_ENGINE_IO_INFO_T
+        {
+            AX_ENGINE_IO_META_T* pInputs;
+            AX_U32 nInputSize;
+            AX_ENGINE_IO_META_T* pOutputs;
+            AX_U32 nOutputSize;
+            AX_U32 nMaxBatchSize;
+            AX_BOOL bDynamicBatchSize;
+            AX_U64 u64Reserved[11];
+        } AX_ENGINE_IO_INFO_T;
+
+        typedef struct _AX_ENGINE_IO_BUFFER_T
+        {
+            AX_U64 phyAddr;
+            AX_VOID* pVirAddr;
+            AX_U32 nSize;
+            AX_S32* pStride;
+            AX_U8 nStrideSize;
+            AX_U64 u64Reserved[11];
+        } AX_ENGINE_IO_BUFFER_T;
+
+        typedef struct _AX_ENGINE_IO_T
+        {
+            AX_ENGINE_IO_BUFFER_T* pInputs;
+            AX_U32 nInputSize;
+            AX_ENGINE_IO_BUFFER_T* pOutputs;
+            AX_U32 nOutputSize;
+            AX_U32 nBatchSize;
+            AX_ENGINE_IO_SETTING_T* pIoSetting;
+            AX_U64 u64Reserved[10];
+        } AX_ENGINE_IO_T;
+        """
+    )
+else:
+    engine_cffi.cdef(
+        """
+        typedef struct _AX_ENGINE_IO_META_T
+        {
+            AX_CHAR* pName;
+            AX_S32* pShape;
+            AX_U8 nShapeSize;
+            AX_ENGINE_TENSOR_LAYOUT_T eLayout;
+            AX_ENGINE_MEMORY_TYPE_T eMemoryType;
+            AX_ENGINE_DATA_TYPE_T eDataType;
+            AX_ENGINE_IO_META_EX_T* pExtraMeta;
+            AX_U32 nSize;
+            AX_U32 nQuantizationValue;
+            AX_S32* pStride;
+            AX_U64 u64Reserved[11];
+        } AX_ENGINE_IO_META_T;
+
+        typedef struct _AX_ENGINE_IO_INFO_T
+        {
+            AX_ENGINE_IO_META_T* pInputs;
+            AX_U32 nInputSize;
+            AX_ENGINE_IO_META_T* pOutputs;
+            AX_U32 nOutputSize;
+            AX_U32 nMaxBatchSize;
+            AX_BOOL bDynamicBatchSize;
+            AX_U64 u64Reserved[13];
+        } AX_ENGINE_IO_INFO_T;
+
+        typedef struct _AX_ENGINE_IO_BUFFER_T
+        {
+            AX_U64 phyAddr;
+            AX_VOID* pVirAddr;
+            AX_U32 nSize;
+            AX_S32* pStride;
+            AX_U8 nStrideSize;
+            AX_U64 u64Reserved[13];
+        } AX_ENGINE_IO_BUFFER_T;
+
+        typedef struct _AX_ENGINE_IO_T
+        {
+            AX_ENGINE_IO_BUFFER_T* pInputs;
+            AX_U32 nInputSize;
+            AX_ENGINE_IO_BUFFER_T* pOutputs;
+            AX_U32 nOutputSize;
+            AX_U32 nBatchSize;
+            AX_ENGINE_IO_SETTING_T* pIoSetting;
+            AX_U64 u64Reserved[12];
+        } AX_ENGINE_IO_T;
+        """
+    )
+
+# ax_engine_api.h
+engine_cffi.cdef(
+    """
+    const AX_CHAR* AX_ENGINE_GetVersion(AX_VOID);
+
+    AX_VOID AX_ENGINE_NPUReset(AX_VOID);
+    AX_S32 AX_ENGINE_Init(AX_ENGINE_NPU_ATTR_T* pNpuAttr);
+    AX_S32 AX_ENGINE_GetVNPUAttr(AX_ENGINE_NPU_ATTR_T* pNpuAttr);
+    AX_S32 AX_ENGINE_Deinit(AX_VOID);
+
+    AX_S32 AX_ENGINE_GetModelType(const AX_VOID* pData, AX_U32 nDataSize, AX_ENGINE_MODEL_TYPE_T* pModelType);
+
+    AX_S32 AX_ENGINE_CreateHandleV2(uint64_t** pHandle, const AX_VOID* pData, AX_U32 nDataSize, AX_ENGINE_HANDLE_EXTRA_T* pExtraParam);
+    AX_S32 AX_ENGINE_DestroyHandle(uint64_t* nHandle);
+
+    AX_S32 AX_ENGINE_GetIOInfo(uint64_t* nHandle, AX_ENGINE_IO_INFO_T** pIO);
+    AX_S32 AX_ENGINE_GetGroupIOInfoCount(uint64_t* nHandle, AX_U32* pCount);
+    AX_S32 AX_ENGINE_GetGroupIOInfo(uint64_t* nHandle, AX_U32 nIndex, AX_ENGINE_IO_INFO_T** pIO);
+
+    AX_S32 AX_ENGINE_GetHandleModelType(uint64_t* nHandle, AX_ENGINE_MODEL_TYPE_T* pModelType);
+
+    AX_S32 AX_ENGINE_CreateContextV2(uint64_t* nHandle, uint64_t** pContext);
+
+    AX_S32 AX_ENGINE_RunSyncV2(uint64_t* handle, uint64_t* context, AX_ENGINE_IO_T* pIO);
+    AX_S32 AX_ENGINE_RunGroupIOSync(uint64_t* handle, uint64_t* context, AX_U32 nIndex, AX_ENGINE_IO_T* pIO);
+
+    AX_S32 AX_ENGINE_SetAffinity(uint64_t* nHandle, AX_ENGINE_NPU_SET_T nNpuSet);
+    AX_S32 AX_ENGINE_GetAffinity(uint64_t* nHandle, AX_ENGINE_NPU_SET_T* pNpuSet);
+
+    AX_S32 AX_ENGINE_GetCMMUsage(uint64_t* nHandle, AX_ENGINE_CMM_INFO_T* pCMMInfo);
+
+    const AX_CHAR* AX_ENGINE_GetModelToolsVersion(uint64_t* nHandle);
+
+    // internal use API
+    AX_S32 AX_ENGINE_GetTotalOps();
+    """
+)
+
+engine_name = "ax_engine"
+engine_path = ctypes.util.find_library(engine_name)
+assert (
+    engine_path is not None
+), f"Failed to find library {engine_name}. Please ensure it is installed and in the library path."
+
+engine_lib = engine_cffi.dlopen(engine_path)
+assert engine_lib is not None, f"Failed to load library {engine_path}. Please ensure it is installed and in the library path."
python/axengine/_axe_types.py ADDED
@@ -0,0 +1,29 @@
+# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+#
+# This source file is the property of Axera Semiconductor Co., Ltd. and
+# may not be copied or distributed in any isomorphic form without the prior
+# written consent of Axera Semiconductor Co., Ltd.
+#
+
+from enum import Enum
+
+
+class VNPUType(Enum):
+    DISABLED = 0
+    ENABLED = 1
+    BIG_LITTLE = 2
+    LITTLE_BIG = 3
+
+
+class ModelType(Enum):
+    HALF = 0    # for MC20E, which means chip is AX630C(x), or AX620Q(x)
+    FULL = 1    # for MC20E
+    SINGLE = 0  # for MC50, which means chip is AX650A or AX650N, and M57H
+    DUAL = 1    # for MC50
+    TRIPLE = 2  # for MC50
+
+
+class ChipType(Enum):
+    MC20E = 0
+    MC50 = 1
+    M57H = 2
python/axengine/_base_session.py ADDED
@@ -0,0 +1,59 @@
+# Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+#
+# This source file is the property of Axera Semiconductor Co., Ltd. and
+# may not be copied or distributed in any isomorphic form without the prior
+# written consent of Axera Semiconductor Co., Ltd.
+#
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+
+from ._node import NodeArg
+
+
+class SessionOptions:
+    pass
+
+
+class Session(ABC):
+    def __init__(self) -> None:
+        self._shape_count = 0
+        self._inputs = []
+        self._outputs = []
+
+    def _validate_input(self, feed_input_names: dict[str, np.ndarray]):
+        missing_input_names = []
+        for i in self.get_inputs():
+            if i.name not in feed_input_names:
+                missing_input_names.append(i.name)
+        if missing_input_names:
+            raise ValueError(
+                f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names}).")
+
+    def _validate_output(self, output_names: list[str]):
+        if output_names is not None:
+            for name in output_names:
+                if name not in [o.name for o in self.get_outputs()]:
+                    raise ValueError(f"Output name '{name}' is not in model outputs name list.")
+
+    def get_inputs(self, shape_group: int = 0) -> list[NodeArg]:
+        if shape_group > self._shape_count:
+            raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
+        selected_info = self._inputs[shape_group]
+        return selected_info
+
+    def get_outputs(self, shape_group: int = 0) -> list[NodeArg]:
+        if shape_group > self._shape_count:
+            raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
+        selected_info = self._outputs[shape_group]
+        return selected_info
+
+    @abstractmethod
+    def run(
+        self,
+        output_names: list[str] | None,
+        input_feed: dict[str, np.ndarray],
+        run_options=None
+    ) -> list[np.ndarray]:
+        pass
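
Editor's note: a short hypothetical sketch of how the base-class helpers above behave from the caller's side; it is not part of the commit, and `sess` stands for any already-constructed concrete Session (AXEngineSession or AXCLRTSession).

# Hypothetical sketch only (not part of this commit).
import numpy as np

inputs = sess.get_inputs()  # NodeArg list for shape group 0, the default group
feed = {n.name: np.zeros(n.shape, dtype=n.dtype) for n in inputs}

# run() first checks the feed against the model's input names (_validate_input)
# and the requested names against the model's outputs (_validate_output),
# then returns only the requested output arrays.
first_output = sess.get_outputs()[0].name
result = sess.run([first_output], feed)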
python/axengine/_node.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+ #
+ # This source file is the property of Axera Semiconductor Co., Ltd. and
+ # may not be copied or distributed in any isomorphic form without the prior
+ # written consent of Axera Semiconductor Co., Ltd.
+ #
+
+
+ class NodeArg(object):
+     def __init__(self, name, dtype, shape):
+         self.name = name
+         self.dtype = dtype
+         self.shape = shape
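NodeArg is the onnxruntime-style descriptor returned by get_inputs()/get_outputs(); a trivial illustration follows (the values are made up):

from axengine._node import NodeArg

# Describe one model input: name, dtype string, and shape tuple.
arg = NodeArg(name="input", dtype="uint8", shape=(1, 518, 518, 3))
print(arg.name, arg.dtype, arg.shape)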
python/axengine/_providers.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+ #
+ # This source file is the property of Axera Semiconductor Co., Ltd. and
+ # may not be copied or distributed in any isomorphic form without the prior
+ # written consent of Axera Semiconductor Co., Ltd.
+ #
+
+ import ctypes.util as cutil
+
+ providers = []
+ axengine_provider_name = 'AxEngineExecutionProvider'
+ axclrt_provider_name = 'AXCLRTExecutionProvider'
+
+ _axengine_lib_name = 'ax_engine'
+ _axclrt_lib_name = 'axcl_rt'
+
+ # check if axcl_rt is installed; if it is, it becomes the default provider
+ if cutil.find_library(_axclrt_lib_name) is not None:
+     providers.append(axclrt_provider_name)
+
+ # check if ax_engine is installed
+ if cutil.find_library(_axengine_lib_name) is not None:
+     providers.append(axengine_provider_name)
+
+
+ def get_all_providers():
+     return [axengine_provider_name, axclrt_provider_name]
+
+
+ def get_available_providers():
+     return providers
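A quick way to see which providers the current environment exposes; the result depends on whether ctypes can locate the axcl_rt and ax_engine shared libraries on this machine.

from axengine._providers import get_all_providers, get_available_providers

print("all:      ", get_all_providers())
print("available:", get_available_providers())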
python/axengine/_session.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
+ #
+ # This source file is the property of Axera Semiconductor Co., Ltd. and
+ # may not be copied or distributed in any isomorphic form without the prior
+ # written consent of Axera Semiconductor Co., Ltd.
+ #
+
+ import os
+ from typing import Any, Sequence
+
+ import numpy as np
+
+ from ._base_session import SessionOptions
+ from ._node import NodeArg
+ from ._providers import axclrt_provider_name, axengine_provider_name
+ from ._providers import get_available_providers
+
+
+ class InferenceSession:
+     def __init__(
+             self,
+             path_or_bytes: str | bytes | os.PathLike,
+             sess_options: SessionOptions | None = None,
+             providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
+             provider_options: Sequence[dict[Any, Any]] | None = None, **kwargs,
+     ) -> None:
+         self._sess = None
+         self._sess_options = sess_options
+         self._provider = None
+         self._provider_options = None
+         self._available_providers = get_available_providers()
+
+         # at least one provider is available at this point; __init__.py checks this at import time
+         if providers is None:
+             # use the first available provider as the default
+             _provider_name = self._available_providers[0]
+             self._provider = _provider_name
+         else:
+             # only one provider is specified
+             if isinstance(providers, str):
+                 if providers not in self._available_providers:
+                     raise ValueError(f"Selected provider: '{providers}' is not available.")
+                 self._provider = providers
+             # multiple providers are specified; use the first available one as the default
+             elif isinstance(providers, list):
+                 _unavailable_provider = []
+                 for p in providers:
+                     assert isinstance(p, str) or isinstance(p, tuple), \
+                         f"Invalid provider type: {type(p)}. Must be str or tuple."
+                     if isinstance(p, str):
+                         if p not in self._available_providers:
+                             _unavailable_provider.append(p)
+                         elif self._provider is None:
+                             self._provider = p
+                     if isinstance(p, tuple):
+                         assert len(p) == 2, f"Invalid provider type: {p}. Must be tuple with 2 elements."
+                         assert isinstance(p[0], str), f"Invalid provider type: {type(p[0])}. Must be str."
+                         assert isinstance(p[1], dict), f"Invalid provider type: {type(p[1])}. Must be dict."
+                         if p[0] not in self._available_providers:
+                             _unavailable_provider.append(p[0])
+                         elif self._provider is None:
+                             self._provider = p[0]
+                             # FIXME: check provider options
+                             self._provider_options = p[1]
+                 if _unavailable_provider:
+                     if self._provider is None:
+                         raise ValueError(f"Selected provider(s): {_unavailable_provider} is(are) not available.")
+                     else:
+                         print(f"[WARNING] Selected provider(s): {_unavailable_provider} is(are) not available.")
+
+         # FIXME: can we remove this check?
+         if self._provider is None:
+             raise ValueError(f"No available provider found in {providers}.")
+         print(f"[INFO] Using provider: {self._provider}")
+
+         if self._provider == axclrt_provider_name:
+             from ._axclrt import AXCLRTSession
+             self._sess = AXCLRTSession(path_or_bytes, sess_options, provider_options, **kwargs)
+         if self._provider == axengine_provider_name:
+             from ._axe import AXEngineSession
+             self._sess = AXEngineSession(path_or_bytes, sess_options, provider_options, **kwargs)
+         if self._sess is None:
+             raise RuntimeError(f"Create session failed with provider: {self._provider}")
+
+     # support the 'with' statement
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         # do not suppress exceptions
+         return False
+
+     def get_session_options(self):
+         """
+         Return the session options. See :class:`axengine.SessionOptions`.
+         """
+         return self._sess_options
+
+     def get_providers(self):
+         """
+         Return the execution provider selected for this session.
+         """
+         return self._provider
+
+     def get_inputs(self, shape_group: int = 0) -> list[NodeArg]:
+         return self._sess.get_inputs(shape_group)
+
+     def get_outputs(self, shape_group: int = 0) -> list[NodeArg]:
+         return self._sess.get_outputs(shape_group)
+
+     def run(
+             self,
+             output_names: list[str] | None,
+             input_feed: dict[str, np.ndarray],
+             run_options=None
+     ) -> list[np.ndarray]:
+         return self._sess.run(output_names, input_feed, run_options)
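A minimal usage sketch of this wrapper is shown below. The model file name matches the ax650 axmodel uploaded in this commit, and the provider list mirrors the one used by python/infer.py; it only runs on a device where one of the two providers is actually available.

from axengine import InferenceSession

session = InferenceSession(
    "depth_anything_v2_vits_ax650.axmodel",
    providers=["AxEngineExecutionProvider", "AXCLRTExecutionProvider"],
)
# Inspect the model's declared I/O metadata (NodeArg objects).
for i in session.get_inputs():
    print("input :", i.name, i.dtype, i.shape)
for o in session.get_outputs():
    print("output:", o.name, o.dtype, o.shape)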
python/examples/demo01.jpg ADDED
python/examples/demo02.jpg ADDED
python/examples/demo03.jpg ADDED
python/examples/demo04.jpg ADDED
python/examples/demo05.jpg ADDED
python/examples/demo06.jpg ADDED
python/examples/demo07.jpg ADDED
python/examples/demo08.jpg ADDED
python/examples/demo09.jpg ADDED
python/examples/demo10.jpg ADDED
python/examples/demo11.jpg ADDED
python/examples/demo12.jpg ADDED
python/examples/demo13.jpg ADDED
python/examples/demo14.jpg ADDED
python/examples/demo15.jpg ADDED
python/examples/demo16.jpg ADDED
python/examples/demo17.jpg ADDED
python/examples/demo18.jpg ADDED
python/examples/demo19.jpg ADDED

Git LFS Details

  • SHA256: 7cdb09c34eb0b4d2ac5f6070aec47c8f983a0b1b2c9ee1fc30decafb64f1bd98
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
python/examples/demo20.jpg ADDED
python/infer.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import argparse
+ import cv2
+ import numpy as np
+ from axengine import InferenceSession
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--img",
+         type=str,
+         required=True,
+         help="Path to input image.",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         required=True,
+         help="Path to compiled axmodel.",
+     )
+
+     return parser.parse_args()
+
+
+ def infer(img: str, model: str, viz: bool = False):
+     img_raw = cv2.imread(img)
+     image = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB)
+     orig_h, orig_w = image.shape[:2]
+     image = cv2.resize(image, (518, 518))
+     image = image[None]
+
+     session = InferenceSession(path_or_bytes=model, providers=['AxEngineExecutionProvider', 'AXCLRTExecutionProvider'])
+
+     depth = session.run(output_names=["output"], input_feed={"input": image})[0]
+
+     depth = cv2.resize(depth[0, 0], (orig_w, orig_h))
+     depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+     depth = depth.astype(np.uint8)
+
+     depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
+
+     combined_result = cv2.hconcat([img_raw, depth_color])
+
+     cv2.imwrite("output-ax.png", combined_result)
+
+     return depth
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     infer(**vars(args))
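For completeness, an example of running this script is given below; the relative paths assume it is launched from the python/ directory of this repo and are otherwise illustrative.

# Shell: python3 infer.py --img examples/demo19.jpg --model ../depth_anything_v2_vits_ax650.axmodel
# Programmatic use (same path assumptions):
from infer import infer

depth = infer(img="examples/demo19.jpg", model="../depth_anything_v2_vits_ax650.axmodel")
print(depth.shape, depth.dtype)  # (orig_h, orig_w), uint8 depth visualization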
python/infer_onnx.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import argparse
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--img",
+         type=str,
+         required=True,
+         help="Path to input image.",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         required=True,
+         help="Path to ONNX model.",
+     )
+
+     return parser.parse_args()
+
+
+ def infer(img: str, model: str, viz: bool = False):
+     img_raw = cv2.imread(img)
+     image = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB)
+     orig_h, orig_w = image.shape[:2]
+     image = cv2.resize(image, (518, 518))
+     mean = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, 1, 3)
+     std = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, 1, 3)
+
+     image = (image - mean) / std
+     image = image.transpose(2, 0, 1)
+     image = image[None]
+
+     session = ort.InferenceSession(
+         model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+     )
+     depth = session.run(None, {"input": image})[0]
+
+     depth = cv2.resize(depth[0, 0], (orig_w, orig_h))
+     depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+     depth = depth.astype(np.uint8)
+
+     depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
+
+     combined_result = cv2.hconcat([img_raw, depth_color])
+
+     cv2.imwrite("output-onnx.png", combined_result)
+
+     return depth
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     infer(**vars(args))
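The two demo scripts differ only in preprocessing: this ONNX version applies ImageNet mean/std normalization and NCHW layout, while infer.py feeds the resized RGB frame straight to the axmodel, presumably because the compiled model folds normalization in. A side-by-side sketch of those two paths, under that assumption:

import cv2
import numpy as np

_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, 1, 3)
_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, 1, 3)

def preprocess_onnx(rgb: np.ndarray) -> np.ndarray:
    # float32, ImageNet-normalized, NCHW with batch dim - what infer_onnx.py builds
    x = (cv2.resize(rgb, (518, 518)).astype(np.float32) - _MEAN) / _STD
    return x.transpose(2, 0, 1)[None]

def preprocess_axmodel(rgb: np.ndarray) -> np.ndarray:
    # raw uint8, NHWC with batch dim - what infer.py passes to the axmodel
    return cv2.resize(rgb, (518, 518))[None]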
python/output.png ADDED

Git LFS Details

  • SHA256: e91f86e1a9584ccbd4f484e7e4b74673b633f35eab50a2ec05b28b174d9392c1
  • Pointer size: 132 Bytes
  • Size of remote file: 4.33 MB
python/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
+ onnx
+ onnxruntime
+ opencv-python
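A quick sanity check that the example dependencies resolve; numpy is pulled in transitively, and the axengine package used by infer.py ships in this repo under python/axengine rather than via requirements.txt.

import cv2
import numpy
import onnx
import onnxruntime

print(cv2.__version__, numpy.__version__, onnx.__version__, onnxruntime.__version__)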