tumuyan2 commited on
Commit
d29dbb1
·
1 Parent(s): c7a87ae

update App

Browse files
Files changed (2) hide show
  1. .gitignore +5 -0
  2. app.py +115 -84
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ *.pt
2
+ *.pth
3
+ *.bin
4
+ *.onnx
5
+ *.param
app.py CHANGED
@@ -8,89 +8,40 @@ from spandrel import ImageModelDescriptor, ModelLoader
8
  import torch
9
  import subprocess
10
 
 
 
11
 
12
  # 新增日志开关
13
  log_to_terminal = True
14
 
 
 
 
15
  # 新增日志函数
16
  def print_log(task_id, filename, stage, status):
17
  if log_to_terminal:
18
  print(f"任务 ID: {task_id}, 文件名: {filename}, 状态: [{status}], 阶段: {stage}")
19
 
20
- # 修改 process_file 函数,接收 shape 数组
21
- def process_file(task_id, file_path, model_name, output_folder, shape0, shape1):
22
- # 使用 torch.rand 生成 input_shape
23
- log = ""
24
- print_log(task_id, model_name, "生成输入张量", "开始")
25
- pt_path = output_folder + "/" + model_name + ".pt"
26
- input_tensor0 = torch.rand(shape0) if any(shape0) else None
27
- input_tensor1 = torch.rand(shape1) if any(shape1) else None
28
- if input_tensor0 is not None and input_tensor1 is not None:
29
- example_input = (input_tensor0, input_tensor1)
30
- # 修改此处,去除 shape 字符串中的空格
31
- command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')} inputshape2={str(shape1).replace(' ', '')}"
32
- elif input_tensor0 is not None:
33
- example_input = input_tensor0
34
- command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')}"
35
- else:
36
- example_input = input_tensor1
37
- command = f"pnnx {pt_path}"
38
- print_log(task_id, model_name, "生成输入张量", "完成")
39
-
40
- # 确保 output_folder 存在
41
- if not os.path.exists(output_folder):
42
- os.makedirs(output_folder)
43
-
44
- # load a model from disk
45
- model = ModelLoader().load_from_file(file_path)
46
-
47
- # make sure it's an image to image model
48
- assert isinstance(model, ImageModelDescriptor)
49
-
50
- print_log(task_id, model_name, "加载模型", "完成")
51
- # send it to the GPU and put it in inference mode
52
- # model.cuda().eval()
53
- model.eval()
54
- torch_model = model.model
55
- print_log(task_id, model_name, "获得模型对象", "完成")
56
- if os.path.exists(pt_path):
57
- print_log(task_id, model_name, "转换为TorchScript模型", "跳过")
58
- else:
59
- print_log(task_id, model_name, "转换为TorchScript模型", "开始")
60
- # 使用 torch.jit.trace 进行模型转换
61
- traced_torch_model = torch.jit.trace(torch_model, example_input)
62
- traced_torch_model.save(output_folder + "/" + model_name + ".pt")
63
- print_log(task_id, model_name, "转换为TorchScript模型", "完成")
64
-
65
- print_log(task_id, model_name, "运行命令"+command, "开始")
66
-
67
- try:
68
- # 使用 subprocess.Popen 执行命令
69
- process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
70
- while True:
71
- output = process.stdout.readline()
72
- if output == '' and process.poll() is not None:
73
- break
74
- if output:
75
- log += output.strip() + '\n'
76
- if log_to_terminal:
77
- print(output.strip())
78
- returncode = process.poll()
79
- if returncode != 0:
80
- log += f"执行命令: {command} 失败,返回码: {returncode}\n"
81
  else:
82
- log += f"执行命令: {command} 成功\n"
83
- except Exception as e:
84
- log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
85
 
86
- return [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))], log
 
 
 
 
 
87
 
88
- # 修改为字典类型
89
- downloaded_files = {}
90
- # 修改 start_process 函数,处理新增输入
91
- def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
92
- task_id = str(uuid.uuid4())
93
- log = ""
94
  try:
95
  # 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持
96
  supported_protocols = ('http://', 'https://', 'ftp://', 'webdav://')
@@ -98,10 +49,13 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
98
  url = input1
99
  if url in downloaded_files and os.path.exists(downloaded_files[url]):
100
  file_path = downloaded_files[url]
101
- log += f"跳过下载,文件已存在: {file_path}\n"
102
  print_log(task_id, input2, "检查下载状态", "跳过下载")
 
103
  yield [], log
104
  else:
 
 
 
105
  # 生成唯一文件名
106
  file_name = str(uuid.uuid4()) + input_suffix
107
  file_path = os.path.join(os.getcwd(), file_name)
@@ -117,27 +71,29 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
117
  ftp.retrbinary('RETR ' + remote_file_path, f.write)
118
  ftp.quit()
119
  downloaded_files[url] = file_path
 
120
  log += f"文件下载成功: {file_path}\n"
121
  yield [], log
122
  except Exception as e:
123
- log += f"FTP 文件下载失败: {str(e)}\n"
124
  print_log(task_id, input2, "下载文件", f"失败 (FTP): {str(e)}")
 
125
  yield [], log
126
  return
127
- else :
128
  if url.startswith(('http://', 'https://')):
129
  response = requests.get(url)
130
  if response.status_code == 200:
131
  with open(file_path, 'wb') as f:
132
  f.write(response.content)
133
  downloaded_files[url] = file_path
 
134
  log += f"文件下载成功: {file_path}\n"
135
  yield [], log
136
  else:
 
137
  log += f"文件下载失败,状态码: {response.status_code}\n"
138
  yield [], log
139
  return
140
-
141
  elif input1 is not None:
142
  file_path = input1.name
143
  log += f"使用上传的文件: {file_path}\n"
@@ -152,10 +108,9 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
152
  # 生成新文件夹用于暂存结果
153
  output_folder = os.path.join(os.getcwd(), str(uuid.uuid4()))
154
  os.makedirs(output_folder, exist_ok=True)
155
- log += f"创建临时文件夹: {output_folder}\n"
156
  print_log(task_id, input2, "创建临时文件夹", "完成")
 
157
  yield [], log
158
-
159
  # 解析输入的字符串为数组
160
  try:
161
  # 尝试解析 shape0_str
@@ -176,18 +131,94 @@ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
176
  yield [], log
177
  return
178
 
179
- # 调用处理函数,传递 shape 数组
180
- output_files, process_log = process_file(task_id, file_path, input2, output_folder, shape0, shape1)
181
- log += process_log
182
- log += f"处理完成,输出文件: {output_files}\n"
183
- print_log(task_id, input2, "调用处理函数", "完成")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  yield output_files, log
185
  except Exception as e:
186
  log += f"发生错误: {str(e)}\n"
187
- print_log(task_id, input2, "整体处理", f"失败: {str(e)}")
188
  yield [], log
189
 
190
-
191
  # 创建 Gradio 界面
192
  with gr.Blocks() as demo:
193
  gr.Markdown("文件处理界面")
 
8
  import torch
9
  import subprocess
10
 
11
+ # 定义 downloaded_files 变量
12
+ downloaded_files = {}
13
 
14
  # 新增日志开关
15
  log_to_terminal = True
16
 
17
+ # 新增全局任务计数器
18
+ task_counter = 0
19
+
20
  # 新增日志函数
21
  def print_log(task_id, filename, stage, status):
22
  if log_to_terminal:
23
  print(f"任务 ID: {task_id}, 文件名: {filename}, 状态: [{status}], 阶段: {stage}")
24
 
25
+ # 修改 start_process 函数,处理新增输入
26
+ def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
27
+ global task_counter
28
+ task_counter += 1
29
+ task_id = task_counter
30
+ log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n"
31
+ yield [], log
32
+ if input2 == None or input2.strip() == "":
33
+ if isinstance(input1, str):
34
+ input2 = os.path.splitext(os.path.basename(input1))[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  else:
36
+ input2 = os.path.splitext(os.path.basename(input1.name))[0]
 
 
37
 
38
+ if input2 == "":
39
+ input2 = str(task_id)
40
+ log += f"未提供文件名,使用{input2}\n"
41
+ print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
42
+ yield [], log
43
+ input2 = "output"
44
 
 
 
 
 
 
 
45
  try:
46
  # 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持
47
  supported_protocols = ('http://', 'https://', 'ftp://', 'webdav://')
 
49
  url = input1
50
  if url in downloaded_files and os.path.exists(downloaded_files[url]):
51
  file_path = downloaded_files[url]
 
52
  print_log(task_id, input2, "检查下载状态", "跳过下载")
53
+ log += f"跳过下载,文件已存在: {file_path}\n"
54
  yield [], log
55
  else:
56
+ print_log(task_id, input2, "下载文件", "开始")
57
+ log += f"开始下载文件…\n"
58
+ yield [], log
59
  # 生成唯一文件名
60
  file_name = str(uuid.uuid4()) + input_suffix
61
  file_path = os.path.join(os.getcwd(), file_name)
 
71
  ftp.retrbinary('RETR ' + remote_file_path, f.write)
72
  ftp.quit()
73
  downloaded_files[url] = file_path
74
+ print_log(task_id, input2, "下载文件", "成功")
75
  log += f"文件下载成功: {file_path}\n"
76
  yield [], log
77
  except Exception as e:
 
78
  print_log(task_id, input2, "下载文件", f"失败 (FTP): {str(e)}")
79
+ log += f"FTP 文件下载失败: {str(e)}\n"
80
  yield [], log
81
  return
82
+ else:
83
  if url.startswith(('http://', 'https://')):
84
  response = requests.get(url)
85
  if response.status_code == 200:
86
  with open(file_path, 'wb') as f:
87
  f.write(response.content)
88
  downloaded_files[url] = file_path
89
+ print_log(task_id, input2, "下载文件", "成功")
90
  log += f"文件下载成功: {file_path}\n"
91
  yield [], log
92
  else:
93
+ print_log(task_id, input2, f"下载文件(HTTP): {response.status_code}", "失败")
94
  log += f"文件下载失败,状态码: {response.status_code}\n"
95
  yield [], log
96
  return
 
97
  elif input1 is not None:
98
  file_path = input1.name
99
  log += f"使用上传的文件: {file_path}\n"
 
108
  # 生成新文件夹用于暂存结果
109
  output_folder = os.path.join(os.getcwd(), str(uuid.uuid4()))
110
  os.makedirs(output_folder, exist_ok=True)
 
111
  print_log(task_id, input2, "创建临时文件夹", "完成")
112
+ log += f"创建临时文件夹: {output_folder}\n生成张量\n"
113
  yield [], log
 
114
  # 解析输入的字符串为数组
115
  try:
116
  # 尝试解析 shape0_str
 
131
  yield [], log
132
  return
133
 
134
+ # 以下是 process_file 函数的代码
135
+ # 使用 torch.rand 生成 input_shape
136
+ print_log(task_id, input2, "生成输入张量", "开始")
137
+ log += "生成张量…\n"
138
+ yield [], log
139
+ pt_path = output_folder + "/" + input2 + ".pt"
140
+ input_tensor0 = torch.rand(shape0) if any(shape0) else None
141
+ input_tensor1 = torch.rand(shape1) if any(shape1) else None
142
+ if input_tensor0 is not None and input_tensor1 is not None:
143
+ example_input = (input_tensor0, input_tensor1)
144
+ # 修改此处,去除 shape 字符串中的空格
145
+ command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')} inputshape2={str(shape1).replace(' ', '')}"
146
+ elif input_tensor0 is not None:
147
+ example_input = input_tensor0
148
+ command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')}"
149
+ else:
150
+ example_input = input_tensor1
151
+ command = f"pnnx {pt_path}"
152
+ print_log(task_id, input2, "生成输入张量", "完成")
153
+
154
+
155
+ # 确保 output_folder 存在
156
+ if not os.path.exists(output_folder):
157
+ os.makedirs(output_folder)
158
+
159
+ print_log(task_id, input2, "加载模型", "开始")
160
+ log += "加载模型…\n"
161
+ yield [], log
162
+ # load a model from disk
163
+ model = ModelLoader().load_from_file(file_path)
164
+
165
+ # make sure it's an image to image model
166
+ assert isinstance(model, ImageModelDescriptor)
167
+
168
+ print_log(task_id, input2, "获得���型对象", "开始")
169
+ log += "获得模型对象…\n"
170
+ yield [], log
171
+ # send it to the GPU and put it in inference mode
172
+ # model.cuda().eval()
173
+ model.eval()
174
+ torch_model = model.model
175
+ print_log(task_id, input2, "获得模型对象", "完成")
176
+ yield [], log
177
+ if os.path.exists(pt_path):
178
+ print_log(task_id, input2, "转换为TorchScript模型", "跳过")
179
+ log += "跳过转换为TorchScript模型\n"
180
+ yield [], log
181
+ else:
182
+ print_log(task_id, input2, "转换为TorchScript模型", "开始")
183
+ log+= "转换为TorchScript模型…\n"
184
+ yield [], log
185
+ # 使用 torch.jit.trace 进行模型转换
186
+ traced_torch_model = torch.jit.trace(torch_model, example_input)
187
+ traced_torch_model.save(output_folder + "/" + input2 + ".pt")
188
+ print_log(task_id, input2, "转换为TorchScript模型", "完成")
189
+
190
+ print_log(task_id, input2, "执行命令" + command, "开始")
191
+ log += "执行命令…\n"
192
+ yield [], log
193
+
194
+ try:
195
+ # 使用 subprocess.Popen 执行命令
196
+ process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
197
+ while True:
198
+ output = process.stdout.readline()
199
+ if output == '' and process.poll() is not None:
200
+ break
201
+ if output:
202
+ log += output.strip() + '\n'
203
+ if log_to_terminal:
204
+ print(output.strip())
205
+ returncode = process.poll()
206
+ if returncode != 0:
207
+ log += f"执行命令: {command} 失败,返回码: {returncode}\n"
208
+ else:
209
+ log += f"执行命令: {command} 成功\n"
210
+ except Exception as e:
211
+ log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
212
+
213
+ output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
214
+ log += f"任务完成,输出文件: {output_files}\n"
215
+ print_log(task_id, input2, "执行命令", "完成")
216
  yield output_files, log
217
  except Exception as e:
218
  log += f"发生错误: {str(e)}\n"
219
+ print_log(task_id, input2,str(e) , f"失败")
220
  yield [], log
221
 
 
222
  # 创建 Gradio 界面
223
  with gr.Blocks() as demo:
224
  gr.Markdown("文件处理界面")