|
import torch |
|
import segmentation_models_pytorch as smp |
|
from torchvision import transforms |
|
from PIL import Image |
|
import io |
|
import json |
|
import base64 |
|
import numpy as np |
|
|
|
|
|
# Number of output channels of the segmentation head (one score map per class).
# Must match the checkpoint the handler loads.
NUM_CLASSES = 4

# Input pipeline applied to every decoded RGB image before inference:
#   1. resize to a fixed 256x256 (aspect ratio is NOT preserved),
#   2. convert to a float tensor scaled to [0, 1],
#   3. normalize each channel with the commonly used ImageNet mean/std values
#      (presumably the statistics the checkpoint was trained with -- confirm).
preprocess = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
|
|
|
|
|
# RGB color used to render each predicted class index; the position in the
# list below is the class index it applies to.
COLOR_MAPPING = dict(
    enumerate(
        [
            [0, 0, 0],        # class 0
            [255, 0, 124],    # class 1
            [255, 204, 51],   # class 2
            [51, 221, 255],   # class 3
        ]
    )
)
|
|
|
def colorize_mask(mask, color_mapping=None):
    """Convert a 2D segmentation mask into an RGB image.

    Args:
        mask: 2-D array-like of integer class indices, shape ``(H, W)``.
        color_mapping: Optional ``{class_index: [R, G, B]}`` palette. When
            omitted, the module-level ``COLOR_MAPPING`` is used.

    Returns:
        ``(H, W, 3)`` uint8 NumPy array. Pixels whose class index is not a key
        of the palette stay black ``[0, 0, 0]``.

    Raises:
        ValueError: if ``mask`` is not two-dimensional.
    """
    # Resolve the default lazily so callers see the current module palette.
    if color_mapping is None:
        color_mapping = COLOR_MAPPING

    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2-D (H, W), got shape {mask.shape}")

    height, width = mask.shape
    color_mask = np.zeros((height, width, 3), dtype=np.uint8)
    for cls, color in color_mapping.items():
        # Boolean-mask assignment paints every pixel of this class at once.
        color_mask[mask == cls] = color
    return color_mask
|
|
|
class OilSpillSegmentationHandler:
    """Serve segmentation requests with a U-Net (ResNet-34 encoder).

    A request is a JSON object carrying a base64-encoded image. The handler
    decodes it, runs the model, colorizes the predicted class mask and
    returns it as a base64-encoded PNG inside a JSON response.
    """

    def __init__(self, model_path="model.pth"):
        """Build the U-Net, load its weights and switch it to evaluation mode.

        Args:
            model_path: Path to a ``state_dict`` checkpoint for the model.
                Defaults to ``"model.pth"`` (relative to the current working
                directory), the previously hard-coded location.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # No pretrained encoder weights are fetched: every parameter is
        # replaced by the checkpoint loaded below.
        self.model = smp.Unet(
            encoder_name="resnet34",
            encoder_weights=None,
            in_channels=3,
            classes=NUM_CLASSES,
        )
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On torch >= 1.13, passing weights_only=True restricts
        # loading to tensors and is safer for a plain state_dict.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, image_bytes):
        """Decode raw image bytes and turn them into a model-ready batch.

        Args:
            image_bytes: Encoded image file contents (any format PIL can open).

        Returns:
            Tuple ``(image_tensor, image)``: a batch of one ``(1, 3, H, W)``
            float tensor on ``self.device``, and the decoded RGB PIL image.
        """
        with Image.open(io.BytesIO(image_bytes)) as source:
            image = source.convert("RGB")
        # `preprocess` below is the module-level transform, not this method.
        image_tensor = preprocess(image).unsqueeze(0).to(self.device)
        return image_tensor, image

    def inference(self, image_tensor):
        """Run the model and return the per-pixel argmax class mask.

        Args:
            image_tensor: Input batch on ``self.device``; expected batch size
                1 (as produced by ``preprocess``).

        Returns:
            ``(H, W)`` uint8 NumPy array of predicted class indices.
        """
        with torch.no_grad():
            logits = self.model(image_tensor)
            # dim 1 holds the per-class scores; squeeze(0) drops the batch dim.
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
        return pred_mask

    def postprocess(self, pred_mask):
        """Convert a class-index mask into a colorized PIL image.

        Args:
            pred_mask: ``(H, W)`` array of class indices.

        Returns:
            RGB ``PIL.Image.Image`` built from ``colorize_mask(pred_mask)``.
        """
        colorized_mask = colorize_mask(pred_mask)
        return Image.fromarray(colorized_mask)

    def handle_request(self, request_body):
        """Handle one API request end to end: decode, infer, encode.

        Args:
            request_body: JSON text (``str`` or ``bytes``) of an object whose
                ``"image"`` field holds the base64-encoded input image.

        Returns:
            JSON string: ``{"output_image": <base64 PNG>}`` on success, or
            ``{"error": <message>}`` on failure. Exceptions derived from
            ``Exception`` are logged and reported this way rather than raised.

        The returned mask has the resolution produced by the module-level
        ``preprocess`` transform, not the original image size.
        """
        try:
            data = json.loads(request_body)
            # Validate the payload shape up front so clients get a clear
            # message instead of a bare KeyError/TypeError text.
            if not isinstance(data, dict) or "image" not in data:
                return json.dumps(
                    {"error": "request must be a JSON object with an 'image' field"}
                )

            image_bytes = base64.b64decode(data["image"])
            image_tensor, _ = self.preprocess(image_bytes)
            pred_mask = self.inference(image_tensor)
            output_image = self.postprocess(pred_mask)

            buffered = io.BytesIO()
            output_image.save(buffered, format="PNG")
            output_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            return json.dumps({"output_image": output_b64})
        except Exception as exc:
            # API boundary: never let an exception escape to the server, but
            # keep the traceback in the logs instead of silently discarding it.
            import logging

            logging.getLogger(__name__).exception("Segmentation request failed")
            return json.dumps({"error": str(exc)})
|
|
|
|
|
# Module-level handler instance. It is created at import time, so importing
# this module builds the U-Net and loads the checkpoint from disk
# (onto CUDA when available, otherwise CPU).
handler = OilSpillSegmentationHandler()
|
|