import cv2
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from collections import Counter
from PIL import Image
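

# SceneClassifier samples a handful of frames from a video scene, runs each
# frame through a ViT image classifier, and reports the majority label along
# with the average confidence of the frames that agree with it.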
class SceneClassifier:
    def __init__(self, model_path: str = "2nzi/Image_Surf_NotSurf"):
        # print(f"[DEBUG] Initializing SceneClassifier with model: {model_path}")
        try:
            # Initialize the image processor and the classification model
            self.processor = AutoImageProcessor.from_pretrained(
                "google/vit-base-patch16-224",
                use_fast=True
            )
            self.model = AutoModelForImageClassification.from_pretrained(
                model_path,
                trust_remote_code=True
            )
            self.id_to_label = self.model.config.id2label
            # print("[DEBUG] Model loaded successfully")
        except Exception as e:
            # print(f"[ERROR] Failed to load model: {str(e)}")
            raise

    def _time_to_seconds(self, time_str: str) -> float:
        """Convert an 'HH:MM:SS(.fff)' timestamp string to seconds."""
        h, m, s = time_str.split(':')
        return int(h) * 3600 + int(m) * 60 + float(s)

    def _extract_frames(self, video_path: str, start_time: str, end_time: str, num_frames: int = 5) -> list:
        """Sample num_frames frames evenly spaced within the scene interval."""
        cap = cv2.VideoCapture(video_path)
        start_sec = self._time_to_seconds(start_time)
        end_sec = self._time_to_seconds(end_time)
        scene_duration = end_sec - start_sec
        # Sample interior points so no frame lands exactly on a scene boundary
        frame_interval = scene_duration / (num_frames + 1)
        frames = []
        for i in range(num_frames):
            timestamp = start_sec + frame_interval * (i + 1)
            cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
            success, frame = cap.read()
            if success:
                # OpenCV decodes frames as BGR; convert to RGB for PIL
                image_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append(image_pil)
            else:
                print(f"[WARNING] Failed to extract frame at {timestamp} seconds")
        cap.release()
        return frames

    def _classify_frame(self, frame: Image.Image) -> dict:
        inputs = self.processor(images=frame, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted_class = torch.max(probs, dim=-1)
        return {
            "label": self.id_to_label[predicted_class.item()],
            "confidence": float(confidence.item())
        }

    def classify_scene(self, video_path: str, scene: dict) -> dict:
        print(f"[DEBUG] Classifying scene: {scene['start']} -> {scene['end']}")
        frames = self._extract_frames(video_path, scene["start"], scene["end"])
        if not frames:
            print("[WARNING] No frames extracted for classification")
            return {"recognized_sport": "Unknown", "confidence": 0.0}
        classifications = [self._classify_frame(frame) for frame in frames]
        labels = [c["label"] for c in classifications]
        label_counts = Counter(labels)
        # Majority vote across the sampled frames
        predominant_label, count = label_counts.most_common(1)[0]
        # Average confidence over the frames that voted for the majority label
        confidence_avg = sum(
            c["confidence"] for c in classifications
            if c["label"] == predominant_label
        ) / count
        result = {
            "recognized_sport": predominant_label,
            "confidence": confidence_avg
        }
        print(f"[DEBUG] Classification result: {result}")
        return result
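

# Hypothetical usage sketch (not part of the original app): the video path and
# the scene timestamps below are illustrative assumptions only.
if __name__ == "__main__":
    classifier = SceneClassifier()
    scene = {"start": "00:00:05.000", "end": "00:00:15.000"}
    result = classifier.classify_scene("example_surf_clip.mp4", scene)
    print(result)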