chenglu commited on
Commit
e54d7aa
·
1 Parent(s): f5ccf14

transformers

Browse files
Files changed (1) hide show
  1. app.py +242 -51
app.py CHANGED
@@ -1,55 +1,246 @@
1
- import torch, gradio as gr
2
- from transformers import (
3
- AutoTokenizer, AutoModel,
4
- AutoProcessor, Blip2ForConditionalGeneration
5
- )
6
-
7
- # --------模型加载--------
8
- chat_model_name = "THUDM/chatglm2-6b-int4"
9
- vision_model_name = "Salesforce/blip2-opt-2.7b"
10
-
11
- tokenizer = AutoTokenizer.from_pretrained(chat_model_name, trust_remote_code=True)
12
- chat_model = AutoModel.from_pretrained(chat_model_name, trust_remote_code=True).eval()
13
  if torch.cuda.is_available():
14
- chat_model = chat_model.half().cuda()
15
-
16
- processor = AutoProcessor.from_pretrained(vision_model_name)
17
- vision_model = Blip2ForConditionalGeneration.from_pretrained(
18
- vision_model_name,
19
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
20
- device_map="auto" if torch.cuda.is_available() else None,
21
- )
22
- if not torch.cuda.is_available():
23
- vision_model = vision_model.to("cpu")
24
-
25
- # --------工具函数--------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def describe_image(image):
27
- inputs = processor(image, return_tensors="pt").to(vision_model.device)
28
- ids = vision_model.generate(**inputs, max_new_tokens=50)
29
- return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
30
-
31
- def on_image(img):
32
- caption = describe_image(img)
33
- sys_prompt = f"这是一幅艺术作品图像: {caption}。请为普通观众做简介。"
34
- answer, hist = chat_model.chat(tokenizer, sys_prompt, history=[])
35
- return [[img, answer]], hist
36
-
37
- def on_chat(msg, chat_hist, hist):
38
- chat_hist = chat_hist or []
39
- chat_hist.append([msg, ""])
40
- for out, h in chat_model.stream_chat(tokenizer, msg, history=hist):
41
- chat_hist[-1][1] = out
42
- yield chat_hist, h
43
-
44
- # --------Gradio 界面--------
45
- with gr.Blocks() as demo:
46
- gr.Markdown("# AI 艺术品讲解智能体")
47
- image = gr.Image(type="pil", label="上传艺术品")
48
- chatbot = gr.Chatbot()
49
- txt = gr.Textbox(label="提问")
50
- state = gr.State()
51
- image.upload(on_image, image, [chatbot, state])
52
- txt.submit(on_chat, [txt, chatbot, state], [chatbot, state]).then(lambda: "", None, txt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
 
54
  if __name__ == "__main__":
55
- demo.queue(concurrency_count=2).launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel, AutoProcessor, Blip2ForConditionalGeneration
3
+ import gradio as gr
4
+ import gc
5
+ from PIL import Image
6
+
7
+ # 检查设备和内存
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ print(f"Using device: {device}")
 
 
 
10
  if torch.cuda.is_available():
11
+ print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
12
+
13
+ # 全局变量存储模型
14
+ tokenizer = None
15
+ model = None
16
+ processor = None
17
+ blip_model = None
18
+
19
+ def load_models():
20
+ """延迟加载模型以优化内存使用"""
21
+ global tokenizer, model, processor, blip_model
22
+
23
+ try:
24
+ # 加载对话模型 (ChatGLM2-6B int4量化版本)
25
+ model_name = "THUDM/chatglm2-6b-int4"
26
+ print(f"正在加载对话模型: {model_name}")
27
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
28
+ model = AutoModel.from_pretrained(
29
+ model_name,
30
+ trust_remote_code=True,
31
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
32
+ )
33
+
34
+ if device == "cuda":
35
+ model = model.half().cuda()
36
+ model.eval()
37
+ print("✅ 对话模型加载完成")
38
+
39
+ # 加载图像理解模型 (BLIP-2)
40
+ vision_model = "Salesforce/blip2-opt-2.7b"
41
+ print(f"正在加载图像理解模型: {vision_model}")
42
+ processor = AutoProcessor.from_pretrained(vision_model)
43
+ blip_model = Blip2ForConditionalGeneration.from_pretrained(
44
+ vision_model,
45
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
46
+ device_map="auto" if device == "cuda" else None,
47
+ load_in_8bit=True if device == "cuda" else False # 使用8bit量化节省内存
48
+ )
49
+
50
+ if device == "cpu":
51
+ blip_model = blip_model.to("cpu")
52
+
53
+ print("✅ 图像理解模型加载完成")
54
+ return True
55
+
56
+ except Exception as e:
57
+ print(f"❌ 模型加载失败: {str(e)}")
58
+ return False
59
+
60
  def describe_image(image):
61
+ """使用BLIP-2生成图像描述"""
62
+ if blip_model is None or processor is None:
63
+ return "模型未正确加载"
64
+
65
+ try:
66
+ # 确保图像格式正确
67
+ if not isinstance(image, Image.Image):
68
+ image = Image.fromarray(image)
69
+
70
+ # 预处理图像
71
+ inputs = processor(image, return_tensors="pt")
72
+
73
+ # 移动到正确的设备
74
+ if device == "cuda":
75
+ inputs = {k: v.to(device) for k, v in inputs.items()}
76
+
77
+ # 生成描述
78
+ with torch.no_grad():
79
+ generated_ids = blip_model.generate(
80
+ **inputs,
81
+ max_new_tokens=50,
82
+ num_beams=3,
83
+ temperature=0.7,
84
+ do_sample=True
85
+ )
86
+
87
+ caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
88
+ return caption
89
+
90
+ except Exception as e:
91
+ print(f"图像描述生成错误: {str(e)}")
92
+ return f"图像描述生成失败: {str(e)}"
93
+
94
+ def on_image_upload(image):
95
+ """处理图像上传事件"""
96
+ if image is None:
97
+ return [], []
98
+
99
+ try:
100
+ # 重置对话历史
101
+ history = []
102
+ chat_history = []
103
+
104
+ # 生成图像描述
105
+ caption = describe_image(image)
106
+ print(f"图像描述: {caption}")
107
+
108
+ # 构建提示词
109
+ prompt = f"这是一幅艺术作品图像,其内容是: {caption}。请对此艺术作品进行详细的介绍和分析,包括艺术风格、创作技法、可能的历史背景等方面。"
110
+
111
+ # 生成初始分析
112
+ if model is not None and tokenizer is not None:
113
+ try:
114
+ with torch.no_grad():
115
+ response, history = model.chat(tokenizer, prompt, history=history)
116
+ chat_history.append([image, response])
117
+ except Exception as e:
118
+ print(f"对话生成错误: {str(e)}")
119
+ chat_history.append([image, f"分析生成失败: {str(e)}"])
120
+ else:
121
+ chat_history.append([image, "对话模型未正确加载"])
122
+
123
+ return chat_history, history
124
+
125
+ except Exception as e:
126
+ print(f"图像上传处理错误: {str(e)}")
127
+ return [[None, f"处理失败: {str(e)}"]], []
128
+
129
+ def on_user_message(user_message, chat_history, history):
130
+ """处理用户消息"""
131
+ if not user_message.strip():
132
+ yield chat_history, history
133
+ return
134
+
135
+ if model is None or tokenizer is None:
136
+ chat_history = chat_history or []
137
+ chat_history.append([user_message, "对话模型未正确加载"])
138
+ yield chat_history, history
139
+ return
140
+
141
+ try:
142
+ chat_history = chat_history or []
143
+ chat_history.append([user_message, ""])
144
+
145
+ # 使用流式响应
146
+ for output, new_history in model.stream_chat(tokenizer, user_message, history):
147
+ chat_history[-1][1] = output
148
+ yield chat_history, new_history
149
+
150
+ except Exception as e:
151
+ print(f"用户消息处理错误: {str(e)}")
152
+ chat_history[-1][1] = f"回复生成失败: {str(e)}"
153
+ yield chat_history, history
154
+
155
+ def clear_chat():
156
+ """清空对话"""
157
+ return [], []
158
+
159
+ # 构建Gradio界面
160
+ def create_interface():
161
+ with gr.Blocks(title="AI艺术品讲解智能体", theme=gr.themes.Soft()) as demo:
162
+ gr.Markdown("# 🎨 AI 艺术品讲解智能体")
163
+ gr.Markdown("上传一张艺术品图像,让 AI 为您描述这件艺术作品,并回答有关它的问题。")
164
+
165
+ with gr.Row():
166
+ with gr.Column(scale=1):
167
+ image_input = gr.Image(
168
+ label="上传艺术品图像",
169
+ type="pil",
170
+ height=300
171
+ )
172
+ clear_btn = gr.Button("🗑️ 清空对话", variant="secondary")
173
+
174
+ with gr.Column(scale=2):
175
+ chatbot = gr.Chatbot(
176
+ label="对话区域",
177
+ height=500,
178
+ show_label=True
179
+ )
180
+
181
+ user_input = gr.Textbox(
182
+ label="询问问题",
183
+ placeholder="请输入关于这幅作品的提问...",
184
+ lines=2
185
+ )
186
+
187
+ # 状态管理
188
+ state = gr.State([]) # 存储模型对话历史
189
+
190
+ # 事件绑定
191
+ image_input.upload(
192
+ fn=on_image_upload,
193
+ inputs=image_input,
194
+ outputs=[chatbot, state]
195
+ )
196
+
197
+ user_input.submit(
198
+ fn=on_user_message,
199
+ inputs=[user_input, chatbot, state],
200
+ outputs=[chatbot, state]
201
+ )
202
+
203
+ user_input.submit(
204
+ lambda: "",
205
+ inputs=[],
206
+ outputs=[user_input]
207
+ )
208
+
209
+ clear_btn.click(
210
+ fn=clear_chat,
211
+ inputs=[],
212
+ outputs=[chatbot, state]
213
+ )
214
+
215
+ # 添加使用说明
216
+ gr.Markdown("""
217
+ ### 使用说明:
218
+ 1. 点击上传区域选择一张艺术品图像
219
+ 2. AI 会自动分析图像并生成初始介绍
220
+ 3. 在下方输入框中提问关于艺术品的问题
221
+ 4. 支持多轮对话,可以深入讨论艺术品的各个方面
222
+
223
+ ### 注意事项:
224
+ - 支持常见图片格式(JPG, PNG, WebP等)
225
+ - 建议上传清晰的艺术品图像以获得更好的分析效果
226
+ - 首次加载模型可能需要一些时间,请耐心等待
227
+ """)
228
+
229
+ return demo
230
 
231
+ # 主程序
232
  if __name__ == "__main__":
233
+ print("🚀 启动 AI 艺术品讲解智能体...")
234
+
235
+ # 加载模型
236
+ if load_models():
237
+ print("✅ 所有模型加载完成,启动界��...")
238
+ demo = create_interface()
239
+ demo.queue(max_size=10).launch(
240
+ share=True,
241
+ server_name="0.0.0.0",
242
+ server_port=7860,
243
+ show_error=True
244
+ )
245
+ else:
246
+ print("❌ 模型加载失败,请检查环境配置")