# A100 Zero GPU
import spaces
import time
import torch
import gradio as gr
from PIL import Image
from utils.utils import *
from threading import Thread
import torch.nn.functional as F
from accelerate import Accelerator
from meteor.load_mmamba import load_mmamba
from meteor.load_meteor import load_meteor
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor
# flash attention (FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA kernels at install time)
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# accel
accel = Accelerator()
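# accel.device is the target device used when moving weights inside the GPU-decorated handler below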
# loading meteor model
mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=4)
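# Meteor pairs a Mamba-based rationale model (Meteor-Mamba) with a 4-bit quantized language model (Meteor-MLM)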
# freeze model
freeze_model(mmamba)
freeze_model(meteor)
# previous length
previous_length = 0
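# runs the blocking generate() call in a worker thread so the streamer can be consumed concurrently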
def threading_function(inputs, image_token_number, streamer, device):
    # Meteor Mamba
    mmamba_inputs = mmamba.eval_process(inputs=inputs, tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
    if 'image' in mmamba_inputs.keys():
        clip_features = meteor.clip_features(mmamba_inputs['image'])
        mmamba_inputs.update({"image_features": clip_features})
    mmamba_outputs = mmamba(**mmamba_inputs)
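    # mmamba_outputs.tor_features are the traversal-of-rationale embeddings consumed by Meteor below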
    # Meteor
    meteor_inputs = meteor.eval_process(inputs=inputs, data='demo', tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
    if 'image' in mmamba_inputs.keys():
        meteor_inputs.update({"image_features": clip_features})
    meteor_inputs.update({"tor_features": mmamba_outputs.tor_features})
    generation_kwargs = meteor_inputs
    generation_kwargs.update({
        'streamer': streamer,
        'do_sample': True,
        'max_new_tokens': 128,
        'top_p': 0.95,
        'temperature': 0.9,
        'use_cache': True,
    })
    return meteor.generate(**generation_kwargs)
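# @spaces.GPU requests a ZeroGPU (A100) allocation for the duration of each call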
@spaces.GPU
def bot_streaming(message, history):
    # move model parameters onto the allocated GPU
    for param in mmamba.parameters():
        param.data = param.to(accel.device)
    for param in meteor.parameters():
        param.data = param.to(accel.device)
    # prompt type -> input prompt (text-only or image + text)
    image_token_number = int((490/14)**2)
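    # 490x490 input with 14-pixel patches -> (490/14)**2 = 35*35 = 1225 image tokens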
    if len(message['files']) != 0:
        # Image Load
        image = F.interpolate(pil_to_tensor(Image.open(message['files'][0]).convert("RGB")).unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
        inputs = [{'image': image, 'question': message['text']}]
    else:
        inputs = [{'question': message['text']}]
    # [4] Meteor Generation
    with torch.inference_mode():
        # kwargs
        streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
        # Threading generation
        thread = Thread(target=threading_function, kwargs=dict(inputs=inputs, image_token_number=image_token_number, streamer=streamer, device=accel.device))
        thread.start()
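        # drain the streamer on the main thread while the worker thread generates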
        # generated text
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
    # Text decoding
    response = generated_text.split('assistant\n')[-1].split('[U')[0].strip()
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.02)
        yield buffer
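# multimodal ChatInterface: one text box that also accepts image uploads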
demo = gr.ChatInterface(fn=bot_streaming, title="☄️ Meteor",
                        description="Meteor is an efficient 7B-scale Large Language and Vision Model built with the help of the traversal of rationale",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()