Spaces:
Runtime error
Runtime error
File size: 4,383 Bytes
2a572c2 a63e231 c376f3c 2a572c2 a63e231 2a572c2 d9c7dce 2a572c2 c376f3c a63e231 2a572c2 d8d9ab6 2a572c2 d8d9ab6 2a572c2 a63e231 2a572c2 128e4f0 2a572c2 128e4f0 2a572c2 128e4f0 c376f3c cf87b0c ac53593 c376f3c cf87b0c c376f3c 2a572c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from enum import Enum
class TrackerType(Enum):
    """Tracker variants stored in ``InferenceConfig.associative_tracker``."""

    NONE = 0
    CONF_BOOST = 1
    BYTETRACK = 2

    # NOTE(review): the scraped source lost its indentation; this assumes
    # toString was defined inside the enum (it takes `val` in place of
    # `self`). Confirm against callers.
    def toString(val):
        """Return the display label for a tracker member.

        Usable as ``TrackerType.toString(member)`` or ``member.toString()``.
        Returns None when ``val`` equals none of the three members.
        """
        labelled_members = (
            (TrackerType.NONE, "None"),
            (TrackerType.CONF_BOOST, "Confidence Boost"),
            (TrackerType.BYTETRACK, "ByteTrack"),
        )
        for member, label in labelled_members:
            if val == member:
                return label
        return None
### Configuration options
# Default weights path used by InferenceConfig(weights=...).
WEIGHTS = 'models/v5m_896_300best.pt'

# will need to configure these based on GPU hardware
BATCH_SIZE = 32

CONF_THRES = 0.05  # detection confidence threshold
NMS_IOU = 0.25  # IoU threshold for non-maximum suppression (NMS)
MAX_AGE = 20  # time until a missing fish gets a new id
MIN_HITS = 11  # minimum number of frames with a specific fish for it to count
MIN_LENGTH = 0.3  # minimum fish length, in meters
# NOTE(review): 0 presumably means "no upper bound" — confirm in the filter code.
MAX_LENGTH = 0  # maximum fish length, in meters
IOU_THRES = 0.01  # IoU threshold for tracking
# NOTE(review): 0 presumably disables the travel filter — confirm.
MIN_TRAVEL = 0  # minimum distance a track has to travel
# Associative tracker selected by InferenceConfig() until an enable_* method is called.
DEFAULT_TRACKER = TrackerType.BYTETRACK
class InferenceConfig:
    """Detection, tracking and filtering parameters for an inference run.

    Constructor defaults come from the module-level configuration constants.
    The associative tracker starts as ``DEFAULT_TRACKER`` and can be switched
    with the ``enable_*`` methods or ``enable_tracker_from_string``.
    """

    def __init__(self,
                 weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
                 min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH,
                 max_length=MAX_LENGTH, min_travel=MIN_TRAVEL):
        self.weights = weights
        self.conf_thresh = conf_thresh
        self.nms_iou = nms_iou
        self.min_hits = min_hits
        self.max_age = max_age
        self.min_length = min_length
        self.max_length = max_length
        self.min_travel = min_travel

        # Tracker selection and per-tracker parameters. These defaults match
        # the keyword defaults of enable_conf_boost / enable_byte_track.
        self.associative_tracker = DEFAULT_TRACKER
        self.boost_power = 2
        self.boost_decay = 0.1
        self.byte_low_conf = 0.1
        self.byte_high_conf = 0.3

    def enable_sort_track(self):
        """Select ``TrackerType.NONE`` (no associative tracker)."""
        self.associative_tracker = TrackerType.NONE

    def enable_conf_boost(self, power=2, decay=0.1):
        """Select confidence-boost tracking with the given power and decay."""
        self.associative_tracker = TrackerType.CONF_BOOST
        self.boost_power = power
        self.boost_decay = decay

    def enable_byte_track(self, low=0.1, high=0.3):
        """Select ByteTrack with the given low/high confidence thresholds."""
        self.associative_tracker = TrackerType.BYTETRACK
        self.byte_low_conf = low
        self.byte_high_conf = high

    def enable_tracker_from_string(self, associativity):
        """Configure the associative tracker from a compact spec string.

        Accepted forms:
            ``""``                       -> ``enable_sort_track()``
            ``"boost:<power>:<decay>"``  -> ``enable_conf_boost(power, decay)``
            ``"bytetrack:<low>:<high>"`` -> ``enable_byte_track(low, high)``

        The tracker type is matched with ``startswith``, so only the prefix
        of the spec is checked.

        Returns:
            True if the tracker was configured. False if the spec has an
            unknown type, the wrong number of fields, or non-numeric
            parameters; a diagnostic is printed and the current tracker
            settings are left unchanged.
        """
        if associativity == "":
            self.enable_sort_track()
            return True

        if associativity.startswith("boost"):
            label, enable = "CONFIDENCE BOOST", self.enable_conf_boost
        elif associativity.startswith("bytetrack"):
            label, enable = "BYTETRACK", self.enable_byte_track
        else:
            print("INVALID ASSOCIATIVITY TYPE:", associativity)
            return False

        conf = associativity.split(":")
        if len(conf) != 3:
            print(f"INVALID PARAMETERS FOR {label}:", associativity)
            return False
        try:
            first, second = float(conf[1]), float(conf[2])
        except ValueError:
            # A non-numeric parameter is just another malformed spec; report
            # it the same way instead of letting ValueError escape.
            print(f"INVALID PARAMETERS FOR {label}:", associativity)
            return False

        # Both enable_* methods take their two parameters in spec order.
        enable(first, second)
        return True

    def find_model(self, model_list):
        """Return the name whose path in ``model_list`` equals ``self.weights``.

        Args:
            model_list: mapping of model name -> weights path.

        Returns:
            The first matching name, or None if nothing matches. The search
            is traced to stdout.
        """
        print("weights", self.weights)
        for model_name, path in model_list.items():
            print("Path", path, "->", model_name)
            if path == self.weights:
                return model_name
        print("not found")
        return None

    def to_dict(self):
        """Return the configuration as a plain dict.

        Always contains the common parameters. The tracker label and that
        tracker's own parameters are added according to
        ``associative_tracker``. ``conf_thresh`` is only reported for the
        non-ByteTrack trackers; ByteTrack reports its low/high thresholds.
        """
        result = {
            'weights': self.weights,
            'nms_iou': self.nms_iou,
            'min_hits': self.min_hits,
            'max_age': self.max_age,
            'min_length': self.min_length,
            'max_length': self.max_length,
            'min_travel': self.min_travel,
        }
        # Add tracker specific parameters
        if self.associative_tracker == TrackerType.BYTETRACK:
            result['tracker'] = "ByteTrack"
            result['byte_low_conf'] = self.byte_low_conf
            result['byte_high_conf'] = self.byte_high_conf
        elif self.associative_tracker == TrackerType.CONF_BOOST:
            result['tracker'] = "Confidence Boost"
            result['conf_thresh'] = self.conf_thresh
            result['boost_power'] = self.boost_power
            result['boost_decay'] = self.boost_decay
        elif self.associative_tracker == TrackerType.NONE:
            result['tracker'] = "None"
            result['conf_thresh'] = self.conf_thresh
        return result