File size: 15,602 Bytes
4d5d386
 
 
 
 
 
 
 
 
 
d29dbb1
 
4d5d386
 
 
 
d29dbb1
 
 
4d5d386
 
 
5e70c4c
4d5d386
d29dbb1
 
 
 
 
8282334
d29dbb1
 
 
097f8b4
 
 
 
 
 
 
ee22f2b
097f8b4
 
 
 
 
d29dbb1
 
 
4d5d386
 
 
 
 
 
 
 
 
d29dbb1
4d5d386
 
d29dbb1
 
 
4d5d386
5e70c4c
4d5d386
 
 
 
 
 
 
 
 
 
 
 
 
d29dbb1
4d5d386
 
 
 
d29dbb1
4d5d386
 
d29dbb1
4d5d386
 
 
 
 
 
d29dbb1
4d5d386
 
 
d29dbb1
4d5d386
 
 
 
8282334
 
4d5d386
 
 
 
 
 
 
 
 
8282334
 
 
 
 
 
 
 
 
 
 
 
 
 
4d5d386
 
 
 
d29dbb1
4d5d386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d29dbb1
 
 
 
 
 
ee22f2b
d29dbb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4a3999
ee22f2b
 
d29dbb1
 
 
 
 
 
 
 
 
 
 
 
 
ee22f2b
 
 
f64dbb3
eb74bc9
ee22f2b
 
d29dbb1
 
 
 
 
 
 
 
 
 
 
 
8282334
d29dbb1
 
8282334
 
d29dbb1
 
a4a3999
d29dbb1
a4a3999
d29dbb1
 
ee22f2b
c0d8acd
 
 
 
 
 
 
 
ee22f2b
 
 
 
 
 
1d066cb
 
 
ee22f2b
 
 
 
 
 
 
 
 
c0d8acd
 
 
 
1d066cb
d29dbb1
a4a3999
d29dbb1
4d5d386
 
5e70c4c
 
4d5d386
 
 
 
 
 
 
 
a4a3999
 
4d5d386
 
a4a3999
097f8b4
4d5d386
 
 
 
5e70c4c
 
 
 
4d5d386
 
 
 
 
 
5e70c4c
4d5d386
 
 
 
5e70c4c
 
 
 
 
 
 
4d5d386
b96054d
 
097f8b4
 
1d066cb
 
b96054d
 
 
097f8b4
b96054d
 
 
 
4d5d386
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import gradio as gr
import os
import requests
import shutil
import uuid
from ftplib import FTP
from spandrel import ImageModelDescriptor, ModelLoader
import torch
import subprocess

# 定义 downloaded_files 变量
downloaded_files = {}

# 新增日志开关
log_to_terminal = True

# 新增全局任务计数器
task_counter = 0

# 新增日志函数
def print_log(task_id, filename, stage, status):
    if log_to_terminal:
        print(f"任务{task_id}: {filename}, [{status}] {stage}")

# 修改 start_process 函数,处理新增输入
def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
    global task_counter
    task_counter += 1
    task_id = task_counter
    print_log(task_id, input2,  input1, "input1")
    log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n"
    yield [], log
    if input2 == None or input2.strip() == "":
        split_input = os.path.splitext(os.path.basename(input1))
        if len(split_input) > 1:
            suffix = split_input[1].split('?')[0].lower()
            if  suffix not in [".pth" , ".safetensors" , ".ckpt"]:
                print_log(task_id, input2,  "不支持此文件的格式 suffix="+suffix, "错误")
                log += f"不支持此文件的格式\n"
                return [] , log
            input2 = split_input[0]
        print_log(task_id, input2, "检查文件名", "开始")
        log += f"检查文件名…\n"
        yield [], log
        if input2 == None or input2.strip() == "":
            input2 = str(task_id)
        log += f"未提供文件名,使用{input2}\n"
        print_log(task_id, input2,  f"未提供文件名,使用{input2}", "修正")
        yield [], log

    try:
        # 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持
        supported_protocols = ('http://', 'https://', 'ftp://', 'webdav://')
        if isinstance(input1, str) and input1.startswith(supported_protocols):
            url = input1
            if url in downloaded_files and os.path.exists(downloaded_files[url]):
                file_path = downloaded_files[url]
                print_log(task_id, input2, "检查下载状态", "跳过下载")
                log += f"跳过下载,文件已存在: {file_path}\n"
                yield [], log
            else:
                print_log(task_id, input2, "下载文件", "开始")
                log += f"开始下载文件…\n"
                yield [], log
                # 生成唯一文件名
                file_name = str(task_id) + input_suffix
                file_path = os.path.join(os.getcwd(), file_name)
                if url.startswith('ftp://'):
                    try:
                        # 解析 ftp 地址
                        parts = url.replace('ftp://', '').split('/')
                        host = parts[0]
                        remote_file_path = '/'.join(parts[1:])
                        ftp = FTP(host)
                        ftp.login()
                        with open(file_path, 'wb') as f:
                            ftp.retrbinary('RETR ' + remote_file_path, f.write)
                        ftp.quit()
                        downloaded_files[url] = file_path
                        print_log(task_id, input2, "下载文件", "成功")
                        log += f"文件下载成功: {file_path}\n"
                        yield [], log
                    except Exception as e:
                        print_log(task_id, input2, "下载文件", f"失败 (FTP): {str(e)}")
                        log += f"FTP 文件下载失败: {str(e)}\n"
                        yield [], log
                        return
                else:
                    if url.startswith(('http://', 'https://')):
                        response = requests.get(url)
                        if response.status_code == 200:
                            with open(file_path, 'wb') as f:
                                f.write(response.content)
                            downloaded_files[url] = file_path
                            print_log(task_id, input2, "下载文件", "成功")
                            log += f"文件下载成功: {file_path}\n"
                            yield [], log
                        else:
                            print_log(task_id, input2, f"下载文件(HTTP): {response.status_code}", "失败")
                            log += f"文件下载失败,状态码: {response.status_code}\n"
                            yield [], log
                            return
        elif input1 is not None:
            print("check file" , input1, os.path.exists(input1))
            file_path = input1
            log += f"使用上传的文件: {file_path}\n"
            print_log(task_id, input2, "使用上传文件", "开始")
            yield [], log
        else:
            log += "未提供有效文件或地址\n"
            print_log(task_id, input2, "检查文件输入", "失败 (无有效输入)")
            yield [], log
            return

        # 检查文件大小
        try:
            file_size = os.path.getsize(file_path) / 1024 /1024  # 转换为 KB
            if file_size > 100 :
                log += f"文件太大,建议 100MB 以内,当前文件大小为 {file_size } MB。\n"
                print_log(task_id, input2, "文件太大("+ file_size +"MB)", "失败")
                yield [], log
                return
        except Exception as e:
            log += f"获取文件大小失败: {str(e)}\n"
            print_log(task_id, input2, "检查文件大小", f"失败: {str(e)}")
            yield [], log
            return

        # 生成新文件夹用于暂存结果
        output_folder = os.path.join(os.getcwd(), str(uuid.uuid4()))
        os.makedirs(output_folder, exist_ok=True)
        print_log(task_id, input2, "创建临时文件夹", "完成")
        log += f"创建临时文件夹: {output_folder}\n生成张量\n"
        yield [], log
        # 解析输入的字符串为数组
        try:
            # 尝试解析 shape0_str
            shape0 = [int(x) for x in shape0_str.split(',')] if shape0_str else [0, 0, 0, 0]
            # 检查 shape0 是否为 4 个元素,如果不是则设置为全 0
            if len(shape0) != 4:
                shape0 = [0, 0, 0, 0]
            # 尝试解析 shape1_str
            shape1 = [int(x) for x in shape1_str.split(',')] if shape1_str else [0, 0, 0, 0]
            # 检查 shape1 是否为 4 个元素,如果不是则设置为全 0
            if len(shape1) != 4:
                shape1 = [0, 0, 0, 0]
        except ValueError:
            # 如果解析过程中出现 ValueError,将 shape0 和 shape1 设置为全 0
            shape0 = [0, 0, 0, 0]
            shape1 = [0, 0, 0, 0]
            log += "输入的 shape 字符串格式不正确,请使用逗号分隔的整数。\n"
            yield [], log
            return

        # 以下是 process_file 函数的代码
        # 使用 torch.rand 生成 input_shape
        print_log(task_id, input2, "生成输入张量", "开始")
        log += "生成张量…\n"
        yield [], log
        pt_path = output_folder + "/" + input2 + ".pt"
        # onnx_path = output_folder + "/" + input2 + ".onnx"
        input_tensor0 = torch.rand(shape0) if any(shape0) else None
        input_tensor1 = torch.rand(shape1) if any(shape1) else None
        if input_tensor0 is not None and input_tensor1 is not None:
            example_input = (input_tensor0, input_tensor1)
            # 修改此处,去除 shape 字符串中的空格
            command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')} inputshape2={str(shape1).replace(' ', '')}"
        elif input_tensor0 is not None:
            example_input = input_tensor0
            command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')}"
        else:
            example_input = input_tensor1
            command = f"pnnx {pt_path}"
        print_log(task_id, input2, "生成输入张量", "完成")


        # 确保 output_folder 存在
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)

        print_log(task_id, input2, "加载模型", "开始")
        log += "加载模型…\n"
        yield [], log
        # load a model from disk
        model = ModelLoader().load_from_file(file_path)

        # make sure it's an image to image model
        assert isinstance(model, ImageModelDescriptor)

        print_log(task_id, input2, "获得模型对象", "开始")
        log += "获得模型对象…\n"
        yield [], log
        # send it to the GPU and put it in inference mode
        # model.cuda().eval()
        model.eval()
        torch_model = model.model
        print_log(task_id, input2, "获得模型对象", "完成")
        yield [], log


        width_ratio = 4
        if os.path.exists(pt_path):
            print_log(task_id, input2, "转换为TorchScript模型", "跳过")
            log += "跳过转换为TorchScript模型\n"
            yield [], log
        else:
            print_log(task_id, input2, "转换为TorchScript模型", "开始")
            log+= "转换为TorchScript模型…\n"
            yield [], log
            # 使用 torch.jit.trace 进行模型转换
            traced_torch_model = torch.jit.trace(torch_model, example_input)
            traced_torch_model.save(output_folder + "/" + input2 + ".pt")
            print_log(task_id, input2, "转换为TorchScript模型", "完成")

            # 获取输出
            example_output = traced_torch_model(example_input)
            width_ratio = example_output.shape[2] / example_input.shape[2]
            print_log(task_id, input2, "获得缩放倍率="+ str(width_ratio)+", 输出shape="+str(list(example_output.shape)), "完成")
            log+= ("获得缩放倍率="+str(width_ratio)+", 输出shape="+str(list(example_output.shape))+"\n")
            yield [], log

        print_log(task_id, input2, "执行命令" + command, "开始")
        log += "执行命令…\n"
        yield [], log

        try:
            # 使用 subprocess.Popen 执行命令
            process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
            while True:
                output = process.stdout.readline()
                if output == '' and process.poll() is not None:
                    break
                if output:
                    # 
                    if log_to_terminal:
                        print(output.strip())
                        log += output.strip() + '\n'
                        yield [], log
            returncode = process.poll()
            if returncode != 0:
                log += f"执行命令失败,返回码: {returncode},命令: {command} \n"
            else:
                log += f"执行命令成功: {command} \n"
        except Exception as e:
            log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
            
        # 查找 output_folder 目录下以 .ncnn.bin 和 .ncnn.param 结尾的文件
        bin_files = [f for f in os.listdir(output_folder) if f.endswith('.ncnn.bin')]
        param_files = [f for f in os.listdir(output_folder) if f.endswith('.ncnn.param')]
        

        if bin_files and param_files:
            param_file = os.path.join(output_folder, param_files[0])
            bin_file = os.path.join(output_folder, bin_files[0])
            import zipfile
            # 压缩包名称
            zip_file_name = os.path.join(output_folder, f"models-{input2}.zip")
            # 压缩包内文件夹名称
            zip_folder_name = f"models-{input2}"
            # 重命名后的文件名
            scale = int(width_ratio)
            new_bin_name = f"x{scale}.bin"
            new_param_name = f"x{scale}.param"
            # 创建压缩包
            with zipfile.ZipFile(zip_file_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
                # 写入重命名后的.bin文件
                zipf.write(bin_file, os.path.join(zip_folder_name, new_bin_name))
                # 写入重命名后的.param文件
                zipf.write(param_file, os.path.join(zip_folder_name, new_param_name))
            log += f"已创建压缩包: {zip_file_name}\n"
            print_log(task_id, input2, "创建压缩包"+zip_file_name, "完成")
            yield [], log
        else:
            log += f"未找到 ncnn 文件\n"
            print_log(task_id, input2, "查找 ncnn 文件", "失败")
            yield [], log

        output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
        log += f"任务完成\n"
        print_log(task_id, input2, "执行命令", "完成")
        yield output_files, log
    except Exception as e:
        log += f"发生错误: {e}\n"
        print_log(task_id, input2, e , f"失败")
        yield [], log

# 创建 Gradio 界面
with gr.Blocks() as demo:
    gr.Markdown("文件处理界面")
    with gr.Row():
        # 左侧列,包含输入组件和按钮
        with gr.Column():
            # 添加文本提示
            gr.Markdown("请输入的url,或者上传一个文件。限制文件为小于100M的*.pth模型")
            with gr.Row():
                input1 = gr.Textbox(label="粘贴地址")
                # 新增文件上传组件
                input1_file = gr.File(label="上传文件", file_types=[".pth", ".safetensors", ".ckpt"])
            input2 = gr.Textbox(label="自定义文件名")
            # 修改为字符串输入控件
            shape0_str = gr.Textbox(label="shape0 (逗号分隔的整数)", value="1,3,128,128")
            shape1_str = gr.Textbox(label="shape1 (逗号分隔的整数)", value="0,0,0,0")
            with gr.Row():
                start_button = gr.Button("开始")
                # 添加取消按钮
                cancel_button = gr.Button("取消")
        # 右侧列,包含输出组件和日志文本框
        with gr.Column():
            output = gr.File(label="输出文件", file_count="multiple")
            log_textbox = gr.Textbox(label="日志", lines=10, interactive=False)

    # 绑定事件,修改输入参数
    process = start_button.click(
        fn=start_process,
        inputs=[input1_file if input1_file.value else input1, input2, shape0_str, shape1_str],
        outputs=[output, log_textbox]
    )
    # 为取消按钮添加点击事件绑定,使用 cancels 属性取消 start_process 任务
    cancel_button.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[process]
    )

    # 添加范例
    examples = [
        ["https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", "", "1,3,128,128", "0,0,0,0"],
        ["https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth", "", "1,3,128,128", "0,0,0,0"],
        ["https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth", "", "1,3,128,128", "0,0,0,0"],
        ["https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth", "", "1,3,128,128", "0,0,0,0"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input1, input2, shape0_str, shape1_str],
        outputs=[output, log_textbox],
        fn=start_process
    )

demo.launch()