new-blip / handler.py
from PIL import Image
from typing import Dict, Any
import torch
import base64
from io import BytesIO
from transformers import BlipForConditionalGeneration, BlipProcessor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler:
    def __init__(self, path=""):
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(device)
        self.model.eval()
        self.max_length = 16
        self.num_beams = 4

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        try:
            image_data = data.get("inputs", None)
            if image_data is None:
                return {"caption": "", "error": "No image found under 'inputs'"}
            # Decode the base64-encoded image string to bytes
            image_bytes = base64.b64decode(image_data)
            # Open the image from an in-memory buffer; the BLIP processor
            # expects a PIL image, not a raw BytesIO buffer
            raw_image = Image.open(BytesIO(image_bytes))
            # Ensure the image is in RGB mode (if necessary)
            if raw_image.mode != "RGB":
                raw_image = raw_image.convert(mode="RGB")
            # Process the image with the processor and move tensors to the device
            processed_inputs = self.processor(raw_image, return_tensors="pt").to(device)
            # Generate the caption
            gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
            with torch.no_grad():
                output_ids = self.model.generate(**processed_inputs, **gen_kwargs)
            caption = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            return {"caption": caption}
        except Exception as e:
            # Log the error for better tracking
            print(f"Error during processing: {str(e)}")
            return {"caption": "", "error": str(e)}