import gradio as gr import os import requests import shutil import uuid from ftplib import FTP from spandrel import ImageModelDescriptor, ModelLoader import torch import subprocess # 定义 downloaded_files 变量 downloaded_files = {} # 新增日志开关 log_to_terminal = True # 新增全局任务计数器 task_counter = 0 # 新增日志函数 def print_log(task_id, filename, stage, status): if log_to_terminal: print(f"任务{task_id}: {filename}, [{status}] {stage}") # 修改 start_process 函数,处理新增输入 def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"): global task_counter task_counter += 1 task_id = task_counter print_log(task_id, input2, input1, "input1") log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n" yield [], log if input2 == None or input2.strip() == "": split_input = os.path.splitext(os.path.basename(input1)) if len(split_input) > 1: suffix = split_input[1].split('?')[0].lower() if suffix not in [".pth" , ".safetensors" , ".ckpt"]: print_log(task_id, input2, "不支持此文件的格式 suffix="+suffix, "错误") log += f"不支持此文件的格式\n" return [] , log input2 = split_input[0] print_log(task_id, input2, "检查文件名", "开始") log += f"检查文件名…\n" yield [], log if input2 == None or input2.strip() == "": input2 = str(task_id) log += f"未提供文件名,使用{input2}\n" print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正") yield [], log try: # 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持 supported_protocols = ('http://', 'https://', 'ftp://', 'webdav://') if isinstance(input1, str) and input1.startswith(supported_protocols): url = input1 if url in downloaded_files and os.path.exists(downloaded_files[url]): file_path = downloaded_files[url] print_log(task_id, input2, "检查下载状态", "跳过下载") log += f"跳过下载,文件已存在: {file_path}\n" yield [], log else: print_log(task_id, input2, "下载文件", "开始") log += f"开始下载文件…\n" yield [], log # 生成唯一文件名 file_name = str(task_id) + input_suffix file_path = os.path.join(os.getcwd(), file_name) if url.startswith('ftp://'): try: # 解析 ftp 地址 parts = url.replace('ftp://', '').split('/') host = parts[0] remote_file_path = '/'.join(parts[1:]) ftp = FTP(host) ftp.login() with open(file_path, 'wb') as f: ftp.retrbinary('RETR ' + remote_file_path, f.write) ftp.quit() downloaded_files[url] = file_path print_log(task_id, input2, "下载文件", "成功") log += f"文件下载成功: {file_path}\n" yield [], log except Exception as e: print_log(task_id, input2, "下载文件", f"失败 (FTP): {str(e)}") log += f"FTP 文件下载失败: {str(e)}\n" yield [], log return else: if url.startswith(('http://', 'https://')): response = requests.get(url) if response.status_code == 200: with open(file_path, 'wb') as f: f.write(response.content) downloaded_files[url] = file_path print_log(task_id, input2, "下载文件", "成功") log += f"文件下载成功: {file_path}\n" yield [], log else: print_log(task_id, input2, f"下载文件(HTTP): {response.status_code}", "失败") log += f"文件下载失败,状态码: {response.status_code}\n" yield [], log return elif input1 is not None: print("check file" , input1, os.path.exists(input1)) file_path = input1 log += f"使用上传的文件: {file_path}\n" print_log(task_id, input2, "使用上传文件", "开始") yield [], log else: log += "未提供有效文件或地址\n" print_log(task_id, input2, "检查文件输入", "失败 (无有效输入)") yield [], log return # 检查文件大小 try: file_size = os.path.getsize(file_path) / 1024 /1024 # 转换为 KB if file_size > 100 : log += f"文件太大,建议 100MB 以内,当前文件大小为 {file_size } MB。\n" print_log(task_id, input2, "文件太大("+ file_size +"MB)", "失败") yield [], log return except Exception as e: log += f"获取文件大小失败: {str(e)}\n" print_log(task_id, input2, "检查文件大小", f"失败: {str(e)}") yield [], log return # 生成新文件夹用于暂存结果 output_folder = os.path.join(os.getcwd(), str(uuid.uuid4())) os.makedirs(output_folder, exist_ok=True) print_log(task_id, input2, "创建临时文件夹", "完成") log += f"创建临时文件夹: {output_folder}\n生成张量\n" yield [], log # 解析输入的字符串为数组 try: # 尝试解析 shape0_str shape0 = [int(x) for x in shape0_str.split(',')] if shape0_str else [0, 0, 0, 0] # 检查 shape0 是否为 4 个元素,如果不是则设置为全 0 if len(shape0) != 4: shape0 = [0, 0, 0, 0] # 尝试解析 shape1_str shape1 = [int(x) for x in shape1_str.split(',')] if shape1_str else [0, 0, 0, 0] # 检查 shape1 是否为 4 个元素,如果不是则设置为全 0 if len(shape1) != 4: shape1 = [0, 0, 0, 0] except ValueError: # 如果解析过程中出现 ValueError,将 shape0 和 shape1 设置为全 0 shape0 = [0, 0, 0, 0] shape1 = [0, 0, 0, 0] log += "输入的 shape 字符串格式不正确,请使用逗号分隔的整数。\n" yield [], log return # 以下是 process_file 函数的代码 # 使用 torch.rand 生成 input_shape print_log(task_id, input2, "生成输入张量", "开始") log += "生成张量…\n" yield [], log pt_path = output_folder + "/" + input2 + ".pt" # onnx_path = output_folder + "/" + input2 + ".onnx" input_tensor0 = torch.rand(shape0) if any(shape0) else None input_tensor1 = torch.rand(shape1) if any(shape1) else None if input_tensor0 is not None and input_tensor1 is not None: example_input = (input_tensor0, input_tensor1) # 修改此处,去除 shape 字符串中的空格 command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')} inputshape2={str(shape1).replace(' ', '')}" elif input_tensor0 is not None: example_input = input_tensor0 command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')}" else: example_input = input_tensor1 command = f"pnnx {pt_path}" print_log(task_id, input2, "生成输入张量", "完成") # 确保 output_folder 存在 if not os.path.exists(output_folder): os.makedirs(output_folder) print_log(task_id, input2, "加载模型", "开始") log += "加载模型…\n" yield [], log # load a model from disk model = ModelLoader().load_from_file(file_path) # make sure it's an image to image model assert isinstance(model, ImageModelDescriptor) print_log(task_id, input2, "获得模型对象", "开始") log += "获得模型对象…\n" yield [], log # send it to the GPU and put it in inference mode # model.cuda().eval() model.eval() torch_model = model.model print_log(task_id, input2, "获得模型对象", "完成") yield [], log width_ratio = 4 if os.path.exists(pt_path): print_log(task_id, input2, "转换为TorchScript模型", "跳过") log += "跳过转换为TorchScript模型\n" yield [], log else: print_log(task_id, input2, "转换为TorchScript模型", "开始") log+= "转换为TorchScript模型…\n" yield [], log # 使用 torch.jit.trace 进行模型转换 traced_torch_model = torch.jit.trace(torch_model, example_input) traced_torch_model.save(output_folder + "/" + input2 + ".pt") print_log(task_id, input2, "转换为TorchScript模型", "完成") # 获取输出 example_output = traced_torch_model(example_input) width_ratio = example_output.shape[2] / example_input.shape[2] print_log(task_id, input2, "获得缩放倍率="+ str(width_ratio)+", 输出shape="+str(list(example_output.shape)), "完成") log+= ("获得缩放倍率="+str(width_ratio)+", 输出shape="+str(list(example_output.shape))+"\n") yield [], log print_log(task_id, input2, "执行命令" + command, "开始") log += "执行命令…\n" yield [], log try: # 使用 subprocess.Popen 执行命令 process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) while True: output = process.stdout.readline() if output == '' and process.poll() is not None: break if output: # if log_to_terminal: print(output.strip()) log += output.strip() + '\n' yield [], log returncode = process.poll() if returncode != 0: log += f"执行命令失败,返回码: {returncode},命令: {command} \n" else: log += f"执行命令成功: {command} \n" except Exception as e: log += f"执行命令: {command} 失败,错误信息: {str(e)}\n" # 查找 output_folder 目录下以 .ncnn.bin 和 .ncnn.param 结尾的文件 bin_files = [f for f in os.listdir(output_folder) if f.endswith('.ncnn.bin')] param_files = [f for f in os.listdir(output_folder) if f.endswith('.ncnn.param')] if bin_files and param_files: param_file = os.path.join(output_folder, param_files[0]) bin_file = os.path.join(output_folder, bin_files[0]) import zipfile # 压缩包名称 zip_file_name = os.path.join(output_folder, f"models-{input2}.zip") # 压缩包内文件夹名称 zip_folder_name = f"models-{input2}" # 重命名后的文件名 scale = int(width_ratio) new_bin_name = f"x{scale}.bin" new_param_name = f"x{scale}.param" # 创建压缩包 with zipfile.ZipFile(zip_file_name, 'w', zipfile.ZIP_DEFLATED) as zipf: # 写入重命名后的.bin文件 zipf.write(bin_file, os.path.join(zip_folder_name, new_bin_name)) # 写入重命名后的.param文件 zipf.write(param_file, os.path.join(zip_folder_name, new_param_name)) log += f"已创建压缩包: {zip_file_name}\n" print_log(task_id, input2, "创建压缩包"+zip_file_name, "完成") yield [], log else: log += f"未找到 ncnn 文件\n" print_log(task_id, input2, "查找 ncnn 文件", "失败") yield [], log output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))] log += f"任务完成\n" print_log(task_id, input2, "执行命令", "完成") yield output_files, log except Exception as e: log += f"发生错误: {e}\n" print_log(task_id, input2, e , f"失败") yield [], log # 创建 Gradio 界面 with gr.Blocks() as demo: gr.Markdown("文件处理界面") with gr.Row(): # 左侧列,包含输入组件和按钮 with gr.Column(): # 添加文本提示 gr.Markdown("请输入的url,或者上传一个文件。限制文件为小于100M的*.pth模型") with gr.Row(): input1 = gr.Textbox(label="粘贴地址") # 新增文件上传组件 input1_file = gr.File(label="上传文件", file_types=[".pth", ".safetensors", ".ckpt"]) input2 = gr.Textbox(label="自定义文件名") # 修改为字符串输入控件 shape0_str = gr.Textbox(label="shape0 (逗号分隔的整数)", value="1,3,128,128") shape1_str = gr.Textbox(label="shape1 (逗号分隔的整数)", value="0,0,0,0") with gr.Row(): start_button = gr.Button("开始") # 添加取消按钮 cancel_button = gr.Button("取消") # 右侧列,包含输出组件和日志文本框 with gr.Column(): output = gr.File(label="输出文件", file_count="multiple") log_textbox = gr.Textbox(label="日志", lines=10, interactive=False) # 绑定事件,修改输入参数 process = start_button.click( fn=start_process, inputs=[input1_file if input1_file.value else input1, input2, shape0_str, shape1_str], outputs=[output, log_textbox] ) # 为取消按钮添加点击事件绑定,使用 cancels 属性取消 start_process 任务 cancel_button.click( fn=None, inputs=None, outputs=None, cancels=[process] ) # 添加范例 examples = [ ["https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", "", "1,3,128,128", "0,0,0,0"], ["https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth", "", "1,3,128,128", "0,0,0,0"], ["https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth", "", "1,3,128,128", "0,0,0,0"], ["https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth", "", "1,3,128,128", "0,0,0,0"], ] gr.Examples( examples=examples, inputs=[input1, input2, shape0_str, shape1_str], outputs=[output, log_textbox], fn=start_process ) demo.launch()