Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -38,15 +38,12 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
|
|
38 |
print_log(task_id, input2, "不支持此文件的格式 suffix="+suffix, "错误")
|
39 |
log += f"不支持此文件的格式\n"
|
40 |
return [] , log
|
|
|
41 |
print_log(task_id, input2, "检查文件名", "开始")
|
42 |
log += f"检查文件名…\n"
|
43 |
yield [], log
|
44 |
if input2 == None or input2.strip() == "":
|
45 |
input2 = str(task_id)
|
46 |
-
|
47 |
-
input1 = split_input[0]
|
48 |
-
if input2 == "":
|
49 |
-
input2 = str(task_id)
|
50 |
log += f"未提供文件名,使用{input2}\n"
|
51 |
print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
|
52 |
yield [], log
|
@@ -161,7 +158,7 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
|
|
161 |
log += "生成张量…\n"
|
162 |
yield [], log
|
163 |
pt_path = output_folder + "/" + input2 + ".pt"
|
164 |
-
onnx_path = output_folder + "/" + input2 + ".onnx"
|
165 |
input_tensor0 = torch.rand(shape0) if any(shape0) else None
|
166 |
input_tensor1 = torch.rand(shape1) if any(shape1) else None
|
167 |
if input_tensor0 is not None and input_tensor1 is not None:
|
@@ -200,6 +197,8 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
|
|
200 |
print_log(task_id, input2, "获得模型对象", "完成")
|
201 |
yield [], log
|
202 |
|
|
|
|
|
203 |
if os.path.exists(pt_path):
|
204 |
print_log(task_id, input2, "转换为TorchScript模型", "跳过")
|
205 |
log += "跳过转换为TorchScript模型\n"
|
@@ -213,6 +212,13 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
|
|
213 |
traced_torch_model.save(output_folder + "/" + input2 + ".pt")
|
214 |
print_log(task_id, input2, "转换为TorchScript模型", "完成")
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
print_log(task_id, input2, "执行命令" + command, "开始")
|
217 |
log += "执行命令…\n"
|
218 |
yield [], log
|
@@ -237,6 +243,28 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
|
|
237 |
log += f"执行命令成功: {command} \n"
|
238 |
except Exception as e:
|
239 |
log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
|
242 |
log += f"任务完成\n"
|
|
|
38 |
print_log(task_id, input2, "不支持此文件的格式 suffix="+suffix, "错误")
|
39 |
log += f"不支持此文件的格式\n"
|
40 |
return [] , log
|
41 |
+
input2 = split_input[0]
|
42 |
print_log(task_id, input2, "检查文件名", "开始")
|
43 |
log += f"检查文件名…\n"
|
44 |
yield [], log
|
45 |
if input2 == None or input2.strip() == "":
|
46 |
input2 = str(task_id)
|
|
|
|
|
|
|
|
|
47 |
log += f"未提供文件名,使用{input2}\n"
|
48 |
print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
|
49 |
yield [], log
|
|
|
158 |
log += "生成张量…\n"
|
159 |
yield [], log
|
160 |
pt_path = output_folder + "/" + input2 + ".pt"
|
161 |
+
# onnx_path = output_folder + "/" + input2 + ".onnx"
|
162 |
input_tensor0 = torch.rand(shape0) if any(shape0) else None
|
163 |
input_tensor1 = torch.rand(shape1) if any(shape1) else None
|
164 |
if input_tensor0 is not None and input_tensor1 is not None:
|
|
|
197 |
print_log(task_id, input2, "获得模型对象", "完成")
|
198 |
yield [], log
|
199 |
|
200 |
+
|
201 |
+
width_ratio = 4
|
202 |
if os.path.exists(pt_path):
|
203 |
print_log(task_id, input2, "转换为TorchScript模型", "跳过")
|
204 |
log += "跳过转换为TorchScript模型\n"
|
|
|
212 |
traced_torch_model.save(output_folder + "/" + input2 + ".pt")
|
213 |
print_log(task_id, input2, "转换为TorchScript模型", "完成")
|
214 |
|
215 |
+
# 获取输出
|
216 |
+
example_output = traced_torch_model(example_input)
|
217 |
+
width_ratio = example_output.shape[2] / example_input.shape[2]
|
218 |
+
print_log(task_id, input2, "获得缩放倍率="+width_ratio+", 输出shape="+example_output.shape, "完成")
|
219 |
+
log+= "获得缩放倍率="+width_ratio+", 输出shape="+example_output.shape+"\n"
|
220 |
+
yield [], log
|
221 |
+
|
222 |
print_log(task_id, input2, "执行命令" + command, "开始")
|
223 |
log += "执行命令…\n"
|
224 |
yield [], log
|
|
|
243 |
log += f"执行命令成功: {command} \n"
|
244 |
except Exception as e:
|
245 |
log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
|
246 |
+
|
247 |
+
# 检查文件是否存在
|
248 |
+
bin_file = os.path.join(output_folder, input2 + ".ncnn.bin")
|
249 |
+
param_file = os.path.join(output_folder, input2 + ".ncnn.param")
|
250 |
+
if os.path.exists(bin_file) and os.path.exists(param_file):
|
251 |
+
import zipfile
|
252 |
+
# 压缩包名称
|
253 |
+
zip_file_name = os.path.join(output_folder, f"models-{input2}.zip")
|
254 |
+
# 压缩包内文件夹名称
|
255 |
+
zip_folder_name = f"models-{input2}"
|
256 |
+
# 重命名后的文件名
|
257 |
+
new_bin_name = f"x{width_ratio}.bin"
|
258 |
+
new_param_name = f"x{width_ratio}.param"
|
259 |
+
# 创建压缩包
|
260 |
+
with zipfile.ZipFile(zip_file_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
261 |
+
# 写入重命名后的.bin文件
|
262 |
+
zipf.write(bin_file, os.path.join(zip_folder_name, new_bin_name))
|
263 |
+
# 写入重命名后的.param文件
|
264 |
+
zipf.write(param_file, os.path.join(zip_folder_name, new_param_name))
|
265 |
+
log += f"已创建压缩包: {zip_file_name}\n"
|
266 |
+
print_log(task_id, input2, "创建压缩包"+zip_file_name, "完成")
|
267 |
+
yield [], log
|
268 |
|
269 |
output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
|
270 |
log += f"任务完成\n"
|