robot0820 committed
Commit 513a480 · verified · 1 Parent(s): 7c695c3

Update app.py

Files changed (1):
  1. app.py +46 -27
app.py CHANGED

@@ -1,15 +1,16 @@
+# app.py
 import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, BitsAndBytesConfig
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
 
-# Model path
+# ==== Model settings ====
 model_path = "deepseek-ai/deepseek-vl-7b-chat"
 
-# ==== BitsAndBytes 4-bit quantization config ====
+# BitsAndBytes 4-bit quantization config
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.float16,  # force float16
+    bnb_4bit_compute_dtype=torch.float16,
     bnb_4bit_use_double_quant=True
 )
 
@@ -17,7 +18,7 @@ bnb_config = BitsAndBytesConfig(
 vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
-# Load the model (4-bit quantization + float16)
+# Load the model
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
     model_path,
     quantization_config=bnb_config,
@@ -25,30 +26,35 @@ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True
 ).eval()
 
-# ==== Single-image inference function ====
+# ==== Conversation history ====
+chat_history = []
+
+# ==== Text + image inference function ====
 def chat_with_image(image, user_message):
+    global chat_history
     try:
-        # Build the conversation
-        conversation = [
-            {"role": "User", "content": "<image_placeholder>" + user_message, "images": [image]},
-            {"role": "Assistant", "content": ""}
-        ]
+        # Build the conversation
+        conversation = chat_history.copy()
+        conversation.append({
+            "role": "User",
+            "content": "<image_placeholder>" + user_message,
+            "images": [image] if image else []
+        })
+        conversation.append({"role": "Assistant", "content": ""})
 
-        # Pass the PIL.Image directly instead of using load_pil_images
+        # Prepare the inputs
         prepare_inputs = vl_chat_processor(
             conversations=conversation,
-            images=[image],
+            images=[image] if image else [],
             force_batchify=True
         ).to(vl_gpt.device)
 
-        # 🚨 Convert BatchedVLChatProcessorOutput to a dict
+        # Convert to a dict and fix the dtypes
         prepare_inputs = {k: getattr(prepare_inputs, k) for k in prepare_inputs.__dataclass_fields__.keys()}
-
-        # Correct dtypes: keep input_ids/labels as long, cast the other tensors to float16
         new_inputs = {}
         for k, v in prepare_inputs.items():
             if torch.is_tensor(v):
-                if k in ["input_ids", "labels","attention_mask"]:
+                if k in ["input_ids", "labels"]:
                     new_inputs[k] = v.to(torch.long)
                 else:
                     new_inputs[k] = v.to(torch.float16)
@@ -66,27 +72,40 @@ def chat_with_image(image, user_message):
             pad_token_id=tokenizer.eos_token_id,
             bos_token_id=tokenizer.bos_token_id,
             eos_token_id=tokenizer.eos_token_id,
-            max_new_tokens=128,  # reduce memory use
+            max_new_tokens=128,
             do_sample=False,
             use_cache=True
         )
 
         # Decode
         answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-        return answer
+
+        # Update the history
+        chat_history.append((user_message, answer))
+        return answer, chat_history
 
     except Exception as e:
-        return f"Error: {str(e)}"
+        return f"Error: {str(e)}", chat_history
+
+def reset_chat():
+    global chat_history
+    chat_history = []
+    return "", []
 
 # ==== Gradio Web UI ====
-demo = gr.Interface(
-    fn=chat_with_image,
-    inputs=[gr.Image(type="pil", label="Upload Image"),
-            gr.Textbox(lines=2, placeholder="Ask about the image...")],
-    outputs="text",
-    title="DeepSeek-VL-7B-Chat Demo (4-bit, float16)",
-    description="Upload an image and enter a question; the model will generate an answer related to the image"
-)
+with gr.Blocks() as demo:
+    gr.Markdown("# DeepSeek-VL-7B-Chat Demo (4-bit, float16)")
+    with gr.Row():
+        image_input = gr.Image(type="pil", label="Upload Image")
+        text_input = gr.Textbox(lines=2, placeholder="Ask about the image...")
+    with gr.Row():
+        submit_btn = gr.Button("Submit")
+        reset_btn = gr.Button("Reset Chat")
+    output_text = gr.Textbox(label="Answer")
+    chat_display = gr.Chatbot(label="Chat History")
+
+    submit_btn.click(chat_with_image, inputs=[image_input, text_input], outputs=[output_text, chat_display])
+    reset_btn.click(reset_chat, inputs=[], outputs=[output_text, chat_display])
 
 if __name__ == "__main__":
     demo.launch()
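A note on the quantization setup above: the commit keeps bnb_4bit_compute_dtype=torch.float16 and double quantization, but leaves the 4-bit quant type at the transformers default. Below is a minimal sketch of the same config with the quant type spelled out; the "nf4" choice is an illustrative assumption, not something this commit sets.

import torch
from transformers import BitsAndBytesConfig

# Same settings as app.py, plus an explicit 4-bit quant type (assumption:
# "nf4"; the committed code relies on the library default instead).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # explicit 4-bit data type
    bnb_4bit_compute_dtype=torch.float16,  # matmuls run in float16
    bnb_4bit_use_double_quant=True,        # also quantize the quant constants
)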
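The dtype loop also changed: the old version kept input_ids, labels, and attention_mask as torch.long, while the new version lists only input_ids and labels, so attention_mask is now cast to float16 along with the image tensors. A standalone sketch of the dtype rule, keeping the mask integer as the pre-commit loop did; the sample tensor names in the demo batch are illustrative.

import torch

# Keys that must stay integer indices; everything else is cast to float16
# to match the 4-bit model's compute dtype. attention_mask is kept integer
# here, as in the pre-commit version of the loop.
LONG_KEYS = {"input_ids", "labels", "attention_mask"}

def normalize_dtypes(inputs: dict) -> dict:
    out = {}
    for k, v in inputs.items():
        if torch.is_tensor(v):
            out[k] = v.to(torch.long) if k in LONG_KEYS else v.to(torch.float16)
        else:
            out[k] = v  # pass non-tensor fields through unchanged
    return out

if __name__ == "__main__":
    batch = {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.ones(1, 3, dtype=torch.long),
        "pixel_values": torch.rand(1, 3, 384, 384),  # float32 -> float16
    }
    for k, v in normalize_dtypes(batch).items():
        print(k, v.dtype)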
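Finally, the conversation history: chat_with_image appends (user_message, answer) tuples to chat_history, which is the pair format gr.Chatbot displays, but the same list is copied into conversation on the next turn, where VLChatProcessor expects role dicts. One way to reconcile the two formats is to keep separate model-facing and UI-facing histories; a minimal sketch follows, with hypothetical names (model_history, ui_history, record_turn, build_conversation) that are not part of the commit.

# Minimal sketch: keep the processor-format transcript and the gr.Chatbot
# transcript separately. All names here are hypothetical.
model_history = []  # [{"role": "User"/"Assistant", "content": ...}, ...]
ui_history = []     # [(user_message, answer), ...] for gr.Chatbot

def record_turn(user_message: str, answer: str) -> None:
    # Model-facing transcript: role dicts, mirroring app.py's entries.
    model_history.append({"role": "User", "content": user_message})
    model_history.append({"role": "Assistant", "content": answer})
    # UI-facing transcript: pairs for gr.Chatbot.
    ui_history.append((user_message, answer))

def build_conversation(user_message: str, has_image: bool) -> list:
    # Start from the role-dict history rather than the tuple list, then
    # append the new turn the same way app.py does.
    conversation = model_history.copy()
    conversation.append({
        "role": "User",
        "content": ("<image_placeholder>" if has_image else "") + user_message,
    })
    conversation.append({"role": "Assistant", "content": ""})
    return conversation

if __name__ == "__main__":
    record_turn("What is in the picture?", "A cat on a sofa.")
    print(build_conversation("What color is it?", has_image=False))

Under this split, chat_with_image would return ui_history to the Chatbot while build_conversation feeds model_history to the processor.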