import os

# Disable the Xet/CAS download backend (it's what's throwing the runtime error).
os.environ["HF_HUB_DISABLE_XET"] = "1"
# Use the Rust-based downloader for big files. Note: this requires the
# ``hf_transfer`` package to be installed, otherwise downloads will fail.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Optional but helpful on some filesystems: silence the symlink warning.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

import gradio as gr
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
# Hugging Face model identifier. See the model card for more details:
# https://huggingface.co/xtuner/llava-phi-3-mini-hf
MODEL_ID = "xtuner/llava-phi-3-mini-hf"
# Determine the computation device. If a CUDA‑enabled GPU is
# available we will use it and cast the weights to half precision to
# reduce memory consumption. Otherwise we fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    TORCH_DTYPE = torch.float16
else:
    DEVICE = torch.device("cpu")
    TORCH_DTYPE = torch.float32


def load_model():
    """Load the LLaVA model and its processor.

    The model is loaded with ``trust_remote_code=True`` so that any custom
    code shipped with the repository can be used. ``device_map='auto'`` lets
    the ``accelerate`` library place the model on the available hardware
    (GPU/CPU) automatically, and ``torch_dtype`` loads the weights in half
    precision on a GPU and in full precision on a CPU.
    """
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor
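

# Optional: a minimal sketch of loading the same checkpoint in 4-bit precision
# for GPUs with limited memory. It is not used by the app by default and
# assumes the ``bitsandbytes`` package is installed; remove it or swap it in
# for ``load_model`` to suit your hardware.
def load_model_4bit():
    """Load the model with 4-bit quantization (sketch; requires bitsandbytes)."""
    from transformers import BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor

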
# Load the model and processor at import time. Loading is expensive so
# we only do it once. If the model fails to load (for example
# because of missing dependencies) the exception will be raised here.
MODEL, PROCESSOR = load_model()


def answer_question(image: "PIL.Image.Image", question: str) -> str:
    """Generate an answer for the given question about the uploaded image.

    Parameters
    ----------
    image: PIL.Image.Image
        The user-provided image. Gradio supplies images as PIL objects,
        which the LLaVA processor accepts directly.
    question: str
        The user's question about the image.

    Returns
    -------
    str
        The answer generated by the model. If either the image or the
        question is missing, an explanatory message is returned.
    """
    # Basic validation: ensure both inputs are provided.
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question about the image."
    # Build the chat prompt. The ``<image>`` placeholder marks where the image
    # features will be inserted; the surrounding template follows the Phi-3
    # instruct format shown on the xtuner/llava-phi-3-mini-hf model card.
    prompt = f"<|user|>\n<image>\n{question.strip()}<|end|>\n<|assistant|>\n"
    # Tokenize the inputs. The processor handles both the image and the text
    # and returns PyTorch tensors. Move them to the model's device and cast
    # floating-point tensors (the pixel values) to the model dtype to avoid
    # device/dtype mismatch errors.
    inputs = PROCESSOR(
        images=image,
        text=prompt,
        return_tensors="pt",
    )
    inputs = {
        k: v.to(DEVICE, TORCH_DTYPE) if torch.is_floating_point(v) else v.to(DEVICE)
        for k, v in inputs.items()
    }
    # Generate the answer. Limit the number of new tokens to 256 to avoid
    # excessive memory usage; adjust this value depending on your hardware
    # constraints and desired response length.
    with torch.no_grad():
        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )
    # ``generate`` returns the prompt followed by the newly generated tokens,
    # so slice off the prompt before decoding. This is more robust than
    # splitting the decoded string on a delimiter.
    new_token_ids = generated_ids[:, inputs["input_ids"].shape[1]:]
    answer = PROCESSOR.batch_decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0].strip()
    return answer
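

# A small convenience helper for quick local debugging outside the Gradio UI.
# It is a sketch and is not called anywhere by default; ``image_path`` is just
# whatever test image you happen to have on disk.
def run_example(image_path: str, question: str = "What is shown in this image?") -> str:
    """Run ``answer_question`` on a local image file (debugging helper)."""
    from PIL import Image

    return answer_question(Image.open(image_path), question)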


def build_interface() -> gr.Interface:
    """Construct the Gradio Interface object for the app."""
    description = (
        "Upload an image and ask a question about it.\n\n"
        "This demo uses the multimodal model xtuner/llava-phi-3-mini-hf to "
        "perform visual question answering. The model pairs a CLIP vision "
        "encoder with the Phi-3-mini instruct language model following the "
        "LLaVA recipe. Note: a GPU with sufficient memory is strongly "
        "recommended; on a CPU generation will be extremely slow."
    )
    iface = gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Textbox(
                label="Question",
                placeholder="Describe or ask something about the image",
                lines=1,
            ),
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Visual Question Answering with LLaVA Phi-3 Mini",
        description=description,
        allow_flagging="never",
    )
    return iface


def main() -> None:
    """Launch the Gradio app."""
    iface = build_interface()
    # When running on Hugging Face Spaces the app automatically uses the
    # appropriate host and port. For local development you can pass
    # ``server_name="0.0.0.0"`` to ``launch`` to make the app reachable from
    # other machines on your network.
    iface.launch()


if __name__ == "__main__":
    main()