import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import threading
import spaces
import accelerate
import time
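# Krypton: a multimodal Gradio chat demo that streams answers about uploaded
# images from the xtuner LLaVA-Llama-3 8B model.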
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton π</h1>
<p>Powered by the open-source model <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a>.</p>
</div>
'''
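# Load the LLaVA model in half precision straight onto the GPU.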
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
model = LlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
).to('cuda')
processor = AutoProcessor.from_pretrained(model_id)
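# Llama 3 marks the end of each turn with <|eot_id|> (token id 128009).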
model.generation_config.eos_token_id = 128009
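# Reserve a ZeroGPU device for up to 120 seconds per call on Hugging Face Spaces.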
@spaces.GPU(duration=120)
def krypton(input,
history,
max_new_tokens,
temperature,
num_beams,
            do_sample: bool = True):
"""
Recieves inputs (prompts with images if they were added),
the image is formated for pil and prompt is formated for the model,
to place it's output to the user, these prompts and images are passed in
the processor and generation of the model, than the output is decoded from the processor,
onto the UI.
"""
if input["files"]:
if type(input["files"][-1]) == dict:
image = input["files"][-1]["path"]
else:
image = input["files"][-1]
else:
# If no images were passed now, look at the past images to keep up as reference still to the prompts
# kept inside in tuples, the last one
for hist in history:
if type(hist[0]) == tuple:
image = hist[0][0]
try:
if image is None:
gr.Error("You need to upload an image please for krypton to work.")
except NameError:
# Image is not defined at all
gr.Error("Uplaod an image for Krypton to work")
    prompt = (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{input['text']}<|eot_id|>"
              "<|start_header_id|>assistant<|end_header_id|>\n\n")
    image = Image.open(image)
    # Keyword arguments avoid the text/image positional mismatch between
    # transformers versions; cast to fp16 on the GPU to match the model.
    inputs = processor(text=prompt, images=image, return_tensors='pt').to('cuda', torch.float16)
    # Streamer: decode tokens incrementally with the underlying tokenizer
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=False, skip_prompt=True)
    if temperature == 0.0:
        do_sample = False  # greedy decoding when sampling temperature is zero
    # TextIteratorStreamer cannot be combined with beam search in transformers,
    # so fall back to a single beam while streaming.
    if num_beams > 1:
        num_beams = 1
    # Generation kwargs
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        num_beams=num_beams,
        do_sample=do_sample
    )
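    # Run generate() on a background thread so this function can consume the
    # streamer and yield partial text to the UI as it arrives.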
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
buffer = ""
time.sleep(0.5)
for new_text in streamer:
# find <|eot_id|> and remove it from the new_text
if "<|eot_id|>" in new_text:
new_text = new_text.split("<|eot_id|>")[0]
buffer += new_text
# generated_text_without_prompt = buffer[len(text_prompt):]
generated_text_without_prompt = buffer
# print(generated_text_without_prompt)
time.sleep(0.06)
# print(f"new_text: {generated_text_without_prompt}")
yield generated_text_without_prompt
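# Gradio UI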
chatbot = gr.Chatbot(height=600, label="Krypt AI")
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter your question or upload an image.", show_label=False)
with gr.Blocks(fill_height=True) as demo:
gr.Markdown(DESCRIPTION)
gr.ChatInterface(
fn=krypton,
chatbot=chatbot,
fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
additional_inputs=[
gr.Slider(minimum=20,
maximum=80,
step=1,
value=50,
label="Max New Tokens",
render=False),
gr.Slider(minimum=0.0,
maximum=1.0,
step=0.1,
value=0.7,
label="Temperature",
render=False),
gr.Slider(minimum=1,
maximum=12,
step=1,
value=5,
label="Number of Beams",
render=False),
],
multimodal=True,
textbox=chat_input,
)
if __name__ == "__main__":
demo.launch()