tumuyan2 commited on
Commit
ee22f2b
·
1 Parent(s): 097f8b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -5
app.py CHANGED
@@ -38,15 +38,12 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
38
  print_log(task_id, input2, "不支持此文件的格式 suffix="+suffix, "错误")
39
  log += f"不支持此文件的格式\n"
40
  return [] , log
 
41
  print_log(task_id, input2, "检查文件名", "开始")
42
  log += f"检查文件名…\n"
43
  yield [], log
44
  if input2 == None or input2.strip() == "":
45
  input2 = str(task_id)
46
-
47
- input1 = split_input[0]
48
- if input2 == "":
49
- input2 = str(task_id)
50
  log += f"未提供文件名,使用{input2}\n"
51
  print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
52
  yield [], log
@@ -161,7 +158,7 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
161
  log += "生成张量…\n"
162
  yield [], log
163
  pt_path = output_folder + "/" + input2 + ".pt"
164
- onnx_path = output_folder + "/" + input2 + ".onnx"
165
  input_tensor0 = torch.rand(shape0) if any(shape0) else None
166
  input_tensor1 = torch.rand(shape1) if any(shape1) else None
167
  if input_tensor0 is not None and input_tensor1 is not None:
@@ -200,6 +197,8 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
200
  print_log(task_id, input2, "获得模型对象", "完成")
201
  yield [], log
202
 
 
 
203
  if os.path.exists(pt_path):
204
  print_log(task_id, input2, "转换为TorchScript模型", "跳过")
205
  log += "跳过转换为TorchScript模型\n"
@@ -213,6 +212,13 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
213
  traced_torch_model.save(output_folder + "/" + input2 + ".pt")
214
  print_log(task_id, input2, "转换为TorchScript模型", "完成")
215
 
 
 
 
 
 
 
 
216
  print_log(task_id, input2, "执行命令" + command, "开始")
217
  log += "执行命令…\n"
218
  yield [], log
@@ -237,6 +243,28 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
237
  log += f"执行命令成功: {command} \n"
238
  except Exception as e:
239
  log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
242
  log += f"任务完成\n"
 
38
  print_log(task_id, input2, "不支持此文件的格式 suffix="+suffix, "错误")
39
  log += f"不支持此文件的格式\n"
40
  return [] , log
41
+ input2 = split_input[0]
42
  print_log(task_id, input2, "检查文件名", "开始")
43
  log += f"检查文件名…\n"
44
  yield [], log
45
  if input2 == None or input2.strip() == "":
46
  input2 = str(task_id)
 
 
 
 
47
  log += f"未提供文件名,使用{input2}\n"
48
  print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
49
  yield [], log
 
158
  log += "生成张量…\n"
159
  yield [], log
160
  pt_path = output_folder + "/" + input2 + ".pt"
161
+ # onnx_path = output_folder + "/" + input2 + ".onnx"
162
  input_tensor0 = torch.rand(shape0) if any(shape0) else None
163
  input_tensor1 = torch.rand(shape1) if any(shape1) else None
164
  if input_tensor0 is not None and input_tensor1 is not None:
 
197
  print_log(task_id, input2, "获得模型对象", "完成")
198
  yield [], log
199
 
200
+
201
+ width_ratio = 4
202
  if os.path.exists(pt_path):
203
  print_log(task_id, input2, "转换为TorchScript模型", "跳过")
204
  log += "跳过转换为TorchScript模型\n"
 
212
  traced_torch_model.save(output_folder + "/" + input2 + ".pt")
213
  print_log(task_id, input2, "转换为TorchScript模型", "完成")
214
 
215
+ # 获取输出
216
+ example_output = traced_torch_model(example_input)
217
+ width_ratio = example_output.shape[2] / example_input.shape[2]
218
+ print_log(task_id, input2, "获得缩放倍率="+width_ratio+", 输出shape="+example_output.shape, "完成")
219
+ log+= "获得缩放倍率="+width_ratio+", 输出shape="+example_output.shape+"\n"
220
+ yield [], log
221
+
222
  print_log(task_id, input2, "执行命令" + command, "开始")
223
  log += "执行命令…\n"
224
  yield [], log
 
243
  log += f"执行命令成功: {command} \n"
244
  except Exception as e:
245
  log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
246
+
247
+ # 检查文件是否存在
248
+ bin_file = os.path.join(output_folder, input2 + ".ncnn.bin")
249
+ param_file = os.path.join(output_folder, input2 + ".ncnn.param")
250
+ if os.path.exists(bin_file) and os.path.exists(param_file):
251
+ import zipfile
252
+ # 压缩包名称
253
+ zip_file_name = os.path.join(output_folder, f"models-{input2}.zip")
254
+ # 压缩包内文件夹名称
255
+ zip_folder_name = f"models-{input2}"
256
+ # 重命名后的文件名
257
+ new_bin_name = f"x{width_ratio}.bin"
258
+ new_param_name = f"x{width_ratio}.param"
259
+ # 创建压缩包
260
+ with zipfile.ZipFile(zip_file_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
261
+ # 写入重命名后的.bin文件
262
+ zipf.write(bin_file, os.path.join(zip_folder_name, new_bin_name))
263
+ # 写入重命名后的.param文件
264
+ zipf.write(param_file, os.path.join(zip_folder_name, new_param_name))
265
+ log += f"已创建压缩包: {zip_file_name}\n"
266
+ print_log(task_id, input2, "创建压缩包"+zip_file_name, "完成")
267
+ yield [], log
268
 
269
  output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
270
  log += f"任务完成\n"