[classifier] kind of handling mulitple detections in one frame

This commit is contained in:
Jan Grewe 2025-02-26 08:19:59 +01:00
parent 430ee4fac7
commit d6b91c25d2

View File

@ -1,7 +1,7 @@
import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor
@ -24,7 +24,7 @@ class Detection():
self.userlabeled = userlabeled
class WorkerSignals(QObject):
error = Signal(str)
message = Signal(str)
running = Signal(bool)
progress = Signal(int, int, int)
currentframe = Signal(int)
@ -52,7 +52,7 @@ class ConsitencyDataLoader(QRunnable):
self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation()
self.lengths = self.data.animalLength()
self.bendedness = self.data.bendedness()
# self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"]
@ -94,25 +94,13 @@ class ConsistencyWorker(QRunnable):
detections.append(d)
return detections
def needs_checking(original, new):
res = False
for n, o in zip(new, original):
res = (o == 1 or o == 2) and n != o
if res:
print("inverted assignment, needs cross-checking?")
if not res:
res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
if res:
print("all detections would be assigned to one track!")
return res
def assign_by_distance(d):
t1_step = d.frame - last_detections[1].frame
t2_step = d.frame - last_detections[2].frame
if t1_step == 0 or t2_step == 0:
print(f"framecount is zero! current frame {f}, last frame {last_detections[1].frame} and {last_detections[2].frame}")
distance_to_trackone = np.linalg.norm(d.position - last_detections[1].position)/t1_step
distance_to_tracktwo = np.linalg.norm(d.position - last_detections[2].position)/t2_step
distance_to_trackone = np.linalg.norm(d.position - last_detections[1].position) /t1_step
distance_to_tracktwo = np.linalg.norm(d.position - last_detections[2].position) /t2_step
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
distances = np.zeros(2)
distances[0] = distance_to_trackone
@ -138,6 +126,15 @@ class ConsistencyWorker(QRunnable):
most_likely_track = np.argmin(length_differences) + 1
return most_likely_track, length_differences
def check_multiple_detections(detections):
distances = np.zeros((len(detections), len(detections)))
for i, d1 in enumerate(detections):
for j, d2 in enumerate(detections):
distances[i, j] = np.abs(np.linalg.norm(d2.position - d1.position))
lowest_dist = np.argmin(np.sum(distances, axis=1))
del detections[lowest_dist]
return detections
unique_frames = np.unique(self.frames)
steps = int((len(unique_frames) - self._startframe) // 100)
errors = 0
@ -150,19 +147,28 @@ class ConsistencyWorker(QRunnable):
if self._stoprequest:
break
error = False
message = ""
self.signals.currentframe.emit(f)
indices = np.where(self.frames == f)[0]
detections = get_detections(f, indices)
done = [False, False]
if len(detections) == 0:
continue
if len(detections) > 2:
message = f"Frame {f}: More than 2 detections ({len(detections)}) in the same frame!"
logging.info("ConsistencyTracker: %s", message)
self.signals.message.emit(message)
while len(detections) > 2:
detections = check_multiple_detections(detections)
if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]):
# more than one detection
if detections[0].userlabeled and detections[1].userlabeled:
if detections[0].track == detections[1].track:
error = True
logging.info("Classification error both detections in the same frame are assigned to the same track!")
message = f"Frame {f}: Classification error both detections in the same frame are assigned to the same track!"
logging.info("ConsistencyTracker: %s", message)
self.signals.message.emit(message)
elif detections[0].userlabeled and not detections[1].userlabeled:
detections[1].track = 1 if detections[0].track == 2 else 2
elif not detections[0].userlabeled and detections[1].userlabeled:
@ -178,50 +184,52 @@ class ConsistencyWorker(QRunnable):
elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled
last_detections[detections[0].track] = detections[0]
done[0] = True
if np.sum(done) == len(detections):
continue
# if f == 2088:
# embed()
# return
if error and self._stoponerror:
self.signals.error.emit("Classification error both detections in the same frame are assigned to the same track!")
self.signals.message.emit("Tracking stopped at frame %i.", f)
break
elif error:
continue
dist_assignments = np.zeros(2, dtype=int)
orientation_assignments = np.zeros_like(dist_assignments)
length_assignments = np.zeros_like(dist_assignments)
distances = np.zeros((2, 2))
orientations = np.zeros_like(distances)
lengths = np.zeros_like(distances)
assignments = np.zeros((2, 2))
assignments = np.zeros(2)
for i, d in enumerate(detections):
dist_assignments[i], distances[i, :] = assign_by_distance(d)
orientation_assignments[i], orientations[i,:] = assign_by_orientation(d)
length_assignments[i], lengths[i, :] = assign_by_length(d)
assignments[i, :] = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3
assignments = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3
diffs = np.diff(assignments, axis=1)
error = False
temp = {}
message = ""
for i, d in enumerate(detections):
temp = {}
if diffs[i] == 0: # both are equally likely
if len(detections) > 1:
if assignments[0] == assignments[1]:
d.track = -1
error = True
message = "Classification error both detections in the same frame are assigned to the same track!"
message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!"
break
if diffs[i] < 0:
d.track = 1
else:
d.track = 2
self.tracks[d.id] = d.track
if d.track not in temp:
temp[d.track] = d
elif assignments[0] != assignments[1]:
detections[0].track = assignments[0]
detections[1].track = assignments[1]
temp[detections[0].track] = detections[0]
temp[detections[1].track] = detections[1]
self.tracks[detections[0].id] = detections[0].track
self.tracks[detections[1].id] = detections[1].track
else:
if np.abs(np.diff(distances[0,:])) > 50: # maybe include the time difference into this?
detections[0].track = assignments[0]
temp[detections[0].track] = detections[0]
self.tracks[detections[0].id] = detections[0].track
else:
self.tracks[detections[0].id] = -1
message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned."
error = True
message = "Double assignment to the same track!"
break
if not error:
for k in temp:
@ -232,14 +240,14 @@ class ConsistencyWorker(QRunnable):
self.tracks[idx] = -1
errors += 1
if self._stoponerror:
self.signals.error.emit(message)
self.signals.message.emit(message)
break
processed += 1
if steps > 0 and f % steps == 0:
progress += 1
self.signals.progress.emit(progress, processed, errors)
self.signals.message.emit("Tracking stopped at frame %i.", f)
self.signals.stopped.emit(f)
@ -487,6 +495,10 @@ class ConsistencyClassifier(QWidget):
self._stoponerror.setChecked(True)
self.threadpool = QThreadPool()
self._messagebox = QTextEdit()
self._messagebox.setFocusPolicy(Qt.NoFocus)
self._messagebox.setReadOnly(True)
lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
@ -499,13 +511,14 @@ class ConsistencyClassifier(QWidget):
lyt.addWidget(self._assignedlabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0)
lyt.addWidget(self._errorlabel, 5, 1)
lyt.addWidget(self._startbtn, 6, 0)
lyt.addWidget(self._stopbtn, 6, 1)
lyt.addWidget(self._proceedbtn, 6, 2)
lyt.addWidget(self._apply_btn, 7, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 7, 2, 1, 1)
lyt.addWidget(self._progressbar, 8, 0, 1, 3)
lyt.addWidget(self._messagebox, 6, 0, 2, 3)
lyt.addWidget(self._startbtn, 8, 0)
lyt.addWidget(self._stopbtn, 8, 1)
lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._apply_btn, 9, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 9, 2, 1, 1)
lyt.addWidget(self._progressbar, 10, 0, 1, 3)
self.setLayout(lyt)
def setData(self, data:TrackingData):
@ -575,12 +588,16 @@ class ConsistencyClassifier(QWidget):
self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error)
self._worker.signals.currentframe.connect(self.worker_frame)
self.threadpool.start(self._worker)
def worker_frame(self, frame):
self._framelabel.setText(str(frame))
def worker_error(self, msg):
self._messagebox.append(msg)
def proceed(self):
self.start()
@ -666,7 +683,7 @@ def main():
import pickle
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small_beginning.pkl"
datafile = PACKAGE_ROOT / "data/merged_small_starter.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)