File size: 3,527 Bytes
16c970a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import cv2
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from collections import Counter
from PIL import Image
import os

class SceneClassifier:
    """Classify video scenes by majority vote over a few sampled frames.

    Frames are extracted with OpenCV, preprocessed with a Hugging Face image
    processor and classified one at a time by an image-classification model.
    """

    def __init__(
        self,
        model_path: str = "2nzi/Image_Surf_NotSurf",
        *,
        processor_path: str = "google/vit-base-patch16-224",
    ):
        """Load the image processor and the classification model.

        Args:
            model_path: Hugging Face Hub id or local directory of the
                image-classification model.
            processor_path: Hub id or local directory of the image processor
                used to preprocess frames. The default is the base ViT
                processor; NOTE(review): presumably the checkpoint the model
                was fine-tuned from — confirm.

        Exceptions raised by ``from_pretrained`` propagate unchanged.
        """
        # Initialize the processor and the model.
        self.processor = AutoImageProcessor.from_pretrained(
            processor_path,
            use_fast=True,
        )
        # NOTE(security): trust_remote_code=True allows the model repository to
        # run arbitrary Python on load; only use model_path values you trust.
        self.model = AutoModelForImageClassification.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        # Class-index -> label-name mapping from the model config.
        self.id_to_label = self.model.config.id2label

    def _time_to_seconds(self, time_str: str) -> float:
        """Convert an ``[[HH:]MM:]SS[.fff]`` timestamp to seconds.

        Hours and minutes must be integers; seconds may be fractional.

        Raises:
            ValueError: if the string has more than three ``:``-separated
                fields or a field is not numeric.
        """
        fields = time_str.split(':')
        if len(fields) > 3:
            raise ValueError(
                f"Invalid timestamp {time_str!r}; expected [[HH:]MM:]SS[.fff]"
            )
        total = 0.0
        for whole_unit in fields[:-1]:  # hours then minutes, most significant first
            total = total * 60 + int(whole_unit)
        return total * 60 + float(fields[-1])

    def _extract_frames(self, video_path: str, start_time: str, end_time: str, num_frames: int = 5) -> list:
        """Read ``num_frames`` evenly spaced RGB frames from inside a scene.

        The span ``[start, end]`` is divided into ``num_frames + 1`` equal
        intervals and one frame is read at each interior boundary, so the exact
        start and end instants are never sampled.

        Returns:
            A list of ``PIL.Image.Image``. Frames that cannot be read are skipped
            with a warning, so the list may be shorter than ``num_frames``.
            The list is empty when ``num_frames < 1``, when the scene ends before
            it starts, or when the video cannot be opened.
        """
        if num_frames < 1:
            return []

        # Validate timestamps before acquiring the capture so bad input cannot leak it.
        start_sec = self._time_to_seconds(start_time)
        end_sec = self._time_to_seconds(end_time)
        scene_duration = end_sec - start_sec
        if scene_duration < 0:
            print(f"[WARNING] Scene end ({end_time}) is before its start ({start_time})")
            return []
        frame_interval = scene_duration / (num_frames + 1)

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"[WARNING] Could not open video: {video_path}")
                return []

            frames = []
            for i in range(1, num_frames + 1):
                timestamp = start_sec + frame_interval * i
                cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
                success, frame = cap.read()
                if not success:
                    print(f"[WARNING] Failed to extract frame at {timestamp} seconds")
                    continue
                # OpenCV decodes to BGR; PIL and the processor expect RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            return frames
        finally:
            # Always release the capture, even if decoding/conversion raises.
            cap.release()

    def _classify_frame(self, frame: "Image.Image") -> dict:
        """Classify one frame.

        Returns:
            ``{"label": <top-1 label>, "confidence": <its softmax probability>}``.
        """
        inputs = self.processor(images=frame, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            confidence, predicted_class = torch.max(probs, dim=-1)

        return {
            "label": self.id_to_label[predicted_class.item()],
            "confidence": float(confidence.item()),
        }

    def classify_scene(self, video_path: str, scene: dict, num_frames: int = 5) -> dict:
        """Classify a scene by majority vote over frames sampled from it.

        Args:
            video_path: Path to the video file.
            scene: Mapping with ``"start"`` and ``"end"`` timestamps in
                ``[[HH:]MM:]SS[.fff]`` form.
            num_frames: Number of frames to sample and classify.

        Returns:
            ``{"recognized_sport": label, "confidence": avg}``. Here ``label`` is
            the most frequent per-frame label and ``avg`` is the mean confidence
            of the frames that predicted it. The result is
            ``{"recognized_sport": "Unknown", "confidence": 0.0}`` when no frame
            could be extracted.
        """
        print(f"[DEBUG] Classifying scene: {scene['start']} -> {scene['end']}")
        frames = self._extract_frames(
            video_path, scene["start"], scene["end"], num_frames=num_frames
        )
        if not frames:
            print("[WARNING] No frames extracted for classification")
            return {"recognized_sport": "Unknown", "confidence": 0.0}

        result = self._summarize([self._classify_frame(frame) for frame in frames])
        print(f"[DEBUG] Classification result: {result}")
        return result

    @staticmethod
    def _summarize(classifications: list) -> dict:
        """Reduce non-empty per-frame predictions to one majority-vote result."""
        votes = Counter(c["label"] for c in classifications)
        # Counter.most_common orders equal counts by first appearance,
        # so ties go to the label predicted earliest in the scene.
        predominant_label, count = votes.most_common(1)[0]
        confidence_avg = sum(
            c["confidence"] for c in classifications
            if c["label"] == predominant_label
        ) / count
        return {
            "recognized_sport": predominant_label,
            "confidence": confidence_avg,
        }