Thastp committed on
Commit
ac1a6d7
·
verified ·
1 Parent(s): 0216f11

Upload model

Browse files
Files changed (2) hide show
  1. configuration_retinanet.py +35 -8
  2. modeling_retinanet.py +125 -103
configuration_retinanet.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers.configuration_utils import PretrainedConfig
2
  from optimum.exporters.onnx.model_configs import ViTOnnxConfig
3
- from typing import Optional, Dict
 
4
 
5
  class RetinaNetConfig(PretrainedConfig):
6
  model_type = 'retinanet'
@@ -19,23 +20,49 @@ class RetinaNetConfig(PretrainedConfig):
19
 
20
  super().__init__(**kwargs)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  class RetinaNetOnnxConfig(ViTOnnxConfig):
 
 
23
  @property
24
  def inputs(self) -> Dict[str, Dict[int, str]]:
25
  return {
26
  "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
27
- "image_sizes": {0: "batch_size", 1: "image_size"}
28
  }
29
 
30
  @property
31
  def outputs(self) -> Dict[str, Dict[int, str]]:
32
- common_outputs = super().outputs
33
-
34
- if self.task == "object-detection":
35
- common_outputs["logits"] = {0: "batch_size", 1: "num_queries", 2: "num_classes"}
36
- common_outputs["pred_boxes"] = {0: "batch_size", 1: "num_queries", 2: "coordinates"}
37
 
38
- return common_outputs
 
 
 
 
 
 
39
 
40
  __all__ = [
41
  'RetinaNetConfig',
 
1
  from transformers.configuration_utils import PretrainedConfig
2
  from optimum.exporters.onnx.model_configs import ViTOnnxConfig
3
+ from optimum.utils import DummyVisionInputGenerator
4
+ from typing import Optional, Dict, OrderedDict
5
 
6
  class RetinaNetConfig(PretrainedConfig):
7
  model_type = 'retinanet'
 
20
 
21
  super().__init__(**kwargs)
22
 
23
class RetinaNetObjectDetectionInputGenerator(DummyVisionInputGenerator):
    """Dummy-input generator for RetinaNet ONNX export.

    Extends the standard vision generator with an extra ``image_sizes`` input:
    one ``(height, width)`` pair per image in the batch, alongside the usual
    ``pixel_values`` tensor.
    """

    SUPPORTED_INPUT_NAMES = (
        "pixel_values",
        "image_sizes",
    )

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Generate a random dummy tensor for *input_name*.

        Args:
            input_name: One of ``SUPPORTED_INPUT_NAMES``.
            framework: Target framework for the tensor (e.g. ``"pt"``).
            int_dtype: Integer dtype used for ``image_sizes``.
            float_dtype: Float dtype used for ``pixel_values``.

        Raises:
            ValueError: If *input_name* is not supported. (The previous
                implementation silently fell through and returned ``None``,
                which surfaced later as an opaque failure in the exporter.)
        """
        if input_name == "image_sizes":
            # One (height, width) pair per image, bounded by the dummy frame size.
            return self.random_int_tensor(
                shape=[self.batch_size, 2],
                min_value=1,
                max_value=max(self.height, self.width),
                framework=framework,
                dtype=int_dtype,
            )

        if input_name == "pixel_values":
            return self.random_float_tensor(
                shape=[self.batch_size, self.num_channels, self.height, self.width],
                framework=framework,
                dtype=float_dtype,
            )

        raise ValueError(
            f"Unsupported input name: {input_name!r}. "
            f"Supported names: {self.SUPPORTED_INPUT_NAMES}."
        )
46
+
47
class RetinaNetOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for RetinaNet.

    Declares the dynamic-axis layout of the exported graph and wires in the
    custom dummy-input generator that knows how to produce ``image_sizes``.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (RetinaNetObjectDetectionInputGenerator,)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Dynamic axes of the model inputs."""
        image_axes = {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}
        return {"pixel_values": image_axes}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Dynamic axes of the post-processed detection outputs."""
        per_prediction = {0: "batch_size", 1: "num_predictions"}
        spec = OrderedDict()
        spec["boxes"] = {0: "batch_size", 1: "num_predictions", 2: "bbox_coordinates"}
        spec["labels"] = dict(per_prediction)
        spec["scores"] = dict(per_prediction)
        return spec
66
 
67
  __all__ = [
68
  'RetinaNetConfig',
modeling_retinanet.py CHANGED
@@ -1,104 +1,126 @@
1
- import torch
2
- from dataclasses import dataclass
3
- from torchvision.models import ResNet50_Weights
4
- from torchvision.models.detection import retinanet_resnet50_fpn, RetinaNet_ResNet50_FPN_Weights
5
- from torchvision.models.detection.anchor_utils import AnchorGenerator
6
-
7
- from transformers import PreTrainedModel
8
- from transformers.utils import ModelOutput
9
- from typing import OrderedDict, List, Tuple
10
-
11
- from .configuration_retinanet import RetinaNetConfig
12
-
13
- def _default_anchorgen():
14
- anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
15
- aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
16
- anchor_generator = RetinaNetAnchorGenerator(anchor_sizes, aspect_ratios)
17
- return anchor_generator
18
-
19
- @dataclass
20
- class RetinaNetObjectDetectionOutput(ModelOutput):
21
- logits: torch.FloatTensor = None
22
- pred_boxes: torch.FloatTensor = None
23
- image_sizes: List[Tuple] = None
24
- anchors: List[torch.Tensor] = None
25
- features: List[torch.Tensor] = None
26
-
27
- class RetinaNetAnchorGenerator(AnchorGenerator):
28
- def __init__(
29
- self,
30
- sizes=((128, 256, 512),),
31
- aspect_ratios=((0.5, 1.0, 2.0),)
32
- ):
33
- super().__init__(sizes, aspect_ratios)
34
-
35
- def forward(self, pixel_values: torch.Tensor, feature_maps: List[torch.Tensor]) -> List[torch.Tensor]:
36
- grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
37
- image_size = pixel_values.shape[-2:]
38
- dtype, device = feature_maps[0].dtype, feature_maps[0].device
39
- strides = [
40
- [
41
- torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
42
- torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
43
- ]
44
- for g in grid_sizes
45
- ]
46
- self.set_cell_anchors(dtype, device)
47
- anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
48
- anchors: List[List[torch.Tensor]] = []
49
- for _ in range(pixel_values.shape[0]):
50
- anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
51
- anchors.append(anchors_in_image)
52
- anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
53
- return anchors
54
-
55
- class RetinaNetModelForObjectDetection(PreTrainedModel):
56
- config_class = RetinaNetConfig
57
-
58
- def __init__(self, config):
59
- super().__init__(config)
60
-
61
- self.config = config
62
-
63
- model_config = {
64
- 'weights': None,
65
- 'weights_backbone': None,
66
- 'num_classes': None
67
- }
68
-
69
- if config.pretrained:
70
- model_config['weights'] = RetinaNet_ResNet50_FPN_Weights.DEFAULT
71
- else:
72
- model_config['num_classes'] = config.num_classes
73
- if config.pretrained_backbone:
74
- model_config['weights_backbone'] = ResNet50_Weights.DEFAULT
75
-
76
- self.model = retinanet_resnet50_fpn(**model_config)
77
- self.model.anchor_generator = _default_anchorgen()
78
-
79
- def forward(self, pixel_values, image_sizes, labels=None):
80
- if labels is not None:
81
- raise NotImplementedError
82
-
83
- features = self.model.backbone(pixel_values)
84
- if isinstance(features, torch.Tensor):
85
- features = OrderedDict([("0", features)])
86
- features = list(features.values())
87
-
88
- # compute the retinanet heads outputs using the features
89
- head_outputs = self.model.head(features)
90
-
91
- # create the set of anchors
92
- anchors = self.model.anchor_generator(pixel_values, features)
93
-
94
- return RetinaNetObjectDetectionOutput(
95
- logits=head_outputs['cls_logits'],
96
- pred_boxes=head_outputs['bbox_regression'],
97
- image_sizes=image_sizes,
98
- anchors=anchors,
99
- features=features
100
- )
101
-
102
- __all__ = [
103
- "RetinaNetModelForObjectDetection"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  ]
 
1
+ import torch
2
+ from dataclasses import dataclass
3
+ from torchvision.models import ResNet50_Weights
4
+ from torchvision.models.detection import retinanet_resnet50_fpn, RetinaNet_ResNet50_FPN_Weights
5
+ from torchvision.models.detection.anchor_utils import AnchorGenerator
6
+
7
+ from transformers import PreTrainedModel
8
+ from transformers.utils import ModelOutput
9
+ from typing import OrderedDict, List, Union
10
+
11
+ from configuration_retinanet import RetinaNetConfig
12
+
13
def _default_anchorgen():
    """Build the default RetinaNet anchor generator.

    Three scales per octave (x1, x2^(1/3), x2^(2/3)) over five base sizes,
    with aspect ratios 0.5 / 1.0 / 2.0 at every location.
    """
    scale_mid = 2 ** (1.0 / 3)
    scale_high = 2 ** (2.0 / 3)
    anchor_sizes = tuple(
        (base, int(base * scale_mid), int(base * scale_high))
        for base in (32, 64, 128, 256, 512)
    )
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return RetinaNetAnchorGenerator(anchor_sizes, aspect_ratios)
18
+
19
@dataclass
class RetinaNetObjectDetectionOutput(ModelOutput):
    """Raw (un-postprocessed) outputs of the RetinaNet detection heads."""

    # Classification logits from the head's `cls_logits` output.
    logits: torch.FloatTensor = None
    # Box regression deltas from the head's `bbox_regression` output
    # (relative to `anchors`, not absolute coordinates).
    pred_boxes: torch.FloatTensor = None
    # Per-image (height, width) pairs as passed to the forward call
    # — presumably the original pre-padding sizes; confirm against caller.
    image_sizes: torch.Tensor = None
    # Anchors stacked across the batch: every image shares the same set.
    anchors: torch.Tensor = None
    # Per-FPN-level location counts (H*W of each feature map), for splitting
    # the flat prediction tensors back into levels during postprocessing.
    num_anchors_per_level: torch.Tensor = None
26
+
27
class RetinaNetAnchorGenerator(AnchorGenerator):
    """Anchor generator with an export-friendly forward signature.

    Unlike torchvision's `AnchorGenerator.forward`, which takes an
    `ImageList`, this variant takes the raw pixel tensor and the list of
    feature maps directly, so it can run without torchvision's
    transform/ImageList machinery (e.g. inside an ONNX export graph).
    """

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),)
    ):
        super().__init__(sizes, aspect_ratios)

    def forward(self, pixel_values: torch.Tensor, feature_maps: List[torch.Tensor]) -> List[torch.Tensor]:
        """Return one `(num_anchors, 4)` anchor tensor per image in the batch.

        Args:
            pixel_values: Batched input images; only the batch size and the
                spatial size (last two dims) are read.
            feature_maps: FPN feature maps, one per pyramid level.
        """
        # (H, W) of each pyramid level and of the padded input batch.
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = pixel_values.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        # Per-level strides relative to the input image. Built as 0-d tensors
        # via empty(()).fill_() — mirrors torchvision's implementation,
        # presumably to keep the values as tensors during tracing rather than
        # baking them in as Python ints (TODO confirm).
        strides = [
            [
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
            ]
            for g in grid_sizes
        ]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
        anchors: List[List[torch.Tensor]] = []
        # All images in the batch share the same anchors, since anchors depend
        # only on the (shared, padded) input size — replicate per image.
        for _ in range(pixel_values.shape[0]):
            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
            anchors.append(anchors_in_image)
        # Concatenate across pyramid levels: one flat anchor tensor per image.
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        return anchors
54
+
55
class RetinaNetModelForObjectDetection(PreTrainedModel):
    """Hugging Face wrapper around torchvision's `retinanet_resnet50_fpn`."""

    config_class = RetinaNetConfig

    def __init__(self, config):
        super().__init__(config)

        self.config = config

        model_config = {
            'weights': None,
            'weights_backbone': None,
            'num_classes': None
        }

        # Full pretrained detector, or fresh head (optionally on a
        # pretrained backbone) with a custom class count.
        if config.pretrained:
            model_config['weights'] = RetinaNet_ResNet50_FPN_Weights.DEFAULT
        else:
            model_config['num_classes'] = config.num_classes
            if config.pretrained_backbone:
                model_config['weights_backbone'] = ResNet50_Weights.DEFAULT

        self.model = retinanet_resnet50_fpn(**model_config)
        # Install the export-friendly anchor generator once, instead of
        # rebuilding it on every forward pass.
        self.model.anchor_generator = _default_anchorgen()

    def forward_without_processing(self, pixel_values, image_sizes=None, labels=None):
        """Run backbone + heads + anchors, skipping torchvision's
        pre/post-processing (no resizing, no NMS).

        Args:
            pixel_values: Batched image tensor `(batch, channels, H, W)`.
            image_sizes: Optional `(batch, 2)` tensor of per-image
                (height, width). Defaults to the full padded tensor size
                for every image.
            labels: Unsupported; training is not implemented.

        Returns:
            RetinaNetObjectDetectionOutput with raw logits, box regression
            deltas, anchors, and per-level location counts.

        Raises:
            NotImplementedError: If `labels` is provided.
        """
        if labels is not None:
            raise NotImplementedError("Training is not supported.")
        # eval() recursively clears `training` on every submodule (BatchNorm
        # in the backbone, etc.); assigning `self.model.training = False`
        # would only flip the flag on the top-level module.
        self.model.eval()

        if image_sizes is None:
            # Default: every image occupies the full padded input tensor.
            height, width = pixel_values.shape[-2:]
            image_sizes = torch.tensor(
                [[height, width]], device=pixel_values.device
            ).repeat(pixel_values.shape[0], 1)

        features = self.model.backbone(pixel_values)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.model.head(features)

        # create the set of anchors
        anchors = self.model.anchor_generator(pixel_values, features)
        # NOTE(review): torchvision additionally multiplies H*W by the number
        # of anchors per location (A) before splitting predictions per level —
        # confirm the downstream consumer accounts for that factor.
        num_anchors_per_level = [x.size(2) * x.size(3) for x in features]

        return RetinaNetObjectDetectionOutput(
            logits=head_outputs['cls_logits'],
            pred_boxes=head_outputs['bbox_regression'],
            image_sizes=image_sizes,
            anchors=torch.stack(anchors, dim=0),
            num_anchors_per_level=torch.tensor(num_anchors_per_level)
        )

    def forward(self, pixel_values: Union[torch.Tensor, List[torch.Tensor]], labels=None):
        """
        Don't use preprocessor for calling the main forward function.

        Delegates to torchvision's full pipeline (resize, heads,
        postprocess/NMS) in inference mode and returns its detections.

        Raises:
            NotImplementedError: If `labels` is provided.
        """
        if labels is not None:
            raise NotImplementedError("Training is not supported.")
        # See forward_without_processing: eval() propagates to submodules.
        self.model.eval()

        detections = self.model(pixel_values, labels)

        return detections
123
+
124
# Public API of this module for `from modeling_retinanet import *`.
__all__ = [
    "RetinaNetModelForObjectDetection"
]