JohnJoelMota committed on
Commit
d43dcb1
·
verified ·
1 Parent(s): 833766c

Fix example paths

Browse files
Files changed (1) hide show
  1. app.py +85 -79
app.py CHANGED
@@ -1,83 +1,89 @@
1
- import torch
2
- import torchvision
3
- from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
4
- from PIL import Image
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- import gradio as gr
8
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Load the pre-trained model once
11
- model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
12
- model.eval()
13
-
14
- # COCO class names
15
- COCO_INSTANCE_CATEGORY_NAMES = [
16
- '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
17
- 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
18
- 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
19
- 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
20
- 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
21
- 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
22
- 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
23
- 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
24
- 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
25
- 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
26
- 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
27
- 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
28
- ]
29
-
30
- # Gradio-compatible detection function
31
- def detect_objects(image, threshold=0.5):
32
- transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
33
- image_tensor = transform(image).unsqueeze(0)
34
-
35
- with torch.no_grad():
36
- prediction = model(image_tensor)[0]
37
-
38
- boxes = prediction['boxes'].cpu().numpy()
39
- labels = prediction['labels'].cpu().numpy()
40
- scores = prediction['scores'].cpu().numpy()
41
-
42
- image_np = np.array(image)
43
- plt.figure(figsize=(10, 10))
44
- plt.imshow(image_np)
45
- ax = plt.gca()
46
-
47
- for box, label, score in zip(boxes, labels, scores):
48
- if score >= threshold:
49
- x1, y1, x2, y2 = box
50
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
51
- fill=False, color='red', linewidth=2))
52
- class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
53
- ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5),
54
- fontsize=12, color='black')
55
-
56
- plt.axis('off')
57
- plt.tight_layout()
58
-
59
- # Save the figure to return
60
- plt.savefig("output.png")
61
- plt.close()
62
- return "output.png"
63
-
64
- # List the example images
65
  example_images = [
66
- ["TEST_IMG_1.jpg"],
67
- ["TEST_IMG_2.jpg"],
68
- ["TEST_IMG_3.jpg"],
69
- ["TEST_IMG_4.jpg"]
70
  ]
71
 
72
- # Create Gradio interface
73
- gr.Interface(
74
- fn=detect_objects,
75
- inputs=[
76
- gr.Image(type="pil"),
77
- gr.Slider(0, 1, value=0.5, label="Confidence Threshold")
78
- ],
79
- outputs=gr.Image(type="filepath"),
80
- examples=example_images,
81
- title="Faster R-CNN Object Detection",
82
- description="Upload an image to detect objects using a pretrained Faster R-CNN model."
83
- ).launch()
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
4
+ from PIL import Image
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import gradio as gr
8
+ import os
9
+
10
# Build the pretrained Faster R-CNN detector a single time at import and put it
# straight into inference mode (``eval()`` returns the module itself), so every
# request reuses the same instance.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
).eval()
13
+
14
# Human-readable names for the detector's integer labels. The list is indexed
# directly by label id in detect_objects, so the 'N/A' placeholders must stay
# in place to keep every later name at the right index.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',            # 0-4
    'airplane', 'bus', 'train', 'truck', 'boat',                           # 5-9
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',  # 10-14
    'bench', 'bird', 'cat', 'dog', 'horse',                                # 15-19
    'sheep', 'cow', 'elephant', 'bear', 'zebra',                           # 20-24
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',                       # 25-29
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',                        # 30-34
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',            # 35-39
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',  # 40-44
    'N/A', 'wine glass', 'cup', 'fork', 'knife',                           # 45-49
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',                        # 50-54
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',                    # 55-59
    'donut', 'cake', 'chair', 'couch', 'potted plant',                     # 60-64
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',                            # 65-69
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',                              # 70-74
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',               # 75-79
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',                      # 80-84
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',               # 85-89
    'toothbrush',                                                          # 90
]
29
+
30
# Gradio-compatible detection function
def detect_objects(image, threshold=0.5):
    """Detect objects in ``image`` and return the path of an annotated PNG.

    The image is run through the module-level Faster R-CNN ``model``; every
    detection whose score is at least ``threshold`` is drawn as a red box
    with a ``"<class>: <score>"`` caption.

    Args:
        image: PIL image to analyse. ``None`` is tolerated (presumably what
            the UI passes when nothing was uploaded -- confirm against the
            Gradio version in use).
        threshold: Minimum confidence score for a detection to be drawn.

    Returns:
        Path to a freshly written PNG file, or ``None`` when ``image`` is
        ``None``.
    """
    # Local import: keeps this block self-contained without touching the
    # file's import section.
    import tempfile

    if image is None:
        return None

    # The detector works on 3-channel input; RGBA, greyscale or palette
    # uploads would otherwise fail inside the model's normalisation step.
    image = image.convert("RGB")

    transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
    image_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    boxes = prediction['boxes'].cpu().numpy()
    labels = prediction['labels'].cpu().numpy()
    scores = prediction['scores'].cpu().numpy()

    # Use an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", which is global state shared by all requests.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(np.array(image))

        for box, label, score in zip(boxes, labels, scores):
            if score < threshold:
                continue
            x1, y1, x2, y2 = box
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, color='red', linewidth=2))
            class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
            ax.text(x1, y1, f'{class_name}: {score:.2f}',
                    bbox=dict(facecolor='yellow', alpha=0.5),
                    fontsize=12, color='black')

        ax.axis('off')
        fig.tight_layout()

        # A unique file per call: a fixed "output.png" would be overwritten
        # by any other request running at the same time. The handle is closed
        # before saving so the path can be reopened on every platform.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            output_path = tmp.name
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    return output_path
63
 
64
# Example inputs offered in the UI. The paths are resolved relative to this
# file, inside its "Object-Detection" subdirectory, so they do not depend on
# the process's working directory. Each example is a one-element list.
examples_dir = os.path.join(os.path.dirname(__file__), "Object-Detection")
example_images = [
    [os.path.join(examples_dir, file_name)]
    for file_name in (
        "TEST_IMG_1.jpg",
        "TEST_IMG_2.JPG",  # this file's extension really is upper-case
        "TEST_IMG_3.jpg",
        "TEST_IMG_4.jpg",
    )
]
73
 
74
# Assemble the web UI: an image upload plus a confidence slider go in, the
# annotated image file produced by detect_objects comes out.
interface = gr.Interface(
    title="Faster R-CNN Object Detection",
    description="Upload an image to detect objects using a pretrained Faster R-CNN model.",
    fn=detect_objects,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(0, 1, value=0.5, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="filepath"),
    examples=example_images,
)

# Start the server only when this file is run as a script; launch() is called
# with its default options.
if __name__ == "__main__":
    interface.launch()