# LLaVADinov2 / app.py
import os
# Disable the Xet/CAS download backend (works around download failures seen with it on some hosts)
os.environ["HF_HUB_ENABLE_XET"] = "0"
# Use the robust Rust downloader for big files
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Optional but helpful: resume and avoid symlinks on some filesystems
os.environ["HF_HUB_ENABLE_RESUME"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
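# Note: HF_HUB_ENABLE_HF_TRANSFER only takes effect when the ``hf_transfer``
# package is installed (add it to requirements.txt); recent huggingface_hub
# versions raise an error if the flag is set but the package is missing.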
import gradio as gr
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
# Hugging Face model identifier. See the model card for more details:
# https://huggingface.co/xtuner/llava-phi-3-mini-hf
MODEL_ID = "xtuner/llava-phi-3-mini-hf"
# Determine the computation device. If a CUDA‑enabled GPU is
# available we will use it and cast the weights to half precision to
# reduce memory consumption. Otherwise we fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    TORCH_DTYPE = torch.float16
else:
    DEVICE = torch.device("cpu")
    TORCH_DTYPE = torch.float32

def load_model():
    """Load the LLaVA model and its processor.

    ``device_map='auto'`` lets the ``accelerate`` library place the model
    on the available hardware (GPU/CPU) automatically, and ``torch_dtype``
    loads the weights in half precision on a GPU and in full precision on
    a CPU. ``trust_remote_code=True`` is kept so that any custom code
    shipped with the checkpoint can be loaded; for this standard HF-format
    LLaVA checkpoint it is not strictly required.
    """
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor

# Load the model and processor at import time. Loading is expensive so
# we only do it once. If the model fails to load (for example
# because of missing dependencies) the exception will be raised here.
MODEL, PROCESSOR = load_model()
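# Rough sizing note (estimate, not measured here): the roughly 4B-parameter
# checkpoint needs on the order of 8 GB of memory in fp16 and about twice
# that in fp32 on CPU, so small CPU instances may fail to load it.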

def answer_question(image: "PIL.Image.Image", question: str) -> str:
    """Generate an answer for the given question about the uploaded image.

    Parameters
    ----------
    image: PIL.Image.Image
        The user-provided image. Gradio supplies images as PIL
        objects, which the LLaVA processor accepts directly.
    question: str
        The user's question about the image.

    Returns
    -------
    str
        The answer generated by the model. If either the image or
        question is missing, an explanatory message is returned.
    """
    # Basic validation: ensure both inputs are provided.
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question about the image."
    # Build the chat prompt using the Phi-3 chat format that
    # llava-phi-3-mini-hf expects (see the model card); the ``<image>``
    # placeholder marks where the image features are inserted.
    prompt = (
        "<|user|>\n"
        f"<image>\n{question.strip()}<|end|>\n"
        "<|assistant|>\n"
    )
    # Tokenize the inputs. The processor handles both the image and the
    # text and returns PyTorch tensors.
    inputs = PROCESSOR(
        images=image,
        text=prompt,
        return_tensors="pt",
    )
    # Move the tensors to the model's device; floating-point tensors (the
    # pixel values) are also cast to the model dtype so that half-precision
    # weights do not receive float32 inputs.
    inputs = {
        k: v.to(DEVICE, TORCH_DTYPE) if torch.is_floating_point(v) else v.to(DEVICE)
        for k, v in inputs.items()
    }
    # Generate the answer. We limit the number of new tokens to 256 to
    # avoid excessive memory usage; adjust this value depending on your
    # hardware constraints and desired response length.
    with torch.no_grad():
        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )
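    # Tuning note: ``do_sample=False`` gives deterministic greedy decoding.
    # Passing ``do_sample=True`` together with ``temperature``/``top_p`` to
    # ``generate`` produces more varied (but non-reproducible) answers.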
    # Decode only the newly generated tokens (everything after the prompt)
    # so the returned answer does not include the prompt text itself.
    input_length = inputs["input_ids"].shape[-1]
    answer = PROCESSOR.decode(
        generated_ids[0][input_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    ).strip()
    return answer


def build_interface() -> gr.Interface:
    """Construct the Gradio Interface object for the app."""
    description = (
        "Upload an image and ask a question about it.\n\n"
        "This demo uses the multimodal model xtuner/llava-phi-3-mini-hf to "
        "perform visual question answering. The model pairs a CLIP ViT-L/14 "
        "vision encoder with the Phi-3-mini-4k-instruct language model via a "
        "LLaVA-style projector. Note: inference is fastest on a GPU with "
        "sufficient memory; on a CPU, generation will be very slow."
    )
    iface = gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Textbox(
                label="Question",
                placeholder="Describe or ask something about the image",
                lines=1,
            ),
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Visual Question Answering with LLaVA Phi-3 Mini",
        description=description,
        allow_flagging="never",
    )
    return iface


def main() -> None:
    """Launch the Gradio app."""
    iface = build_interface()
    # On Hugging Face Spaces the appropriate host and port are picked up
    # automatically. For local development you can pass ``server_name`` to
    # make the app reachable from other machines on your network, e.g.:
    #   iface.launch(server_name="0.0.0.0")
    iface.launch()


if __name__ == "__main__":
    main()
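# Dependency sketch for requirements.txt (assumed package set; pin versions
# as needed): gradio, torch, transformers, accelerate, pillow, hf_transfer.
# ``accelerate`` is required for device_map="auto"; ``hf_transfer`` backs the
# HF_HUB_ENABLE_HF_TRANSFER flag set above.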