File size: 4,383 Bytes
2a572c2
 
a63e231
 
 
 
 
c376f3c
 
 
 
 
2a572c2
 
 
 
 
 
a63e231
 
 
2a572c2
d9c7dce
2a572c2
c376f3c
a63e231
2a572c2
 
 
 
d8d9ab6
2a572c2
 
 
 
 
 
d8d9ab6
2a572c2
 
a63e231
 
 
 
 
 
 
2a572c2
 
128e4f0
2a572c2
 
 
 
128e4f0
2a572c2
 
 
 
128e4f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c376f3c
cf87b0c
ac53593
 
 
c376f3c
cf87b0c
c376f3c
 
 
2a572c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from enum import Enum

class TrackerType(Enum):
    """Associative tracking strategies selectable on InferenceConfig."""

    NONE = 0
    CONF_BOOST = 1
    BYTETRACK = 2

    def toString(val):
        """Return the display label for ``val``.

        Declared without ``self``/``@staticmethod`` on purpose: it can be
        called either as ``TrackerType.toString(member)`` or as
        ``member.toString()``. Any value that is not one of the three
        members yields ``None``.
        """
        # Kept local: a table assigned in the Enum class body would itself
        # become an enum member.
        labelled = (
            (TrackerType.NONE, "None"),
            (TrackerType.CONF_BOOST, "Confidence Boost"),
            (TrackerType.BYTETRACK, "ByteTrack"),
        )
        for member, label in labelled:
            if val == member:
                return label
        return None

### Configuration options
WEIGHTS = 'models/v5m_896_300best.pt'  # default model weights file
# will need to configure these based on GPU hardware
BATCH_SIZE = 32

CONF_THRES = 0.05  # detection confidence threshold
NMS_IOU = 0.25  # IoU threshold for non-maximum suppression
# Time until a missing fish gets a new id.
# NOTE(review): unit not shown here — presumably frames; confirm.
MAX_AGE = 20
MIN_HITS = 11  # minimum number of frames with a specific fish for it to count
MIN_LENGTH = 0.3  # minimum fish length, in meters
# Maximum fish length, in meters.
# NOTE(review): 0 presumably disables the upper bound — confirm with consumer.
MAX_LENGTH = 0
IOU_THRES = 0.01  # IoU threshold for tracking
# Minimum distance a track has to travel (unit not shown here — confirm).
MIN_TRAVEL = 0
# Tracker selected by a freshly constructed InferenceConfig.
DEFAULT_TRACKER = TrackerType.BYTETRACK

class InferenceConfig:
    """Bundle of detection and tracking parameters for an inference run.

    Constructor defaults come from the module-level constants (WEIGHTS,
    CONF_THRES, NMS_IOU, MIN_HITS, MAX_AGE, MIN_LENGTH, MAX_LENGTH,
    MIN_TRAVEL). The associative tracker starts as DEFAULT_TRACKER and can be
    switched with the ``enable_*`` methods or ``enable_tracker_from_string``.
    """

    def __init__(self,
                 weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
                 min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH, max_length=MAX_LENGTH, min_travel=MIN_TRAVEL):
        # Detection and track-filtering parameters.
        self.weights = weights
        self.conf_thresh = conf_thresh
        self.nms_iou = nms_iou
        self.min_hits = min_hits
        self.max_age = max_age
        self.min_length = min_length
        self.max_length = max_length
        self.min_travel = min_travel

        # Associative tracker selection and its tuning knobs; the knob
        # defaults match the defaults of enable_conf_boost/enable_byte_track.
        self.associative_tracker = DEFAULT_TRACKER
        self.boost_power = 2
        self.boost_decay = 0.1
        self.byte_low_conf = 0.1
        self.byte_high_conf = 0.3

    def enable_sort_track(self):
        """Disable associative tracking (select ``TrackerType.NONE``)."""
        self.associative_tracker = TrackerType.NONE

    def enable_conf_boost(self, power=2, decay=0.1):
        """Select the Confidence Boost tracker with the given power/decay."""
        self.associative_tracker = TrackerType.CONF_BOOST
        self.boost_power = power
        self.boost_decay = decay

    def enable_byte_track(self, low=0.1, high=0.3):
        """Select ByteTrack with the given low/high confidence thresholds."""
        self.associative_tracker = TrackerType.BYTETRACK
        self.byte_low_conf = low
        self.byte_high_conf = high

    def enable_tracker_from_string(self, associativity):
        """Configure the associative tracker from a spec string.

        Accepted forms:
          * ``""`` or ``None``          -> ``TrackerType.NONE``
          * ``"boost:<power>:<decay>"`` -> Confidence Boost
          * ``"bytetrack:<low>:<high>"`` -> ByteTrack

        Only the prefix identifies the tracker type; exactly two
        ``:``-separated numeric parameters must follow it.

        Returns:
            True if the tracker was configured. False, after printing a
            diagnostic, if the type is unknown or its parameters are missing
            or non-numeric. The configuration is unchanged on failure.
        """
        if not associativity:
            self.enable_sort_track()
            return True

        if associativity.startswith("boost"):
            params = self._parse_tracker_params(associativity)
            if params is None:
                print("INVALID PARAMETERS FOR CONFIDENCE BOOST:", associativity)
                return False
            power, decay = params
            self.enable_conf_boost(power=power, decay=decay)
            return True

        if associativity.startswith("bytetrack"):
            params = self._parse_tracker_params(associativity)
            if params is None:
                print("INVALID PARAMETERS FOR BYTETRACK:", associativity)
                return False
            low, high = params
            self.enable_byte_track(low=low, high=high)
            return True

        print("INVALID ASSOCIATIVITY TYPE:", associativity)
        return False

    @staticmethod
    def _parse_tracker_params(associativity):
        """Return the two floats of ``"<type>:<a>:<b>"``, or None if malformed."""
        parts = associativity.split(":")
        if len(parts) != 3:
            return None
        try:
            return float(parts[1]), float(parts[2])
        except ValueError:
            # Non-numeric parameters are reported as invalid, not raised.
            return None

    def find_model(self, model_list):
        """Return the name in ``model_list`` whose path equals ``self.weights``.

        Args:
            model_list: mapping of model name -> weights path.

        Returns:
            The first matching name in iteration order, or None if no entry
            matches. Prints debug output while searching.
        """
        print("weights", self.weights)
        for model_name, path in model_list.items():
            print("Path", path, "->", model_name)
            if path == self.weights:
                return model_name
        print("not found")
        return None

    def to_dict(self):
        """Return the current configuration as a plain dict.

        Always contains weights, nms_iou, min_hits, max_age, min_length,
        max_length and min_travel. A ``'tracker'`` label and the active
        tracker's parameters are added; ``conf_thresh`` is reported only for
        the Confidence Boost and None trackers.
        """
        config = {
            'weights': self.weights,
            'nms_iou': self.nms_iou,
            'min_hits': self.min_hits,
            'max_age': self.max_age,
            'min_length': self.min_length,
            # Reported alongside min_length so both length bounds are recorded.
            'max_length': self.max_length,
            'min_travel': self.min_travel,
        }

        # Add tracker-specific parameters; labels come from TrackerType so
        # they cannot drift from the enum's own display names.
        tracker = self.associative_tracker
        if tracker == TrackerType.BYTETRACK:
            config['tracker'] = TrackerType.toString(tracker)
            config['byte_low_conf'] = self.byte_low_conf
            config['byte_high_conf'] = self.byte_high_conf
        elif tracker == TrackerType.CONF_BOOST:
            config['tracker'] = TrackerType.toString(tracker)
            config['conf_thresh'] = self.conf_thresh
            config['boost_power'] = self.boost_power
            config['boost_decay'] = self.boost_decay
        elif tracker == TrackerType.NONE:
            config['tracker'] = TrackerType.toString(tracker)
            config['conf_thresh'] = self.conf_thresh

        return config