File size: 3,052 Bytes
dcef8cb
2c94591
3153182
b75d3d8
 
dcef8cb
758f4b1
dcef8cb
758f4b1
b75d3d8
3153182
 
b75d3d8
3153182
 
 
b75d3d8
 
 
 
 
 
758f4b1
b75d3d8
3153182
b75d3d8
 
dcef8cb
b75d3d8
 
a675f47
b75d3d8
a675f47
b75d3d8
a675f47
 
dcef8cb
b75d3d8
 
a675f47
 
b75d3d8
a675f47
 
758f4b1
b75d3d8
 
 
 
 
 
 
 
 
 
 
 
 
 
a675f47
758f4b1
b75d3d8
a675f47
 
b75d3d8
a675f47
 
 
b75d3d8
a675f47
 
 
758f4b1
b75d3d8
a675f47
b75d3d8
3153182
a675f47
 
dcef8cb
2c94591
 
b75d3d8
 
 
2c94591
b75d3d8
 
2c94591
 
 
b75d3d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vl.utils.io import load_pil_images

# Hugging Face Hub id of the checkpoint served by this demo.
model_path = "deepseek-ai/deepseek-vl-7b-chat"

# ==== BitsAndBytes 4-bit quantization settings ====
# Weights are loaded in 4 bits with double quantization enabled; the
# compute dtype is pinned to float16 (chat_with_image feeds float16 pixels).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Processor that builds prompts / preprocesses images, and the tokenizer it
# wraps (reused below for special-token ids and for decoding the answer).
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Load the multimodal model with the 4-bit config, letting device_map="auto"
# place the layers, then put it in evaluation mode.
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config,
)
vl_gpt.eval()

# ==== Single-image inference ====
def _to_rgb_images(image, conversation):
    """Return the list of RGB PIL images to hand to ``vl_chat_processor``.

    ``gr.Image(type="pil")`` already delivers a decoded PIL image. deepseek_vl's
    ``load_pil_images`` opens every entry under ``"images"`` as a file path, so
    an image *object* cannot go through it; it is only used for path inputs.
    """
    if hasattr(image, "convert"):  # PIL.Image-like object
        return [image.convert("RGB")]
    return load_pil_images(conversation)


def _cast_float_tensors(inputs, dtype=torch.float16):
    """Copy processor outputs into a dict, casting only floating-point tensors.

    Floating tensors (presumably just the image pixel values) are cast to
    *dtype* so they match the float16 compute dtype. Integer/bool tensors such
    as ``input_ids`` and the attention / image masks keep their dtype: turning
    masks into float16 breaks their use as masks and indices.

    ``inputs`` may be a plain dict or deepseek_vl's batched processor output,
    so only ``keys()`` and ``[]`` are used (the latter may lack ``items()``).
    """
    converted = {}
    for key in inputs.keys():
        value = inputs[key]
        if torch.is_tensor(value) and torch.is_floating_point(value):
            value = value.to(dtype)
        converted[key] = value
    return converted


def chat_with_image(image, user_message, *, max_new_tokens=128):
    """Answer *user_message* about a single *image* with DeepSeek-VL.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input (an
            image file path is also accepted). ``None`` if nothing was uploaded.
        user_message: The user's question; ``None`` is treated as empty.
        max_new_tokens: Upper bound on generated tokens (kept small by default
            to limit memory use).

    Returns:
        The decoded answer, or a string starting with ``"Error: "`` describing
        the failure, so the web UI shows the problem instead of crashing.
    """
    if image is None:
        return "Error: please upload an image."

    try:
        # One user turn carrying the image placeholder, then an empty
        # assistant turn for the model to complete.
        conversation = [
            {
                "role": "User",
                "content": "<image_placeholder>" + (user_message or ""),
                "images": [image],
            },
            {"role": "Assistant", "content": ""},
        ]

        pil_images = _to_rgb_images(image, conversation)
        prepare_inputs = vl_chat_processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True,
        ).to(vl_gpt.device)
        model_inputs = _cast_float_tensors(prepare_inputs)

        # No autograd graph is needed; this avoids keeping activations of the
        # vision forward pass alive during generation.
        with torch.no_grad():
            inputs_embeds = vl_gpt.prepare_inputs_embeds(**model_inputs)
            outputs = vl_gpt.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=model_inputs["attention_mask"],
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                use_cache=True,
            )

        return tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)

    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback
        # in the server log instead of silently dropping it.
        import logging

        logging.getLogger(__name__).exception("chat_with_image failed")
        return f"Error: {str(e)}"

# ==== Gradio web UI ====
# Two inputs, passed positionally to chat_with_image(image, user_message);
# the returned string is shown in a plain text output. The description text
# reads: "Upload an image and enter a question; the model will generate an
# answer related to the image."
demo = gr.Interface(
    chat_with_image,
    [
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(lines=2, placeholder="Ask about the image..."),
    ],
    "text",
    title="DeepSeek-VL-7B-Chat Demo (4-bit, float16)",
    description="上傳圖片並輸入問題,模型會生成與圖片相關的回答",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()