Rausda6 committed (verified)
Commit ff3e8f4 · 1 Parent(s): 16dfb59

Update app.py

Files changed (1)
  1. app.py +70 -87
app.py CHANGED
@@ -1,47 +1,30 @@
 import os
 
-# Disable Xet/CAS backend (it’s what’s throwing the error)
+# ---- Hub download settings (apply before any HF imports) ----
 os.environ["HF_HUB_ENABLE_XET"] = "0"
-
-# Use the robust Rust downloader for big files
+os.environ["HF_HUB_DISABLE_XET"] = "1"
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
-# Optional but helpful: resume and avoid symlinks on some filesystems
 os.environ["HF_HUB_ENABLE_RESUME"] = "1"
 os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
 
-
 import gradio as gr
 import torch
 from transformers import AutoProcessor, LlavaForConditionalGeneration
 from PIL import Image
 
-
-# Hugging Face model identifier. See the model card for more details:
-# https://huggingface.co/StarCycle/llava-dinov2-internlm2-7b-v1
+# Use the compact HF-format LLaVA model
 MODEL_ID = "xtuner/llava-phi-3-mini-hf"
 
-# Determine the computation device. If a CUDA‑enabled GPU is
-# available we will use it and cast the weights to half precision to
-# reduce memory consumption. Otherwise we fall back to CPU.
+# Device + dtype
 if torch.cuda.is_available():
-    DEVICE = torch.device("cuda")
     TORCH_DTYPE = torch.float16
 else:
-    DEVICE = torch.device("cpu")
     TORCH_DTYPE = torch.float32
 
 
 def load_model():
-    """Load the LLaVA model and its processor.
-
-    The model is loaded with ``trust_remote_code=True`` to allow the
-    repository’s custom projector and adapter classes to be registered
-    correctly. We specify ``device_map='auto'`` so that the
-    ``accelerate`` library will distribute the model across the
-    available hardware (GPU/CPU) automatically. The ``torch_dtype``
-    argument ensures that the model weights are loaded in half
-    precision on a GPU and in full precision on a CPU.
+    """
+    Load the LLaVA model and its processor.
     """
     model = LlavaForConditionalGeneration.from_pretrained(
         MODEL_ID,
@@ -51,93 +34,99 @@ def load_model():
         low_cpu_mem_usage=True,
     )
     processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+    # ---- Robustness: ensure processor carries vision attrs expected by LLaVA ----
+    vcfg = getattr(model.config, "vision_config", None)
+
+    if not hasattr(processor, "patch_size") or processor.patch_size is None:
+        # CLIP-L/336 typically uses patch_size=14; default to 14 if missing
+        processor.patch_size = getattr(vcfg, "patch_size", 14)
+
+    if (
+        not hasattr(processor, "vision_feature_select_strategy")
+        or processor.vision_feature_select_strategy is None
+    ):
+        processor.vision_feature_select_strategy = getattr(
+            model.config, "vision_feature_select_strategy", "default"
+        )
+
+    if (
+        not hasattr(processor, "num_additional_image_tokens")
+        or processor.num_additional_image_tokens is None
+    ):
+        # CLIP ViT uses a single CLS token
+        processor.num_additional_image_tokens = 1
+
     return model, processor
 
 
-# Load the model and processor at import time. Loading is expensive so
-# we only do it once. If the model fails to load (for example
-# because of missing dependencies) the exception will be raised here.
+# Load once at import
 MODEL, PROCESSOR = load_model()
 
 
 def answer_question(image: Image.Image, question: str) -> str:
-    """Generate an answer for the given question about the uploaded image.
-
-    Parameters
-    ----------
-    image: PIL.Image.Image
-        The user‑provided image. Gradio supplies images as PIL
-        objects, which the LLaVA processor accepts directly.
-    question: str
-        The user’s question about the image.
-
-    Returns
-    -------
-    str
-        The answer generated by the model. If either the image or
-        question is missing, an explanatory message is returned.
+    """
+    Generate an answer about the uploaded image.
     """
-    # Basic validation: ensure both inputs are provided.
     if image is None:
         return "Please upload an image."
     if not question or not question.strip():
         return "Please enter a question about the image."
 
-    # Build the chat prompt. The LLaVA model uses the ``<image>``
-    # placeholder to indicate where the image will be inserted.
-    prompt = f"USER: <image>\n{question.strip()} ASSISTANT:"
-
-    # Tokenize the inputs. The processor will process both the image
-    # and the text and return PyTorch tensors. We move these to the
-    # same device as the model to avoid device mismatch errors.
-    inputs = PROCESSOR(
-        images=image,
-        text=prompt,
-        return_tensors="pt",
-    )
-    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+    try:
+        # ---- Preferred: chat-template path (handles image + text cleanly) ----
+        conversation = [{
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": question.strip()},
+            ],
+        }]
+
+        inputs = PROCESSOR.apply_chat_template(
+            conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            images=[image],
+        )
+    except Exception:
+        # ---- Fallback: legacy prompt with <image> placeholder ----
+        prompt = f"USER: <image>\n{question.strip()} ASSISTANT:"
+        inputs = PROCESSOR(
+            images=image,
+            text=prompt,
+            return_tensors="pt",
+        )
 
-    # Generate the answer. We limit the number of new tokens to 256 to
-    # avoid excessive memory usage. Feel free to adjust this value
-    # depending on your hardware constraints and desired response length.
-    with torch.no_grad():
+    # Move all tensors to the model's device
+    inputs = {k: (v.to(MODEL.device) if hasattr(v, "to") else v) for k, v in inputs.items()}
+
+    with torch.inference_mode():
         generated_ids = MODEL.generate(
             **inputs,
             max_new_tokens=256,
            do_sample=False,
        )
 
-    # Decode the generated ids back into text. The output will include
-    # the entire conversation (e.g., ``USER: ... ASSISTANT: ...``).
-    output = PROCESSOR.batch_decode(
+    text = PROCESSOR.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]
 
-    # Extract the assistant's response by splitting on the
-    # ``ASSISTANT:`` delimiter.
-    if "ASSISTANT:" in output:
-        answer = output.split("ASSISTANT:")[-1].strip()
-    else:
-        # Fallback if the delimiter is not present.
-        answer = output.strip()
-
-    return answer
+    return text.strip()
 
 
 def build_interface() -> gr.Interface:
-    """Construct the Gradio Interface object for the app."""
     description = (
         "Upload an image and ask a question about it.\n\n"
-        "This demo uses the multimodal model "
-        "StarCycle/llava‑dinov2‑internlm2‑7b‑v1 to perform visual "
-        "question answering. The model combines the Dinov2 vision encoder with "
-        "the InternLM2‑Chat‑7B language model via a lightweight projector and "
-        "LoRA adapters. Note: inference requires a GPU with sufficient "
-        "memory; on a CPU the generation will be extremely slow."
+        "This demo uses **xtuner/llava-phi-3-mini-hf** (LLaVA in HF format) "
+        "to perform visual question answering. Note: a GPU is recommended; "
+        "CPU inference will be slow."
     )
-    iface = gr.Interface(
+    return gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Image(type="pil", label="Image"),
@@ -148,22 +137,16 @@ def build_interface() -> gr.Interface:
            ),
        ],
        outputs=gr.Textbox(label="Answer"),
-        title="Visual Question Answering with LLaVA Dinov2 InternLM2 7B",
+        title="Visual Question Answering (LLaVA Phi-3 Mini)",
        description=description,
        flagging_mode="never",
    )
-    return iface
 
 
 def main() -> None:
-    """Launch the Gradio app."""
     iface = build_interface()
-    # When running on Hugging Face Spaces the app will automatically set
-    # the appropriate host and port. For local development you can
-    # uncomment the ``server_name`` argument to make the app reachable
-    # from other machines on your network.
     iface.launch()
 
 
 if __name__ == "__main__":
-    main()
+    main()
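
Side note, not part of the commit: setting HF_HUB_ENABLE_HF_TRANSFER=1 makes huggingface_hub raise an error at download time if the optional hf_transfer package is not installed in the Space. A minimal sketch of how the toggle could be guarded instead of set unconditionally (the guard itself is an illustration, not code from this diff):

import importlib.util
import os

# Enable the Rust-based downloader only when the optional `hf_transfer`
# package is actually importable; otherwise leave the flag unset so
# huggingface_hub falls back to its default downloader.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
else:
    os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)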