import os
import tempfile

import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Fetch the shoplifting-detection weights from the Hugging Face model hub.
# hf_hub_download caches locally, so repeated startups skip the download.
# NOTE: repo and file names keep the upstream "shop-lifiting"/"shoplifiting"
# spellings — do not "correct" them or the download will 404.
model_path = hf_hub_download(
    repo_id="henriquequeirozcunha/YOLOv12-shop-lifiting",
    filename="yolov12-shoplifiting.pt",
)

# Load the detector once at module import so every request reuses it.
model = YOLO(model_path)

def detect_violence(video_file):
    """Run the shoplifting detector on every frame of a video.

    Args:
        video_file: Path to the input video (as delivered by ``gr.Video``).

    Returns:
        Path to a new ``.mp4`` with a red box and ``"label conf"`` text drawn
        on every detection above the 0.4 confidence threshold.
    """
    cap = cv2.VideoCapture(video_file)
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Some containers report 0 (or NaN-ish) fps, which would produce a broken
    # VideoWriter; fall back to a sane default.
    if not fps or fps <= 0:
        fps = 25.0
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # mkstemp instead of the deprecated, race-prone tempfile.mktemp: the file
    # is created atomically; close our fd and let VideoWriter reopen the path.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            results = model.predict(frame, conf=0.4, verbose=False)
            for r in results:
                for box in r.boxes:
                    cls_id = int(box.cls[0])
                    # model.names maps class id -> label; tolerate unknown ids.
                    label = model.names.get(cls_id, "Unknown")
                    conf = float(box.conf[0])

                    x1, y1, x2, y2 = box.xyxy[0].int().tolist()
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)

                    # Filled background behind the text keeps it legible; place
                    # the label above the box unless it would run off-frame.
                    label_text = f"{label} {conf:.2f}"
                    (text_w, text_h), baseline = cv2.getTextSize(
                        label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
                    text_x = x1
                    text_y = y1 - 20 if y1 - 20 > text_h else y1 + text_h + 20

                    cv2.rectangle(frame, (text_x, text_y - text_h - baseline),
                                  (text_x + text_w, text_y + baseline),
                                  (0, 0, 255), -1)
                    cv2.putText(frame, label_text, (text_x, text_y),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

            out.write(frame)
    finally:
        # Always release capture and writer so the output file is finalized
        # even if inference raises mid-video.
        cap.release()
        out.release()
    return out_path

# Wire the detector into an upload-a-video / get-a-video Gradio UI.
# launch() runs at import so hosting platforms (e.g. HF Spaces) serve it directly.
# Description fixed to match the model: it flags shoplifting, not violence.
gr.Interface(fn=detect_violence,
             inputs=gr.Video(),
             outputs=gr.Video(),
             title="YOLOv12 Shoplifting",
             description="Upload a video and get annotated shoplifting predictions.").launch()