robot0820 committed on
Commit
21f59de
·
verified ·
1 Parent(s): d150731

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -28,19 +28,23 @@ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
28
  # ==== 單張圖片推理函式 ====
29
  def chat_with_image(image, user_message):
30
  try:
 
31
  conversation = [
32
  {"role": "User", "content": "<image_placeholder>" + user_message, "images": [image]},
33
  {"role": "Assistant", "content": ""}
34
  ]
35
 
36
- # 直接傳入 PIL.Image,不再使用 load_pil_images
37
  prepare_inputs = vl_chat_processor(
38
  conversations=conversation,
39
  images=[image],
40
  force_batchify=True
41
  ).to(vl_gpt.device)
42
 
43
- # 正確 dtype 處理
 
 
 
44
  new_inputs = {}
45
  for k, v in prepare_inputs.items():
46
  if torch.is_tensor(v):
@@ -62,11 +66,12 @@ def chat_with_image(image, user_message):
62
  pad_token_id=tokenizer.eos_token_id,
63
  bos_token_id=tokenizer.bos_token_id,
64
  eos_token_id=tokenizer.eos_token_id,
65
- max_new_tokens=128,
66
  do_sample=False,
67
  use_cache=True
68
  )
69
 
 
70
  answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
71
  return answer
72
 
 
28
  # ==== 單張圖片推理函式 ====
29
  def chat_with_image(image, user_message):
30
  try:
31
+ # 建立對話
32
  conversation = [
33
  {"role": "User", "content": "<image_placeholder>" + user_message, "images": [image]},
34
  {"role": "Assistant", "content": ""}
35
  ]
36
 
37
+ # 直接傳入 PIL.Image,不使用 load_pil_images
38
  prepare_inputs = vl_chat_processor(
39
  conversations=conversation,
40
  images=[image],
41
  force_batchify=True
42
  ).to(vl_gpt.device)
43
 
44
+ # 🚨 BatchedVLChatProcessorOutput 轉 dict
45
+ prepare_inputs = {k: getattr(prepare_inputs, k) for k in prepare_inputs.__dataclass_fields__.keys()}
46
+
47
+ # 正確 dtype:input_ids/labels 保持 long,其他 tensor 轉 float16
48
  new_inputs = {}
49
  for k, v in prepare_inputs.items():
50
  if torch.is_tensor(v):
 
66
  pad_token_id=tokenizer.eos_token_id,
67
  bos_token_id=tokenizer.bos_token_id,
68
  eos_token_id=tokenizer.eos_token_id,
69
+ max_new_tokens=128, # 減少記憶體
70
  do_sample=False,
71
  use_cache=True
72
  )
73
 
74
+ # 解碼
75
  answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
76
  return answer
77