from threading import Thread
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
import gradio as gr
import torch
from PIL import Image
import argparse
import spaces
import os
import time
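
# Directory containing this script, printed once at startup for debugging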
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)
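
# Command-line arguments: model path, device, conversation template, sampling
# settings, and optional 8-bit / 4-bit quantization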
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Load the model, tokenizer, and image processor
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device
)
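
# spaces.GPU allocates a ZeroGPU device for the duration of each call to the
# decorated function (this Space runs on ZeroGPU)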
@spaces.GPU
def bot_streaming(message, history):
    print(message)
    image_path = None

    # Check if there's an image in the current message
    if message["files"]:
        # message["files"][-1] could be a dictionary or a string
        if isinstance(message["files"][-1], dict):
            image_path = message["files"][-1]["path"]
        else:
            image_path = message["files"][-1]
    else:
        # If no image in the current message, look in the history for the last image path
        for hist in history:
            if isinstance(hist[0], tuple):
                image_path = hist[0][0]

    # Error handling if no image path is found
    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # If image_path is a string, there is no need to load it into a PIL image;
    # the path is passed directly to chat_llava below
    print(f"\033[91m{image_path}, {type(image_path)}\033[0m")

    # The prompt sent to the model is the plain text of the message
    prompt = message['text']

    # Plain list used as a buffer between the generation thread and the streaming loop below
    streamer = []

    # Call chat_llava in a separate thread and push its output into the buffer
    def generate_output():
        output = chat_llava(
            args=args,
            image_file=image_path,
            text=prompt,
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len
        )
        for new_text in output:
            streamer.append(new_text)

    # Start the generation in a separate thread
    thread = Thread(target=generate_output)
    thread.start()

    # Stream the output while the generation thread is running
    buffer = ""
    while thread.is_alive() or streamer:
        while streamer:
            new_text = streamer.pop(0)
            buffer += new_text
            yield buffer
        time.sleep(0.1)

    # Ensure any remaining text is yielded after the thread completes
    while streamer:
        new_text = streamer.pop(0)
        buffer += new_text
        yield buffer
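
# Gradio UI: a multimodal chat interface that streams responses from bot_streaming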
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
            {"text": "What is this?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
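
# Example local launch (assuming this file is saved as app.py; flags are optional):
#   python app.py --device cuda:0 --load-4bit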