import cv2
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from collections import Counter
from PIL import Image
import os


class SceneClassifier:
    def __init__(self, model_path: str = "2nzi/Image_Surf_NotSurf"):
        # print(f"[DEBUG] Initializing SceneClassifier with model: {model_path}")
        try:
            # Initialize the image processor and the classification model
            self.processor = AutoImageProcessor.from_pretrained(
                "google/vit-base-patch16-224",
                use_fast=True
            )
            self.model = AutoModelForImageClassification.from_pretrained(
                model_path,
                trust_remote_code=True
            )
            self.id_to_label = self.model.config.id2label
            # print("[DEBUG] Model loaded successfully")
        except Exception as e:
            # print(f"[ERROR] Failed to load model: {str(e)}")
            raise

    def _time_to_seconds(self, time_str: str) -> float:
        """Convert an 'HH:MM:SS(.ms)' timestamp string to seconds."""
        h, m, s = time_str.split(':')
        return int(h) * 3600 + int(m) * 60 + float(s)

    def _extract_frames(self, video_path: str, start_time: str, end_time: str, num_frames: int = 5) -> list:
        """Sample num_frames evenly spaced frames from the scene as PIL images."""
        cap = cv2.VideoCapture(video_path)
        start_sec = self._time_to_seconds(start_time)
        end_sec = self._time_to_seconds(end_time)
        scene_duration = end_sec - start_sec
        frame_interval = scene_duration / (num_frames + 1)
        frames = []
        for i in range(num_frames):
            timestamp = start_sec + frame_interval * (i + 1)
            # Seek by timestamp (milliseconds) and grab the frame
            cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
            success, frame = cap.read()
            if success:
                # OpenCV returns BGR; convert to RGB before building the PIL image
                image_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append(image_pil)
            else:
                print(f"[WARNING] Failed to extract frame at {timestamp} seconds")
        cap.release()
        return frames

    def _classify_frame(self, frame: Image.Image) -> dict:
        """Run the classifier on a single frame and return its label and confidence."""
        inputs = self.processor(images=frame, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted_class = torch.max(probs, dim=-1)
        return {
            "label": self.id_to_label[predicted_class.item()],
            "confidence": float(confidence.item())
        }

    def classify_scene(self, video_path: str, scene: dict) -> dict:
        """Classify a scene by majority vote over the sampled frames."""
        print(f"[DEBUG] Classifying scene: {scene['start']} -> {scene['end']}")
        frames = self._extract_frames(video_path, scene["start"], scene["end"])
        if not frames:
            print("[WARNING] No frames extracted for classification")
            return {"recognized_sport": "Unknown", "confidence": 0.0}
        classifications = [self._classify_frame(frame) for frame in frames]
        labels = [c["label"] for c in classifications]
        label_counts = Counter(labels)
        predominant_label, count = label_counts.most_common(1)[0]
        # Average confidence over the frames that voted for the winning label
        confidence_avg = sum(
            c["confidence"] for c in classifications
            if c["label"] == predominant_label
        ) / count
        result = {
            "recognized_sport": predominant_label,
            "confidence": confidence_avg
        }
        print(f"[DEBUG] Classification result: {result}")
        return result
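

# --- Usage sketch (not part of the original file) ---
# A minimal example of how this class might be called, assuming a local video
# file and a scene dict carrying "start"/"end" timestamp strings; the file path
# and scene bounds below are hypothetical placeholders.
if __name__ == "__main__":
    classifier = SceneClassifier()
    scene = {"start": "00:00:05.000", "end": "00:00:12.500"}  # hypothetical scene bounds
    result = classifier.classify_scene("example_video.mp4", scene)  # hypothetical path
    print(result)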