Spaces:
Running
Running
import gradio as gr | |
import os | |
import requests | |
import shutil | |
import uuid | |
from ftplib import FTP | |
from spandrel import ImageModelDescriptor, ModelLoader | |
import torch | |
import subprocess | |
# 定义 downloaded_files 变量 | |
downloaded_files = {} | |
# 新增日志开关 | |
log_to_terminal = True | |
# 新增全局任务计数器 | |
task_counter = 0 | |
# 新增日志函数 | |
def print_log(task_id, filename, stage, status): | |
if log_to_terminal: | |
print(f"任务 ID: {task_id}, 文件名: {filename}, 状态: [{status}], 阶段: {stage}") | |
# 修改 start_process 函数,处理新增输入 | |
def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"): | |
global task_counter | |
task_counter += 1 | |
task_id = task_counter | |
log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n" | |
yield [], log | |
if input2 == None or input2.strip() == "": | |
if isinstance(input1, str): | |
input2 = os.path.splitext(os.path.basename(input1))[0] | |
else: | |
input2 = os.path.splitext(os.path.basename(input1.name))[0] | |
if input2 == "": | |
input2 = str(task_id) | |
log += f"未提供文件名,使用{input2}\n" | |
print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正") | |
yield [], log | |
input2 = "output" | |
try: | |
# 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持 | |
supported_protocols = ('http://', 'https://', 'ftp://', 'webdav://') | |
if isinstance(input1, str) and input1.startswith(supported_protocols): | |
url = input1 | |
if url in downloaded_files and os.path.exists(downloaded_files[url]): | |
file_path = downloaded_files[url] | |
print_log(task_id, input2, "检查下载状态", "跳过下载") | |
log += f"跳过下载,文件已存在: {file_path}\n" | |
yield [], log | |
else: | |
print_log(task_id, input2, "下载文件", "开始") | |
log += f"开始下载文件…\n" | |
yield [], log | |
# 生成唯一文件名 | |
file_name = str(uuid.uuid4()) + input_suffix | |
file_path = os.path.join(os.getcwd(), file_name) | |
if url.startswith('ftp://'): | |
try: | |
# 解析 ftp 地址 | |
parts = url.replace('ftp://', '').split('/') | |
host = parts[0] | |
remote_file_path = '/'.join(parts[1:]) | |
ftp = FTP(host) | |
ftp.login() | |
with open(file_path, 'wb') as f: | |
ftp.retrbinary('RETR ' + remote_file_path, f.write) | |
ftp.quit() | |
downloaded_files[url] = file_path | |
print_log(task_id, input2, "下载文件", "成功") | |
log += f"文件下载成功: {file_path}\n" | |
yield [], log | |
except Exception as e: | |
print_log(task_id, input2, "下载文件", f"失败 (FTP): {str(e)}") | |
log += f"FTP 文件下载失败: {str(e)}\n" | |
yield [], log | |
return | |
else: | |
if url.startswith(('http://', 'https://')): | |
response = requests.get(url) | |
if response.status_code == 200: | |
with open(file_path, 'wb') as f: | |
f.write(response.content) | |
downloaded_files[url] = file_path | |
print_log(task_id, input2, "下载文件", "成功") | |
log += f"文件下载成功: {file_path}\n" | |
yield [], log | |
else: | |
print_log(task_id, input2, f"下载文件(HTTP): {response.status_code}", "失败") | |
log += f"文件下载失败,状态码: {response.status_code}\n" | |
yield [], log | |
return | |
elif input1 is not None: | |
file_path = input1.name | |
log += f"使用上传的文件: {file_path}\n" | |
print_log(task_id, input2, "使用上传文件", "开始") | |
yield [], log | |
else: | |
log += "未提供有效文件或地址\n" | |
print_log(task_id, input2, "检查文件输入", "失败 (无有效输入)") | |
yield [], log | |
return | |
# 生成新文件夹用于暂存结果 | |
output_folder = os.path.join(os.getcwd(), str(uuid.uuid4())) | |
os.makedirs(output_folder, exist_ok=True) | |
print_log(task_id, input2, "创建临时文件夹", "完成") | |
log += f"创建临时文件夹: {output_folder}\n生成张量\n" | |
yield [], log | |
# 解析输入的字符串为数组 | |
try: | |
# 尝试解析 shape0_str | |
shape0 = [int(x) for x in shape0_str.split(',')] if shape0_str else [0, 0, 0, 0] | |
# 检查 shape0 是否为 4 个元素,如果不是则设置为全 0 | |
if len(shape0) != 4: | |
shape0 = [0, 0, 0, 0] | |
# 尝试解析 shape1_str | |
shape1 = [int(x) for x in shape1_str.split(',')] if shape1_str else [0, 0, 0, 0] | |
# 检查 shape1 是否为 4 个元素,如果不是则设置为全 0 | |
if len(shape1) != 4: | |
shape1 = [0, 0, 0, 0] | |
except ValueError: | |
# 如果解析过程中出现 ValueError,将 shape0 和 shape1 设置为全 0 | |
shape0 = [0, 0, 0, 0] | |
shape1 = [0, 0, 0, 0] | |
log += "输入的 shape 字符串格式不正确,请使用逗号分隔的整数。\n" | |
yield [], log | |
return | |
# 以下是 process_file 函数的代码 | |
# 使用 torch.rand 生成 input_shape | |
print_log(task_id, input2, "生成输入张量", "开始") | |
log += "生成张量…\n" | |
yield [], log | |
pt_path = output_folder + "/" + input2 + ".pt" | |
input_tensor0 = torch.rand(shape0) if any(shape0) else None | |
input_tensor1 = torch.rand(shape1) if any(shape1) else None | |
if input_tensor0 is not None and input_tensor1 is not None: | |
example_input = (input_tensor0, input_tensor1) | |
# 修改此处,去除 shape 字符串中的空格 | |
command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')} inputshape2={str(shape1).replace(' ', '')}" | |
elif input_tensor0 is not None: | |
example_input = input_tensor0 | |
command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')}" | |
else: | |
example_input = input_tensor1 | |
command = f"pnnx {pt_path}" | |
print_log(task_id, input2, "生成输入张量", "完成") | |
# 确保 output_folder 存在 | |
if not os.path.exists(output_folder): | |
os.makedirs(output_folder) | |
print_log(task_id, input2, "加载模型", "开始") | |
log += "加载模型…\n" | |
yield [], log | |
# load a model from disk | |
model = ModelLoader().load_from_file(file_path) | |
# make sure it's an image to image model | |
assert isinstance(model, ImageModelDescriptor) | |
print_log(task_id, input2, "获得模型对象", "开始") | |
log += "获得模型对象…\n" | |
yield [], log | |
# send it to the GPU and put it in inference mode | |
# model.cuda().eval() | |
model.eval() | |
torch_model = model.model | |
print_log(task_id, input2, "获得模型对象", "完成") | |
yield [], log | |
if os.path.exists(pt_path): | |
print_log(task_id, input2, "转换为TorchScript模型", "跳过") | |
log += "跳过转换为TorchScript模型\n" | |
yield [], log | |
else: | |
print_log(task_id, input2, "转换为TorchScript模型", "开始") | |
log+= "转换为TorchScript模型…\n" | |
yield [], log | |
# 使用 torch.jit.trace 进行模型转换 | |
traced_torch_model = torch.jit.trace(torch_model, example_input) | |
traced_torch_model.save(output_folder + "/" + input2 + ".pt") | |
print_log(task_id, input2, "转换为TorchScript模型", "完成") | |
print_log(task_id, input2, "执行命令" + command, "开始") | |
log += "执行命令…\n" | |
yield [], log | |
try: | |
# 使用 subprocess.Popen 执行命令 | |
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) | |
while True: | |
output = process.stdout.readline() | |
if output == '' and process.poll() is not None: | |
break | |
if output: | |
log += output.strip() + '\n' | |
if log_to_terminal: | |
print(output.strip()) | |
returncode = process.poll() | |
if returncode != 0: | |
log += f"执行命令: {command} 失败,返回码: {returncode}\n" | |
else: | |
log += f"执行命令: {command} 成功\n" | |
except Exception as e: | |
log += f"执行命令: {command} 失败,错误信息: {str(e)}\n" | |
output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))] | |
log += f"任务完成,输出文件: {output_files}\n" | |
print_log(task_id, input2, "执行命令", "完成") | |
yield output_files, log | |
except Exception as e: | |
log += f"发生错误: {str(e)}\n" | |
print_log(task_id, input2,str(e) , f"失败") | |
yield [], log | |
# 创建 Gradio 界面 | |
with gr.Blocks() as demo: | |
gr.Markdown("文件处理界面") | |
with gr.Row(): | |
# 左侧列,包含输入组件和按钮 | |
with gr.Column(): | |
with gr.Row(): | |
input1 = gr.Textbox(label="粘贴地址") | |
input1_file = gr.File(label="上传文件") | |
input2 = gr.Textbox(label="自定义文件名") | |
# 修改为字符串输入控件 | |
shape0_str = gr.Textbox(label="shape0 (逗号分隔的整数)", value="1,3,128,128") | |
shape1_str = gr.Textbox(label="shape1 (逗号分隔的整数)", value="0,0,0,0") | |
start_button = gr.Button("开始") | |
# 右侧列,包含输出组件和日志文本框 | |
with gr.Column(): | |
output = gr.File(label="输出文件", file_count="multiple") | |
log_textbox = gr.Textbox(label="日志", lines=10, interactive=False) | |
# 绑定事件,修改输入参数 | |
start_button.click( | |
fn=start_process, | |
inputs=[input1_file if input1_file.value else input1, input2, shape0_str, shape1_str], | |
outputs=[output, log_textbox] | |
) | |
demo.launch() |