Spaces:
Running
Running
File size: 9,438 Bytes
1999a98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import io
from typing import Any
import cv2
from ultralytics import YOLO
from ultralytics.utils import LOGGER
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
class Inference:
    """
    Run Ultralytics YOLO inference inside a Streamlit web application.

    Performs object detection, image classification, image segmentation and pose estimation with Ultralytics YOLO
    models. The class builds the page and sidebar, lets the user pick a source (webcam or uploaded video), a model
    and the classes to keep, then shows original and annotated frames side by side.

    Attributes:
        st (module): Streamlit module for UI creation.
        temp_dict (dict): Keyword arguments given at construction, with a "model" key that defaults to None.
        model_path (str | None): Custom model passed through the `model` keyword argument, if any.
        model (YOLO): The YOLO model instance, set by `configure`.
        source (str): Selected video source, "webcam" or "video".
        enable_trk (bool | str): False until `sidebar` runs, then the radio value "Yes" or "No".
        conf (float): Confidence threshold.
        iou (float): IoU threshold for non-max suppression.
        org_frame: Streamlit placeholder that displays the original frame.
        ann_frame: Streamlit placeholder that displays the annotated frame.
        vid_file_name (str | int): "ultralytics.mp4" for an uploaded video, 0 for the webcam, "" when nothing is set.
        selected_ind (list): Class ids selected in the sidebar.

    Methods:
        web_ui: Sets up the Streamlit web interface with custom HTML elements.
        sidebar: Configures the Streamlit sidebar for model and inference settings.
        source_upload: Handles video file uploads through the Streamlit interface.
        configure: Configures the model and loads selected classes for inference.
        inference: Performs real-time object detection inference.

    Examples:
        >>> inf = solutions.Inference(model="path/to/model.pt")  # The model argument is optional.
        >>> inf.inference()
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the Inference class, checking Streamlit requirements and setting up the model path.

        Args:
            **kwargs (Any): Additional keyword arguments for model configuration. Only `model` is read here; the
                full mapping is kept in `temp_dict`.
        """
        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
        import streamlit as st

        self.st = st  # Reference to the Streamlit module
        self.source = None  # Placeholder for video or webcam source details
        self.enable_trk = False  # Flag to toggle object tracking (replaced by "Yes"/"No" in `sidebar`)
        self.conf = 0.25  # Confidence threshold for detection
        self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
        self.org_frame = None  # Container for the original frame to be displayed
        self.ann_frame = None  # Container for the annotated frame to be displayed
        self.vid_file_name = None  # Holds the name of the video file
        self.selected_ind = []  # List of selected classes for detection or tracking
        self.model = None  # Container for the loaded model instance

        self.temp_dict = {"model": None, **kwargs}
        self.model_path = self.temp_dict["model"]  # Custom model file name with path, or None

        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

    @staticmethod
    def _weights_file(model_name: str, is_custom: bool = False) -> str:
        """
        Map a sidebar model selection to the weights file name handed to YOLO.

        Args:
            model_name (str): The entry selected in the "Model" dropdown.
            is_custom (bool): True when the entry is the user-supplied model path rather than a built-in name.

        Returns:
            (str): For built-in entries, the lowercased name with ".pt" appended. For custom entries, the path
                unchanged when it already has a file suffix, otherwise the path with ".pt" appended.
        """
        if is_custom:
            from pathlib import Path  # scoped import, only needed for custom model paths

            # Custom paths keep their case and any directory parts; only a missing suffix is filled in
            return model_name if Path(model_name).suffix else f"{model_name}.pt"
        return f"{model_name.lower()}.pt"

    def web_ui(self) -> None:
        """Set up the Streamlit web interface with custom HTML elements."""
        # NOTE(review): the selector has no leading '#' or '.', so it may not match any element - confirm intent
        menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

        # Main title of streamlit application
        main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
        font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

        # Subtitle of streamlit application
        sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
        of Ultralytics YOLO! 🚀</h4></div>"""

        # Set html page configuration and append custom HTML
        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
        for html_snippet in (menu_style_cfg, main_title_cfg, sub_title_cfg):
            self.st.markdown(html_snippet, unsafe_allow_html=True)

    def sidebar(self) -> None:
        """Configure the Streamlit sidebar for source, tracking and threshold settings, and create frame slots."""
        with self.st.sidebar:  # Add Ultralytics LOGO
            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
            self.st.image(logo, width=250)

        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
        self.source = self.st.sidebar.selectbox("Video", ("webcam", "video"))  # Source selection dropdown
        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
        self.conf = float(self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01))
        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # NMS threshold

        # Two side-by-side placeholders: original frame on the left, annotated frame on the right
        col1, col2 = self.st.columns(2)
        self.org_frame = col1.empty()
        self.ann_frame = col2.empty()

    def source_upload(self) -> None:
        """
        Resolve `vid_file_name` from the selected source, saving an uploaded video to disk when needed.

        Sets `vid_file_name` to 0 for the webcam, to "ultralytics.mp4" once an uploaded video has been written to
        that file, and to "" otherwise.
        """
        self.vid_file_name = ""
        if self.source == "video":
            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
            if vid_file is not None:
                with open("ultralytics.mp4", "wb") as out:  # Persist the upload so OpenCV can open it by name
                    out.write(vid_file.read())
                self.vid_file_name = "ultralytics.mp4"
        elif self.source == "webcam":
            self.vid_file_name = 0

    def configure(self) -> None:
        """Let the user pick a model and classes in the sidebar, load the model and store the selected class ids."""
        # Add dropdown menu for model selection
        available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
        if self.model_path:  # A user-provided model is listed first, exactly as it was given
            available_models.insert(0, self.model_path)
        selected_model = self.st.sidebar.selectbox("Model", available_models)
        is_custom = bool(self.model_path) and selected_model == self.model_path

        with self.st.spinner("Model is downloading..."):
            self.model = YOLO(self._weights_file(selected_model, is_custom))  # Load the YOLO model
            class_names = list(self.model.names.values())  # Convert dictionary to list of class names
        self.st.success("Model loaded successfully!")

        # Map each class name to its id as stored in the model; the first id wins if a name repeats
        class_id_by_name = {}
        for class_id, class_name in self.model.names.items():
            class_id_by_name.setdefault(class_name, class_id)

        # Multiselect box with class names and get ids of selected classes
        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
        self.selected_ind = [class_id_by_name[option] for option in selected_classes]

    def inference(self) -> None:
        """Build the UI, then stream frames through the model once the sidebar "Start" button is pressed."""
        self.web_ui()  # Initialize the web interface
        self.sidebar()  # Create the sidebar
        self.source_upload()  # Upload the video source
        self.configure()  # Configure the app

        if self.st.sidebar.button("Start"):
            stop_button = self.st.button("Stop")  # Button to stop the inference
            cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
            if not cap.isOpened():
                self.st.error("Could not open webcam or video source.")

            try:
                while cap.isOpened():
                    success, frame = cap.read()
                    if not success:
                        self.st.warning(
                            "Failed to read frame from webcam. Please verify the webcam is connected properly."
                        )
                        break

                    # Store model predictions
                    if self.enable_trk == "Yes":
                        results = self.model.track(
                            frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                        )
                    else:
                        results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
                    annotated_frame = results[0].plot()  # Add annotations on frame

                    if stop_button:
                        self.st.stop()  # Stop streamlit app; the capture is released in `finally`

                    self.org_frame.image(frame, channels="BGR")  # Display original frame
                    self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame
            finally:
                cap.release()  # Release the capture even when inference raises or the app is stopped

        cv2.destroyAllWindows()  # Destroy window
if __name__ == "__main__":
    import sys  # Needed only when run as a script, to read the command-line arguments

    # The first command-line argument, when given, names the model; otherwise None is passed
    cli_args = sys.argv[1:]
    model = cli_args[0] if cli_args else None

    # Build the Streamlit inference app and run it
    Inference(model=model).inference()
|