import os

# Disable the Xet/CAS download backend (the backend that was raising the
# download error this script works around)
os.environ["HF_HUB_DISABLE_XET"] = "1"

# Use the robust Rust downloader for big files
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Optional but helpful: silence the symlink warning on filesystems that do
# not support symlinks.  (Interrupted downloads are resumed automatically
# by huggingface_hub, so no extra flag is needed for that.)
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"


import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration


# Hugging Face model identifier.  See the model card for more details:
# https://huggingface.co/xtuner/llava-phi-3-mini-hf
MODEL_ID = "xtuner/llava-phi-3-mini-hf"

# Determine the computation device.  If a CUDA‑enabled GPU is
# available we will use it and cast the weights to half precision to
# reduce memory consumption.  Otherwise we fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    TORCH_DTYPE = torch.float16
else:
    DEVICE = torch.device("cpu")
    TORCH_DTYPE = torch.float32


def load_model():
    """Load the LLaVA model and its processor.

    ``device_map='auto'`` lets the ``accelerate`` library place the model
    on the available hardware (GPU/CPU) automatically, and ``torch_dtype``
    loads the weights in half precision on a GPU and in full precision on
    a CPU.  ``trust_remote_code=True`` is kept so that any custom code
    shipped with the checkpoint can be registered; the ``-hf`` conversion
    of this model loads with the standard LLaVA classes.
    """
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, processor


# Load the model and processor at import time.  Loading is expensive so
# we only do it once.  If the model fails to load (for example
# because of missing dependencies) the exception will be raised here.
MODEL, PROCESSOR = load_model()


def answer_question(image: Image.Image, question: str) -> str:
    """Generate an answer for the given question about the uploaded image.

    Parameters
    ----------
    image: PIL.Image.Image
        The user‑provided image.  Gradio supplies images as PIL
        objects, which the LLaVA processor accepts directly.
    question: str
        The user’s question about the image.

    Returns
    -------
    str
        The answer generated by the model.  If either the image or
        question is missing, an explanatory message is returned.
    """
    # Basic validation: ensure both inputs are provided.
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question about the image."

    # Build the chat prompt.  llava-phi-3-mini-hf follows the Phi-3 chat
    # template, and the ``<image>`` placeholder marks where the image
    # features are inserted.
    prompt = f"<|user|>\n<image>\n{question.strip()}<|end|>\n<|assistant|>\n"

    # Preprocess the inputs.  The processor handles both the image and the
    # text and returns PyTorch tensors.  Floating-point tensors (the pixel
    # values) are cast to the model dtype and everything is moved to the
    # model's device to avoid dtype/device mismatch errors.
    inputs = PROCESSOR(
        images=image,
        text=prompt,
        return_tensors="pt",
    )
    inputs = {
        k: v.to(DEVICE, TORCH_DTYPE) if torch.is_floating_point(v) else v.to(DEVICE)
        for k, v in inputs.items()
    }

    # Generate the answer.  We limit the number of new tokens to 256 to
    # avoid excessive memory usage.  Feel free to adjust this value
    # depending on your hardware constraints and desired response length.
    with torch.no_grad():
        generated_ids = MODEL.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )

    # Decode only the newly generated tokens (everything after the prompt)
    # so the answer does not echo the question or the chat template markers.
    prompt_length = inputs["input_ids"].shape[1]
    answer = PROCESSOR.batch_decode(
        generated_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0].strip()

    return answer
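
# Example of calling ``answer_question`` directly, outside the Gradio UI.
# "example.jpg" is a hypothetical local file used purely for illustration;
# it is not shipped with this script:
#
#     from PIL import Image
#     img = Image.open("example.jpg")
#     print(answer_question(img, "What objects are visible in this image?"))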


def build_interface() -> gr.Interface:
    """Construct the Gradio Interface object for the app."""
    description = (
        "Upload an image and ask a question about it.\n\n"
        "This demo uses the multimodal model xtuner/llava-phi-3-mini-hf to "
        "perform visual question answering.  The model follows the LLaVA "
        "architecture, pairing a CLIP vision encoder with the Phi-3-mini "
        "instruct language model via a lightweight projector.  Note: "
        "inference requires a GPU with sufficient memory; on a CPU the "
        "generation will be extremely slow."
    )
    iface = gr.Interface(
        fn=answer_question,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Textbox(
                label="Question",
                placeholder="Describe or ask something about the image",
                lines=1,
            ),
        ],
        outputs=gr.Textbox(label="Answer"),
        title="Visual Question Answering with LLaVA Dinov2 InternLM2 7B",
        description=description,
        allow_flagging="never",
    )
    return iface


def main() -> None:
    """Launch the Gradio app."""
    iface = build_interface()
    # On Hugging Face Spaces the appropriate host and port are set
    # automatically.  For local development, uncomment the line below to
    # make the app reachable from other machines on your network.
    # iface.launch(server_name="0.0.0.0")
    iface.launch()


if __name__ == "__main__":
    main()
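
# Rough dependency sketch for running this script (package names only; exact
# versions are an assumption left to the deployer):
#     pip install gradio torch transformers accelerate pillow hf_transfer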