File size: 3,555 Bytes
19fe404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import time
from datetime import datetime
from typing import List, Union
from pathlib import Path

import sglang as sgl
from PIL import Image

from utils.logger import logger

TMP_DIR = "./tmp"


def get_timestamp():
    """Return the current local time as a filename-safe string.

    Format: ``YYYY-MM-DD_HH-MM-SS-mmm`` (millisecond precision), used to
    build unique temporary image file names under ``TMP_DIR``.
    """
    # datetime.now() already carries microsecond precision; dropping the last
    # three digits of %f leaves milliseconds. This replaces the original
    # redundant round-trip through time.time_ns() -> ms -> fromtimestamp().
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:-3]


class LLaVASRT:
    """LLaVA v1.6 (vicuna-7b) visual QA served through the SGLang runtime.

    Accepts either a single (str prompt, image path or PIL image) pair or
    parallel lists for batched inference. PIL images are spilled to JPEG
    files under ``TMP_DIR`` because the SGLang program takes a file path,
    and the temporary files are removed after inference.
    """

    def __init__(self, device: str = "cuda:0", quantized: bool = True):
        # NOTE: `device` and `quantized` are kept only for interface
        # compatibility; SGLang manages device placement itself (see log).
        self.runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.6-vicuna-7b", tokenizer_path="llava-hf/llava-1.5-7b-hf")
        sgl.set_default_backend(self.runtime)
        logger.info(
            f"Start the SGLang runtime for llava-v1.6-vicuna-7b with chat template: {self.runtime.endpoint.chat_template.name}. "
            "Input parameter device and quantized do not take effect."
        )
        # exist_ok=True makes the separate os.path.exists() pre-check redundant.
        os.makedirs(TMP_DIR, exist_ok=True)

    @sgl.function
    def image_qa(s, prompt: str, image: str):
        # SGLang program: one user turn carrying the image plus the prompt,
        # then generate the assistant's answer.
        s += sgl.user(sgl.image(image) + prompt)
        s += sgl.assistant(sgl.gen("answer"))

    def __call__(self, prompt: Union[str, List[str]], image: Union[str, Image.Image, List[str]]):
        """Run inference.

        Returns ``(answer, state)`` for single input or
        ``(list_of_answers, list_of_states)`` for batched input.
        Raises ValueError on mismatched or unsupported input combinations.
        """
        pil_input_flag = False
        if isinstance(prompt, str) and isinstance(image, (str, Image.Image)):
            if isinstance(image, Image.Image):
                pil_input_flag = True
                image_path = os.path.join(TMP_DIR, get_timestamp() + ".jpg")
                # JPEG cannot store an alpha channel; convert defensively so
                # RGBA/P-mode inputs do not crash the save.
                image.convert("RGB").save(image_path)
                # BUGFIX: run inference on the saved file path — the original
                # passed the PIL object to image_qa.run and then called
                # os.remove() on it, which raised TypeError.
                image = image_path
            state = self.image_qa.run(prompt=prompt, image=image, max_new_tokens=256)
            # Post-process: drop the temporary file created for PIL input.
            if pil_input_flag:
                os.remove(image)

            return state["answer"], state
        elif isinstance(prompt, list) and isinstance(image, list):
            # Input validation must survive `python -O`, so raise instead of assert.
            if len(prompt) != len(image):
                raise ValueError("Input prompt and image must be both strings or list of strings with the same length.")
            if image and isinstance(image[0], Image.Image):
                pil_input_flag = True
                image_path = [os.path.join(TMP_DIR, get_timestamp() + f"-{i}" + ".jpg") for i in range(len(image))]
                for i in range(len(image)):
                    image[i].convert("RGB").save(image_path[i])
                image = image_path
            batch_query = [{"prompt": p, "image": img} for p, img in zip(prompt, image)]
            state = self.image_qa.run_batch(batch_query, max_new_tokens=256)
            # Post-process: drop the temporary files created for PIL input.
            if pil_input_flag:
                for path in image:
                    os.remove(path)

            return [s["answer"] for s in state], state
        else:
            raise ValueError("Input prompt and image must be both strings or list of strings with the same length.")

    def __del__(self):
        # __init__ may have failed before `runtime` was assigned; guard so
        # interpreter teardown does not raise AttributeError.
        runtime = getattr(self, "runtime", None)
        if runtime is not None:
            runtime.shutdown()


if __name__ == "__main__":
    # Collect every demo image matching the supported extensions, in
    # per-pattern glob order (same as the original extend loop).
    patterns = ("*.jpg", "*.png")
    image_list = [str(path) for pattern in patterns for path in Path("demo/").glob(pattern)]
    # SGLang need the exclusive GPU and cannot re-initialize CUDA in forked subprocess.
    llava_srt = LLaVASRT()
    prompt_text = "Please describe this image in detail."
    # Batch inference: one prompt per image.
    response, _ = llava_srt([prompt_text] * len(image_list), image_list)
    print(response)
    # Single-image inference, one call per file.
    for image in image_list:
        response, _ = llava_srt(prompt_text, image)
        print(image, response)