import subprocess
import sys
import os

import gradio as gr


def install_package(package):
    """Install a Python package with pip, returning True on success."""
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        return True
    except subprocess.CalledProcessError:
        return False
def check_and_install_dependencies():
    """Check for required dependencies and install any that are missing."""
    print("🔍 Checking dependencies...")
    # Check whether transformers is importable
    try:
        import transformers
        print("✅ transformers is already installed")
        return True
    except ImportError:
        print("❌ transformers is not installed, attempting installation...")
        # Try to install transformers and its companion packages
        packages_to_install = [
            "transformers==4.35.2",
            "accelerate==0.24.1",
            "bitsandbytes==0.41.3"
        ]
        for package in packages_to_install:
            print(f"📦 Installing {package}...")
            if install_package(package):
                print(f"✅ {package} installed successfully")
            else:
                print(f"❌ Failed to install {package}")
        # Re-check after installation
        try:
            import transformers
            print("✅ transformers is now available")
            return True
        except ImportError:
            print("❌ transformers is still unavailable after installation")
            return False
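
# Runtime pip installs are only a fallback. On HF Spaces the standard way to
# declare dependencies is a requirements.txt file next to app.py; a minimal
# sketch matching the versions pinned above (assumed, not taken from this repo):
#
#   transformers==4.35.2
#   accelerate==0.24.1
#   bitsandbytes==0.41.3
#   torch
#   Pillow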

# Check and install dependencies before anything imports transformers
dependencies_ok = check_and_install_dependencies()

if dependencies_ok:
    # Dependencies are OK; import the required libraries
    try:
        import torch
        from transformers import AutoTokenizer, AutoModel, AutoProcessor, Blip2ForConditionalGeneration
        from PIL import Image
        print("✅ All libraries imported successfully")

        # HF Spaces environment detection
        IS_SPACES = os.environ.get("SPACE_ID") is not None
        print(f"Running on HF Spaces: {IS_SPACES}")

        # Device configuration
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Global model handles, populated by load_models()
        tokenizer = None
        model = None
        processor = None
        blip_model = None
        def load_models():
            """Load the image-captioning and chat models into the globals."""
            global tokenizer, model, processor, blip_model
            try:
                print("🔄 Loading models...")

                # Load the image-captioning model
                vision_model = "Salesforce/blip2-opt-2.7b"
                print(f"📷 Loading vision model: {vision_model}")
                processor = AutoProcessor.from_pretrained(vision_model)
                blip_model = Blip2ForConditionalGeneration.from_pretrained(
                    vision_model,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    device_map="auto" if device == "cuda" else None,
                    load_in_8bit=(device == "cuda")
                )
                if device == "cpu":
                    blip_model = blip_model.to("cpu")
                print("✅ Vision model loaded")

                # Load the chat model
                model_name = "THUDM/chatglm2-6b-int4"
                print(f"💬 Loading chat model: {model_name}")
                tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    trust_remote_code=True
                )
                model = AutoModel.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    low_cpu_mem_usage=True
                )
                if device == "cuda":
                    model = model.half().cuda()
                model.eval()
                print("✅ Chat model loaded")
                return True
            except Exception as e:
                print(f"❌ Model loading failed: {str(e)}")
                return False
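
        # Note: passing load_in_8bit directly to from_pretrained is deprecated in
        # newer transformers releases in favor of an explicit quantization config.
        # A minimal sketch of the replacement, assuming transformers >= 4.30 with
        # bitsandbytes available (untested here):
        #
        #   from transformers import BitsAndBytesConfig
        #   blip_model = Blip2ForConditionalGeneration.from_pretrained(
        #       vision_model,
        #       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        #       device_map="auto",
        #   )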
        def describe_image(image):
            """Generate a short caption for the given image."""
            if blip_model is None or processor is None:
                return "Vision model is not loaded"
            try:
                if not isinstance(image, Image.Image):
                    image = Image.fromarray(image)
                # Downscale large images in place to keep inference fast
                if image.size[0] > 512 or image.size[1] > 512:
                    image.thumbnail((512, 512), Image.Resampling.LANCZOS)
                inputs = processor(image, return_tensors="pt")
                if device == "cuda":
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    generated_ids = blip_model.generate(
                        **inputs,
                        max_new_tokens=30,
                        num_beams=2,
                        do_sample=False
                    )
                caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
                return caption
            except Exception as e:
                print(f"Image captioning error: {str(e)}")
                return "Failed to generate an image caption"
        def on_image_upload(image):
            """Handle an uploaded image: caption it, then run an initial analysis."""
            if image is None:
                return [], []
            try:
                print("🖼️ Processing uploaded image...")
                history = []
                chat_history = []
                caption = describe_image(image)
                print(f"Image caption: {caption}")
                prompt = (
                    f"This is a work of art, described as: {caption}. "
                    "Please introduce and analyze this artwork in Chinese."
                )
                if model is not None and tokenizer is not None:
                    try:
                        with torch.no_grad():
                            response, history = model.chat(tokenizer, prompt, history=history)
                        # gr.Chatbot's tuple format expects text (or file tuples),
                        # not raw PIL images, so show the caption as the user turn
                        chat_history.append([caption, response])
                        print("✅ Initial analysis complete")
                    except Exception as e:
                        print(f"Chat generation error: {str(e)}")
                        chat_history.append([caption, "Sorry, an error occurred during analysis. Please try again."])
                else:
                    chat_history.append([caption, "The chat model did not load correctly. Please refresh the page and retry."])
                return chat_history, history
            except Exception as e:
                print(f"Image processing error: {str(e)}")
                return [[None, "Image processing failed. Please upload it again."]], []
        def on_user_message(user_message, chat_history, history):
            """Handle a user message and stream the model's reply."""
            if not user_message or not user_message.strip():
                yield chat_history or [], history or []
                return
            if model is None or tokenizer is None:
                chat_history = chat_history or []
                chat_history.append([user_message, "The chat model is not loaded. Please refresh the page."])
                yield chat_history, history or []
                return
            try:
                chat_history = chat_history or []
                history = history or []
                # Append a placeholder reply and fill it in as tokens stream back
                chat_history.append([user_message, ""])
                for output, new_history in model.stream_chat(
                    tokenizer,
                    user_message,
                    history,
                    max_length=2048,
                    temperature=0.7,
                    top_p=0.8
                ):
                    chat_history[-1][1] = output
                    yield chat_history, new_history
            except Exception as e:
                print(f"Chat error: {str(e)}")
                if chat_history:
                    chat_history[-1][1] = "Failed to generate a reply. Please try again."
                yield chat_history, history or []
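
        # Because on_user_message is a generator, Gradio treats it as a streaming
        # handler and re-renders the Chatbot on every yield. Streaming handlers
        # require the request queue, which demo.queue(...) enables below.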
        def clear_chat():
            """Clear the conversation."""
            return [], []
        # Build the interface
        with gr.Blocks(title="AI Artwork Interpretation Agent") as demo:
            gr.HTML("""
            <div style="text-align: center; margin-bottom: 20px;">
                <h1>🎨 AI Artwork Interpretation Agent</h1>
                <p>Upload an image of an artwork to receive a professional analysis and interpretation</p>
            </div>
            """)
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="📤 Upload an artwork image",
                        type="pil",
                        height=350
                    )
                    clear_btn = gr.Button("🗑️ Clear conversation", variant="secondary")
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        label="🤖 AI Analyst",
                        height=500,
                        show_label=True
                    )
                    user_input = gr.Textbox(
                        label="💬 Ask a follow-up question",
                        placeholder="e.g. What painting techniques does this work use? What is its historical background?",
                        lines=2
                    )

            # Conversation state (the model-side chat history)
            state = gr.State([])

            # Event bindings
            image_input.upload(
                fn=on_image_upload,
                inputs=image_input,
                outputs=[chatbot, state],
                show_progress=True
            )
            user_input.submit(
                fn=on_user_message,
                inputs=[user_input, chatbot, state],
                outputs=[chatbot, state],
                show_progress=True
            )
            user_input.submit(lambda: "", inputs=[], outputs=[user_input])
            clear_btn.click(fn=clear_chat, inputs=[], outputs=[chatbot, state])
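
            # Note: the two listeners on the same submit event run independently,
            # so the textbox clear is not ordered relative to the streaming
            # handler. Chaining with .then(...) guarantees ordering; a sketch,
            # assuming a Gradio version with event chaining:
            #
            #   user_input.submit(fn=on_user_message, ...) \
            #       .then(lambda: "", inputs=[], outputs=[user_input])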

        # Launch the application; queue() is required for the streaming handler
        print("🚀 Starting application...")
        if load_models():
            print("✅ Startup successful")
            demo.queue(max_size=20).launch()
        else:
            print("❌ Model loading failed, launching a fallback version")
            with gr.Blocks() as simple_demo:
                gr.HTML("<h2>Model loading failed</h2><p>Please wait for dependency installation to finish, then retry</p>")
            simple_demo.launch()
    except Exception as e:
        print(f"❌ Import failed: {str(e)}")
        # Build an error page
        with gr.Blocks() as error_demo:
            gr.HTML(f"""
            <div style="text-align: center; padding: 50px;">
                <h2>❌ Library import failed</h2>
                <p>Error: {str(e)}</p>
                <p>Attempting automatic recovery...</p>
            </div>
            """)
        error_demo.launch()
else:
    # Dependency installation failed; show an error page
    with gr.Blocks() as error_demo:
        gr.HTML("""
        <div style="text-align: center; padding: 50px;">
            <h2>❌ Dependency installation failed</h2>
            <p>The transformers library could not be installed</p>
            <p>Try the following remedies:</p>
            <ol style="text-align: left; display: inline-block;">
                <li>Check that a requirements.txt file exists</li>
                <li>Run a Factory reboot from Settings</li>
                <li>Wait for HF Spaces to rebuild the environment</li>
            </ol>
        </div>
        """)
    error_demo.launch()