[classifier] kind of handling mulitple detections in one frame

This commit is contained in:
Jan Grewe 2025-02-26 08:19:59 +01:00
parent 430ee4fac7
commit d6b91c25d2

View File

@ -1,7 +1,7 @@
import logging import logging
import numpy as np import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
@ -24,7 +24,7 @@ class Detection():
self.userlabeled = userlabeled self.userlabeled = userlabeled
class WorkerSignals(QObject): class WorkerSignals(QObject):
error = Signal(str) message = Signal(str)
running = Signal(bool) running = Signal(bool)
progress = Signal(int, int, int) progress = Signal(int, int, int)
currentframe = Signal(int) currentframe = Signal(int)
@ -52,7 +52,7 @@ class ConsitencyDataLoader(QRunnable):
self.positions = self.data.centerOfGravity() self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation() self.orientations = self.data.orientation()
self.lengths = self.data.animalLength() self.lengths = self.data.animalLength()
self.bendedness = self.data.bendedness() # self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"] self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries. self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"] self.frames = self.data["frame"]
@ -94,25 +94,13 @@ class ConsistencyWorker(QRunnable):
detections.append(d) detections.append(d)
return detections return detections
def needs_checking(original, new):
res = False
for n, o in zip(new, original):
res = (o == 1 or o == 2) and n != o
if res:
print("inverted assignment, needs cross-checking?")
if not res:
res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
if res:
print("all detections would be assigned to one track!")
return res
def assign_by_distance(d): def assign_by_distance(d):
t1_step = d.frame - last_detections[1].frame t1_step = d.frame - last_detections[1].frame
t2_step = d.frame - last_detections[2].frame t2_step = d.frame - last_detections[2].frame
if t1_step == 0 or t2_step == 0: if t1_step == 0 or t2_step == 0:
print(f"framecount is zero! current frame {f}, last frame {last_detections[1].frame} and {last_detections[2].frame}") print(f"framecount is zero! current frame {f}, last frame {last_detections[1].frame} and {last_detections[2].frame}")
distance_to_trackone = np.linalg.norm(d.position - last_detections[1].position)/t1_step distance_to_trackone = np.linalg.norm(d.position - last_detections[1].position) /t1_step
distance_to_tracktwo = np.linalg.norm(d.position - last_detections[2].position)/t2_step distance_to_tracktwo = np.linalg.norm(d.position - last_detections[2].position) /t2_step
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1 most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
distances = np.zeros(2) distances = np.zeros(2)
distances[0] = distance_to_trackone distances[0] = distance_to_trackone
@ -138,6 +126,15 @@ class ConsistencyWorker(QRunnable):
most_likely_track = np.argmin(length_differences) + 1 most_likely_track = np.argmin(length_differences) + 1
return most_likely_track, length_differences return most_likely_track, length_differences
def check_multiple_detections(detections):
distances = np.zeros((len(detections), len(detections)))
for i, d1 in enumerate(detections):
for j, d2 in enumerate(detections):
distances[i, j] = np.abs(np.linalg.norm(d2.position - d1.position))
lowest_dist = np.argmin(np.sum(distances, axis=1))
del detections[lowest_dist]
return detections
unique_frames = np.unique(self.frames) unique_frames = np.unique(self.frames)
steps = int((len(unique_frames) - self._startframe) // 100) steps = int((len(unique_frames) - self._startframe) // 100)
errors = 0 errors = 0
@ -150,19 +147,28 @@ class ConsistencyWorker(QRunnable):
if self._stoprequest: if self._stoprequest:
break break
error = False error = False
message = ""
self.signals.currentframe.emit(f) self.signals.currentframe.emit(f)
indices = np.where(self.frames == f)[0] indices = np.where(self.frames == f)[0]
detections = get_detections(f, indices) detections = get_detections(f, indices)
done = [False, False] done = [False, False]
if len(detections) == 0: if len(detections) == 0:
continue continue
if len(detections) > 2:
message = f"Frame {f}: More than 2 detections ({len(detections)}) in the same frame!"
logging.info("ConsistencyTracker: %s", message)
self.signals.message.emit(message)
while len(detections) > 2:
detections = check_multiple_detections(detections)
if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]): if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]):
# more than one detection # more than one detection
if detections[0].userlabeled and detections[1].userlabeled: if detections[0].userlabeled and detections[1].userlabeled:
if detections[0].track == detections[1].track: if detections[0].track == detections[1].track:
error = True error = True
logging.info("Classification error both detections in the same frame are assigned to the same track!") message = f"Frame {f}: Classification error both detections in the same frame are assigned to the same track!"
logging.info("ConsistencyTracker: %s", message)
self.signals.message.emit(message)
elif detections[0].userlabeled and not detections[1].userlabeled: elif detections[0].userlabeled and not detections[1].userlabeled:
detections[1].track = 1 if detections[0].track == 2 else 2 detections[1].track = 1 if detections[0].track == 2 else 2
elif not detections[0].userlabeled and detections[1].userlabeled: elif not detections[0].userlabeled and detections[1].userlabeled:
@ -178,50 +184,52 @@ class ConsistencyWorker(QRunnable):
elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled
last_detections[detections[0].track] = detections[0] last_detections[detections[0].track] = detections[0]
done[0] = True done[0] = True
if np.sum(done) == len(detections): if np.sum(done) == len(detections):
continue continue
# if f == 2088:
# embed()
# return
if error and self._stoponerror: if error and self._stoponerror:
self.signals.error.emit("Classification error both detections in the same frame are assigned to the same track!") self.signals.message.emit("Tracking stopped at frame %i.", f)
break break
elif error:
continue
dist_assignments = np.zeros(2, dtype=int) dist_assignments = np.zeros(2, dtype=int)
orientation_assignments = np.zeros_like(dist_assignments) orientation_assignments = np.zeros_like(dist_assignments)
length_assignments = np.zeros_like(dist_assignments) length_assignments = np.zeros_like(dist_assignments)
distances = np.zeros((2, 2)) distances = np.zeros((2, 2))
orientations = np.zeros_like(distances) orientations = np.zeros_like(distances)
lengths = np.zeros_like(distances) lengths = np.zeros_like(distances)
assignments = np.zeros((2, 2)) assignments = np.zeros(2)
for i, d in enumerate(detections): for i, d in enumerate(detections):
dist_assignments[i], distances[i, :] = assign_by_distance(d) dist_assignments[i], distances[i, :] = assign_by_distance(d)
orientation_assignments[i], orientations[i,:] = assign_by_orientation(d) orientation_assignments[i], orientations[i,:] = assign_by_orientation(d)
length_assignments[i], lengths[i, :] = assign_by_length(d) length_assignments[i], lengths[i, :] = assign_by_length(d)
assignments[i, :] = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3 assignments = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3
diffs = np.diff(assignments, axis=1)
error = False error = False
temp = {} temp = {}
message = "" message = ""
for i, d in enumerate(detections): if len(detections) > 1:
temp = {} if assignments[0] == assignments[1]:
if diffs[i] == 0: # both are equally likely
d.track = -1 d.track = -1
error = True error = True
message = "Classification error both detections in the same frame are assigned to the same track!" message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!"
break break
if diffs[i] < 0: elif assignments[0] != assignments[1]:
d.track = 1 detections[0].track = assignments[0]
else: detections[1].track = assignments[1]
d.track = 2 temp[detections[0].track] = detections[0]
self.tracks[d.id] = d.track temp[detections[1].track] = detections[1]
if d.track not in temp: self.tracks[detections[0].id] = detections[0].track
temp[d.track] = d self.tracks[detections[1].id] = detections[1].track
else:
if np.abs(np.diff(distances[0,:])) > 50: # maybe include the time difference into this?
detections[0].track = assignments[0]
temp[detections[0].track] = detections[0]
self.tracks[detections[0].id] = detections[0].track
else: else:
self.tracks[detections[0].id] = -1
message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned."
error = True error = True
message = "Double assignment to the same track!"
break
if not error: if not error:
for k in temp: for k in temp:
@ -232,14 +240,14 @@ class ConsistencyWorker(QRunnable):
self.tracks[idx] = -1 self.tracks[idx] = -1
errors += 1 errors += 1
if self._stoponerror: if self._stoponerror:
self.signals.error.emit(message) self.signals.message.emit(message)
break break
processed += 1 processed += 1
if steps > 0 and f % steps == 0: if steps > 0 and f % steps == 0:
progress += 1 progress += 1
self.signals.progress.emit(progress, processed, errors) self.signals.progress.emit(progress, processed, errors)
self.signals.message.emit("Tracking stopped at frame %i.", f)
self.signals.stopped.emit(f) self.signals.stopped.emit(f)
@ -487,6 +495,10 @@ class ConsistencyClassifier(QWidget):
self._stoponerror.setChecked(True) self._stoponerror.setChecked(True)
self.threadpool = QThreadPool() self.threadpool = QThreadPool()
self._messagebox = QTextEdit()
self._messagebox.setFocusPolicy(Qt.NoFocus)
self._messagebox.setReadOnly(True)
lyt = QGridLayout() lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 ) lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2) lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
@ -499,13 +511,14 @@ class ConsistencyClassifier(QWidget):
lyt.addWidget(self._assignedlabel, 4, 1) lyt.addWidget(self._assignedlabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0) lyt.addWidget(QLabel("errors/issues"), 5, 0)
lyt.addWidget(self._errorlabel, 5, 1) lyt.addWidget(self._errorlabel, 5, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 3)
lyt.addWidget(self._startbtn, 6, 0)
lyt.addWidget(self._stopbtn, 6, 1) lyt.addWidget(self._startbtn, 8, 0)
lyt.addWidget(self._proceedbtn, 6, 2) lyt.addWidget(self._stopbtn, 8, 1)
lyt.addWidget(self._apply_btn, 7, 0, 1, 2) lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._refreshbtn, 7, 2, 1, 1) lyt.addWidget(self._apply_btn, 9, 0, 1, 2)
lyt.addWidget(self._progressbar, 8, 0, 1, 3) lyt.addWidget(self._refreshbtn, 9, 2, 1, 1)
lyt.addWidget(self._progressbar, 10, 0, 1, 3)
self.setLayout(lyt) self.setLayout(lyt)
def setData(self, data:TrackingData): def setData(self, data:TrackingData):
@ -575,12 +588,16 @@ class ConsistencyClassifier(QWidget):
self._startframe_spinner.value(), self._stoponerror.isChecked()) self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error)
self._worker.signals.currentframe.connect(self.worker_frame) self._worker.signals.currentframe.connect(self.worker_frame)
self.threadpool.start(self._worker) self.threadpool.start(self._worker)
def worker_frame(self, frame): def worker_frame(self, frame):
self._framelabel.setText(str(frame)) self._framelabel.setText(str(frame))
def worker_error(self, msg):
self._messagebox.append(msg)
def proceed(self): def proceed(self):
self.start() self.start()
@ -666,7 +683,7 @@ def main():
import pickle import pickle
from fixtracks.info import PACKAGE_ROOT from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small_beginning.pkl" datafile = PACKAGE_ROOT / "data/merged_small_starter.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)