CuriousDolphin committed
Commit 41eedc7 · 1 Parent(s): 1d636f3

first commit (wip)
.gitignore CHANGED
@@ -1,160 +1,3 @@
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- share/python-wheels/
- *.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
- .pytest_cache/
- cover/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- .pybuilder/
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- # For a library or package, you might want to ignore these files since the code is
- # intended to run in multiple environments; otherwise, check them in:
- # .python-version
-
- # pipenv
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
- # install all needed dependencies.
- #Pipfile.lock
-
- # poetry
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
- #poetry.lock
-
- # pdm
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
- #pdm.lock
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
- # in version control.
- # https://pdm.fming.dev/#use-with-ide
- .pdm.toml
-
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
- __pypackages__/
-
- # Celery stuff
- celerybeat-schedule
- celerybeat.pid
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
  .venv
+ __pycache__
+ data/cache
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pyre type checker
- .pyre/
-
- # pytype static type analyzer
- .pytype/
-
- # Cython debug symbols
- cython_debug/
-
- # PyCharm
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
- # and can be added to the global gitignore or merged into this file. For a more nuclear
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
- #.idea/
.vscode/settings.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "flake8.args": [
+         "--ignore=E24,E128,E201,E202,E225,E231,E252,E265,E302,E303,E401,E402,E501,E731,W504,W605",
+         "--verbose"
+     ],
+ }
Dockerfile ADDED
File without changes
data/assets/000000039769.jpg ADDED
data/assets/dog_bike_car.jpeg ADDED
detr/__init__.py ADDED
File without changes
detr/detr.py ADDED
@@ -0,0 +1,316 @@
+ from functools import cache
+ import torch
+ import torchvision.transforms as T
+ import os
+ import numpy as np
+ from torch import nn
+ from torchvision.models import resnet50
+
+ from supervision import Detections, BoxAnnotator
+
+ torch.set_grad_enabled(False)
+
+
+ # https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb#scrollTo=cfCcEYjg7y46
+
+ DETR_DEMO_WEIGHTS_URI = "https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth"
+
+ TORCH_HOME = os.path.abspath(os.curdir) + "/data/cache"
+
+ os.environ["TORCH_HOME"] = TORCH_HOME
+
+ print("Torch home:", TORCH_HOME)
+
+
+ # standard PyTorch mean-std input image normalization
+
+
+ def normalize_img(image):
+     transform = T.Compose(
+         [
+             T.Resize(800),
+             T.ToTensor(),
+             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+     return transform(image).unsqueeze(0)
+
+
+ # for output bounding box post-processing
+ def box_cxcywh_to_xyxy(x):
+     x_c, y_c, w, h = x.unbind(1)
+     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+     return torch.stack(b, dim=1)
+
+
+ def rescale_bboxes(out_bbox, size):
+     img_w, img_h = size
+     b = box_cxcywh_to_xyxy(out_bbox)
+     b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+     return b
+
+
+ class DETRdemo(nn.Module):
+     """
+     Demo DETR implementation.
+
+     Demo implementation of DETR in minimal number of lines, with the
+     following differences wrt DETR in the paper:
+     * learned positional encoding (instead of sine)
+     * positional encoding is passed at input (instead of attention)
+     * fc bbox predictor (instead of MLP)
+     The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
+     Only batch size 1 supported.
+     """
+
+     def __init__(
+         self,
+         num_classes,
+         hidden_dim=256,
+         nheads=8,
+         num_encoder_layers=6,
+         num_decoder_layers=6,
+     ):
+         super().__init__()
+
+         # create ResNet-50 backbone
+         self.backbone = resnet50()
+         del self.backbone.fc
+
+         # create conversion layer
+         self.conv = nn.Conv2d(2048, hidden_dim, 1)
+
+         # create a default PyTorch transformer
+         self.transformer = nn.Transformer(
+             hidden_dim, nheads, num_encoder_layers, num_decoder_layers
+         )
+
+         # prediction heads, one extra class for predicting non-empty slots
+         # note that in baseline DETR linear_bbox layer is 3-layer MLP
+         self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
+         self.linear_bbox = nn.Linear(hidden_dim, 4)
+
+         # output positional encodings (object queries)
+         self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
+
+         # spatial positional encodings
+         # note that in baseline DETR we use sine positional encodings
+         self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
+         self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
+
+     def forward(self, inputs):
+         # propagate inputs through ResNet-50 up to avg-pool layer
+         x = self.backbone.conv1(inputs)
+         x = self.backbone.bn1(x)
+         x = self.backbone.relu(x)
+         x = self.backbone.maxpool(x)
+
+         x = self.backbone.layer1(x)
+         x = self.backbone.layer2(x)
+         x = self.backbone.layer3(x)
+         x = self.backbone.layer4(x)
+
+         # convert from 2048 to 256 feature planes for the transformer
+         h = self.conv(x)
+
+         # construct positional encodings
+         H, W = h.shape[-2:]
+         pos = (
+             torch.cat(
+                 [
+                     self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
+                     self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
+                 ],
+                 dim=-1,
+             )
+             .flatten(0, 1)
+             .unsqueeze(1)
+         )
+
+         # propagate through the transformer
+         h = self.transformer(
+             pos + 0.1 * h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1)
+         ).transpose(0, 1)
+
+         # finally project transformer outputs to class labels and bounding boxes
+         return {
+             "pred_logits": self.linear_class(h),
+             "pred_boxes": self.linear_bbox(h).sigmoid(),
+         }
+
+
+ class SimpleDetr:
+     @cache
+     def __init__(self):
+         self.model = DETRdemo(num_classes=91)
+         state_dict = torch.hub.load_state_dict_from_url(
+             url=DETR_DEMO_WEIGHTS_URI,
+             map_location="cpu",
+             check_hash=True,
+         )
+         self.model.load_state_dict(state_dict)
+         self.model.eval()
+         self.box_annotator: BoxAnnotator = BoxAnnotator()
+
+     def detect(self, image, conf):
+         # mean-std normalize the input image (batch-size: 1)
+         img = normalize_img(image)
+
+         # the demo model only supports, by default, images with aspect ratio between 0.5 and 2
+         # if you want to use images with an aspect ratio outside this range
+         # rescale your image so that the maximum size is at most 1333 for best results
+         assert (
+             img.shape[-2] <= 1600 and img.shape[-1] <= 1600
+         ), "demo model only supports images up to 1600 pixels on each side"
+
+         # propagate through the model
+         outputs = self.model(img)
+         # keep only predictions above the confidence threshold
+         scores = outputs["pred_logits"].softmax(-1)[0, :, :-1]
+         keep = scores.max(-1).values > conf
+         # convert boxes from [0; 1] to image scales
+         bboxes_scaled = rescale_bboxes(outputs["pred_boxes"][0, keep], image.size)
+         probas = scores[keep]
+         class_id = []
+         confidence = []
+         for prob in probas:
+             cls_id = prob.argmax()
+             c = prob[cls_id]
+             class_id.append(int(cls_id))
+             confidence.append(float(c))
+         print(class_id, confidence)
+         detections = Detections(
+             xyxy=bboxes_scaled.cpu().detach().numpy(),
+             class_id=np.array(class_id),
+             confidence=np.array(confidence),
+         )
+         annotated = self.box_annotator.annotate(
+             scene=np.array(image),
+             skip_label=False,
+             detections=detections,
+             labels=[
+                 f"{CLASSES[cls_id]} {conf:.2f}"
+                 for cls_id, conf in zip(detections.class_id, detections.confidence)
+             ],
+         )
+         return annotated
+
+
+ class PanopticDetrResenet101:
+     @cache
+     def __init__(self):
+         self.model, self.postprocessor = torch.hub.load(
+             "facebookresearch/detr",
+             "detr_resnet101_panoptic",
+             pretrained=True,
+             return_postprocessor=True,
+             num_classes=250,
+         )
+         self.model.eval()
+
+     def detect(self, image, conf):
+         # mean-std normalize the input image (batch-size: 1)
+         img = normalize_img(image)
+
+         outputs = self.model(img)
+         # keep only predictions above the confidence threshold
+         # compute the scores, excluding the "no-object" class (the last one)
+         scores = outputs["pred_logits"].softmax(-1)[..., :-1].max(-1)[0]
+         # threshold the confidence
+         keep = scores > conf
+
+
+ # COCO classes
+ CLASSES = [
+     "N/A",
+     "person",
+     "bicycle",
+     "car",
+     "motorcycle",
+     "airplane",
+     "bus",
+     "train",
+     "truck",
+     "boat",
+     "traffic light",
+     "fire hydrant",
+     "N/A",
+     "stop sign",
+     "parking meter",
+     "bench",
+     "bird",
+     "cat",
+     "dog",
+     "horse",
+     "sheep",
+     "cow",
+     "elephant",
+     "bear",
+     "zebra",
+     "giraffe",
+     "N/A",
+     "backpack",
+     "umbrella",
+     "N/A",
+     "N/A",
+     "handbag",
+     "tie",
+     "suitcase",
+     "frisbee",
+     "skis",
+     "snowboard",
+     "sports ball",
+     "kite",
+     "baseball bat",
+     "baseball glove",
+     "skateboard",
+     "surfboard",
+     "tennis racket",
+     "bottle",
+     "N/A",
+     "wine glass",
+     "cup",
+     "fork",
+     "knife",
+     "spoon",
+     "bowl",
+     "banana",
+     "apple",
+     "sandwich",
+     "orange",
+     "broccoli",
+     "carrot",
+     "hot dog",
+     "pizza",
+     "donut",
+     "cake",
+     "chair",
+     "couch",
+     "potted plant",
+     "bed",
+     "N/A",
+     "dining table",
+     "N/A",
+     "N/A",
+     "toilet",
+     "N/A",
+     "tv",
+     "laptop",
+     "mouse",
+     "remote",
+     "keyboard",
+     "cell phone",
+     "microwave",
+     "oven",
+     "toaster",
+     "sink",
+     "refrigerator",
+     "N/A",
+     "book",
+     "clock",
+     "vase",
+     "scissors",
+     "teddy bear",
+     "hair drier",
+     "toothbrush",
+ ]
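A minimal usage sketch for the module added above (not part of the commit): it assumes the repository root is the working directory so that data/cache and data/assets resolve, and the 0.7 threshold and output filename are illustrative.

from PIL import Image

from detr.detr import SimpleDetr

# instantiate the demo model; weights are downloaded into data/cache via TORCH_HOME
model = SimpleDetr()

# run detection on one of the bundled sample images
image = Image.open("data/assets/dog_bike_car.jpeg").convert("RGB")
annotated = model.detect(image, conf=0.7)  # numpy array with boxes and class labels drawn

Image.fromarray(annotated).save("annotated.jpg")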
detr/main_gradio.py ADDED
@@ -0,0 +1,66 @@
+ import gradio as gr
+ import supervision as sv
+ import os
+ from detr import SimpleDetr, PanopticDetrResenet101
+
+ ASSETS_DIR = os.path.abspath(os.curdir) + "/data/assets"
+
+ print("Assets:", ASSETS_DIR)
+
+
+ def run_inference(image, confidence, model_name, progress=gr.Progress(track_tqdm=True)):
+     progress(0.1, "loading model..")
+
+     if model_name == "detr_demo_boxes":
+         model = SimpleDetr()
+     else:
+         model = PanopticDetrResenet101()
+     progress(0.1, "Inference..")
+
+     annotated_img = model.detect(image, confidence)
+     return annotated_img, None, None
+
+
+ with gr.Blocks() as inference_gradio:
+     gr.Markdown("# DETR inference")
+     with gr.Row():
+         with gr.Column():
+             img_file = gr.Image(type="pil")
+             # with gr.Row():
+             model_name = gr.Dropdown(
+                 label="Model",
+                 scale=3,
+                 choices=["detr_demo_boxes", "detr_resnet101_panoptic"],
+                 value="detr_demo_boxes",
+             )
+
+             conf = gr.Slider(label="Confidence", minimum=0, maximum=0.99, value=0.5)
+
+             with gr.Row():
+                 start_btn = gr.Button("Start", variant="primary")
+
+         with gr.Column():
+             annotated_img = gr.Image(label="Annotated Image")
+             speed = gr.JSON(label="speed")
+             json_out = gr.JSON(label="output")
+     examples = gr.Examples(
+         examples=[
+             [path]
+             for path in sv.list_files_with_extensions(
+                 directory=ASSETS_DIR, extensions=["jpeg", "jpg"]
+             )
+         ],
+         inputs=[img_file],
+     )
+     start_btn.click(
+         fn=run_inference,
+         inputs=[img_file, conf, model_name],
+         outputs=[annotated_img, speed, json_out],
+     )
+
+ if __name__ == "__main__":
+     inference_gradio.queue(2).launch(
+         debug=True,
+         server_name="0.0.0.0",
+         server_port=7000,
+     )
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==4.8.0
+ torch==2.1.1
+ numpy
+ matplotlib
+ torchvision
+ supervision==0.17.1