robot0820 committed on
Commit b75d3d8 · verified · 1 Parent(s): 375f86a

Update app.py

Files changed (1)
  1. app.py +45 -32
app.py CHANGED
@@ -1,79 +1,92 @@
  import torch
  import gradio as gr
  from transformers import AutoModelForCausalLM, BitsAndBytesConfig
- from deepseek_vl.models import VLChatProcessor
+ from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
+ from deepseek_vl.utils.io import load_pil_images

  # Model path
  model_path = "deepseek-ai/deepseek-vl-7b-chat"

- # Load the processor and tokenizer
- vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
- tokenizer = vl_chat_processor.tokenizer
-
- # ==== Quantized model settings (4-bit) ====
+ # ==== BitsAndBytes 4-bit quantization settings ====
  bnb_config = BitsAndBytesConfig(
      load_in_4bit=True,
-     bnb_4bit_compute_dtype=torch.float16,
+     bnb_4bit_compute_dtype=torch.float16,  # force float16
      bnb_4bit_use_double_quant=True
  )

- vl_gpt: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(
+ # Load the processor and tokenizer
+ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
+ tokenizer = vl_chat_processor.tokenizer
+
+ # Load the model (4-bit quantization + float16)
+ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
      model_path,
-     trust_remote_code=True,
+     quantization_config=bnb_config,
      device_map="auto",
-     quantization_config=bnb_config
- )
- vl_gpt.eval()
+     trust_remote_code=True
+ ).eval()

- # ==== Single-image processing + reduced max_new_tokens ====
- def generate_answer(image, text):
+ # ==== Single-image inference function ====
+ def chat_with_image(image, user_message):
      try:
-         # Combine the image and text into the conversation format
+         # Build the conversation format
          conversation = [
-             {"role": "User", "content": "<image_placeholder>" + text, "images": [image]},
+             {"role": "User", "content": "<image_placeholder>" + user_message, "images": [image]},
              {"role": "Assistant", "content": ""}
          ]

-         # 🚨 Pass [image] in directly here: it is already a PIL.Image, so load_pil_images is not needed
+         # Prepare the inputs
+         pil_images = load_pil_images(conversation)
          prepare_inputs = vl_chat_processor(
              conversations=conversation,
-             images=[image],
+             images=pil_images,
              force_batchify=True
          ).to(vl_gpt.device)

-         prepare_inputs = dict(prepare_inputs)
-         prepare_inputs = {k: (v.to(torch.float16) if torch.is_tensor(v) else v)
-                           for k, v in prepare_inputs.items()}
-
-         # Convert to embeddings
+         # 🚨 Correct dtype handling:
+         # cast only the tensors that need it to float16; input_ids must stay long
+         new_inputs = {}
+         for k, v in prepare_inputs.items():
+             if torch.is_tensor(v):
+                 if k in ["input_ids", "labels"]:
+                     new_inputs[k] = v.to(torch.long)
+                 else:
+                     new_inputs[k] = v.to(torch.float16)
+             else:
+                 new_inputs[k] = v
+         prepare_inputs = new_inputs
+
+         # Get the input embeddings
          inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

-         # Generate the response
+         # Generate the answer
          outputs = vl_gpt.language_model.generate(
              inputs_embeds=inputs_embeds,
-             attention_mask=prepare_inputs.attention_mask,
+             attention_mask=prepare_inputs["attention_mask"],
              pad_token_id=tokenizer.eos_token_id,
              bos_token_id=tokenizer.bos_token_id,
              eos_token_id=tokenizer.eos_token_id,
              max_new_tokens=128,  # shorter generation length to save memory
              do_sample=False,
              use_cache=True
          )

+         # Decode
          answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-         return f"{prepare_inputs['sft_format'][0]} {answer}"
+         return answer

      except Exception as e:
          return f"Error: {str(e)}"

  # ==== Gradio Web UI ====
  demo = gr.Interface(
-     fn=generate_answer,
-     inputs=[gr.Image(type="pil", label="Upload Image"), gr.Textbox(label="Question")],
+     fn=chat_with_image,
+     inputs=[gr.Image(type="pil", label="Upload Image"),
+             gr.Textbox(lines=2, placeholder="Ask about the image...")],
      outputs="text",
-     title="DeepSeek-VL-7B Chat Demo",
-     description="Upload an image and ask a question; the model generates an answer about the image (4-bit quantization, low-memory mode)"
+     title="DeepSeek-VL-7B-Chat Demo (4-bit, float16)",
+     description="Upload an image and ask a question; the model will generate an answer about the image"
  )

  if __name__ == "__main__":
      demo.launch()
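
The key fix in this commit is the dtype handling: the old code cast every tensor returned by the processor to float16, which corrupts `input_ids` (token indices must stay `torch.long` for the embedding lookup in `prepare_inputs_embeds`). A minimal sketch of the casting rule on dummy tensors only, no model or processor required; the `pixel_values` key here is illustrative, not the processor's actual schema:

```python
import torch

# Dummy stand-ins for the processor output (keys are illustrative).
batch = {
    "input_ids": torch.tensor([[101, 7592, 102]]),  # token ids: must stay integer
    "pixel_values": torch.randn(1, 3, 384, 384),    # image tensor: safe to cast down
    "attention_mask": torch.ones(1, 3),
    "sft_format": ["<prompt>"],                     # non-tensor entry: passed through
}

new_inputs = {}
for k, v in batch.items():
    if torch.is_tensor(v):
        # Same rule as the diff: ids/labels stay long, everything else goes to float16.
        new_inputs[k] = v.to(torch.long) if k in ["input_ids", "labels"] else v.to(torch.float16)
    else:
        new_inputs[k] = v

assert new_inputs["input_ids"].dtype == torch.long
assert new_inputs["pixel_values"].dtype == torch.float16
```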
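
Once the app is running, the `gr.Interface` can also be queried programmatically. A hedged sketch using `gradio_client`; the Space id and image path are placeholders, not taken from this commit:

```python
from gradio_client import Client, handle_file

# Hypothetical Space id; replace with the real deployment of this app.py.
client = Client("robot0820/deepseek-vl-7b-chat")

answer = client.predict(
    handle_file("example.jpg"),        # the gr.Image input (local path or URL)
    "What is shown in this picture?",  # the question Textbox
    api_name="/predict",               # default endpoint name for a gr.Interface
)
print(answer)
```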