Spaces:
Running
Running
update
Browse files- app.py +9 -5
- requirements.txt +1 -0
app.py
CHANGED
@@ -196,11 +196,15 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
|
|
196 |
print_log(task_id, input2, "转换为ONNX模型", "开始")
|
197 |
log += "转换为ONNX模型…\n"
|
198 |
yield [], log
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
204 |
|
205 |
if os.path.exists(pt_path):
|
206 |
print_log(task_id, input2, "转换为TorchScript模型", "跳过")
|
|
|
196 |
print_log(task_id, input2, "转换为ONNX模型", "开始")
|
197 |
log += "转换为ONNX模型…\n"
|
198 |
yield [], log
|
199 |
+
try:
|
200 |
+
# 使用 torch.onnx.export 进行模型转换
|
201 |
+
# 将列表转换为元组
|
202 |
+
shape_tuple = tuple(shape0)
|
203 |
+
torch.onnx.export(torch_model, torch.rand(shape_tuple), onnx_path, verbose=True, input_names=["data"], output_names=["output"])
|
204 |
+
except Exception as e:
|
205 |
+
print_log(task_id, input2, "转换为ONNX模型"+e, f"失败")
|
206 |
+
log += f"转换为ONNX模型失败: {e}\n"
|
207 |
+
yield [], log
|
208 |
|
209 |
if os.path.exists(pt_path):
|
210 |
print_log(task_id, input2, "转换为TorchScript模型", "跳过")
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
spandrel
|
2 |
torch
|
3 |
pnnx
|
|
|
|
1 |
spandrel
|
2 |
torch
|
3 |
pnnx
|
4 |
+
onnx
|