[classifier] make sure, we always start with user labeled detections

This commit is contained in:
Jan Grewe 2025-02-27 17:41:07 +01:00
parent 15264dbe48
commit ae24463be2

View File

@ -135,6 +135,16 @@ class ConsistencyWorker(QRunnable):
del detections[lowest_dist]
return detections
def find_last_userlabeled(startframe):
t1index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 1))[0][-1]
t2index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 2))[0][-1]
d1 = Detection(t1index, self.frames[t1index], self.tracks[t1index], self.positions[t1index],
self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index])
d2 = Detection(t1index, self.frames[t2index], self.tracks[t2index], self.positions[t2index],
self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index])
last_detections[1] = d1
last_detections[2] = d2
unique_frames = np.unique(self.frames)
steps = int((len(unique_frames) - self._startframe) // 100)
errors = 0
@ -142,6 +152,7 @@ class ConsistencyWorker(QRunnable):
progress = 0
self._stoprequest = False
last_detections = {1: None, 2: None, -1: None}
find_last_userlabeled(self._startframe)
for f in unique_frames[unique_frames >= self._startframe]:
if self._stoprequest:
@ -188,7 +199,7 @@ class ConsistencyWorker(QRunnable):
continue
if error and self._stoponerror:
self.signals.message.emit("Tracking stopped at frame %i.", f)
self.signals.message.emit(f"Tracking stopped at frame {f}.")
break
elif error:
continue
@ -559,8 +570,11 @@ class ConsistencyClassifier(QWidget):
self._messagebox.append("Error preparing data! Make sure that the first user-labeled frames contain both tracks!")
self.setEnabled(False)
return
max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]])
min_startframe = np.max([t1_userlabeled[0], t2_userlabeled[0]])
max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]]) -1
first_guess = np.max([t1_userlabeled[0], t2_userlabeled[0]])
while first_guess not in t1_userlabeled or first_guess not in t2_userlabeled:
first_guess += 1
min_startframe = first_guess + 1
self._maxframes = np.max(self._frames)
self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_startframe)