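"""Scene classification for video clips: samples frames from a scene interval and
labels the scene by majority vote over per-frame predictions from an image
classifier (surf vs. not-surf)."""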
import cv2
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from collections import Counter
from PIL import Image


class SceneClassifier:
    def __init__(self, model_path: str = "2nzi/Image_Surf_NotSurf"):
        # print(f"[DEBUG] Initializing SceneClassifier with model: {model_path}")
        try:
            # Initialize the image processor and the classification model
            self.processor = AutoImageProcessor.from_pretrained(
                "google/vit-base-patch16-224",
                use_fast=True
            )
            self.model = AutoModelForImageClassification.from_pretrained(
                model_path,
                trust_remote_code=True
            )
            self.id_to_label = self.model.config.id2label
            # print("[DEBUG] Model loaded successfully")
        except Exception as e:
            # print(f"[ERROR] Failed to load model: {str(e)}")
            raise

    def _time_to_seconds(self, time_str: str) -> float:
        """Convert an 'HH:MM:SS(.fff)' timestamp to seconds."""
        h, m, s = time_str.split(':')
        return int(h) * 3600 + int(m) * 60 + float(s)

    def _extract_frames(self, video_path: str, start_time: str, end_time: str, num_frames: int = 5) -> list:
        """Sample num_frames evenly spaced frames from the scene interval."""
        cap = cv2.VideoCapture(video_path)
        start_sec = self._time_to_seconds(start_time)
        end_sec = self._time_to_seconds(end_time)
        scene_duration = end_sec - start_sec
        # num_frames + 1 intervals -> frames at interior points, skipping the exact boundaries
        frame_interval = scene_duration / (num_frames + 1)
        frames = []
        for i in range(num_frames):
            timestamp = start_sec + frame_interval * (i + 1)
            cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
            success, frame = cap.read()
            if success:
                # OpenCV decodes to BGR; convert to RGB before wrapping in a PIL image
                image_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append(image_pil)
            else:
                print(f"[WARNING] Failed to extract frame at {timestamp} seconds")
        cap.release()
        return frames

    def _classify_frame(self, frame: Image.Image) -> dict:
        """Run a single frame through the model and return the top label with its confidence."""
        inputs = self.processor(images=frame, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted_class = torch.max(probs, dim=-1)
        return {
            "label": self.id_to_label[predicted_class.item()],
            "confidence": float(confidence.item())
        }

    def classify_scene(self, video_path: str, scene: dict) -> dict:
        """Classify a scene by majority vote over its sampled frames."""
        print(f"[DEBUG] Classifying scene: {scene['start']} -> {scene['end']}")
        frames = self._extract_frames(video_path, scene["start"], scene["end"])
        if not frames:
            print("[WARNING] No frames extracted for classification")
            return {"recognized_sport": "Unknown", "confidence": 0.0}
        classifications = [self._classify_frame(frame) for frame in frames]
        labels = [c["label"] for c in classifications]
        label_counts = Counter(labels)
        predominant_label, count = label_counts.most_common(1)[0]
        # Average confidence over only the frames that voted for the winning label
        confidence_avg = sum(
            c["confidence"] for c in classifications
            if c["label"] == predominant_label
        ) / count
        result = {
            "recognized_sport": predominant_label,
            "confidence": confidence_avg
        }
        print(f"[DEBUG] Classification result: {result}")
        return result
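

# A minimal usage sketch (not part of the original file): how SceneClassifier
# might be driven from a script. The video path and the scene dict's
# "start"/"end" keys (HH:MM:SS timestamps, as consumed by classify_scene above)
# are illustrative assumptions only.
if __name__ == "__main__":
    classifier = SceneClassifier()  # downloads "2nzi/Image_Surf_NotSurf" on first use
    scene = {"start": "00:00:05", "end": "00:00:12"}  # hypothetical scene bounds
    result = classifier.classify_scene("surf_session.mp4", scene)  # hypothetical path
    print(result)  # e.g. {"recognized_sport": ..., "confidence": ...}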