chenglu committed on
Commit f5ccf14 · 1 Parent(s): d99a437

transformers

Files changed (1)
  1. app.py +40 -62
app.py CHANGED
@@ -1,77 +1,55 @@
- import torch
- from transformers import AutoTokenizer, AutoModel, AutoProcessor, Blip2ForConditionalGeneration
- import gradio as gr
-
- # Load the Chinese conversational model (ChatGLM 6B, int4 quantized version)
- model_name = "THUDM/chatglm2-6b-int4"
- print(f"Loading conversation model: {model_name}")
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
- # If GPU is available, use half precision on GPU for faster inference
  if torch.cuda.is_available():
-     model = model.half().cuda()
- model.eval()
-
- # Load the image captioning model (BLIP-2 with OPT 2.7B LLM)
- vision_model = "Salesforce/blip2-opt-2.7b"
- print(f"Loading image captioning model: {vision_model}")
- processor = AutoProcessor.from_pretrained(vision_model)
- blip_model = Blip2ForConditionalGeneration.from_pretrained(
-     vision_model,
-     torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
-     device_map=("auto" if torch.cuda.is_available() else None)
  )
- # Ensure BLIP model on CPU if no GPU
  if not torch.cuda.is_available():
-     blip_model = blip_model.to("cpu")
-
- # Function: generate a descriptive caption for the image using BLIP-2
  def describe_image(image):
-     inputs = processor(image, return_tensors="pt").to(blip_model.device)
-     generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
-     caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-     return caption
-
- # Event handler: when a new image is uploaded
- def on_image_upload(image):
-     # Reset histories for a new conversation
-     history = []       # model's conversation history
-     chat_history = []  # chat display history for Gradio
-     # Describe the uploaded artwork image
-     caption = describe_image(image)
-     # Build the prompt for the conversational model (include the image description)
-     prompt = f"这是一幅艺术作品图像,其内容是: {caption}。请对此艺术作品进行简要的介绍和分析。"
-     # Generate the initial analysis using the conversation model
-     response, history = model.chat(tokenizer, prompt, history=history)
-     # Add the image (user side) and the model's response (assistant side) to chat history
-     chat_history.append([image, response])
-     return chat_history, history
-
- # Event handler: when the user sends a new text message (question)
- def on_user_message(user_message, chat_history, history):
-     chat_history = chat_history or []
-     # Append the user's question and an empty response placeholder
-     chat_history.append([user_message, ""])
-     # Use streaming response from the model
-     for output, new_history in model.stream_chat(tokenizer, user_message, history):
-         # Update the assistant's response in the chat history
-         chat_history[-1][1] = output
-         # Yield the updated chat history and model history for streaming in UI
-         yield chat_history, new_history
-
- # Build Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("# AI 艺术品讲解智能体")
-     gr.Markdown("上传一张艺术品图像,让 AI 为您描述这件艺术作品,并回答有关它的问题。")
-     image_input = gr.Image(label="上传艺术品图像", type="pil")
      chatbot = gr.Chatbot()
-     user_input = gr.Textbox(label="询问问题", placeholder="请输入关于这幅作品的提问...")
-     state = gr.State()  # state to store model history
-     # Connect events
-     image_input.upload(fn=on_image_upload, inputs=image_input, outputs=[chatbot, state])
-     user_input.submit(fn=on_user_message, inputs=[user_input, chatbot, state], outputs=[chatbot, state])
-     user_input.submit(lambda: "", inputs=[], outputs=[user_input])  # clear input field
-
- # Launch the app (if running locally; not required in HF Spaces)
  if __name__ == "__main__":
-     demo.queue().launch(share=True)
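Both the removed and the added handlers drive ChatGLM2 through its remote-code `chat` / `stream_chat` methods. For reference, an editor's sketch (not part of this commit) of how that API behaves, assuming the `THUDM/chatglm2-6b-int4` checkpoint loads with `trust_remote_code=True`:

# Editor's sketch, not part of the commit: the ChatGLM2 chat API
# used by on_image_upload / on_user_message above and on_image / on_chat below.
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b-int4", trust_remote_code=True)
glm = AutoModel.from_pretrained("THUDM/chatglm2-6b-int4", trust_remote_code=True).eval()

# chat() returns the complete reply plus the updated history in one call.
reply, history = glm.chat(tok, "你好", history=[])

# stream_chat() is a generator yielding progressively longer replies together
# with the updated history; this is what lets the handler stream into gr.Chatbot.
for partial, history in glm.stream_chat(tok, "请简要介绍印象派绘画", history=history):
    print(partial)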
 
+ import torch, gradio as gr
+ from transformers import (
+     AutoTokenizer, AutoModel,
+     AutoProcessor, Blip2ForConditionalGeneration
+ )
+
+ # -------- Model loading --------
+ chat_model_name = "THUDM/chatglm2-6b-int4"
+ vision_model_name = "Salesforce/blip2-opt-2.7b"
+
+ tokenizer = AutoTokenizer.from_pretrained(chat_model_name, trust_remote_code=True)
+ chat_model = AutoModel.from_pretrained(chat_model_name, trust_remote_code=True).eval()
  if torch.cuda.is_available():
+     chat_model = chat_model.half().cuda()
+
+ processor = AutoProcessor.from_pretrained(vision_model_name)
+ vision_model = Blip2ForConditionalGeneration.from_pretrained(
+     vision_model_name,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto" if torch.cuda.is_available() else None,
  )
  if not torch.cuda.is_available():
+     vision_model = vision_model.to("cpu")
+
+ # -------- Helper functions --------
  def describe_image(image):
+     inputs = processor(image, return_tensors="pt").to(vision_model.device)
+     ids = vision_model.generate(**inputs, max_new_tokens=50)
+     return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
+
+ def on_image(img):
+     caption = describe_image(img)
+     sys_prompt = f"这是一幅艺术作品图像: {caption}。请为普通观众做简介。"
+     answer, hist = chat_model.chat(tokenizer, sys_prompt, history=[])
+     return [[img, answer]], hist
+
+ def on_chat(msg, chat_hist, hist):
+     chat_hist = chat_hist or []
+     chat_hist.append([msg, ""])
+     for out, h in chat_model.stream_chat(tokenizer, msg, history=hist):
+         chat_hist[-1][1] = out
+         yield chat_hist, h
+
+ # -------- Gradio UI --------
  with gr.Blocks() as demo:
      gr.Markdown("# AI 艺术品讲解智能体")
+     image = gr.Image(type="pil", label="上传艺术品")
      chatbot = gr.Chatbot()
+     txt = gr.Textbox(label="提问")
+     state = gr.State()
+     image.upload(on_image, image, [chatbot, state])
+     txt.submit(on_chat, [txt, chatbot, state], [chatbot, state]).then(lambda: "", None, txt)
+
  if __name__ == "__main__":
+     demo.queue(concurrency_count=2).launch(share=True)
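One compatibility caveat with the new launch line: `queue(concurrency_count=2)` is the Gradio 3.x signature, and Gradio 4 removed that parameter in favor of `default_concurrency_limit`. If the Space may run on either major version, a version-guarded launch along these lines should work (an editor's sketch, not part of the commit):

# Editor's sketch, not part of the commit: guard the queue() call so the app
# launches on both Gradio 3.x (concurrency_count) and 4.x (default_concurrency_limit).
import gradio as gr

if gr.__version__.startswith("3."):
    demo.queue(concurrency_count=2).launch(share=True)
else:
    demo.queue(default_concurrency_limit=2).launch(share=True)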