"""SpaceLlama3.1 demo gradio app.""" | |
"""SpaceLlama3.1 demo gradio app.""" | |

import datetime
import logging
import os

import gradio as gr
import PIL.Image
import requests
import torch
from prismatic import load
INTRO_TEXT = """SpaceLlama3.1 demo\n\n | |
| [Model](https://huggingface.co/remyxai/SpaceLlama3.1) | |
| [GitHub](https://github.com/remyxai/VQASynth/tree/main) | |
| [Demo](https://huggingface.co/spaces/remyxai/SpaceLlama3.1) | |
| [Discord](https://discord.gg/DAy3P5wYJk) | |
\n\n | |
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications. | |
""" | |


def compute(image, prompt, model_location):
    """Runs model inference."""
    if image is None:
        raise gr.Error("Image required")

    logging.info('prompt="%s"', prompt)

    # Open the image file.
    if isinstance(image, str):
        image = PIL.Image.open(image).convert("RGB")

    # Set device and load the model.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vlm = load(model_location)
    vlm.to(device, dtype=torch.bfloat16)

    # Build the prompt in the model's expected chat format.
    prompt_builder = vlm.get_prompt_builder()
    prompt_builder.add_turn(role="human", message=prompt)
    prompt_text = prompt_builder.get_prompt()

    # Generate text conditioned on the image and prompt.
    generated_text = vlm.generate(
        image,
        prompt_text,
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        min_length=1,
    )
    output = generated_text.split("</s>")[0]
    logging.info('output="%s"', output)

    # gr.HighlightedText expects a list of (text, label) tuples, not a plain
    # string, so wrap the output accordingly.
    return [(output, None)]
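

# A minimal sketch, not part of the original app: compute() above reloads the
# model weights on every request. A module-level cache would pay that cost
# once per process. `load_vlm` is a hypothetical helper name; compute() could
# call load_vlm(model_location) in place of its load(...) / .to(...) pair.
import functools


@functools.lru_cache(maxsize=1)
def load_vlm(model_location):
    """Loads the model once per process and reuses it on later requests."""
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vlm = load(model_location)
    vlm.to(device, dtype=torch.bfloat16)
    return vlm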


def reset():
    """Resets the input fields."""
    return "", None


def create_app():
    """Creates demo UI."""
    # Model location; a plain string, not a Gradio component, so it is bound
    # to the click handler via a closure below rather than the inputs list.
    model_location = "remyxai/SpaceLlama3.1"  # Update as needed

    with gr.Blocks() as demo:
        # Main UI structure.
        gr.Markdown(INTRO_TEXT)
        with gr.Row():
            image = gr.Image(value=None, label="Image", type="filepath", visible=True)  # input
            with gr.Column():
                prompt = gr.Textbox(value="", label="Prompt", visible=True)
                model_info = gr.Markdown(label="Model Info")
                run = gr.Button("Run", variant="primary")
                clear = gr.Button("Clear")
                highlighted_text = gr.HighlightedText(label="Output", visible=True)

        # Button event handlers.
        run.click(
            lambda image, prompt: compute(image, prompt, model_location),
            [image, prompt],
            highlighted_text,
        )
        clear.click(reset, None, [prompt, image])

        # Status.
        status = gr.Markdown(f"Startup: {datetime.datetime.now()}")
        gpu_kind = gr.Markdown("GPU=?")
        demo.load(
            lambda: f"Model `{model_location}` loaded.",
            None,
            model_info,
        )

    return demo
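

# Example (assumptions: this file is saved as app.py and an example.jpg exists
# locally): a quick smoke test of the inference path without launching the UI.
#
#   from app import compute
#   print(compute("example.jpg", "How far apart are the two chairs?",
#                 "remyxai/SpaceLlama3.1"))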
if __name__ == "__main__": | |
logging.basicConfig( | |
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" | |
) | |
for k, v in os.environ.items(): | |
logging.info('environ["%s"] = %r', k, v) | |
create_app().queue().launch() | |