scfive commited on
Commit
3c00238
·
verified ·
1 Parent(s): 16181ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -23
app.py CHANGED
@@ -4,40 +4,59 @@ from PIL import Image
4
  import numpy as np
5
  from ultralytics import YOLO
6
  from huggingface_hub import hf_hub_download
 
 
 
 
 
7
 
8
  # Download the model from Hugging Face
9
- model_path = hf_hub_download(repo_id="StephanST/WALDO30", filename="WALDO30_yolov8m_640x640.pt")
10
- model = YOLO(model_path) # Load YOLOv8 model
 
 
 
 
 
 
 
 
11
 
12
  # Detection function for images
13
  def detect_on_image(image):
14
- results = model(image) # Perform detection
15
- annotated_frame = results[0].plot() # Get annotated image
16
- return Image.fromarray(annotated_frame)
 
 
 
17
 
18
  # Detection function for videos
19
  def detect_on_video(video):
20
- temp_video_path = "processed_video.mp4"
21
- cap = cv2.VideoCapture(video)
22
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
23
- out = cv2.VideoWriter(temp_video_path, fourcc, cap.get(cv2.CAP_PROP_FPS),
24
- (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
25
-
26
- while cap.isOpened():
27
- ret, frame = cap.read()
28
- if not ret:
29
- break
30
- results = model(frame) # Perform detection
31
- annotated_frame = results[0].plot() # Get annotated frame
32
- out.write(annotated_frame)
33
-
34
- cap.release()
35
- out.release()
36
- return temp_video_path
 
 
 
37
 
38
  # Gradio Interface
39
  image_input = gr.Image(type="pil", label="Upload Image")
40
- video_input = gr.Video(type="file", label="Upload Video")
41
  image_output = gr.Image(type="pil", label="Detected Image")
42
  video_output = gr.Video(label="Detected Video")
43
 
 
4
  import numpy as np
5
  from ultralytics import YOLO
6
  from huggingface_hub import hf_hub_download
7
+ import os
8
+
9
+ # Verify that Hugging Face repo and file paths are correct
10
+ REPO_ID = "StephanST/WALDO30" # Update if the repository ID is different
11
+ MODEL_FILENAME = "WALDO30_yolov8m_640x640.pt"
12
 
13
  # Download the model from Hugging Face
14
+ try:
15
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
16
+ except Exception as e:
17
+ raise RuntimeError(f"Failed to download model from Hugging Face. Verify `repo_id` and `filename`. Error: {e}")
18
+
19
+ # Load the YOLOv8 model
20
+ try:
21
+ model = YOLO(model_path) # Ensure the model path is correct
22
+ except Exception as e:
23
+ raise RuntimeError(f"Failed to load the YOLO model. Verify the model file at `{model_path}`. Error: {e}")
24
 
25
  # Detection function for images
26
  def detect_on_image(image):
27
+ try:
28
+ results = model(image) # Perform detection
29
+ annotated_frame = results[0].plot() # Get annotated image
30
+ return Image.fromarray(annotated_frame)
31
+ except Exception as e:
32
+ raise RuntimeError(f"Error during image processing: {e}")
33
 
34
  # Detection function for videos
35
  def detect_on_video(video):
36
+ try:
37
+ temp_video_path = "processed_video.mp4"
38
+ cap = cv2.VideoCapture(video)
39
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
40
+ out = cv2.VideoWriter(temp_video_path, fourcc, cap.get(cv2.CAP_PROP_FPS),
41
+ (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
42
+
43
+ while cap.isOpened():
44
+ ret, frame = cap.read()
45
+ if not ret:
46
+ break
47
+ results = model(frame) # Perform detection
48
+ annotated_frame = results[0].plot() # Get annotated frame
49
+ out.write(annotated_frame)
50
+
51
+ cap.release()
52
+ out.release()
53
+ return temp_video_path
54
+ except Exception as e:
55
+ raise RuntimeError(f"Error during video processing: {e}")
56
 
57
  # Gradio Interface
58
  image_input = gr.Image(type="pil", label="Upload Image")
59
+ video_input = gr.Video(label="Upload Video") # Removed invalid `type` argument
60
  image_output = gr.Image(type="pil", label="Detected Image")
61
  video_output = gr.Video(label="Detected Video")
62