model-convert / app.py
tumuyan2's picture
update App
d29dbb1
raw
history blame
11.1 kB
import gradio as gr
import os
import requests
import shutil
import uuid
from ftplib import FTP
from spandrel import ImageModelDescriptor, ModelLoader
import torch
import subprocess
# 定义 downloaded_files 变量
downloaded_files = {}
# 新增日志开关
log_to_terminal = True
# 新增全局任务计数器
task_counter = 0
# 新增日志函数
def print_log(task_id, filename, stage, status):
if log_to_terminal:
print(f"任务 ID: {task_id}, 文件名: {filename}, 状态: [{status}], 阶段: {stage}")
# 修改 start_process 函数,处理新增输入
def start_process(input1, input2, shape0_str, shape1_str, input_suffix=".pth"):
global task_counter
task_counter += 1
task_id = task_counter
log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n"
yield [], log
if input2 == None or input2.strip() == "":
if isinstance(input1, str):
input2 = os.path.splitext(os.path.basename(input1))[0]
else:
input2 = os.path.splitext(os.path.basename(input1.name))[0]
if input2 == "":
input2 = str(task_id)
log += f"未提供文件名,使用{input2}\n"
print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
yield [], log
input2 = "output"
try:
# 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持
supported_protocols = ('http://', 'https://', 'ftp://', 'webdav://')
if isinstance(input1, str) and input1.startswith(supported_protocols):
url = input1
if url in downloaded_files and os.path.exists(downloaded_files[url]):
file_path = downloaded_files[url]
print_log(task_id, input2, "检查下载状态", "跳过下载")
log += f"跳过下载,文件已存在: {file_path}\n"
yield [], log
else:
print_log(task_id, input2, "下载文件", "开始")
log += f"开始下载文件…\n"
yield [], log
# 生成唯一文件名
file_name = str(uuid.uuid4()) + input_suffix
file_path = os.path.join(os.getcwd(), file_name)
if url.startswith('ftp://'):
try:
# 解析 ftp 地址
parts = url.replace('ftp://', '').split('/')
host = parts[0]
remote_file_path = '/'.join(parts[1:])
ftp = FTP(host)
ftp.login()
with open(file_path, 'wb') as f:
ftp.retrbinary('RETR ' + remote_file_path, f.write)
ftp.quit()
downloaded_files[url] = file_path
print_log(task_id, input2, "下载文件", "成功")
log += f"文件下载成功: {file_path}\n"
yield [], log
except Exception as e:
print_log(task_id, input2, "下载文件", f"失败 (FTP): {str(e)}")
log += f"FTP 文件下载失败: {str(e)}\n"
yield [], log
return
else:
if url.startswith(('http://', 'https://')):
response = requests.get(url)
if response.status_code == 200:
with open(file_path, 'wb') as f:
f.write(response.content)
downloaded_files[url] = file_path
print_log(task_id, input2, "下载文件", "成功")
log += f"文件下载成功: {file_path}\n"
yield [], log
else:
print_log(task_id, input2, f"下载文件(HTTP): {response.status_code}", "失败")
log += f"文件下载失败,状态码: {response.status_code}\n"
yield [], log
return
elif input1 is not None:
file_path = input1.name
log += f"使用上传的文件: {file_path}\n"
print_log(task_id, input2, "使用上传文件", "开始")
yield [], log
else:
log += "未提供有效文件或地址\n"
print_log(task_id, input2, "检查文件输入", "失败 (无有效输入)")
yield [], log
return
# 生成新文件夹用于暂存结果
output_folder = os.path.join(os.getcwd(), str(uuid.uuid4()))
os.makedirs(output_folder, exist_ok=True)
print_log(task_id, input2, "创建临时文件夹", "完成")
log += f"创建临时文件夹: {output_folder}\n生成张量\n"
yield [], log
# 解析输入的字符串为数组
try:
# 尝试解析 shape0_str
shape0 = [int(x) for x in shape0_str.split(',')] if shape0_str else [0, 0, 0, 0]
# 检查 shape0 是否为 4 个元素,如果不是则设置为全 0
if len(shape0) != 4:
shape0 = [0, 0, 0, 0]
# 尝试解析 shape1_str
shape1 = [int(x) for x in shape1_str.split(',')] if shape1_str else [0, 0, 0, 0]
# 检查 shape1 是否为 4 个元素,如果不是则设置为全 0
if len(shape1) != 4:
shape1 = [0, 0, 0, 0]
except ValueError:
# 如果解析过程中出现 ValueError,将 shape0 和 shape1 设置为全 0
shape0 = [0, 0, 0, 0]
shape1 = [0, 0, 0, 0]
log += "输入的 shape 字符串格式不正确,请使用逗号分隔的整数。\n"
yield [], log
return
# 以下是 process_file 函数的代码
# 使用 torch.rand 生成 input_shape
print_log(task_id, input2, "生成输入张量", "开始")
log += "生成张量…\n"
yield [], log
pt_path = output_folder + "/" + input2 + ".pt"
input_tensor0 = torch.rand(shape0) if any(shape0) else None
input_tensor1 = torch.rand(shape1) if any(shape1) else None
if input_tensor0 is not None and input_tensor1 is not None:
example_input = (input_tensor0, input_tensor1)
# 修改此处,去除 shape 字符串中的空格
command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')} inputshape2={str(shape1).replace(' ', '')}"
elif input_tensor0 is not None:
example_input = input_tensor0
command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')}"
else:
example_input = input_tensor1
command = f"pnnx {pt_path}"
print_log(task_id, input2, "生成输入张量", "完成")
# 确保 output_folder 存在
if not os.path.exists(output_folder):
os.makedirs(output_folder)
print_log(task_id, input2, "加载模型", "开始")
log += "加载模型…\n"
yield [], log
# load a model from disk
model = ModelLoader().load_from_file(file_path)
# make sure it's an image to image model
assert isinstance(model, ImageModelDescriptor)
print_log(task_id, input2, "获得模型对象", "开始")
log += "获得模型对象…\n"
yield [], log
# send it to the GPU and put it in inference mode
# model.cuda().eval()
model.eval()
torch_model = model.model
print_log(task_id, input2, "获得模型对象", "完成")
yield [], log
if os.path.exists(pt_path):
print_log(task_id, input2, "转换为TorchScript模型", "跳过")
log += "跳过转换为TorchScript模型\n"
yield [], log
else:
print_log(task_id, input2, "转换为TorchScript模型", "开始")
log+= "转换为TorchScript模型…\n"
yield [], log
# 使用 torch.jit.trace 进行模型转换
traced_torch_model = torch.jit.trace(torch_model, example_input)
traced_torch_model.save(output_folder + "/" + input2 + ".pt")
print_log(task_id, input2, "转换为TorchScript模型", "完成")
print_log(task_id, input2, "执行命令" + command, "开始")
log += "执行命令…\n"
yield [], log
try:
# 使用 subprocess.Popen 执行命令
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
while True:
output = process.stdout.readline()
if output == '' and process.poll() is not None:
break
if output:
log += output.strip() + '\n'
if log_to_terminal:
print(output.strip())
returncode = process.poll()
if returncode != 0:
log += f"执行命令: {command} 失败,返回码: {returncode}\n"
else:
log += f"执行命令: {command} 成功\n"
except Exception as e:
log += f"执行命令: {command} 失败,错误信息: {str(e)}\n"
output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
log += f"任务完成,输出文件: {output_files}\n"
print_log(task_id, input2, "执行命令", "完成")
yield output_files, log
except Exception as e:
log += f"发生错误: {str(e)}\n"
print_log(task_id, input2,str(e) , f"失败")
yield [], log
# 创建 Gradio 界面
with gr.Blocks() as demo:
gr.Markdown("文件处理界面")
with gr.Row():
# 左侧列,包含输入组件和按钮
with gr.Column():
with gr.Row():
input1 = gr.Textbox(label="粘贴地址")
input1_file = gr.File(label="上传文件")
input2 = gr.Textbox(label="自定义文件名")
# 修改为字符串输入控件
shape0_str = gr.Textbox(label="shape0 (逗号分隔的整数)", value="1,3,128,128")
shape1_str = gr.Textbox(label="shape1 (逗号分隔的整数)", value="0,0,0,0")
start_button = gr.Button("开始")
# 右侧列,包含输出组件和日志文本框
with gr.Column():
output = gr.File(label="输出文件", file_count="multiple")
log_textbox = gr.Textbox(label="日志", lines=10, interactive=False)
# 绑定事件,修改输入参数
start_button.click(
fn=start_process,
inputs=[input1_file if input1_file.value else input1, input2, shape0_str, shape1_str],
outputs=[output, log_textbox]
)
demo.launch()