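"""Gradio demo for FinLLaVA: multimodal (image + text) chat with streamed
responses.

A minimal launch sketch (assuming this script is saved as app.py; the
filename is not given in the source):

    python app.py --model-path TheFinAI/FinLLaVA --load-4bit
"""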
from threading import Thread
from llava_llama3.serve.cli import chat_llava
from llava_llama3.model.builder import load_pretrained_model
import gradio as gr
import argparse
import spaces
import os
import time

# Directory containing this script (currently only printed for debugging).
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Load the FinLLaVA model once at startup so all requests share it.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path, 
    None, 
    'llava_llama3', 
    args.load_8bit, 
    args.load_4bit, 
    device=args.device
)
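
# NOTE: the --load-8bit / --load-4bit flags are passed straight through to
# load_pretrained_model; whether quantized loading is actually supported
# depends on the llava_llama3 builder, so treat this as an assumption.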

@spaces.GPU  # request a GPU slot per call on Hugging Face ZeroGPU Spaces (no-op elsewhere)
def bot_streaming(message, history):
    print(message)
    image_path = None
    
    # Check if there's an image in the current message
    if message["files"]:
        # message["files"][-1] could be a dictionary or a string
        if isinstance(message["files"][-1], dict):
            image_path = message["files"][-1]["path"]
        else:
            image_path = message["files"][-1]
    else:
        # No image in the current message: fall back to the most recent
        # image path found in the history (the loop keeps the last match).
        for hist in history:
            if isinstance(hist[0], tuple):
                image_path = hist[0][0]
    
    # Error handling if no image path is found
    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")
    
    # image_path is a plain path/URL string; chat_llava accepts it directly,
    # so there is no need to load it into a PIL image here.
    print(f"\033[91m{image_path}, {type(image_path)}\033[0m")  # debug (printed in red)
    
    # Use the message text as the prompt
    prompt = message['text']
    
    # Plain list used as a cross-thread buffer for generated text chunks;
    # single append/pop operations are atomic under the GIL.
    streamer = []

    # Define a function to call chat_llava in a separate thread
    def generate_output():
        output = chat_llava(
            args=args,
            image_file=image_path,
            text=prompt,
            tokenizer=tokenizer,
            model=llava_model,
            image_processor=image_processor,
            context_len=context_len
        )
        # chat_llava is expected to yield text chunks; if it instead returns
        # a plain string, this loop appends one character at a time.
        for new_text in output:
            streamer.append(new_text)
    
    # Start the generation in a separate thread
    thread = Thread(target=generate_output)
    thread.start()
    
    # Stream the accumulated output while the worker thread runs.
    buffer = ""
    while thread.is_alive() or streamer:
        while streamer:
            buffer += streamer.pop(0)
            yield buffer
        time.sleep(0.1)

    # The loop above only exits once the thread has died and the buffer is
    # fully drained, so join() returns immediately and just reaps the thread.
    thread.join()


# Build the UI: components are created up front and handed to ChatInterface.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
            {"text": "What is this?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)