File size: 2,646 Bytes
0534753
 
 
 
 
 
a9e4caf
0534753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26178eb
 
 
2e3ad82
26178eb
 
 
 
00d9ed8
 
 
 
5d8e38e
0534753
 
 
26178eb
0534753
 
 
d7f0de3
26178eb
0534753
 
26178eb
d7f0de3
 
00d9ed8
0534753
 
 
26178eb
d7f0de3
 
0534753
 
 
 
 
26178eb
 
0534753
 
 
 
 
 
 
26178eb
0534753
 
 
d7f0de3
0534753
 
 
26178eb
0534753
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import cv2
import tempfile


def load_model(repo_id):
    """Fetch a YOLO checkpoint from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository id (e.g. "user/model") containing a
            ``best.pt`` weights file at its root.

    Returns:
        A YOLO model configured for the 'detect' task.
    """
    local_dir = snapshot_download(repo_id)
    print(local_dir)
    weights_path = os.path.join(local_dir, "best.pt")
    print(weights_path)
    return YOLO(weights_path, task='detect')


def predict(pilimg):
    """Run object detection on a PIL image and return the annotated image.

    Args:
        pilimg: Input PIL image, or None when nothing was uploaded.

    Returns:
        A PIL image with detections drawn on it, or None for empty input.
    """
    if pilimg is None:
        return None

    # plot() renders boxes/labels onto a BGR ndarray (OpenCV convention).
    detections = detection_model.predict(pilimg, conf=0.5, iou=0.6)
    annotated_bgr = detections[0].plot()

    # Reverse the channel axis (BGR -> RGB) before wrapping as PIL.
    return Image.fromarray(annotated_bgr[..., ::-1])



def predict_video(video):
    """Run the detector on every frame of a video and write an annotated copy.

    Args:
        video: Filesystem path to the uploaded video (Gradio passes a str),
            or None when nothing was uploaded.

    Returns:
        Path to the annotated .mp4 file, or None if no video was supplied
        or the file could not be opened.
    """
    if video is None:
        return None  # Return None if no video was uploaded

    # video is a string path, so it can be passed straight to OpenCV
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        # Unreadable or corrupt upload — fail gracefully instead of
        # producing an empty writer.
        return None

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 fps, which would yield a broken output file.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

    # NamedTemporaryFile avoids the filename race inherent in the
    # deprecated tempfile.mktemp().
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        temp_output_path = tmp.name
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_output_path, fourcc, fps, (frame_width, frame_height))

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            result = detection_model.predict(frame, conf=0.5, iou=0.6)
            # result[0].plot() returns a BGR image, which is exactly what
            # cv2.VideoWriter expects. Writing it directly fixes the
            # swapped-channel colors the previous BGR->RGB conversion
            # produced, and avoids handing VideoWriter a non-contiguous
            # reversed view.
            out.write(result[0].plot())
    finally:
        # Release resources even if inference raises mid-video.
        cap.release()
        out.release()

    # Return the path to the processed video
    return temp_output_path

def enable_button(image_input, video_input):
    """Enable the submit button only when at least one input is present.

    Args:
        image_input: Current value of the image component (None when empty).
        video_input: Current value of the video component (None when empty).

    Returns:
        A Gradio update toggling the button's ``interactive`` property.
    """
    # gr.update() works on both Gradio 3.x and 4.x; the gr.Button.update()
    # classmethod used previously was removed in Gradio 4 and would raise
    # AttributeError there.
    if image_input is None and video_input is None:
        return gr.update(interactive=False)
    return gr.update(interactive=True)



# Hub repository holding the trained detection weights (best.pt).
REPO_ID = "dexpyw/model"
# Loaded once at import time; shared by predict() and predict_video().
detection_model = load_model(REPO_ID)



# Tab 1: single-image inference — PIL in, annotated PIL out.
image_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Predicted Image"),
    live=False
)

 
# Tab 2: video inference — file path in, annotated .mp4 path out.
video_interface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload Video"),
    outputs=gr.Video(label="Predicted Video"),
    live=False
)

 
# Combine both interfaces into tabs and serve with a public share link.
# launch() blocks, so this must be the last statement.
gr.TabbedInterface([image_interface, video_interface], ["Image", "Video"]).launch(share=True)

# image_interface.launch(share=True)
# video_interface.launch(share=True)