tumuyan2 commited on
Commit
9b29e0b
·
1 Parent(s): a4a3999
Files changed (2) hide show
  1. app.py +9 -5
  2. requirements.txt +1 -0
app.py CHANGED
@@ -196,11 +196,15 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
196
  print_log(task_id, input2, "转换为ONNX模型", "开始")
197
  log += "转换为ONNX模型…\n"
198
  yield [], log
199
- # 使用 torch.onnx.export 进行模型转换
200
- # 将列表转换为元组
201
- shape_tuple = tuple(shape0)
202
- torch.onnx.export(torch_model, torch.rand(shape_tuple), onnx_path, verbose=True, input_names=["data"], output_names=["output"])
203
-
 
 
 
 
204
 
205
  if os.path.exists(pt_path):
206
  print_log(task_id, input2, "转换为TorchScript模型", "跳过")
 
196
  print_log(task_id, input2, "转换为ONNX模型", "开始")
197
  log += "转换为ONNX模型…\n"
198
  yield [], log
199
+ try:
200
+ # 使用 torch.onnx.export 进行模型转换
201
+ # 将列表转换为元组
202
+ shape_tuple = tuple(shape0)
203
+ torch.onnx.export(torch_model, torch.rand(shape_tuple), onnx_path, verbose=True, input_names=["data"], output_names=["output"])
204
+ except Exception as e:
205
+ print_log(task_id, input2, "转换为ONNX模型"+e, f"失败")
206
+ log += f"转换为ONNX模型失败: {e}\n"
207
+ yield [], log
208
 
209
  if os.path.exists(pt_path):
210
  print_log(task_id, input2, "转换为TorchScript模型", "跳过")
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  spandrel
2
  torch
3
  pnnx
 
 
1
  spandrel
2
  torch
3
  pnnx
4
+ onnx