import cv2
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from collections import Counter
from PIL import Image


class SceneClassifier:
    def __init__(self, model_path: str = "2nzi/Image_Surf_NotSurf"):
        try:
            # Initialize the image processor and the classification model
            self.processor = AutoImageProcessor.from_pretrained(
                "google/vit-base-patch16-224", use_fast=True
            )
            self.model = AutoModelForImageClassification.from_pretrained(
                model_path, trust_remote_code=True
            )
            self.id_to_label = self.model.config.id2label
        except Exception as e:
            print(f"[ERROR] Failed to load model: {e}")
            raise

    def _time_to_seconds(self, time_str: str) -> float:
        """Convert an 'HH:MM:SS(.ms)' timestamp to seconds."""
        h, m, s = time_str.split(":")
        return int(h) * 3600 + int(m) * 60 + float(s)

    def _extract_frames(self, video_path: str, start_time: str,
                        end_time: str, num_frames: int = 5) -> list:
        """Sample `num_frames` frames evenly spaced within the scene."""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"[WARNING] Could not open video: {video_path}")
            return []

        start_sec = self._time_to_seconds(start_time)
        end_sec = self._time_to_seconds(end_time)
        scene_duration = end_sec - start_sec
        # Dividing by (num_frames + 1) keeps samples strictly inside the
        # scene, avoiding the exact start and end boundaries.
        frame_interval = scene_duration / (num_frames + 1)

        frames = []
        for i in range(num_frames):
            timestamp = start_sec + frame_interval * (i + 1)
            cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
            success, frame = cap.read()
            if success:
                # OpenCV decodes frames as BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            else:
                print(f"[WARNING] Failed to extract frame at {timestamp:.2f} seconds")

        cap.release()
        return frames

    def _classify_frame(self, frame: Image.Image) -> dict:
        """Run the classifier on a single frame; return its label and confidence."""
        inputs = self.processor(images=frame, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted_class = torch.max(probs, dim=-1)
        return {
            "label": self.id_to_label[predicted_class.item()],
            "confidence": float(confidence.item()),
        }

    def classify_scene(self, video_path: str, scene: dict) -> dict:
        """Classify a scene by majority vote over its sampled frames."""
        print(f"[DEBUG] Classifying scene: {scene['start']} -> {scene['end']}")
        frames = self._extract_frames(video_path, scene["start"], scene["end"])
        if not frames:
            print("[WARNING] No frames extracted for classification")
            return {"recognized_sport": "Unknown", "confidence": 0.0}

        classifications = [self._classify_frame(frame) for frame in frames]
        labels = [c["label"] for c in classifications]
        predominant_label, count = Counter(labels).most_common(1)[0]

        # Average confidence over only the frames that voted for the winning label.
        confidence_avg = sum(
            c["confidence"] for c in classifications
            if c["label"] == predominant_label
        ) / count

        result = {
            "recognized_sport": predominant_label,
            "confidence": confidence_avg,
        }
        print(f"[DEBUG] Classification result: {result}")
        return result
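

# Minimal usage sketch. Assumptions: a local file "video.mp4" exists and the
# scene dict uses the "start"/"end" keys that classify_scene expects; the
# file name and timestamps below are illustrative placeholders, not part of
# the original module.
if __name__ == "__main__":
    classifier = SceneClassifier()
    scene = {"start": "00:00:05", "end": "00:00:15"}
    result = classifier.classify_scene("video.mp4", scene)
    print(result)  # e.g. {"recognized_sport": ..., "confidence": ...}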