handertrails / handler.py
hughtayloe's picture
Update handler.py
5701230 verified
raw
history blame
1.58 kB
from typing import Dict, List, Any
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
class EndpointHandler:
    """Custom inference handler that answers a prompt about an image with LLaVA.

    The model and processor are loaded once in ``__init__`` and reused by every
    ``__call__``.

    NOTE(review): the request schema accepted by ``__call__`` is an assumption
    built from the sample usage in this file. It expects an ``"inputs"`` dict
    with optional ``"image"`` URL and ``"prompt"`` keys, plus optional
    ``"parameters"`` forwarded to ``generate``. Confirm it against the
    endpoint's clients.
    """

    # Fallbacks used when a request omits a field. They reproduce the values
    # that were previously hard-coded for every request.
    DEFAULT_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
    DEFAULT_PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
    # Greedy decoding capped at 200 new tokens unless the request overrides it.
    DEFAULT_GENERATE_KWARGS = {"max_new_tokens": 200, "do_sample": False}
    # Seconds to wait on the image download, so a dead host cannot hang a worker.
    IMAGE_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the LLaVA model and processor from *path*.

        Args:
            path: Model directory or hub id passed to ``from_pretrained``.
        """
        # Keep both on the instance, because __call__ needs them on every request.
        self.model = LlavaForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(0)  # device index 0, i.e. the first CUDA GPU
        self.processor = AutoProcessor.from_pretrained(path)

    @classmethod
    def _parse_request(cls, data):
        """Split a request payload into ``(image_url, prompt, generate_kwargs)``.

        The inputs are ``data["inputs"]``, or *data* itself when there is no
        ``"inputs"`` key. They may be a dict with optional ``"image"`` and
        ``"prompt"`` keys, or a plain string used as the prompt. If
        ``data["parameters"]`` is present, its entries override
        ``DEFAULT_GENERATE_KWARGS``. *data* is not modified.

        Raises:
            TypeError: if the inputs are neither a dict nor a string.
        """
        payload = data.get("inputs", data)
        if isinstance(payload, str):
            payload = {"prompt": payload}
        elif not isinstance(payload, dict):
            raise TypeError(
                f"'inputs' must be a dict or str, got {type(payload).__name__}"
            )
        image_url = payload.get("image") or cls.DEFAULT_IMAGE_URL
        prompt = payload.get("prompt") or cls.DEFAULT_PROMPT
        # Build a new dict so the class-level defaults are never mutated.
        generate_kwargs = {
            **cls.DEFAULT_GENERATE_KWARGS,
            **(data.get("parameters") or {}),
        }
        return image_url, prompt, generate_kwargs

    def _load_image(self, url):
        """Download *url* and return it as an RGB ``PIL.Image``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        with requests.get(url, stream=True, timeout=self.IMAGE_TIMEOUT) as response:
            response.raise_for_status()
            # convert() forces the pixel data to be read before the response
            # is closed. It also normalises palette, RGBA and greyscale images.
            return Image.open(response.raw).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one generation request.

        Args:
            data: Request payload; see ``_parse_request`` for the accepted keys.

        Returns:
            A one-element list ``[{"generated_text": text}]``. *text* is the
            decoded output that follows the prompt tokens.

        Raises:
            TypeError: if the request inputs have an unsupported type.
            requests.HTTPError: if the image download fails.
        """
        image_url, prompt, generate_kwargs = self._parse_request(data)
        image = self._load_image(image_url)
        inputs = self.processor(prompt, image, return_tensors="pt").to(
            0, torch.float16
        )
        output = self.model.generate(**inputs, **generate_kwargs)
        # For this decoder-only model, generate() returns the prompt ids
        # followed by the new tokens. Slice the prompt off so only the answer
        # is returned.
        prompt_length = inputs["input_ids"].shape[-1]
        text = self.processor.decode(
            output[0][prompt_length:], skip_special_tokens=True
        )
        return [{"generated_text": text}]
def main(argv=None):
    """Run the LLaVA sample end to end and print the decoded output.

    Usage: ``python handler.py MODEL_ID_OR_PATH``. Loads the model, downloads
    a sample COCO image, asks "What are these?" about it and prints the
    decoded output.

    Args:
        argv: Command-line arguments without the program name. Defaults to
            ``sys.argv[1:]``.

    Raises:
        SystemExit: if the arguments are not exactly one model id or path.
    """
    import sys

    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 1:
        raise SystemExit("usage: python handler.py MODEL_ID_OR_PATH")
    model_id = args[0]

    prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
    image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    ).to(0)
    processor = AutoProcessor.from_pretrained(model_id)

    # Bounded download. The context manager releases the connection, and
    # convert() reads the pixels before the response is closed.
    with requests.get(image_file, stream=True, timeout=30) as response:
        response.raise_for_status()
        raw_image = Image.open(response.raw).convert("RGB")

    inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
    output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    # NOTE(review): [2:] is carried over from the original sample and skips the
    # first two returned token ids. Confirm whether the whole prompt should be
    # stripped instead.
    print(processor.decode(output[0][2:], skip_special_tokens=True))


# Run the demo only when this file is executed directly, not when the
# inference endpoint imports it to get EndpointHandler.
if __name__ == "__main__":
    main()