|  | from io import BytesIO | 
					
						
						|  | from time import perf_counter | 
					
						
						|  | from typing import Any, List, Tuple, Union | 
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  | from PIL import Image, ImageDraw, ImageFont | 
					
						
						|  |  | 
					
						
						|  | from inference.core.entities.requests.inference import ClassificationInferenceRequest | 
					
						
						|  | from inference.core.entities.responses.inference import ( | 
					
						
						|  | ClassificationInferenceResponse, | 
					
						
						|  | InferenceResponse, | 
					
						
						|  | InferenceResponseImage, | 
					
						
						|  | MultiLabelClassificationInferenceResponse, | 
					
						
						|  | ) | 
					
						
						|  | from inference.core.models.roboflow import OnnxRoboflowInferenceModel | 
					
						
						|  | from inference.core.models.types import PreprocessReturnMetadata | 
					
						
						|  | from inference.core.models.utils.validate import ( | 
					
						
						|  | get_num_classes_from_model_prediction_shape, | 
					
						
						|  | ) | 
					
						
						|  | from inference.core.utils.image_utils import load_image_rgb | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ClassificationBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel): | 
					
						
						|  | """Base class for ONNX models for Roboflow classification inference. | 
					
						
						|  |  | 
					
						
						|  | Attributes: | 
					
						
						|  | multiclass (bool): Whether the classification is multi-class or not. | 
					
						
						|  |  | 
					
						
						|  | Methods: | 
					
						
						|  | get_infer_bucket_file_list() -> list: Get the list of required files for inference. | 
					
						
						|  | softmax(x): Compute softmax values for a given set of scores. | 
					
						
						|  | infer(request: ClassificationInferenceRequest) -> Union[List[Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]], Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]]: Perform inference on a given request and return the response. | 
					
						
						|  | draw_predictions(inference_request, inference_response): Draw prediction visuals on an image. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | task_type = "classification" | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, *args, **kwargs): | 
					
						
						|  | """Initialize the model, setting whether it is multiclass or not.""" | 
					
						
						|  | super().__init__(*args, **kwargs) | 
					
						
						|  | self.multiclass = self.environment.get("MULTICLASS", False) | 
					
						
						|  |  | 
					
						
						|  | def draw_predictions(self, inference_request, inference_response): | 
					
						
						|  | """Draw prediction visuals on an image. | 
					
						
						|  |  | 
					
						
						|  | This method overlays the predictions on the input image, including drawing rectangles and text to visualize the predicted classes. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | inference_request: The request object containing the image and parameters. | 
					
						
						|  | inference_response: The response object containing the predictions and other details. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | bytes: The bytes of the visualized image in JPEG format. | 
					
						
						|  | """ | 
					
						
						|  | image = load_image_rgb(inference_request.image) | 
					
						
						|  | image = Image.fromarray(image) | 
					
						
						|  | draw = ImageDraw.Draw(image) | 
					
						
						|  | font = ImageFont.load_default() | 
					
						
						|  | if isinstance(inference_response.predictions, list): | 
					
						
						|  | prediction = inference_response.predictions[0] | 
					
						
						|  | color = self.colors.get(prediction.class_name, "#4892EA") | 
					
						
						|  | draw.rectangle( | 
					
						
						|  | [0, 0, image.size[1], image.size[0]], | 
					
						
						|  | outline=color, | 
					
						
						|  | width=inference_request.visualization_stroke_width, | 
					
						
						|  | ) | 
					
						
						|  | text = f"{prediction.class_id} - {prediction.class_name} {prediction.confidence:.2f}" | 
					
						
						|  | text_size = font.getbbox(text) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | button_size = (text_size[2] + 20, text_size[3] + 20) | 
					
						
						|  | button_img = Image.new("RGBA", button_size, color) | 
					
						
						|  |  | 
					
						
						|  | button_draw = ImageDraw.Draw(button_img) | 
					
						
						|  | button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image.paste(button_img, (0, 0)) | 
					
						
						|  | else: | 
					
						
						|  | if len(inference_response.predictions) > 0: | 
					
						
						|  | box_color = "#4892EA" | 
					
						
						|  | draw.rectangle( | 
					
						
						|  | [0, 0, image.size[1], image.size[0]], | 
					
						
						|  | outline=box_color, | 
					
						
						|  | width=inference_request.visualization_stroke_width, | 
					
						
						|  | ) | 
					
						
						|  | row = 0 | 
					
						
						|  | predictions = [ | 
					
						
						|  | (cls_name, pred) | 
					
						
						|  | for cls_name, pred in inference_response.predictions.items() | 
					
						
						|  | ] | 
					
						
						|  | predictions = sorted( | 
					
						
						|  | predictions, key=lambda x: x[1].confidence, reverse=True | 
					
						
						|  | ) | 
					
						
						|  | for i, (cls_name, pred) in enumerate(predictions): | 
					
						
						|  | color = self.colors.get(cls_name, "#4892EA") | 
					
						
						|  | text = f"{cls_name} {pred.confidence:.2f}" | 
					
						
						|  | text_size = font.getbbox(text) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | button_size = (text_size[2] + 20, text_size[3] + 20) | 
					
						
						|  | button_img = Image.new("RGBA", button_size, color) | 
					
						
						|  |  | 
					
						
						|  | button_draw = ImageDraw.Draw(button_img) | 
					
						
						|  | button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image.paste(button_img, (0, row)) | 
					
						
						|  | row += button_size[1] | 
					
						
						|  |  | 
					
						
						|  | buffered = BytesIO() | 
					
						
						|  | image = image.convert("RGB") | 
					
						
						|  | image.save(buffered, format="JPEG") | 
					
						
						|  | return buffered.getvalue() | 
					
						
						|  |  | 
					
						
						|  | def get_infer_bucket_file_list(self) -> list: | 
					
						
						|  | """Get the list of required files for inference. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | list: A list of required files for inference, e.g., ["environment.json"]. | 
					
						
						|  | """ | 
					
						
						|  | return ["environment.json"] | 
					
						
						|  |  | 
					
						
						|  | def infer( | 
					
						
						|  | self, | 
					
						
						|  | image: Any, | 
					
						
						|  | disable_preproc_auto_orient: bool = False, | 
					
						
						|  | disable_preproc_contrast: bool = False, | 
					
						
						|  | disable_preproc_grayscale: bool = False, | 
					
						
						|  | disable_preproc_static_crop: bool = False, | 
					
						
						|  | return_image_dims: bool = False, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Perform inference on the provided image(s) and return the predictions. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | image (Any): The image or list of images to be processed. | 
					
						
						|  | disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False. | 
					
						
						|  | disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False. | 
					
						
						|  | disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False. | 
					
						
						|  | disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. | 
					
						
						|  | return_image_dims (bool, optional): If set to True, the function will also return the dimensions of the image. Defaults to False. | 
					
						
						|  | **kwargs: Additional parameters to customize the inference process. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Union[List[np.array], np.array, Tuple[List[np.array], List[Tuple[int, int]]], Tuple[np.array, Tuple[int, int]]]: | 
					
						
						|  | If `return_image_dims` is True and a list of images is provided, a tuple containing a list of prediction arrays and a list of image dimensions (width, height) is returned. | 
					
						
						|  | If `return_image_dims` is True and a single image is provided, a tuple containing the prediction array and image dimensions (width, height) is returned. | 
					
						
						|  | If `return_image_dims` is False and a list of images is provided, only the list of prediction arrays is returned. | 
					
						
						|  | If `return_image_dims` is False and a single image is provided, only the prediction array is returned. | 
					
						
						|  |  | 
					
						
						|  | Notes: | 
					
						
						|  | - The input image(s) will be preprocessed (normalized and reshaped) before inference. | 
					
						
						|  | - This function uses an ONNX session to perform inference on the input image(s). | 
					
						
						|  | """ | 
					
						
						|  | return super().infer( | 
					
						
						|  | image, | 
					
						
						|  | disable_preproc_auto_orient=disable_preproc_auto_orient, | 
					
						
						|  | disable_preproc_contrast=disable_preproc_contrast, | 
					
						
						|  | disable_preproc_grayscale=disable_preproc_grayscale, | 
					
						
						|  | disable_preproc_static_crop=disable_preproc_static_crop, | 
					
						
						|  | return_image_dims=return_image_dims, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def postprocess( | 
					
						
						|  | self, | 
					
						
						|  | predictions: Tuple[np.ndarray], | 
					
						
						|  | preprocess_return_metadata: PreprocessReturnMetadata, | 
					
						
						|  | return_image_dims=False, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: | 
					
						
						|  | predictions = predictions[0] | 
					
						
						|  | return self.make_response( | 
					
						
						|  | predictions, preprocess_return_metadata["img_dims"], **kwargs | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]: | 
					
						
						|  | predictions = self.onnx_session.run(None, {self.input_name: img_in}) | 
					
						
						|  | return (predictions,) | 
					
						
						|  |  | 
					
						
						|  | def preprocess( | 
					
						
						|  | self, image: Any, **kwargs | 
					
						
						|  | ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: | 
					
						
						|  | if isinstance(image, list): | 
					
						
						|  | imgs_with_dims = [ | 
					
						
						|  | self.preproc_image( | 
					
						
						|  | i, | 
					
						
						|  | disable_preproc_auto_orient=kwargs.get( | 
					
						
						|  | "disable_preproc_auto_orient", False | 
					
						
						|  | ), | 
					
						
						|  | disable_preproc_contrast=kwargs.get( | 
					
						
						|  | "disable_preproc_contrast", False | 
					
						
						|  | ), | 
					
						
						|  | disable_preproc_grayscale=kwargs.get( | 
					
						
						|  | "disable_preproc_grayscale", False | 
					
						
						|  | ), | 
					
						
						|  | disable_preproc_static_crop=kwargs.get( | 
					
						
						|  | "disable_preproc_static_crop", False | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  | for i in image | 
					
						
						|  | ] | 
					
						
						|  | imgs, img_dims = zip(*imgs_with_dims) | 
					
						
						|  | img_in = np.concatenate(imgs, axis=0) | 
					
						
						|  | else: | 
					
						
						|  | img_in, img_dims = self.preproc_image( | 
					
						
						|  | image, | 
					
						
						|  | disable_preproc_auto_orient=kwargs.get( | 
					
						
						|  | "disable_preproc_auto_orient", False | 
					
						
						|  | ), | 
					
						
						|  | disable_preproc_contrast=kwargs.get("disable_preproc_contrast", False), | 
					
						
						|  | disable_preproc_grayscale=kwargs.get( | 
					
						
						|  | "disable_preproc_grayscale", False | 
					
						
						|  | ), | 
					
						
						|  | disable_preproc_static_crop=kwargs.get( | 
					
						
						|  | "disable_preproc_static_crop", False | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  | img_dims = [img_dims] | 
					
						
						|  |  | 
					
						
						|  | img_in /= 255.0 | 
					
						
						|  |  | 
					
						
						|  | mean = (0.5, 0.5, 0.5) | 
					
						
						|  | std = (0.5, 0.5, 0.5) | 
					
						
						|  |  | 
					
						
						|  | img_in = img_in.astype(np.float32) | 
					
						
						|  |  | 
					
						
						|  | img_in[:, 0, :, :] = (img_in[:, 0, :, :] - mean[0]) / std[0] | 
					
						
						|  | img_in[:, 1, :, :] = (img_in[:, 1, :, :] - mean[1]) / std[1] | 
					
						
						|  | img_in[:, 2, :, :] = (img_in[:, 2, :, :] - mean[2]) / std[2] | 
					
						
						|  | return img_in, PreprocessReturnMetadata({"img_dims": img_dims}) | 
					
						
						|  |  | 
					
						
						|  | def infer_from_request( | 
					
						
						|  | self, | 
					
						
						|  | request: ClassificationInferenceRequest, | 
					
						
						|  | ) -> Union[List[InferenceResponse], InferenceResponse]: | 
					
						
						|  | """ | 
					
						
						|  | Handle an inference request to produce an appropriate response. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | request (ClassificationInferenceRequest): The request object encapsulating the image(s) and relevant parameters. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Union[List[InferenceResponse], InferenceResponse]: The response object(s) containing the predictions, visualization, and other pertinent details. If a list of images was provided, a list of responses is returned. Otherwise, a single response is returned. | 
					
						
						|  |  | 
					
						
						|  | Notes: | 
					
						
						|  | - Starts a timer at the beginning to calculate inference time. | 
					
						
						|  | - Processes the image(s) through the `infer` method. | 
					
						
						|  | - Generates the appropriate response object(s) using `make_response`. | 
					
						
						|  | - Calculates and sets the time taken for inference. | 
					
						
						|  | - If visualization is requested, the predictions are drawn on the image. | 
					
						
						|  | """ | 
					
						
						|  | t1 = perf_counter() | 
					
						
						|  | responses = self.infer(**request.dict(), return_image_dims=True) | 
					
						
						|  | for response in responses: | 
					
						
						|  | response.time = perf_counter() - t1 | 
					
						
						|  |  | 
					
						
						|  | if request.visualize_predictions: | 
					
						
						|  | for response in responses: | 
					
						
						|  | response.visualization = self.draw_predictions(request, response) | 
					
						
						|  |  | 
					
						
						|  | if not isinstance(request.image, list): | 
					
						
						|  | responses = responses[0] | 
					
						
						|  |  | 
					
						
						|  | return responses | 
					
						
						|  |  | 
					
						
						|  | def make_response( | 
					
						
						|  | self, | 
					
						
						|  | predictions, | 
					
						
						|  | img_dims, | 
					
						
						|  | confidence: float = 0.5, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: | 
					
						
						|  | """ | 
					
						
						|  | Create response objects for the given predictions and image dimensions. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | predictions (list): List of prediction arrays from the inference process. | 
					
						
						|  | img_dims (list): List of tuples indicating the dimensions (width, height) of each image. | 
					
						
						|  | confidence (float, optional): Confidence threshold for filtering predictions. Defaults to 0.5. | 
					
						
						|  | **kwargs: Additional parameters to influence the response creation process. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: A response object or a list of response objects encapsulating the prediction details. | 
					
						
						|  |  | 
					
						
						|  | Notes: | 
					
						
						|  | - If the model is multiclass, a `MultiLabelClassificationInferenceResponse` is generated for each image. | 
					
						
						|  | - If the model is not multiclass, a `ClassificationInferenceResponse` is generated for each image. | 
					
						
						|  | - Predictions below the confidence threshold are filtered out. | 
					
						
						|  | """ | 
					
						
						|  | responses = [] | 
					
						
						|  | confidence_threshold = float(confidence) | 
					
						
						|  | for ind, prediction in enumerate(predictions): | 
					
						
						|  | if self.multiclass: | 
					
						
						|  | preds = prediction[0] | 
					
						
						|  | results = dict() | 
					
						
						|  | predicted_classes = [] | 
					
						
						|  | for i, o in enumerate(preds): | 
					
						
						|  | cls_name = self.class_names[i] | 
					
						
						|  | score = float(o) | 
					
						
						|  | results[cls_name] = {"confidence": score, "class_id": i} | 
					
						
						|  | if score > confidence_threshold: | 
					
						
						|  | predicted_classes.append(cls_name) | 
					
						
						|  | response = MultiLabelClassificationInferenceResponse( | 
					
						
						|  | image=InferenceResponseImage( | 
					
						
						|  | width=img_dims[ind][0], height=img_dims[ind][1] | 
					
						
						|  | ), | 
					
						
						|  | predicted_classes=predicted_classes, | 
					
						
						|  | predictions=results, | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | preds = prediction[0] | 
					
						
						|  | preds = self.softmax(preds) | 
					
						
						|  | results = [] | 
					
						
						|  | for i, cls_name in enumerate(self.class_names): | 
					
						
						|  | score = float(preds[i]) | 
					
						
						|  | pred = { | 
					
						
						|  | "class_id": i, | 
					
						
						|  | "class": cls_name, | 
					
						
						|  | "confidence": round(score, 4), | 
					
						
						|  | } | 
					
						
						|  | results.append(pred) | 
					
						
						|  | results = sorted(results, key=lambda x: x["confidence"], reverse=True) | 
					
						
						|  |  | 
					
						
						|  | response = ClassificationInferenceResponse( | 
					
						
						|  | image=InferenceResponseImage( | 
					
						
						|  | width=img_dims[ind][1], height=img_dims[ind][0] | 
					
						
						|  | ), | 
					
						
						|  | predictions=results, | 
					
						
						|  | top=results[0]["class"], | 
					
						
						|  | confidence=results[0]["confidence"], | 
					
						
						|  | ) | 
					
						
						|  | responses.append(response) | 
					
						
						|  |  | 
					
						
						|  | return responses | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def softmax(x): | 
					
						
						|  | """Compute softmax values for each set of scores in x. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | x (np.array): The input array containing the scores. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | np.array: The softmax values for each set of scores. | 
					
						
						|  | """ | 
					
						
						|  | e_x = np.exp(x - np.max(x)) | 
					
						
						|  | return e_x / e_x.sum() | 
					
						
						|  |  | 
					
						
						|  | def get_model_output_shape(self) -> Tuple[int, int, int]: | 
					
						
						|  | test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8) | 
					
						
						|  | test_image, _ = self.preprocess(test_image) | 
					
						
						|  | output = np.array(self.predict(test_image)) | 
					
						
						|  | return output.shape | 
					
						
						|  |  | 
					
						
						|  | def validate_model_classes(self) -> None: | 
					
						
						|  | output_shape = self.get_model_output_shape() | 
					
						
						|  | num_classes = output_shape[3] | 
					
						
						|  | try: | 
					
						
						|  | assert num_classes == self.num_classes | 
					
						
						|  | except AssertionError: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})" | 
					
						
						|  | ) | 
					
						
						|  |  |