[classifier] fixes and improvements

This commit is contained in:
Jan Grewe 2025-02-19 08:45:23 +01:00
parent 7a2084e159
commit 256e9caa2f

View File

@ -70,8 +70,12 @@ class ConsistencyWorker(QRunnable):
res = False
for n, o in zip(new, original):
res = (o == 1 or o == 2) and n != o
if res:
print("inverted assignment, needs cross-checking?")
if not res:
res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
if res:
print("all detections would be assigned to one track!")
return res
def assign_by_distance(f, p):
@ -94,6 +98,19 @@ class ConsistencyWorker(QRunnable):
orientationchange = np.unwrap((last_angle - o)/np.array([t1_step, t2_step]))
most_likely_track = np.argmin(orientationchange) + 1
return most_likely_track, orientationchange
def assign_by_length(f, o):
length_difference = (last_length - o)
most_likely_track = np.argmin(length_difference) + 1
return most_likely_track, length_difference
def do_assignment(f, indices, assignments):
for i, idx in enumerate(indices):
self.tracks[idx] = assignments[i]
last_pos[assignments[i]-1] = pp[i]
last_frame[assignments[i]-1] = f
last_angle[assignments[i]-1] = self.orientations[idx]
last_length[assignments[i]-1] += ((self.lengths[idx] - last_length[assignments[i]-1])/processed)
last_pos = [self.positions[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
@ -101,16 +118,20 @@ class ConsistencyWorker(QRunnable):
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
last_angle = [self.orientations[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.orientations[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
last_length = [self.lengths[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.lengths[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
errors = 0
processed = 1
progress = 0
assignment_error = False
self._stoprequest = False
maxframes = np.max(self.frames)
startframe = np.max(last_frame)
steps = int((maxframes - startframe) // 200)
for f in np.unique(self.frames[self.frames > startframe]):
print(f)
processed += 1
if self._stoprequest:
break
indices = np.where(self.frames == f)[0]
@ -118,42 +139,50 @@ class ConsistencyWorker(QRunnable):
originaltracks = self.tracks[indices]
dist_assignments = np.zeros_like(originaltracks)
angle_assignments = np.zeros_like(originaltracks)
# userlabeld = np.zeros_like(originaltracks)
length_assignments = np.zeros_like(originaltracks)
userlabeled = np.zeros_like(originaltracks)
distances = np.zeros((len(originaltracks), 2))
orientations = np.zeros((len(originaltracks), 2))
orientations = np.zeros_like(distances)
lengths = np.zeros_like(distances)
for i, (idx, p) in enumerate(zip(indices, pp)):
print(i)
if self.userlabeled[idx]:
print("user")
processed += 1
userlabeled[i] = True
last_pos[originaltracks[i]-1] = pp[i]
last_frame[originaltracks[i]-1] = f
last_angle[originaltracks[i]-1] = self.orientations[idx]
last_length[originaltracks[i]-1] += ((self.lengths[idx] - last_length[originaltracks[i]-1]) / processed)
continue
dist_assignments[i], distances[i, :] = assign_by_distance(f, p)
angle_assignments[i], orientations[i,:] = assign_by_orientation(f, self.orientations[idx])
# check (re) assignment update and proceed
print("dist", distances)
print("angle", orientations)
if needs_checking(originaltracks, dist_assignments):
logging.info("frame %i: Issues assigning based on distances %s", f, str(distances))
assignment_error = True
errors += 1
if self._stoponerror:
embed()
break
length_assignments[i], lengths[i, :] = assign_by_length(f, self.lengths[idx])
if np.any(userlabeled):
continue
# check (re) assignment, update, and proceed
if not needs_checking(originaltracks, dist_assignments):
logging.info("frame %i: Decision based on distance")
do_assignment(f, indices, dist_assignments)
else:
processed += 1
for i, idx in enumerate(indices):
if assignment_error:
self.tracks[idx] = -1
print(distances)
print(orientations)
print(lengths)
if not (np.all(length_assignments == 1) or np.all(length_assignments == 2)): # if I find a solution by body length
logging.debug("frame %i: Decision based on body length", f)
do_assignment(f, indices, length_assignments)
elif not (np.all(angle_assignments == 1) or np.all(angle_assignments == 2)): # else there is a solution based on orientation
logging.info("frame %i: Decision based on orientation", f)
do_assignment(f, indices, angle_assignments)
else:
self.tracks[idx] = dist_assignments[i]
last_pos[dist_assignments[i]-1] = pp[i]
last_frame[dist_assignments[i]-1] = f
last_angle[dist_assignments[i]-1] = self.orientations[idx]
assignment_error = False
logging.info("frame %i: Cannot decide who is who")
for idx in indices:
self.tracks[idx] = -1
errors += 1
if self._stoponerror:
embed()
break
if steps > 0 and f % steps == 0:
progress += 1
self.signals.progress.emit(progress, processed, errors)
@ -510,6 +539,7 @@ class ConsistencyClassifier(QWidget):
def assignedTracks(self):
return self._tracks
class ClassifierWidget(QTabWidget):
apply_classifier = Signal(np.ndarray)
@ -559,7 +589,7 @@ def main():
import pickle
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged.pkl"
datafile = PACKAGE_ROOT / "data/merged2.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)