File size: 1,936 Bytes
851751e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from pathlib import Path
from typing import Optional, List, Dict, Union, Any
import warnings

from torch.utils.data import Dataset

from .conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from .conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder


class Annotated3DObjectsDataset(Dataset):
    """Base dataset holding object-annotation settings and category labels.

    Stores per-image object limits and tokenization settings, filters the
    category list against an optional blacklist, and lazily builds the
    conditional builders ('center' and 'bbox') whose configuration depends on
    the number of categories.

    Args:
        min_objects_per_image: Minimum number of objects per image. Only stored
            here; not used by the code in this class.
        max_objects_per_image: Maximum number of objects per image; passed to
            both conditional builders.
        no_tokens: Passed to both conditional builders.
        num_beams: Passed to both conditional builders.
        cats: Ordered category names. A category's id is its index in the
            filtered list (see ``get_textual_label_for_category_id``).
        cat_blacklist: Category names to drop from ``cats``; ``None`` keeps all.
        **kwargs: Accepted and ignored by this class (presumably so subclasses
            can share one config dict — confirm against callers).
    """

    def __init__(self, min_objects_per_image: int,
                 max_objects_per_image: int, no_tokens: int, num_beams: int, cats: List[str],
                 cat_blacklist: Optional[List[str]] = None, **kwargs: Any) -> None:
        self.min_objects_per_image = min_objects_per_image
        self.max_objects_per_image = max_objects_per_image
        self.no_tokens = no_tokens
        self.num_beams = num_beams

        # Always keep a private copy: aliasing the caller's list would let later
        # mutation of it silently change category ids and no_classes.
        if cat_blacklist is None:
            self.categories = list(cats)
        else:
            blocked = set(cat_blacklist)  # build once; O(1) membership per category
            self.categories = [c for c in cats if c not in blocked]
        # Populated on first access of `conditional_builders`.
        self._conditional_builders = None

    @property
    def no_classes(self) -> int:
        """Number of categories remaining after blacklist filtering."""
        return len(self.categories)

    @property
    def conditional_builders(
        self,
    ) -> "Dict[str, Union[ObjectsCenterPointsConditionalBuilder, ObjectsBoundingBoxConditionalBuilder]]":
        """Return the conditional builders, creating them on first access.

        Returns:
            A dict mapping ``'center'`` to an ``ObjectsCenterPointsConditionalBuilder``
            and ``'bbox'`` to an ``ObjectsBoundingBoxConditionalBuilder``, both
            constructed with ``(no_classes, max_objects_per_image, no_tokens,
            num_beams)``. The same dict object is returned on every later call.
        """
        # Built lazily rather than in __init__ so that no_classes reflects the final
        # category list; per the original note it may only be known after data
        # loading (presumably a subclass sets self.categories) — confirm.
        # NOTE(review): the dict is cached, so changing self.categories after the
        # first access will not rebuild the builders.
        if self._conditional_builders is None:
            builder_args = (
                self.no_classes,
                self.max_objects_per_image,
                self.no_tokens,
                self.num_beams,
            )
            self._conditional_builders = {
                'center': ObjectsCenterPointsConditionalBuilder(*builder_args),
                'bbox': ObjectsBoundingBoxConditionalBuilder(*builder_args),
            }
        return self._conditional_builders

    def get_textual_label_for_category_id(self, category_id: int) -> str:
        """Return the category name for ``category_id``.

        Raises:
            IndexError: if ``category_id`` is not in ``[0, len(self.categories))``.
        """
        # Reject negative ids explicitly: plain list indexing would wrap around
        # and silently return a label from the end of the list.
        if not 0 <= category_id < len(self.categories):
            raise IndexError(
                f"category_id {category_id} out of range for "
                f"{len(self.categories)} categories"
            )
        return self.categories[category_id]