import torch
import gradio as gr
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
# Model path
model_path = "deepseek-ai/deepseek-vl-7b-chat"
# ==== BitsAndBytes 4-bit quantization settings ====
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # force float16 compute
    bnb_4bit_use_double_quant=True
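    # Optional tweak (an assumption, not in the original config): NF4 usually
    # preserves accuracy better than the default FP4 for transformer weights:
    # bnb_4bit_quant_type="nf4",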
)
# Load the processor and tokenizer
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
# Load the model (4-bit quantized + float16)
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
).eval()
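# Rough estimate (not from the original author): 4-bit weights for a 7B model
# take about 3.5 GB (7B params x 0.5 bytes), so with activations and the vision
# encoder this demo typically fits on a single ~8 GB GPU.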
# ==== Single-image inference function ====
def chat_with_image(image, user_message):
    try:
        # Build the conversation in DeepSeek-VL's chat format
        conversation = [
            {"role": "User", "content": "<image_placeholder>" + user_message, "images": [image]},
            {"role": "Assistant", "content": ""}
        ]
        # Input processing: Gradio already supplies a PIL image, so pass it to
        # the processor directly (load_pil_images() expects file paths, not PIL objects)
        pil_images = [image.convert("RGB")]
        prepare_inputs = vl_chat_processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True
        ).to(vl_gpt.device)
        # 🚨 Correct dtype handling:
        # cast only floating-point tensors (e.g. pixel_values) to float16;
        # input_ids, attention_mask and the image masks must keep their integer/bool dtypes
        new_inputs = {}
        for k in prepare_inputs.keys():  # the processor output exposes keys()/[] rather than items()
            v = prepare_inputs[k]
            if torch.is_tensor(v) and torch.is_floating_point(v):
                new_inputs[k] = v.to(torch.float16)
            else:
                new_inputs[k] = v
        prepare_inputs = new_inputs
        # Get the input embeddings (image features merged into the token stream)
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
        # Generate the answer
        outputs = vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs["attention_mask"],
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=128,  # keep generation short to reduce memory use
            do_sample=False,
            use_cache=True
        )
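        # Assumption beyond the original script: for more varied answers, the
        # greedy settings above can be swapped for sampling, e.g.
        # do_sample=True, temperature=0.7, top_p=0.9.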
        # Decode
        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return answer
    except Exception as e:
        return f"Error: {str(e)}"
# ==== Gradio Web UI ====
demo = gr.Interface(
    fn=chat_with_image,
    inputs=[gr.Image(type="pil", label="Upload Image"),
            gr.Textbox(lines=2, placeholder="Ask about the image...")],
    outputs="text",
    title="DeepSeek-VL-7B-Chat Demo (4-bit, float16)",
    description="Upload an image and type a question; the model will answer based on the image."
)
if __name__ == "__main__":
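    # To expose a temporary public URL (e.g. when running on a remote GPU box),
    # demo.launch(share=True) can be used instead of the local-only default.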
    demo.launch()