Spaces:
Runtime error
Runtime error
from enum import Enum | |
class TrackerType(Enum):
    """Which associative tracker an InferenceConfig uses.

    NONE selects plain SORT tracking (see InferenceConfig.enable_sort_track);
    CONF_BOOST and BYTETRACK select the confidence-boost and ByteTrack variants.
    """

    NONE = 0
    CONF_BOOST = 1
    BYTETRACK = 2

    def toString(val):
        """Return the human-readable label for a TrackerType member.

        The parameter is deliberately not named ``self``: the function is used
        both as ``member.toString()`` and as ``TrackerType.toString(member)``.
        Returns None for any value that is not one of the members above.
        """
        # A local table rather than a class attribute: Enum would turn a
        # plain class-level tuple/dict into an extra member.
        labels = (
            (TrackerType.NONE, "None"),
            (TrackerType.CONF_BOOST, "Confidence Boost"),
            (TrackerType.BYTETRACK, "ByteTrack"),
        )
        for member, label in labels:
            if val == member:
                return label
        return None


# Free-function spelling, so callers using ``toString(member)`` keep working
# alongside the method form.
toString = TrackerType.toString
# ---------------------------------------------------------------------------
# Configuration options: module-level defaults used by InferenceConfig.
# ---------------------------------------------------------------------------

# Detection model weights file.
WEIGHTS: str = 'models/v5m_896_300best.pt'

# Needs tuning for the GPU hardware that runs inference.
BATCH_SIZE: int = 32

# Detection / non-maximum suppression.
CONF_THRES: float = 0.05   # detection confidence threshold
NMS_IOU: float = 0.25      # IOU threshold used by NMS

# Tracking.
MAX_AGE: int = 20          # how long a missing fish is kept before it gets a new id (presumably frames — confirm)
MIN_HITS: int = 11         # minimum number of frames a fish must appear in to be counted
IOU_THRES: float = 0.01    # IOU threshold for tracking

# Track filtering.
MIN_LENGTH: float = 0.3    # minimum fish length, in meters
MAX_LENGTH: float = 0      # maximum fish length, in meters (0 presumably means "no upper limit" — confirm)
MIN_TRAVEL: float = 0      # minimum distance a track has to travel

DEFAULT_TRACKER: TrackerType = TrackerType.BYTETRACK
class InferenceConfig:
    """Detection and tracking parameters for one inference run.

    Constructor defaults come from the module-level configuration constants.
    The associative tracker starts as DEFAULT_TRACKER and can be switched with
    the ``enable_*`` methods or ``enable_tracker_from_string``.
    """

    def __init__(self,
                 weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
                 min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH,
                 max_length=MAX_LENGTH, min_travel=MIN_TRAVEL):
        self.weights = weights
        self.conf_thresh = conf_thresh
        self.nms_iou = nms_iou
        self.min_hits = min_hits
        self.max_age = max_age
        self.min_length = min_length
        self.max_length = max_length
        self.min_travel = min_travel

        self.associative_tracker = DEFAULT_TRACKER
        # Tracker-specific parameters. All are stored, but to_dict only
        # exports the ones belonging to the active tracker.
        self.boost_power = 2
        self.boost_decay = 0.1
        self.byte_low_conf = 0.1
        self.byte_high_conf = 0.3

    def enable_sort_track(self):
        """Use plain SORT tracking (no associative tracker)."""
        self.associative_tracker = TrackerType.NONE

    def enable_conf_boost(self, power=2, decay=0.1):
        """Select the confidence-boost tracker with the given power/decay."""
        self.associative_tracker = TrackerType.CONF_BOOST
        self.boost_power = power
        self.boost_decay = decay

    def enable_byte_track(self, low=0.1, high=0.3):
        """Select ByteTrack with the given low/high confidence thresholds."""
        self.associative_tracker = TrackerType.BYTETRACK
        self.byte_low_conf = low
        self.byte_high_conf = high

    @staticmethod
    def _parse_tracker_params(spec):
        """Parse ``"<name>:<a>:<b>"`` into ``(float(a), float(b))``; None if malformed."""
        parts = spec.split(":")
        if len(parts) != 3:
            return None
        try:
            return float(parts[1]), float(parts[2])
        except ValueError:
            return None

    def enable_tracker_from_string(self, associativity):
        """Configure the tracker from a command-line style specification.

        Accepted forms:
          * ``""`` (or None): plain SORT tracking.
          * ``"boost:<power>:<decay>"``: confidence boost.
          * ``"bytetrack:<low>:<high>"``: ByteTrack.
        The tracker name is matched as a prefix.

        Returns True on success. On invalid input, prints a message, leaves the
        current configuration unchanged and returns False. This includes
        non-numeric parameters, which previously raised ValueError.
        """
        if not associativity:
            self.enable_sort_track()
            return True

        if associativity.startswith("boost"):
            params = self._parse_tracker_params(associativity)
            if params is None:
                print("INVALID PARAMETERS FOR CONFIDENCE BOOST:", associativity)
                return False
            power, decay = params
            self.enable_conf_boost(power=power, decay=decay)
            return True

        if associativity.startswith("bytetrack"):
            params = self._parse_tracker_params(associativity)
            if params is None:
                print("INVALID PARAMETERS FOR BYTETRACK:", associativity)
                return False
            low, high = params
            self.enable_byte_track(low=low, high=high)
            return True

        print("INVALID ASSOCIATIVITY TYPE:", associativity)
        return False

    def find_model(self, model_list):
        """Return the name in *model_list* whose path equals ``self.weights``.

        *model_list* maps model name -> weights path. The lookup is printed for
        debugging. Returns None when no entry matches.
        """
        print("weights", self.weights)
        for name, path in model_list.items():
            print("Path", path, "->", name)
            if path == self.weights:
                return name
        print("not found")
        return None

    def to_dict(self):
        """Return the configuration as a plain dict.

        Always includes the shared detection/tracking parameters (now including
        ``max_length``, which was previously omitted). Adds a ``'tracker'`` label
        plus the parameters specific to the active tracker.
        """
        params = {
            'weights': self.weights,
            'nms_iou': self.nms_iou,
            'min_hits': self.min_hits,
            'max_age': self.max_age,
            'min_length': self.min_length,
            'max_length': self.max_length,
            'min_travel': self.min_travel,
        }

        tracker = self.associative_tracker
        if tracker == TrackerType.BYTETRACK:
            # ByteTrack uses its own low/high confidence pair instead of conf_thresh.
            params['tracker'] = "ByteTrack"
            params['byte_low_conf'] = self.byte_low_conf
            params['byte_high_conf'] = self.byte_high_conf
        elif tracker == TrackerType.CONF_BOOST:
            params['tracker'] = "Confidence Boost"
            params['conf_thresh'] = self.conf_thresh
            params['boost_power'] = self.boost_power
            params['boost_decay'] = self.boost_decay
        elif tracker == TrackerType.NONE:
            params['tracker'] = "None"
            params['conf_thresh'] = self.conf_thresh
        return params