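"""Thin wrapper around the SGLang runtime for LLaVA-v1.6 (Vicuna-7B) image question answering."""
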
import os
import time
from datetime import datetime
from typing import List, Union
from pathlib import Path
import sglang as sgl
from PIL import Image
from utils.logger import logger
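
# Directory for temporary JPEG files written when PIL images are passed in directly.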
TMP_DIR = "./tmp"


def get_timestamp():
    """Return the current time as a millisecond-precision string for temporary file names."""
    milliseconds = time.time_ns() // 1_000_000
    return datetime.fromtimestamp(milliseconds / 1000).strftime("%Y-%m-%d_%H-%M-%S-%f")[:-3]


class LLaVASRT:
    """LLaVA-v1.6 (Vicuna-7B) visual question answering served through the SGLang runtime."""

    def __init__(self, device: str = "cuda:0", quantized: bool = True):
        self.runtime = sgl.Runtime(
            model_path="liuhaotian/llava-v1.6-vicuna-7b",
            tokenizer_path="llava-hf/llava-1.5-7b-hf",
        )
        sgl.set_default_backend(self.runtime)
        logger.info(
            f"Started the SGLang runtime for llava-v1.6-vicuna-7b with chat template: {self.runtime.endpoint.chat_template.name}. "
            "The `device` and `quantized` parameters currently have no effect."
        )
        os.makedirs(TMP_DIR, exist_ok=True)

    @sgl.function
    def image_qa(s, prompt: str, image: str):
        # Single-turn chat program: the user turn carries the image plus the prompt,
        # and the assistant generation is stored under the key "answer".
        s += sgl.user(sgl.image(image) + prompt)
        s += sgl.assistant(sgl.gen("answer"))

    def __call__(self, prompt: Union[str, List[str]], image: Union[str, Image.Image, List[str], List[Image.Image]]):
        pil_input_flag = False
        # Single-sample inference.
        if isinstance(prompt, str) and isinstance(image, (str, Image.Image)):
            if isinstance(image, Image.Image):
                # SGLang expects an image path, so dump the PIL image to a temporary file.
                pil_input_flag = True
                image_path = os.path.join(TMP_DIR, get_timestamp() + ".jpg")
                image.save(image_path)
                image = image_path
            state = self.image_qa.run(prompt=prompt, image=image, max_new_tokens=256)
            # Post-process: clean up the temporary file.
            if pil_input_flag:
                os.remove(image)
            return state["answer"], state
        # Batch inference.
        elif isinstance(prompt, list) and isinstance(image, list):
            assert len(prompt) == len(image)
            if isinstance(image[0], Image.Image):
                pil_input_flag = True
                image_path = [os.path.join(TMP_DIR, get_timestamp() + f"-{i}.jpg") for i in range(len(image))]
                for i in range(len(image)):
                    image[i].save(image_path[i])
                image = image_path
            batch_query = [{"prompt": p, "image": img} for p, img in zip(prompt, image)]
            state = self.image_qa.run_batch(batch_query, max_new_tokens=256)
            # Post-process: clean up the temporary files.
            if pil_input_flag:
                for path in image:
                    os.remove(path)
            return [s["answer"] for s in state], state
        else:
            raise ValueError("prompt and image must either both be single items or both be lists of the same length.")

    def __del__(self):
        self.runtime.shutdown()


if __name__ == "__main__":
    image_folder = "demo/"
    wildcard_list = ["*.jpg", "*.png"]
    image_list = []
    for wildcard in wildcard_list:
        image_list.extend([str(image_path) for image_path in Path(image_folder).glob(wildcard)])
    # SGLang needs an exclusive GPU and cannot re-initialize CUDA in a forked subprocess.
    llava_srt = LLaVASRT()
    # Batch inference.
    llava_srt_prompt = ["Please describe this image in detail."] * len(image_list)
    response, _ = llava_srt(llava_srt_prompt, image_list)
    print(response)
    # Single-image inference.
    llava_srt_prompt = "Please describe this image in detail."
    for image in image_list:
        response, _ = llava_srt(llava_srt_prompt, image)
        print(image, response)