robot0820 committed on
Commit
a675f47
·
verified ·
1 Parent(s): 2c94591

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -34
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoModelForCausalLM, BitsAndBytesConfig
4
  from deepseek_vl.models import VLChatProcessor
5
  from deepseek_vl.utils.io import load_pil_images
6
 
@@ -11,50 +11,47 @@ model_path = "deepseek-ai/deepseek-vl-7b-chat"
11
  vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
12
  tokenizer = vl_chat_processor.tokenizer
13
 
14
- # ==== ้‡ๅŒ–ๆจกๅž‹่จญๅฎš (4-bit) ====
15
- bnb_config = BitsAndBytesConfig(
16
- load_in_4bit=True,
17
- bnb_4bit_compute_dtype=torch.float16,
18
- bnb_4bit_use_double_quant=True
19
- )
20
-
21
  vl_gpt: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(
22
  model_path,
23
  trust_remote_code=True,
24
- device_map="auto",
25
- quantization_config=bnb_config
26
  )
27
  vl_gpt.eval()
28
 
29
  # ==== ๅ–ฎๅผตๅœ–็‰‡่™•็† + ๆธ›ๅฐ‘ max_new_tokens ====
30
  def generate_answer(image, text):
31
- conversation = [
32
- {"role": "User", "content": "<image_placeholder>" + text, "images": [image]},
33
- {"role": "Assistant", "content": ""}
34
- ]
 
35
 
36
- pil_images = load_pil_images(conversation)
37
- prepare_inputs = vl_chat_processor(
38
- conversations=conversation,
39
- images=pil_images,
40
- force_batchify=True
41
- ).to(vl_gpt.device)
42
 
43
- inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
44
 
45
- outputs = vl_gpt.language_model.generate(
46
- inputs_embeds=inputs_embeds,
47
- attention_mask=prepare_inputs.attention_mask,
48
- pad_token_id=tokenizer.eos_token_id,
49
- bos_token_id=tokenizer.bos_token_id,
50
- eos_token_id=tokenizer.eos_token_id,
51
- max_new_tokens=128, # ้™ไฝŽ็”Ÿๆˆ้•ทๅบฆ
52
- do_sample=False,
53
- use_cache=True
54
- )
55
 
56
- answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
57
- return f"{prepare_inputs['sft_format'][0]} {answer}"
 
 
58
 
59
  # ==== Gradio Web UI ====
60
  demo = gr.Interface(
@@ -62,7 +59,7 @@ demo = gr.Interface(
62
  inputs=[gr.Image(type="pil", label="Upload Image"), gr.Textbox(label="Question")],
63
  outputs="text",
64
  title="DeepSeek-VL-7B Chat Demo",
65
- description="ไธŠๅ‚ณๅœ–็‰‡ไธฆ่ผธๅ…ฅๅ•้กŒ๏ผŒๆจกๅž‹ๆœƒ็”Ÿๆˆ่ˆ‡ๅœ–็‰‡็›ธ้—œ็š„ๅ›ž็ญ”๏ผˆ4-bit ้‡ๅŒ–๏ผŒไฝŽ่จ˜ๆ†ถ้ซ”ๆจกๅผ๏ผ‰"
66
  )
67
 
68
  if __name__ == "__main__":
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoModelForCausalLM
4
  from deepseek_vl.models import VLChatProcessor
5
  from deepseek_vl.utils.io import load_pil_images
6
 
 
11
# Processor bundles the image preprocessor and the tokenizer for DeepSeek-VL.
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# ==== Load the model (no quantization) ====
# device_map="auto" lets accelerate place layers on GPU and/or CPU;
# bfloat16 halves weight memory versus float32.
_load_kwargs = {
    "trust_remote_code": True,
    "device_map": "auto",
    "torch_dtype": torch.bfloat16,
}
vl_gpt: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, **_load_kwargs
)
vl_gpt.eval()  # inference mode: disables dropout etc.
22
 
23
  # ==== ๅ–ฎๅผตๅœ–็‰‡่™•็† + ๆธ›ๅฐ‘ max_new_tokens ====
24
def generate_answer(image, text):
    """Answer a question about a single uploaded image.

    Args:
        image: PIL image supplied by the Gradio ``gr.Image(type="pil")`` input.
        text: the user's question; prepended with the ``<image_placeholder>``
            token the DeepSeek-VL chat template expects.

    Returns:
        The SFT-formatted prompt followed by the model's decoded answer, or an
        ``"Error: ..."`` string if anything in the pipeline raises (boundary
        catch so the web UI never crashes on a bad input).
    """
    try:
        conversation = [
            {"role": "User", "content": "<image_placeholder>" + text, "images": [image]},
            {"role": "Assistant", "content": ""}
        ]

        pil_images = load_pil_images(conversation)
        prepare_inputs = vl_chat_processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True
        ).to(vl_gpt.device)

        # Inference only: no_grad avoids building the autograd graph, which
        # would otherwise waste memory on this VRAM-constrained demo.
        with torch.no_grad():
            # Fuse image features into the text embedding sequence.
            inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

            outputs = vl_gpt.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=prepare_inputs.attention_mask,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=128,  # keep responses short to bound latency
                do_sample=False,     # greedy decoding: deterministic output
                use_cache=True
            )

        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return f"{prepare_inputs['sft_format'][0]} {answer}"
    except Exception as e:
        # Top-level UI boundary: surface the failure as text instead of a 500.
        return f"Error: {str(e)}"
55
 
56
  # ==== Gradio Web UI ====
57
  demo = gr.Interface(
 
59
  inputs=[gr.Image(type="pil", label="Upload Image"), gr.Textbox(label="Question")],
60
  outputs="text",
61
  title="DeepSeek-VL-7B Chat Demo",
62
+ description="ไธŠๅ‚ณๅœ–็‰‡ไธฆ่ผธๅ…ฅๅ•้กŒ๏ผŒๆจกๅž‹ๆœƒ็”Ÿๆˆ่ˆ‡ๅœ–็‰‡็›ธ้—œ็š„ๅ›ž็ญ”๏ผˆCPU/GPU ๅ…ผๅฎน๏ผŒไฝŽ่จ˜ๆ†ถ้ซ”ๆจกๅผ๏ผ‰"
63
  )
64
 
65
  if __name__ == "__main__":