# VLM_Test / app.py
# Last commit: d150731 (verified) by robot0820 — "Update app.py"; file size 2.82 kB
import logging

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
# Hugging Face Hub id of the DeepSeek-VL chat checkpoint to load.
model_path = "deepseek-ai/deepseek-vl-7b-chat"

# ==== BitsAndBytes 4-bit quantization settings ====
# Weights are loaded in 4 bits with double (nested) quantization enabled;
# bnb_4bit_compute_dtype forces the compute dtype to float16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,  # force float16 compute
    load_in_4bit=True,
)

# Chat processor (prompt formatting + image preprocessing) and the tokenizer
# it wraps; both come from the same checkpoint as the model.
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Quantized model. device_map="auto" leaves device placement to
# transformers/accelerate; trust_remote_code allows the checkpoint's custom
# model code. The model is then switched to eval mode.
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config,
)
vl_gpt = vl_gpt.eval()
# ==== Single-image inference ====
def chat_with_image(image, user_message):
    """Answer ``user_message`` about ``image`` with DeepSeek-VL.

    Builds a one-turn User/Assistant conversation whose user content starts
    with ``<image_placeholder>``, preprocesses it with ``vl_chat_processor``,
    fuses text and image embeddings via ``vl_gpt.prepare_inputs_embeds`` and
    greedily decodes (``do_sample=False``) at most 128 new tokens.

    Args:
        image: The uploaded picture. The Gradio input is declared with
            ``type="pil"``, so this is a ``PIL.Image``, or ``None`` when the
            user submitted without uploading anything.
        user_message: The question text; ``None`` is treated as ``""``.

    Returns:
        The decoded answer, or a string starting with ``"Error: "``. This
        function is the Gradio callback, so failures are reported in the UI
        (and logged with a traceback) instead of being raised.
    """
    if image is None:
        return "Error: please upload an image."
    user_message = user_message or ""

    try:
        conversation = [
            {
                "role": "User",
                "content": "<image_placeholder>" + user_message,
                "images": [image],
            },
            {"role": "Assistant", "content": ""},
        ]

        # The PIL image is passed directly (no load_pil_images needed).
        # NOTE(review): in deepseek_vl's BatchedVLChatProcessorOutput.to(),
        # `dtype` is applied to pixel_values only, while input_ids, the
        # attention mask and the boolean image masks are just moved to the
        # device. The masks must stay bool because prepare_inputs_embeds
        # indexes with them, and casting straight to float16 avoids the
        # lossy default bfloat16 -> float16 double cast. Confirm against the
        # installed deepseek_vl version.
        prepare_inputs = vl_chat_processor(
            conversations=conversation,
            images=[image],
            force_batchify=True,
        ).to(vl_gpt.device, dtype=torch.float16)

        # prepare_inputs_embeds runs outside generate(); without no_grad it
        # would record an autograd graph for the vision tower / embeddings.
        with torch.no_grad():
            inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
            outputs = vl_gpt.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=prepare_inputs["attention_mask"],
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=128,
                do_sample=False,
                use_cache=True,
            )

        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return answer
    except Exception as e:
        # UI boundary: keep the traceback in the server log, show a short
        # message to the user.
        logging.getLogger(__name__).exception("chat_with_image failed")
        return f"Error: {e}"
# ==== Gradio web UI ====
# Two inputs (an image delivered as PIL.Image and a free-text question) are
# passed positionally to chat_with_image; its return value is shown as text.
# The description reads (zh-TW): "Upload an image and enter a question; the
# model will generate an answer related to the image."
demo = gr.Interface(
    title="DeepSeek-VL-7B-Chat Demo (4-bit, float16)",
    description="上傳圖片並輸入問題,模型會生成與圖片相關的回答",
    fn=chat_with_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(lines=2, placeholder="Ask about the image..."),
    ],
    outputs="text",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()