Krypton / app.py
sandz7's picture
commented accelerate temporarily
7523035
raw
history blame
1.56 kB
import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import requests
import threading
import spaces
# import accelerate
# HTML banner shown at the top of the demo page (rendered through gr.Markdown).
# The heading emoji is U+1F54B; it had been corrupted into the mojibake
# sequence "πŸ•‹" by a wrong-encoding round trip of its UTF-8 bytes.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton 🕋</h1>
<p>This uses an Open Source model from <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a></p>
</div>
'''
# Hub checkpoint used for both the model weights and the matching processor.
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

# Load the weights in half precision, then place the model on the GPU.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
)
model = model.to('cuda')

# The processor turns the text prompt and the image into model inputs.
processor = AutoProcessor.from_pretrained(model_id)
@spaces.GPU(duration=120)
def krypton(input_image):
    """Ask the LLaVA model "What are these?" about an image and return its answer.

    Args:
        input_image: Image pixels as an array exposing ``.astype`` (the value
            supplied by the Gradio ``"image"`` input). It is cast to ``uint8``
            and interpreted as RGB. May be ``None`` when no image was given.

    Returns:
        The text generated by the model for the fixed prompt, decoded with
        special tokens removed. The prompt itself is not included.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is None:
        # Show a readable message in the UI instead of failing with an
        # AttributeError on ``None.astype``.
        raise gr.Error("Please upload an image first.")

    pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')

    # Single user turn in the Llama-3 chat format; ``<image>`` marks where the
    # picture is inserted, and the trailing assistant header cues the reply.
    prompt = ("<|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat are these?<|eot_id|>"
              "<|start_header_id|>assistant<|end_header_id|>\n\n")

    # Keyword arguments keep the call independent of the processor's
    # positional parameter order, which differs between transformers releases.
    inputs = processor(text=prompt, images=pil_image, return_tensors='pt').to('cuda', torch.float16)
    outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)

    # generate() returns the prompt tokens followed by the newly generated
    # ones. The previous ``[:2]`` slice decoded only the first two tokens, so
    # the answer was never returned; decode everything after the prompt.
    prompt_length = inputs['input_ids'].shape[-1]
    return processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Page layout: the HTML banner first, then the image-in / text-out interface
# wired to the captioning function.
demo = gr.Blocks(fill_height=True)
with demo:
    gr.Markdown(DESCRIPTION)
    gr.Interface(fn=krypton, inputs="image", outputs="text", fill_height=True)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()