import subprocess
import sys
import os
import gradio as gr

def install_package(package):
    """安装Python包"""
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        return True
    except subprocess.CalledProcessError:
        return False

def check_and_install_dependencies():
    """Check for required libraries and install them at runtime if missing."""
    print("🔍 Checking dependencies...")
    
    # Check for transformers
    try:
        import transformers
        print("✅ transformers is installed")
        return True
    except ImportError:
        print("❌ transformers is not installed, attempting to install...")
        
        # Try to install transformers and its companion packages
        packages_to_install = [
            "transformers==4.35.2",
            "accelerate==0.24.1",
            "bitsandbytes==0.41.3"
        ]
        
        for package in packages_to_install:
            print(f"📦 Installing {package}...")
            if install_package(package):
                print(f"✅ {package} installed successfully")
            else:
                print(f"❌ {package} failed to install")
        
        # Check again after installation
        try:
            import transformers
            print("✅ transformers is now available")
            return True
        except ImportError:
            print("❌ transformers is still unavailable after installation")
            return False

# Check and install dependencies before importing anything heavy
dependencies_ok = check_and_install_dependencies()
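
# Runtime pip installs are a fallback; on HF Spaces the cleaner fix is a
# requirements.txt so the environment is built before the app starts. A
# minimal sketch, assuming the versions pinned above (the torch/gradio/
# Pillow/sentencepiece entries are illustrative assumptions, not taken
# from this file):
#
#   transformers==4.35.2
#   accelerate==0.24.1
#   bitsandbytes==0.41.3
#   torch
#   gradio
#   Pillow
#   sentencepiece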

if dependencies_ok:
    # Dependencies are OK, import the required libraries
    try:
        import torch
        from transformers import AutoTokenizer, AutoModel, AutoProcessor, Blip2ForConditionalGeneration
        from PIL import Image
        print("✅ 所有库导入成功")
        
        # The full application code lives below
        # HF Spaces environment detection
        IS_SPACES = os.environ.get("SPACE_ID") is not None
        print(f"Running on HF Spaces: {IS_SPACES}")

        # Device configuration
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Global model handles, populated by load_models()
        tokenizer = None
        model = None
        processor = None
        blip_model = None

        def load_models():
            """Load the vision (BLIP-2) and chat (ChatGLM2) models."""
            global tokenizer, model, processor, blip_model
            
            try:
                print("🔄 Loading models...")
                
                # Load the image-understanding model
                vision_model = "Salesforce/blip2-opt-2.7b"
                print(f"📷 Loading vision model: {vision_model}")
                processor = AutoProcessor.from_pretrained(vision_model)
                
                blip_model = Blip2ForConditionalGeneration.from_pretrained(
                    vision_model,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    device_map="auto" if device == "cuda" else None,
                    load_in_8bit=device == "cuda"  # 8-bit quantization needs CUDA + bitsandbytes
                )
                
                if device == "cpu":
                    blip_model = blip_model.to("cpu")
                    
                print("✅ Vision model loaded")
                
                # Load the conversation model
                model_name = "THUDM/chatglm2-6b-int4"
                print(f"💬 Loading chat model: {model_name}")
                
                tokenizer = AutoTokenizer.from_pretrained(
                    model_name, 
                    trust_remote_code=True
                )
                
                model = AutoModel.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    low_cpu_mem_usage=True
                )
                
                if device == "cuda":
                    model = model.half().cuda()
                model.eval()
                
                print("✅ Chat model loaded")
                return True
                
            except Exception as e:
                print(f"❌ Model loading failed: {str(e)}")
                return False
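
        # The handlers below rely on the conversational API exposed by
        # ChatGLM2's trust_remote_code model class. A minimal sketch of its
        # shape (illustrative, not executed here):
        #
        #   response, history = model.chat(tokenizer, "Describe this work", history=[])
        #   for partial, new_history in model.stream_chat(tokenizer, query, history):
        #       ...  # `partial` grows as tokens arrive; `new_history` is the updated context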

        def describe_image(image):
            """Generate an English caption for the uploaded image with BLIP-2."""
            if blip_model is None or processor is None:
                return "Vision model is not loaded"
            
            try:
                if not isinstance(image, Image.Image):
                    image = Image.fromarray(image)
                
                # Downscale large images to keep captioning fast and memory-friendly
                if image.size[0] > 512 or image.size[1] > 512:
                    image.thumbnail((512, 512), Image.Resampling.LANCZOS)
                
                inputs = processor(images=image, return_tensors="pt")
                
                if device == "cuda":
                    # Cast floating-point inputs (pixel_values) to fp16 to match the
                    # model dtype; integer tensors are only moved to the device.
                    inputs = {
                        k: (v.to(device, torch.float16) if v.is_floating_point() else v.to(device))
                        for k, v in inputs.items()
                    }
                
                with torch.no_grad():
                    generated_ids = blip_model.generate(
                        **inputs,
                        max_new_tokens=30,
                        num_beams=2,
                        do_sample=False
                    )
                
                caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
                return caption
                
            except Exception as e:
                print(f"Image captioning error: {str(e)}")
                return "Image caption generation failed"

        def on_image_upload(image):
            """Handle a newly uploaded image: caption it, then ask the chat model for an analysis."""
            if image is None:
                return [], []
            
            try:
                print("🖼️ Processing uploaded image...")
                history = []
                chat_history = []
                
                caption = describe_image(image)
                print(f"Image caption: {caption}")
                
                prompt = (
                    f"This is a work of art, described as: {caption}. "
                    "Please introduce and analyze this artwork."
                )
                
                if model is not None and tokenizer is not None:
                    try:
                        with torch.no_grad():
                            response, history = model.chat(tokenizer, prompt, history=history)
                        # gr.Chatbot messages must be strings or (filepath, alt) tuples,
                        # not PIL images, so show a text marker for the upload turn.
                        chat_history.append([f"[Artwork uploaded] {caption}", response])
                        print("✅ Initial analysis complete")
                    except Exception as e:
                        print(f"Chat generation error: {str(e)}")
                        chat_history.append([f"[Artwork uploaded] {caption}",
                                             "Sorry, an error occurred during analysis. Please try again."])
                else:
                    chat_history.append([f"[Artwork uploaded] {caption}",
                                         "The chat model did not load correctly; please refresh the page and retry."])
                
                return chat_history, history
                
            except Exception as e:
                print(f"Image handling error: {str(e)}")
                return [[None, "Image processing failed, please upload again."]], []

        def on_user_message(user_message, chat_history, history):
            """Handle a follow-up user message, streaming the model's reply."""
            if not user_message or not user_message.strip():
                yield chat_history or [], history or []
                return
            
            if model is None or tokenizer is None:
                chat_history = chat_history or []
                chat_history.append([user_message, "The chat model is not loaded; please refresh the page."])
                yield chat_history, history or []
                return
            
            try:
                chat_history = chat_history or []
                history = history or []
                chat_history.append([user_message, ""])
                
                # stream_chat yields progressively longer responses; updating the
                # last chat bubble on each step makes the UI stream the answer.
                for output, new_history in model.stream_chat(
                    tokenizer, 
                    user_message, 
                    history,
                    max_length=2048,
                    temperature=0.7,
                    top_p=0.8
                ):
                    chat_history[-1][1] = output
                    yield chat_history, new_history
                    
            except Exception as e:
                print(f"Chat error: {str(e)}")
                if chat_history:
                    chat_history[-1][1] = "Failed to generate a reply, please try again."
                yield chat_history, history or []

        def clear_chat():
            """Clear the conversation and the model-side history state."""
            return [], []

        # Build the interface
        with gr.Blocks(title="AI Art Interpretation Agent") as demo:
            gr.HTML("""
            <div style="text-align: center; margin-bottom: 20px;">
                <h1>🎨 AI Art Interpretation Agent</h1>
                <p>Upload an image of an artwork to receive a professional analysis and interpretation</p>
            </div>
            """)
            
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="📤 Upload an artwork image",
                        type="pil",
                        height=350
                    )
                    clear_btn = gr.Button("🗑️ Clear conversation", variant="secondary")
                    
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        label="🤖 AI Analyst",
                        height=500,
                        show_label=True
                    )
                    
            user_input = gr.Textbox(
                label="💬 Ask a follow-up question",
                placeholder="For example: What painting techniques does this work use? What is its historical context?",
                lines=2
            )
            
            # Conversation state (the model-side chat history)
            state = gr.State([])
            
            # Event bindings
            image_input.upload(
                fn=on_image_upload,
                inputs=image_input,
                outputs=[chatbot, state],
                show_progress=True
            )
            
            user_input.submit(
                fn=on_user_message,
                inputs=[user_input, chatbot, state],
                outputs=[chatbot, state],
                show_progress=True
            )
            
            user_input.submit(lambda: "", inputs=[], outputs=[user_input])
            clear_btn.click(fn=clear_chat, inputs=[], outputs=[chatbot, state])

        # Launch the app
        print("🚀 Launching app...")
        if load_models():
            print("✅ Startup successful")
            demo.queue(max_size=20).launch()
        else:
            print("❌ Model loading failed, launching fallback version")
            with gr.Blocks() as simple_demo:
                gr.HTML("<h2>Model loading failed</h2><p>Please wait for dependency installation to finish, then retry</p>")
            simple_demo.launch()
            
    except Exception as e:
        print(f"❌ Import failed: {str(e)}")
        # Build an error page
        with gr.Blocks() as error_demo:
            gr.HTML(f"""
            <div style="text-align: center; padding: 50px;">
                <h2>❌ Library import failed</h2>
                <p>Error: {str(e)}</p>
                <p>Attempting automatic recovery...</p>
            </div>
            """)
        error_demo.launch()

else:
    # Dependency installation failed; show an error page
    with gr.Blocks() as error_demo:
        gr.HTML("""
        <div style="text-align: center; padding: 50px;">
            <h2>❌ Dependency installation failed</h2>
            <p>The transformers library could not be installed</p>
            <p>Please try the following:</p>
            <ol style="text-align: left; display: inline-block;">
                <li>Check that a requirements.txt file exists</li>
                <li>Run a Factory reboot from the Settings page</li>
                <li>Wait for HF Spaces to rebuild the environment</li>
            </ol>
        </div>
        """)
    error_demo.launch()