from PIL import Image
from typing import Dict, Any
import torch
import base64
from io import BytesIO
from transformers import BlipForConditionalGeneration, BlipProcessor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler:
    def __init__(self, path=""):
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(device)
        self.model.eval()
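        # Default generation settings: captions capped at 16 tokens,
        # decoded with 4-beam search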
        self.max_length = 16
        self.num_beams = 4
        
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        try:
            image_data = data.get("inputs", None)
            
            # Convert base64 encoded image string to bytes
            image_bytes = base64.b64decode(image_data)
            
            # Convert bytes to a BytesIO object
            image_buffer = BytesIO(image_bytes)

            # Open the image with PIL and ensure it is in RGB mode; the BLIP
            # processor expects a PIL image, not a raw byte buffer
            raw_image = Image.open(image_buffer)
            if raw_image.mode != "RGB":
                raw_image = raw_image.convert(mode="RGB")

            # Process the image with the processor
            processed_inputs = self.processor(raw_image, return_tensors="pt").to(device)

            # Generate the caption without tracking gradients
            gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}
            with torch.no_grad():
                output_ids = self.model.generate(**processed_inputs, **gen_kwargs)

            caption = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            return {"caption": caption}
        except Exception as e:
            # Log the error for better tracking
            print(f"Error during processing: {str(e)}")
            return {"caption": "", "error": str(e)}
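
# Minimal local smoke test for the handler above (a sketch, not part of the
# deployed endpoint). It assumes a test image named "example.jpg" sits next to
# this file; adjust the filename to your own image. The image is base64-encoded
# the same way a client payload would be, then passed to the handler.
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    print(handler({"inputs": encoded_image}))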
