| | |
| | """ |
| | Real-time pose classifier |
| | Uses MediaPipe to capture camera input, perform pose recognition and classification, and display results on screen |
| | |
| | Features: |
| | 1. Use MediaPipe to obtain real-time pose data from camera |
| | 2. Extract joint coordinates and preprocess them |
| | 3. Use trained machine learning models for pose classification |
| | 4. Display classification results and keypoints in real-time on video screen |
| | |
| | Dependencies: |
| | pip install opencv-python mediapipe numpy scikit-learn |
| | |
| | Usage: |
| | python realtime_pose_classifier.py [--model MODEL_PATH] [--camera CAMERA_ID] |
| | """ |
| |
|
import argparse
import json
import sys
import time
import traceback
from collections import Counter, deque
from pathlib import Path

import cv2
import joblib
import mediapipe as mp
import numpy as np
| |
|
| |
|
class RealtimePoseClassifier:
    """Real-time pose classifier built on MediaPipe Pose and a pre-trained model.

    Frames are read from an OpenCV camera and passed through MediaPipe Pose.
    The configured target joints become a nose-relative feature vector.
    A joblib-loaded model (scaler + classifier + label encoder) predicts a
    pose label. Predictions are smoothed by a short majority vote and drawn
    onto the video window.
    """

    def __init__(self, model_path=None, camera_id=0):
        """
        Initialize real-time pose classifier

        Args:
            model_path (str | pathlib.Path | None): Model file path. When None,
                a few default file names in the current directory are tried.
            camera_id (int): Camera ID passed to ``cv2.VideoCapture``, default 0

        Raises:
            FileNotFoundError: No model path was given and no default model file exists.
            RuntimeError: The model file could not be loaded.
        """
        self.camera_id = camera_id

        self.mp_pose = mp.solutions.pose
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_drawing_styles = mp.solutions.drawing_styles

        # Names of MediaPipe's 33 pose landmarks, in the index order of
        # ``results.pose_landmarks.landmark``.
        self.landmark_names = [
            'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer',
            'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
            'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
            'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
            'left_index', 'right_index', 'left_thumb', 'right_thumb',
            'left_hip', 'right_hip', 'left_knee', 'right_knee',
            'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
            'left_foot_index', 'right_foot_index'
        ]

        # Populated by load_model().
        self.model = None
        self.scaler = None
        self.label_encoder = None
        self.target_joints = None
        self.model_info = None

        # Load the model before building the MediaPipe graph so that a missing
        # or broken model fails fast without allocating the pose pipeline.
        self.load_model(model_path)

        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=1,
            enable_segmentation=False,
            min_detection_confidence=0.7,
            min_tracking_confidence=0.5
        )

        # Sliding window of raw predictions used for majority-vote smoothing.
        self.prediction_history = deque()
        self.history_size = 5

        # Display FPS, refreshed every 30 frames by update_fps().
        self.fps_counter = 0
        self.fps_start_time = time.perf_counter()
        self.current_fps = 0

        # Accumulated per-stage processing time (seconds) and frame counts.
        self.mediapipe_time_total = 0.0
        self.mediapipe_time_count = 0
        self.feature_pred_time_total = 0.0
        self.feature_pred_time_count = 0

        # Overlay toggles (keys L and C).
        self.show_landmarks = True
        self.show_connections = True

    @staticmethod
    def _labels_path(model_path):
        """Return the ``<stem>_labels.json`` path next to *model_path*.

        Only the file name is changed; directory components are kept as-is.
        """
        model_path = Path(model_path)
        return model_path.with_name(f"{model_path.stem}_labels.json")

    def load_model(self, model_path=None):
        """
        Load a trained model bundle and its optional label metadata.

        The joblib file must contain the keys ``model``, ``scaler``,
        ``label_encoder`` and ``target_joints``. If ``<stem>_labels.json``
        exists next to the model, it is loaded into ``self.model_info``.

        Args:
            model_path (str | pathlib.Path | None): Model file. When None, the
                first existing default file name is used.

        Raises:
            FileNotFoundError: No path was given and no default model file exists.
            RuntimeError: Loading failed; the original exception is chained.
        """
        if model_path is None:
            possible_models = [
                'pose_classifier_random_forest.pkl',
                'pose_classifier_logistic.pkl',
                'pose_classifier_distilled_rf.pkl'
            ]
            model_path = next(
                (candidate for candidate in possible_models if Path(candidate).exists()),
                None,
            )
            if model_path is None:
                raise FileNotFoundError("No available model file found, please specify model path")

        try:
            print(f"Loading model: {model_path}")
            # NOTE: joblib.load unpickles the file and can execute arbitrary
            # code; only load model files from trusted sources.
            model_data = joblib.load(model_path)

            self.model = model_data['model']
            self.scaler = model_data['scaler']
            self.label_encoder = model_data['label_encoder']
            self.target_joints = model_data['target_joints']

            labels_path = self._labels_path(model_path)
            if labels_path.exists():
                with open(labels_path, 'r', encoding='utf-8') as f:
                    self.model_info = json.load(f)
                print(f"Loaded label information: {labels_path}")

            print("Model loaded successfully!")
            print(f"Target joints: {self.target_joints}")
            print(f"Classification classes: {self.label_encoder.classes_}")

        except Exception as e:
            raise RuntimeError(f"Model loading failed: {e}") from e

    def extract_pose_features(self, landmarks):
        """
        Extract a nose-relative feature vector from MediaPipe landmarks.

        Each joint in ``self.target_joints`` is expressed relative to the nose,
        multiplied by 100 and rounded to 2 decimals. The result is flattened to
        ``[j0_x, j0_y, j0_z, j1_x, ...]``. A target joint not present in
        ``self.landmark_names`` uses a zero position before the nose
        subtraction, which keeps the feature layout fixed.

        Args:
            landmarks: Landmark list exposing ``.landmark``, a sequence of points
                with ``x``, ``y`` and ``z`` attributes, or None.

        Returns:
            np.ndarray | None: float32 vector of length ``3 * len(self.target_joints)``,
            or None when *landmarks* is None.
        """
        if landmarks is None:
            return None

        coords = np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark], dtype=np.float32)

        name_to_index = {name: i for i, name in enumerate(self.landmark_names)}
        head_idx = name_to_index.get('nose')
        if head_idx is None:
            return None
        head_pos = coords[head_idx]

        # Pre-shaped (n, 3) so an empty target_joints still broadcasts cleanly.
        joint_coords = np.zeros((len(self.target_joints), 3), dtype=np.float32)
        for row, joint in enumerate(self.target_joints):
            idx = name_to_index.get(joint)
            if idx is not None:
                joint_coords[row] = coords[idx]

        relative_coords = (joint_coords - head_pos) * 100
        return np.round(relative_coords, 2).flatten()

    def predict_pose(self, features):
        """
        Use machine learning model to predict pose

        Args:
            features: 1-D feature vector from extract_pose_features()

        Returns:
            dict | None: ``predicted_label``, ``confidence`` (max class
            probability, or 0.0 when the model has no ``predict_proba``) and
            ``probabilities`` (label -> probability, or None). Returns None when
            the input or model is missing, or when prediction raises.
        """
        if features is None or self.model is None:
            return None

        try:
            features_scaled = self.scaler.transform(features.reshape(1, -1))

            prediction = self.model.predict(features_scaled)[0]
            predicted_label = self.label_encoder.inverse_transform([prediction])[0]

            confidence = 0.0
            probabilities = None
            if hasattr(self.model, 'predict_proba'):
                probs = self.model.predict_proba(features_scaled)[0]
                confidence = float(np.max(probs))
                # predict_proba columns follow model.classes_ (the encoded ids the
                # model saw in training). These can differ from
                # label_encoder.classes_, so map through the model's own order.
                model_classes = getattr(self.model, 'classes_', None)
                if model_classes is not None:
                    class_names = self.label_encoder.inverse_transform(model_classes)
                else:
                    class_names = self.label_encoder.classes_
                probabilities = dict(zip(class_names, probs))

            return {
                'predicted_label': predicted_label,
                'confidence': confidence,
                'probabilities': probabilities
            }

        except Exception as e:
            print(f"Prediction error: {e}")
            return None

    def smooth_predictions(self, current_prediction):
        """
        Smooth predictions with a majority vote over the last ``history_size`` results.

        Args:
            current_prediction: dict from predict_pose(), or None

        Returns:
            dict | None: None for a None input. With fewer than 3 buffered
            predictions, the input is returned unchanged. Otherwise the result
            contains the most frequent label, the mean confidence of the
            predictions with that label, and ``stability`` (that label's share
            of the window).
        """
        if current_prediction is None:
            return None

        history = self.prediction_history
        history.append(current_prediction)
        while len(history) > self.history_size:
            history.popleft()

        # Too few samples for a meaningful vote: pass the raw prediction through.
        if len(history) < 3:
            return current_prediction

        label_counts = Counter(pred['predicted_label'] for pred in history)
        most_common_label, votes = label_counts.most_common(1)[0]

        avg_confidence = np.mean([
            pred['confidence'] for pred in history
            if pred['predicted_label'] == most_common_label
        ])

        return {
            'predicted_label': most_common_label,
            'confidence': avg_confidence,
            'stability': votes / len(history)
        }

    def draw_pose_info(self, image, landmarks, prediction_result):
        """
        Draw skeleton, target joints, prediction panel and timing stats onto *image* in place.

        Args:
            image: OpenCV image (BGR frame), modified in place
            landmarks: MediaPipe landmarks, or None when no person was detected
            prediction_result: dict from predict_pose()/smooth_predictions(), or None
        """
        height, width = image.shape[:2]

        if landmarks is not None and self.show_connections:
            self.mp_drawing.draw_landmarks(
                image,
                landmarks,
                self.mp_pose.POSE_CONNECTIONS,
                landmark_drawing_spec=self.mp_drawing_styles.get_default_pose_landmarks_style()
            )

        # Highlight and label only the joints the classifier uses.
        if landmarks is not None and self.show_landmarks:
            for i, landmark in enumerate(landmarks.landmark):
                name = self.landmark_names[i]
                if name in self.target_joints:
                    x = int(landmark.x * width)
                    y = int(landmark.y * height)
                    cv2.circle(image, (x, y), 8, (0, 255, 0), -1)
                    cv2.putText(image, name, (x + 10, y - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

        if prediction_result:
            label = prediction_result['predicted_label']
            confidence = prediction_result.get('confidence', 0.0)
            stability = prediction_result.get('stability', 1.0)

            # BGR colour by confidence: green > 0.8, yellow > 0.6, else red.
            if confidence > 0.8:
                color = (0, 255, 0)
            elif confidence > 0.6:
                color = (0, 255, 255)
            else:
                color = (0, 0, 255)

            cv2.rectangle(image, (10, 10), (400, 120), (0, 0, 0), -1)
            cv2.rectangle(image, (10, 10), (400, 120), color, 2)

            cv2.putText(image, f"Pose: {label}", (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
            cv2.putText(image, f"Confidence: {confidence:.2f}", (20, 70),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            cv2.putText(image, f"Stability: {stability:.2f}", (20, 95),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        cv2.putText(image, f"FPS: {self.current_fps:.1f}", (width - 150, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        instructions = [
            "Controls:",
            "Q - Quit",
            "L - Toggle Landmarks",
            "C - Toggle Connections",
            "R - Reset History"
        ]
        for i, instruction in enumerate(instructions):
            cv2.putText(image, instruction, (width - 200, height - 120 + i * 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)

        mp_avg = self.mediapipe_time_total / self.mediapipe_time_count if self.mediapipe_time_count else 0.0
        fp_avg = self.feature_pred_time_total / self.feature_pred_time_count if self.feature_pred_time_count else 0.0
        cv2.putText(image, f"MP avg: {mp_avg*1000:.1f}ms", (width - 150, 55),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        cv2.putText(image, f"FP avg: {fp_avg*1000:.1f}ms", (width - 150, 75),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Processing-only throughput: frames / (MediaPipe + feature/predict
        # time). Capture, drawing and display time are excluded.
        total_frames = max(self.mediapipe_time_count, 1)
        avg_fps = total_frames / max(self.mediapipe_time_total + self.feature_pred_time_total, 1e-6)
        cv2.putText(image, f"Avg FPS: {avg_fps:.1f}", (width - 150, 95),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    def update_fps(self):
        """Refresh ``current_fps`` every 30 frames from the elapsed monotonic time."""
        self.fps_counter += 1
        if self.fps_counter >= 30:
            now = time.perf_counter()
            self.current_fps = 30 / (now - self.fps_start_time)
            self.fps_start_time = now
            self.fps_counter = 0

    def run(self):
        """
        Run the capture -> MediaPipe -> classify -> display loop until the user quits.

        Keys: Q quits, L toggles joint markers, C toggles the skeleton, and R
        clears the smoothing history. Errors inside the loop are printed with a
        traceback rather than re-raised. The camera and windows are always
        released.

        Raises:
            RuntimeError: The camera could not be opened.
        """
        print("Starting real-time pose classifier...")
        print("Press 'Q' to quit, 'L' to toggle landmark display, 'C' to toggle skeleton connections, 'R' to reset history")

        cap = cv2.VideoCapture(self.camera_id)
        if not cap.isOpened():
            cap.release()
            raise RuntimeError(f"Cannot open camera {self.camera_id}")

        # Requested capture settings; the camera driver may not honour them.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        cap.set(cv2.CAP_PROP_FPS, 30)

        try:
            while True:
                success, frame = cap.read()
                if not success:
                    print("Cannot read camera frame")
                    break

                # Mirror horizontally so on-screen movement matches the user's.
                frame = cv2.flip(frame, 1)

                # OpenCV frames are BGR; MediaPipe expects RGB.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # MediaPipe's examples mark the input read-only so it can be
                # passed by reference; rgb_frame is not written afterwards.
                rgb_frame.flags.writeable = False

                mp_start = time.perf_counter()
                results = self.pose.process(rgb_frame)
                self.mediapipe_time_total += time.perf_counter() - mp_start
                self.mediapipe_time_count += 1

                fp_start = time.perf_counter()
                prediction_result = None
                if results.pose_landmarks is not None:
                    features = self.extract_pose_features(results.pose_landmarks)
                    if features is not None:
                        raw_prediction = self.predict_pose(features)
                        prediction_result = self.smooth_predictions(raw_prediction)
                self.feature_pred_time_total += time.perf_counter() - fp_start
                self.feature_pred_time_count += 1

                self.draw_pose_info(frame, results.pose_landmarks, prediction_result)
                self.update_fps()
                cv2.imshow('Real-time Pose Classification', frame)

                key = cv2.waitKey(1) & 0xFF
                if key in (ord('q'), ord('Q')):
                    break
                elif key in (ord('l'), ord('L')):
                    self.show_landmarks = not self.show_landmarks
                    print(f"Landmark display: {'On' if self.show_landmarks else 'Off'}")
                elif key in (ord('c'), ord('C')):
                    self.show_connections = not self.show_connections
                    print(f"Skeleton connection display: {'On' if self.show_connections else 'Off'}")
                elif key in (ord('r'), ord('R')):
                    self.prediction_history.clear()
                    print("Prediction history reset")

        except KeyboardInterrupt:
            print("\nUser interrupted program")
        except Exception as e:
            print(f"Runtime error: {e}")
            traceback.print_exc()
        finally:
            cap.release()
            cv2.destroyAllWindows()
            print("Program exited")
| |
|
| |
|
def main(argv=None):
    """Parse command-line options, build the classifier and run it.

    Args:
        argv (list[str] | None): Arguments to parse. None means ``sys.argv[1:]``,
            which preserves the original command-line behaviour.

    Returns:
        int: 0 once the capture loop has ended. Note that ``run()`` reports
        in-loop errors itself and returns normally. Returns 1 when the
        classifier could not be created or started (for example, a missing
        model or camera).
    """
    parser = argparse.ArgumentParser(description='Real-time pose classifier')
    parser.add_argument('--model', '-m', type=str, default=None,
                        help='Model file path (auto-detect by default)')
    parser.add_argument('--camera', '-c', type=int, default=0,
                        help='Camera ID (default 0)')

    args = parser.parse_args(argv)

    try:
        classifier = RealtimePoseClassifier(
            model_path=args.model,
            camera_id=args.camera
        )
        classifier.run()
    except Exception as e:
        # Diagnostics go to stderr so they do not mix with normal output.
        print(f"Program startup failed: {e}", file=sys.stderr)
        return 1

    return 0
| |
|
| |
|
if __name__ == "__main__":
    # sys.exit is always available; the builtin exit() is injected by the
    # site module for interactive use and may be missing (e.g. python -S).
    sys.exit(main())
| |
|