diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py index 406d114..407f047 100644 --- a/fixtracks/widgets/classifier.py +++ b/fixtracks/widgets/classifier.py @@ -1,7 +1,7 @@ import logging import numpy as np -from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView +from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtGui import QBrush, QColor @@ -24,7 +24,7 @@ class Detection(): self.userlabeled = userlabeled class WorkerSignals(QObject): - error = Signal(str) + message = Signal(str) running = Signal(bool) progress = Signal(int, int, int) currentframe = Signal(int) @@ -52,7 +52,7 @@ class ConsitencyDataLoader(QRunnable): self.positions = self.data.centerOfGravity() self.orientations = self.data.orientation() self.lengths = self.data.animalLength() - self.bendedness = self.data.bendedness() + # self.bendedness = self.data.bendedness() self.userlabeled = self.data["userlabeled"] self.scores = self.data["confidence"] # ignore for now, let's see how far this carries. self.frames = self.data["frame"] @@ -94,25 +94,13 @@ class ConsistencyWorker(QRunnable): detections.append(d) return detections - def needs_checking(original, new): - res = False - for n, o in zip(new, original): - res = (o == 1 or o == 2) and n != o - if res: - print("inverted assignment, needs cross-checking?") - if not res: - res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2)) - if res: - print("all detections would be assigned to one track!") - return res - def assign_by_distance(d): t1_step = d.frame - last_detections[1].frame t2_step = d.frame - last_detections[2].frame if t1_step == 0 or t2_step == 0: print(f"framecount is zero! current frame {f}, last frame {last_detections[1].frame} and {last_detections[2].frame}") - distance_to_trackone = np.linalg.norm(d.position - last_detections[1].position)/t1_step - distance_to_tracktwo = np.linalg.norm(d.position - last_detections[2].position)/t2_step + distance_to_trackone = np.linalg.norm(d.position - last_detections[1].position) /t1_step + distance_to_tracktwo = np.linalg.norm(d.position - last_detections[2].position) /t2_step most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1 distances = np.zeros(2) distances[0] = distance_to_trackone @@ -138,6 +126,15 @@ class ConsistencyWorker(QRunnable): most_likely_track = np.argmin(length_differences) + 1 return most_likely_track, length_differences + def check_multiple_detections(detections): + distances = np.zeros((len(detections), len(detections))) + for i, d1 in enumerate(detections): + for j, d2 in enumerate(detections): + distances[i, j] = np.abs(np.linalg.norm(d2.position - d1.position)) + lowest_dist = np.argmin(np.sum(distances, axis=1)) + del detections[lowest_dist] + return detections + unique_frames = np.unique(self.frames) steps = int((len(unique_frames) - self._startframe) // 100) errors = 0 @@ -150,19 +147,28 @@ class ConsistencyWorker(QRunnable): if self._stoprequest: break error = False + message = "" self.signals.currentframe.emit(f) indices = np.where(self.frames == f)[0] detections = get_detections(f, indices) done = [False, False] if len(detections) == 0: continue + if len(detections) > 2: + message = f"Frame {f}: More than 2 detections ({len(detections)}) in the same frame!" + logging.info("ConsistencyTracker: %s", message) + self.signals.message.emit(message) + while len(detections) > 2: + detections = check_multiple_detections(detections) if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]): # more than one detection if detections[0].userlabeled and detections[1].userlabeled: if detections[0].track == detections[1].track: error = True - logging.info("Classification error both detections in the same frame are assigned to the same track!") + message = f"Frame {f}: Classification error both detections in the same frame are assigned to the same track!" + logging.info("ConsistencyTracker: %s", message) + self.signals.message.emit(message) elif detections[0].userlabeled and not detections[1].userlabeled: detections[1].track = 1 if detections[0].track == 2 else 2 elif not detections[0].userlabeled and detections[1].userlabeled: @@ -178,50 +184,52 @@ class ConsistencyWorker(QRunnable): elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled last_detections[detections[0].track] = detections[0] done[0] = True - if np.sum(done) == len(detections): continue - # if f == 2088: - # embed() - # return + if error and self._stoponerror: - self.signals.error.emit("Classification error both detections in the same frame are assigned to the same track!") + self.signals.message.emit("Tracking stopped at frame %i.", f) break + elif error: + continue dist_assignments = np.zeros(2, dtype=int) orientation_assignments = np.zeros_like(dist_assignments) length_assignments = np.zeros_like(dist_assignments) distances = np.zeros((2, 2)) orientations = np.zeros_like(distances) lengths = np.zeros_like(distances) - assignments = np.zeros((2, 2)) + assignments = np.zeros(2) for i, d in enumerate(detections): dist_assignments[i], distances[i, :] = assign_by_distance(d) orientation_assignments[i], orientations[i,:] = assign_by_orientation(d) length_assignments[i], lengths[i, :] = assign_by_length(d) - assignments[i, :] = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3 + assignments = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3 - diffs = np.diff(assignments, axis=1) error = False temp = {} message = "" - for i, d in enumerate(detections): - temp = {} - if diffs[i] == 0: # both are equally likely + if len(detections) > 1: + if assignments[0] == assignments[1]: d.track = -1 error = True - message = "Classification error both detections in the same frame are assigned to the same track!" + message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!" break - if diffs[i] < 0: - d.track = 1 - else: - d.track = 2 - self.tracks[d.id] = d.track - if d.track not in temp: - temp[d.track] = d + elif assignments[0] != assignments[1]: + detections[0].track = assignments[0] + detections[1].track = assignments[1] + temp[detections[0].track] = detections[0] + temp[detections[1].track] = detections[1] + self.tracks[detections[0].id] = detections[0].track + self.tracks[detections[1].id] = detections[1].track + else: + if np.abs(np.diff(distances[0,:])) > 50: # maybe include the time difference into this? + detections[0].track = assignments[0] + temp[detections[0].track] = detections[0] + self.tracks[detections[0].id] = detections[0].track else: + self.tracks[detections[0].id] = -1 + message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned." error = True - message = "Double assignment to the same track!" - break if not error: for k in temp: @@ -232,14 +240,14 @@ class ConsistencyWorker(QRunnable): self.tracks[idx] = -1 errors += 1 if self._stoponerror: - self.signals.error.emit(message) + self.signals.message.emit(message) break processed += 1 if steps > 0 and f % steps == 0: progress += 1 self.signals.progress.emit(progress, processed, errors) - + self.signals.message.emit("Tracking stopped at frame %i.", f) self.signals.stopped.emit(f) @@ -487,6 +495,10 @@ class ConsistencyClassifier(QWidget): self._stoponerror.setChecked(True) self.threadpool = QThreadPool() + self._messagebox = QTextEdit() + self._messagebox.setFocusPolicy(Qt.NoFocus) + self._messagebox.setReadOnly(True) + lyt = QGridLayout() lyt.addWidget(QLabel("Start frame:"), 0, 0 ) lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2) @@ -499,13 +511,14 @@ class ConsistencyClassifier(QWidget): lyt.addWidget(self._assignedlabel, 4, 1) lyt.addWidget(QLabel("errors/issues"), 5, 0) lyt.addWidget(self._errorlabel, 5, 1) - - lyt.addWidget(self._startbtn, 6, 0) - lyt.addWidget(self._stopbtn, 6, 1) - lyt.addWidget(self._proceedbtn, 6, 2) - lyt.addWidget(self._apply_btn, 7, 0, 1, 2) - lyt.addWidget(self._refreshbtn, 7, 2, 1, 1) - lyt.addWidget(self._progressbar, 8, 0, 1, 3) + lyt.addWidget(self._messagebox, 6, 0, 2, 3) + + lyt.addWidget(self._startbtn, 8, 0) + lyt.addWidget(self._stopbtn, 8, 1) + lyt.addWidget(self._proceedbtn, 8, 2) + lyt.addWidget(self._apply_btn, 9, 0, 1, 2) + lyt.addWidget(self._refreshbtn, 9, 2, 1, 1) + lyt.addWidget(self._progressbar, 10, 0, 1, 3) self.setLayout(lyt) def setData(self, data:TrackingData): @@ -575,12 +588,16 @@ class ConsistencyClassifier(QWidget): self._startframe_spinner.value(), self._stoponerror.isChecked()) self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.progress.connect(self.worker_progress) + self._worker.signals.message.connect(self.worker_error) self._worker.signals.currentframe.connect(self.worker_frame) self.threadpool.start(self._worker) def worker_frame(self, frame): self._framelabel.setText(str(frame)) + def worker_error(self, msg): + self._messagebox.append(msg) + def proceed(self): self.start() @@ -666,7 +683,7 @@ def main(): import pickle from fixtracks.info import PACKAGE_ROOT - datafile = PACKAGE_ROOT / "data/merged_small_beginning.pkl" + datafile = PACKAGE_ROOT / "data/merged_small_starter.pkl" with open(datafile, "rb") as f: df = pickle.load(f)