# gradio-test / app.py
# chenglu's picture
# 1
# a874f3c
# raw
# history blame
# 11.8 kB
import subprocess
import sys
import os
import gradio as gr
def install_package(package):
    """Install a Python package with pip under the current interpreter.

    Args:
        package: Requirement specifier passed to ``pip install``
            (for example ``"transformers==4.35.2"``).

    Returns:
        True when pip exits with status 0, False when it exits non-zero.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    # subprocess.call hands back pip's exit status; zero means success.
    return subprocess.call(pip_command) == 0
def check_and_install_dependencies():
    """Make sure ``transformers`` is importable, installing it with pip if not.

    When the first import fails, the pinned ``transformers``, ``accelerate``
    and ``bitsandbytes`` releases are installed one by one (each result is
    printed; a failed install does not stop the loop) and the import is
    attempted again.

    Returns:
        True if ``transformers`` can be imported, either immediately or after
        the installation attempt; False otherwise.
    """
    import importlib

    print("🔍 检查依赖库...")

    # Fast path: transformers is already available.
    try:
        import transformers  # noqa: F401
        print("✅ transformers 已安装")
        return True
    except ImportError:
        print("❌ transformers 未安装,尝试安装...")

    # Try to install transformers and the libraries used alongside it.
    packages_to_install = [
        "transformers==4.35.2",
        "accelerate==0.24.1",
        "bitsandbytes==0.41.3",
    ]
    for package in packages_to_install:
        print(f"📦 安装 {package}...")
        if install_package(package):
            print(f"✅ {package} 安装成功")
        else:
            print(f"❌ {package} 安装失败")

    # pip wrote new files after this interpreter started; drop the import
    # system's finder caches so the freshly installed package is noticed.
    importlib.invalidate_caches()

    # Check again now that installation has been attempted.
    try:
        import transformers  # noqa: F401
        print("✅ transformers 现已可用")
        return True
    except ImportError:
        print("❌ transformers 安装后仍不可用")
        return False
# 检查并安装依赖
dependencies_ok = check_and_install_dependencies()
if dependencies_ok:
# 如果依赖OK,导入所需库
try:
import torch
from transformers import AutoTokenizer, AutoModel, AutoProcessor, Blip2ForConditionalGeneration
from PIL import Image
print("✅ 所有库导入成功")
# 在这里放置你的完整应用代码
# Detect whether the SPACE_ID environment variable is set (HF Spaces).
IS_SPACES = "SPACE_ID" in os.environ
print(f"Running on HF Spaces: {IS_SPACES}")

# Device selection: use CUDA when torch reports it as available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Module-level model handles; they stay None until the models are loaded.
tokenizer = model = processor = blip_model = None
def load_models():
    """Load the image-captioning model and the chat model into module globals.

    Sets the globals ``processor`` / ``blip_model`` (BLIP-2) and
    ``tokenizer`` / ``model`` (ChatGLM2). Any exception raised while loading
    is printed and reported through the return value instead of propagating.

    Returns:
        True when both models were loaded, False if loading raised.
    """
    global tokenizer, model, processor, blip_model

    try:
        print("🔄 正在加载模型...")

        # The same precision / placement choices apply to both models.
        on_gpu = device == "cuda"
        weight_dtype = torch.float16 if on_gpu else torch.float32

        # Image understanding model.
        vision_model = "Salesforce/blip2-opt-2.7b"
        print(f"📷 加载图像模型: {vision_model}")
        processor = AutoProcessor.from_pretrained(vision_model)
        blip_model = Blip2ForConditionalGeneration.from_pretrained(
            vision_model,
            torch_dtype=weight_dtype,
            device_map="auto" if on_gpu else None,
            load_in_8bit=on_gpu,
        )
        if device == "cpu":
            blip_model = blip_model.to("cpu")
        print("✅ 图像模型加载完成")

        # Chat model; its repository ships custom code, hence
        # trust_remote_code=True.
        model_name = "THUDM/chatglm2-6b-int4"
        print(f"💬 加载对话模型: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=weight_dtype,
            low_cpu_mem_usage=True,
        )
        if on_gpu:
            model = model.half().cuda()
        model.eval()
        print("✅ 对话模型加载完成")
    except Exception as e:
        print(f"❌ 模型加载失败: {str(e)}")
        return False

    return True
def describe_image(image):
    """Generate a short caption for an image with the BLIP-2 model.

    Args:
        image: A ``PIL.Image.Image``, or any object accepted by
            ``Image.fromarray``. The caller's image is never modified.

    Returns:
        The generated caption, or a fixed message when the image model is
        not loaded or when captioning raises (the error is printed).
    """
    if blip_model is None or processor is None:
        return "图像模型未加载"

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        if image.size[0] > 512 or image.size[1] > 512:
            # thumbnail() resizes in place; shrink a copy so the image the
            # caller still holds (and displays) keeps its full resolution.
            image = image.copy()
            image.thumbnail((512, 512), Image.Resampling.LANCZOS)

        inputs = processor(image, return_tensors="pt")
        if device == "cuda":
            # On CUDA the model weights are float16, so floating-point inputs
            # (the pixel values) must be cast as well as moved; leaving them
            # float32 raises a dtype mismatch. Non-float tensors only move.
            inputs = {
                k: (
                    v.to(device, torch.float16)
                    if v.is_floating_point()
                    else v.to(device)
                )
                for k, v in inputs.items()
            }

        with torch.no_grad():
            generated_ids = blip_model.generate(
                **inputs,
                max_new_tokens=30,
                num_beams=2,
                do_sample=False,
            )

        caption = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        return caption
    except Exception as e:
        print(f"图像描述错误: {str(e)}")
        return "图像描述生成失败"
def _image_to_chat_message(image):
    """Save *image* as a temporary PNG and return it as a Chatbot file tuple."""
    import tempfile

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # PNG cannot store every PIL mode (e.g. CMYK); convert() returns a new
    # image, so the caller's object is left untouched.
    if image.mode not in ("1", "L", "LA", "P", "RGB", "RGBA"):
        image = image.convert("RGB")
    # delete=False: the file must outlive this function so it can be served.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
    return (tmp.name,)


def on_image_upload(image):
    """Handle an uploaded image: caption it and ask the chat model to analyse it.

    Args:
        image: The uploaded image, or None when the input was cleared.

    Returns:
        A ``(chat_history, history)`` pair: the messages to show in the
        Chatbot and the chat model's conversation history. Both are empty
        when *image* is None; on failure the Chatbot gets a single error
        message and the history is empty.
    """
    if image is None:
        return [], []

    try:
        print("🖼️ 处理上传的图像...")
        history = []
        chat_history = []

        # gr.Chatbot only accepts str, None or a (filepath, ...) tuple as a
        # message; a raw PIL image is rejected as an invalid message. Save
        # the upload to disk first and show it through its file path.
        image_message = _image_to_chat_message(image)

        caption = describe_image(image)
        print(f"图像描述: {caption}")

        prompt = f"这是一幅艺术作品,描述为: {caption}。请用中文对这件艺术作品进行介绍和分析。"

        if model is not None and tokenizer is not None:
            try:
                with torch.no_grad():
                    response, history = model.chat(tokenizer, prompt, history=history)
                chat_history.append([image_message, response])
                print("✅ 初始分析完成")
            except Exception as e:
                print(f"对话生成错误: {str(e)}")
                chat_history.append([image_message, f"很抱歉,分析过程中出现了错误。请重新尝试。"])
        else:
            chat_history.append([image_message, "对话模型未正确加载,请刷新页面重试。"])

        return chat_history, history
    except Exception as e:
        print(f"图像处理错误: {str(e)}")
        return [[None, "图像处理失败,请重新上传。"]], []
def on_user_message(user_message, chat_history, history):
    """Stream the chat model's reply to a follow-up question.

    This is a generator: each item is a ``(chat_history, history)`` pair for
    the Chatbot and the conversation state.

    Args:
        user_message: Text typed by the user; blank input is ignored.
        chat_history: Messages currently shown in the Chatbot (may be None).
        history: The chat model's conversation history (may be None).

    Yields:
        The updated ``(chat_history, history)`` after every streamed chunk,
        or a single pair when the input is blank, the chat model is not
        loaded, or generation fails.
    """
    # Treat missing/empty containers as empty lists on every path.
    chat_history = chat_history or []
    history = history or []

    # Blank input: hand back the current state untouched.
    if not (user_message and user_message.strip()):
        yield chat_history, history
        return

    # Without a chat model, answer with a fixed notice.
    if model is None or tokenizer is None:
        chat_history.append([user_message, "对话模型未加载,请刷新页面。"])
        yield chat_history, history
        return

    try:
        # Placeholder reply that is overwritten as chunks arrive.
        chat_history.append([user_message, ""])
        reply_stream = model.stream_chat(
            tokenizer,
            user_message,
            history,
            max_length=2048,
            temperature=0.7,
            top_p=0.8,
        )
        for output, new_history in reply_stream:
            chat_history[-1][1] = output
            yield chat_history, new_history
    except Exception as e:
        print(f"对话错误: {str(e)}")
        if chat_history:
            chat_history[-1][1] = "回复生成失败,请重试。"
        yield chat_history, history
def clear_chat():
    """Reset the conversation.

    Returns:
        Two fresh empty lists: one for the Chatbot messages and one for the
        conversation state.
    """
    empty_messages, empty_history = [], []
    return empty_messages, empty_history
# Build the user interface.
with gr.Blocks(title="AI艺术品讲解智能体") as demo:
    # Page header.
    gr.HTML("""
<div style="text-align: center; margin-bottom: 20px;">
<h1>🎨 AI 艺术品讲解智能体</h1>
<p>上传艺术品图像,获得专业的艺术分析和解读</p>
</div>
""")

    with gr.Row():
        # Narrow column: image upload and the reset button.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="📤 上传艺术品图像", height=350)
            clear_btn = gr.Button("🗑️ 清空对话", variant="secondary")

        # Wide column: the conversation and the follow-up question box.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="🤖 AI 分析师", show_label=True, height=500)
            user_input = gr.Textbox(
                lines=2,
                label="💬 继续提问",
                placeholder="例如:这幅作品使用了什么绘画技法?创作背景如何?",
            )

    # Conversation state passed to and returned from the handlers.
    state = gr.State([])

    # Event wiring: both the upload and the submit handler write to the
    # Chatbot and to the conversation state.
    handler_outputs = [chatbot, state]
    image_input.upload(
        fn=on_image_upload,
        inputs=image_input,
        outputs=handler_outputs,
        show_progress=True,
    )
    user_input.submit(
        fn=on_user_message,
        inputs=[user_input, chatbot, state],
        outputs=handler_outputs,
        show_progress=True,
    )
    # A second submit handler empties the text box.
    user_input.submit(lambda: "", inputs=[], outputs=[user_input])
    clear_btn.click(fn=clear_chat, inputs=[], outputs=[chatbot, state])
# Start the application: the full demo when the models load, otherwise a
# one-message fallback page.
print("🚀 启动应用...")
models_ready = load_models()
if not models_ready:
    print("❌ 模型加载失败,启动简化版本")
    with gr.Blocks() as simple_demo:
        gr.HTML("<h2>模型加载失败</h2><p>请等待依赖安装完成后重试</p>")
    simple_demo.launch()
else:
    print("✅ 启动成功")
    # Queue requests (at most 20 waiting) before serving the full demo.
    demo.queue(max_size=20).launch()
except Exception as e:
print(f"❌ 导入失败: {str(e)}")

# Serve an error page that shows the failure message.
import html

# The exception text is arbitrary and may contain characters such as "<" or
# "&"; escape it so it is displayed literally instead of being parsed as
# markup.
error_text = html.escape(str(e))
with gr.Blocks() as error_demo:
    gr.HTML(f"""
<div style="text-align: center; padding: 50px;">
<h2>❌ 库导入失败</h2>
<p>错误: {error_text}</p>
<p>正在尝试自动修复...</p>
</div>
""")
error_demo.launch()
else:
# Dependency installation failed: serve a static page listing recovery steps.
install_failure_html = """
<div style="text-align: center; padding: 50px;">
<h2>❌ 依赖安装失败</h2>
<p>transformers 库无法安装</p>
<p>请尝试以下解决方案:</p>
<ol style="text-align: left; display: inline-block;">
<li>检查 requirements.txt 文件是否存在</li>
<li>在 Settings 中执行 Factory reboot</li>
<li>等待 HF Spaces 重新构建环境</li>
</ol>
</div>
"""
with gr.Blocks() as error_demo:
    gr.HTML(install_failure_html)
error_demo.launch()