File size: 1,972 Bytes
d526f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cd7737
d526f23
 
 
 
 
 
0cd7737
d526f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3163c7a
d526f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys
import PIL.Image as Image
from ultralytics import YOLO
import gradio as gr

# Local imports
from src.logger import logging
from src.exception import CustomExceptionHandling


def predict_pose(
    img: str,
    conf_threshold: float,
    iou_threshold: float,
    max_detections: int,
    model_name: str,
) -> Image.Image:
    """
    Run a YOLO model (intended for pose estimation) on an image and return the annotated result.

    Args:
        - img (str or numpy.ndarray): The input image or path to the image file.
          May be None when the user has not supplied an image.
        - conf_threshold (float): The confidence threshold for detections.
        - iou_threshold (float): The Intersection Over Union (IOU) threshold for non-max suppression.
        - max_detections (int): The maximum number of detections allowed.
        - model_name (str): The name or path of the YOLO model to be used for prediction.

    Returns:
        PIL.Image.Image: The image with the predictions plotted on it (RGB), or None when
        no image was provided or the model produced no results. If the source yields
        several results, the plot of the last one is returned.

    Raises:
        CustomExceptionHandling: Wraps any exception raised while loading the model,
        running inference, or plotting the result.
    """
    try:
        # gr.Warning only shows a UI notification; it does not abort the call.
        # Return early so predict() is never invoked with source=None, which
        # Ultralytics would replace with its bundled sample images.
        if img is None:
            gr.Warning("Please provide an image.")
            return None

        # Load the YOLO model (done per call; the model object is not shared).
        model = YOLO(model_name)

        # Run inference on the image.
        results = model.predict(
            source=img,
            conf=conf_threshold,
            iou=iou_threshold,
            max_det=max_detections,
            show_labels=True,
            show_conf=True,
            imgsz=640,
            # NOTE(review): FP16 is generally not usable on CPU; confirm whether
            # half=True is intended together with device="cpu".
            half=True,
            device="cpu",
        )

        # Guard against an empty result list instead of failing later on an
        # unbound variable.
        if not results:
            gr.Warning("No prediction results were produced.")
            return None

        # Plot only the result that is returned (the last one). Results.plot()
        # returns a BGR array, so reverse the channel axis to get RGB for PIL.
        im_array = results[-1].plot()
        im = Image.fromarray(im_array[..., ::-1])

        # Log the successful prediction
        logging.info("Pose estimated successfully.")

        return im

    # Handle exceptions that may occur during the process
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e