# Scraped page header (not part of the program):
# VQASynth / app.py
# smellslikeml
# Add application file
# 22fc8c6
# raw
# history blame
# 3.12 kB
"""SpaceLlama3.1 demo gradio app."""
import datetime
import functools
import logging
import os

import gradio as gr
import PIL.Image
import requests
import torch
from prismatic import load
# Markdown rendered at the top of the demo page (passed to gr.Markdown in create_app).
# NOTE: the literal "\n\n" escapes plus real newlines only add blank lines, which
# Markdown collapses; edit the text inside the quotes with care, it is shown verbatim.
INTRO_TEXT = """SpaceLlama3.1 demo\n\n
| [Model](https://huggingface.co/remyxai/SpaceLlama3.1)
| [GitHub](https://github.com/remyxai/VQASynth/tree/main)
| [Demo](https://huggingface.co/spaces/remyxai/SpaceLlama3.1)
| [Discord](https://discord.gg/DAy3P5wYJk)
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""
@functools.lru_cache(maxsize=1)
def _load_vlm(model_location):
    """Load the model at *model_location* and move it to the best available device.

    Cached so the checkpoint is loaded once per process instead of on every
    request; ``maxsize=1`` keeps at most one model resident in memory.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vlm = load(model_location)
    vlm.to(device, dtype=torch.bfloat16)
    return vlm


def compute(image, prompt, model_location):
    """Runs model inference.

    Args:
        image: A filepath (as produced by ``gr.Image(type="filepath")``), which is
            opened and converted to RGB; any other non-None value is passed to
            ``vlm.generate`` unchanged.
        prompt: The user's text prompt, added as a single "human" turn.
        model_location: Model identifier/path forwarded to ``prismatic.load``.

    Returns:
        The generated text, truncated at the first ``</s>`` marker.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Image required")
    logging.info('prompt="%s"', prompt)
    # Open the image file; the context manager releases the file handle once
    # the RGB copy has been made.
    if isinstance(image, str):
        with PIL.Image.open(image) as img:
            image = img.convert("RGB")
    # Reuse the already-loaded model (loaded and placed on the device on first use).
    vlm = _load_vlm(model_location)
    # Prepare prompt
    prompt_builder = vlm.get_prompt_builder()
    prompt_builder.add_turn(role="human", message=prompt)
    prompt_text = prompt_builder.get_prompt()
    # Generate the text based on image and prompt
    generated_text = vlm.generate(
        image,
        prompt_text,
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        min_length=1,
    )
    # Keep only the text before the first end-of-sequence marker.
    output = generated_text.split("</s>")[0]
    logging.info('output="%s"', output)
    return output
def reset():
    """Return blank values for the prompt textbox and the image input, in that order."""
    empty_prompt = ""
    empty_image = None
    return empty_prompt, empty_image
def create_app(model_location="remyxai/SpaceLlama3.1"):
    """Creates demo UI.

    Args:
        model_location: Model identifier handed to ``compute`` on every run.
            Defaults to the checkpoint that was previously hard-coded here.

    Returns:
        The configured ``gr.Blocks`` demo (not yet queued or launched).
    """

    def _run(image_path, prompt_text, location):
        # gr.HighlightedText expects a list of (text, label) pairs; a bare
        # string would be iterated character by character and fail to unpack.
        return [(compute(image_path, prompt_text, location), None)]

    with gr.Blocks() as demo:
        # Main UI structure
        gr.Markdown(INTRO_TEXT)
        with gr.Row():
            image = gr.Image(value=None, label="Image", type="filepath", visible=True)  # input
            with gr.Column():
                prompt = gr.Textbox(value="", label="Prompt", visible=True)
                model_info = gr.Markdown(label="Model Info")
                run = gr.Button("Run", variant="primary")
                clear = gr.Button("Clear")
                highlighted_text = gr.HighlightedText(value="", label="Output", visible=True)
        # Event inputs must be Gradio components, so the plain model string is
        # wrapped in gr.State to be forwarded to compute().
        model_state = gr.State(model_location)
        # Button event handlers
        run.click(
            _run,
            [image, prompt, model_state],
            highlighted_text,
        )
        clear.click(reset, None, [prompt, image])
        # Status
        gr.Markdown(f"Startup: {datetime.datetime.now()}")
        gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none"
        gr.Markdown(f"GPU={gpu_name}")
        # A single output component takes a single value, not a one-item list.
        # NOTE(review): this message is shown on page load; the model itself is
        # loaded by compute() on the first run.
        demo.load(
            lambda: f"Model `{model_location}` loaded.",
            None,
            model_info,
        )
    return demo
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    # Log the environment for debugging, but never the values of variables whose
    # names look like credentials (e.g. HF_TOKEN); logs are often shared or persisted.
    sensitive_markers = ("TOKEN", "KEY", "SECRET", "PASSWORD", "AUTH", "CREDENTIAL")
    for name, value in sorted(os.environ.items()):
        if any(marker in name.upper() for marker in sensitive_markers):
            value = "<redacted>"
        logging.info('environ["%s"] = %r', name, value)
    create_app().queue().launch()