handertrails / handler.py
from typing import Dict, Any
from PIL import Image
import requests
import torch
import numpy as np
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

class EndpointHandler():
    def __init__(self, path=""):
        model_id = path
        # Load the LLaVA model in 4-bit precision (bitsandbytes) to reduce GPU memory usage.
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        )
        self.processor = AutoProcessor.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]):
        # "inputs" is expected to be a list of image URLs; "prompt" is the question to ask
        # about each image.
        parameters = data.pop("inputs", data)
        givenprompt = data.pop("prompt", "")
        outputs = []
        print(parameters)
        prompt = f"USER: <image>\n{givenprompt}?\nASSISTANT:"
        for link in parameters:
            try:
                # Fetch the image from the URL
                response = requests.get(link, stream=True)
                response.raise_for_status()  # Raise an exception for 4xx or 5xx status codes
                raw_image = Image.open(response.raw)

                # Process the image and generate a response
                inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(0, torch.float16)
                output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
                readable = self.processor.decode(output[0][2:], skip_special_tokens=True)
                outputs.append(readable)
            except Exception as e:
                # Handle any exceptions and record the error for this image
                outputs.append(f"Error processing image from {link}: {str(e)}")
        return outputs
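

# A minimal local invocation sketch, not part of the deployed handler: the model id and
# image URL below are placeholder assumptions. On a real Inference Endpoint the runtime
# instantiates EndpointHandler with the repository path and passes the request JSON body
# as `data`.
if __name__ == "__main__":
    handler = EndpointHandler(path="llava-hf/llava-1.5-7b-hf")  # assumed LLaVA checkpoint
    payload = {
        "inputs": ["https://example.com/cat.jpg"],  # placeholder image URL
        "prompt": "What is in this image",
    }
    print(handler(payload))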