# Author: TheArchitect416
# Commit: "Create handler.py" (fd04d5f, verified)
import base64
import io
import json
import logging

import numpy as np
import segmentation_models_pytorch as smp
import torch
from PIL import Image
from torchvision import transforms
# Number of segmentation classes the network predicts. It is passed to the
# U-Net's output head, so it must match the checkpoint (and COLOR_MAPPING).
NUM_CLASSES = 4

# ImageNet per-channel statistics used for input normalization.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Input pipeline: resize to 256x256, convert to a tensor, then normalize each
# channel. This must mirror the transforms used during training.
preprocess = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Class index -> RGB colour used when rendering predicted masks.
# The list position is the class index.
COLOR_MAPPING = dict(
    enumerate(
        [
            [0, 0, 0],  # 0: background
            [255, 0, 124],  # 1: oil
            [255, 204, 51],  # 2: others
            [51, 221, 255],  # 3: water
        ]
    )
)
def colorize_mask(mask, color_mapping=None):
    """Render a mask of integer class ids as an RGB ``uint8`` image.

    Args:
        mask: Array-like of class indices. Usually 2-D ``(H, W)``, but any
            shape is accepted; the output gains a trailing channel axis.
        color_mapping: Optional ``{class_id: [r, g, b]}`` palette. Defaults to
            the module-level ``COLOR_MAPPING``.

    Returns:
        ``np.uint8`` array of shape ``mask.shape + (3,)``. Pixels whose class
        id is not in the palette are left black.
    """
    if color_mapping is None:
        color_mapping = COLOR_MAPPING
    mask = np.asarray(mask)
    # Start all-black so ids missing from the palette render as background.
    color_mask = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for cls, color in color_mapping.items():
        color_mask[mask == cls] = color
    return color_mask
class OilSpillSegmentationHandler:
    """Serve oil-spill segmentation requests with a ResNet-34 U-Net.

    A request carries a base64-encoded image in its ``"image"`` field. The
    response is a JSON string with either ``"output_image"`` (a base64 PNG of
    the colorized class mask) or ``"error"`` (why the request failed).
    """

    def __init__(self, model_path="model.pth"):
        """Build the network, load its weights and switch to evaluation mode.

        Args:
            model_path: Path to a ``state_dict`` checkpoint. Defaults to
                ``"model.pth"`` relative to the current working directory.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = smp.Unet(
            encoder_name="resnet34",  # must match the encoder used in training
            encoder_weights=None,  # weights come from the state_dict loaded below
            in_channels=3,
            classes=NUM_CLASSES,
        )
        # torch.load unpickles the file: only point this at trusted checkpoints.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, image_bytes):
        """Decode raw image bytes into a batched, normalized model input.

        Returns:
            ``(image_tensor, image)``: the input tensor with a leading batch
            axis of size 1 on ``self.device``, and the decoded RGB PIL image.
        """
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # Bare ``preprocess`` is the module-level transform pipeline, not this method.
        image_tensor = preprocess(image).unsqueeze(0).to(self.device)
        return image_tensor, image

    def inference(self, image_tensor):
        """Run the model and return the per-pixel argmax class ids.

        Returns:
            ``np.uint8`` array of class indices with the batch axis removed
            (2-D for the single-image batches produced by ``preprocess``).
        """
        with torch.no_grad():
            logits = self.model(image_tensor)
            # dim=1 is the class axis; squeeze(0) drops the batch axis of size 1.
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
        return pred_mask

    def postprocess(self, pred_mask):
        """Convert a class-id mask into a colorized PIL image."""
        colorized_mask = colorize_mask(pred_mask)
        return Image.fromarray(colorized_mask)

    def handle_request(self, request_body):
        """Run the full decode -> infer -> render pipeline for one request.

        Args:
            request_body: JSON text (``str``/``bytes``) or an already-parsed
                ``dict`` whose ``"image"`` field holds base64 image data. A
                leading ``data:<mime>;base64,`` header is tolerated.

        Returns:
            A JSON string: ``{"output_image": <base64 PNG>}`` on success, or
            ``{"error": <message>}`` if any step fails.
        """
        try:
            image_bytes = self._decode_request(request_body)
            image_tensor, _ = self.preprocess(image_bytes)
            pred_mask = self.inference(image_tensor)
            output_image = self.postprocess(pred_mask)
            # NOTE(review): the mask has the resolution produced by the
            # module-level transform, not the uploaded image's size. Confirm
            # that callers expect this.
            return json.dumps({"output_image": self._encode_png(output_image)})
        except Exception as e:
            # API boundary: report the failure to the caller, but keep the
            # traceback in the server logs instead of discarding it.
            logging.getLogger(__name__).exception("Segmentation request failed")
            return json.dumps({"error": str(e)})

    @staticmethod
    def _decode_request(request_body):
        """Parse the request and return the base64-decoded ``"image"`` bytes."""
        if isinstance(request_body, dict):
            data = request_body
        else:
            data = json.loads(request_body)
        if not isinstance(data, dict) or "image" not in data:
            raise ValueError("Request body must be a JSON object with an 'image' field")
        encoded = data["image"]
        # Strip a data-URL header. b64decode would otherwise silently drop its
        # ':', ';' and ',' and decode the remaining letters as garbage bytes.
        if isinstance(encoded, str) and encoded.startswith("data:"):
            encoded = encoded.partition(",")[2]
        return base64.b64decode(encoded)

    @staticmethod
    def _encode_png(image):
        """Serialize a PIL image to PNG and return it as base64 text."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
# Module-level handler instance, created once at import time.
# Constructing it runs OilSpillSegmentationHandler.__init__, which builds the
# U-Net and loads the "model.pth" checkpoint. Importing this module therefore
# requires that file to be present (relative to the current working directory).
handler = OilSpillSegmentationHandler()