From 74fc43b586c0aa45f8e991d99f200f8b84bd4db7 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Mon, 17 Feb 2025 22:26:22 +0100 Subject: [PATCH] [classifier] more interactions on consistencytracker --- fixtracks/widgets/classifier.py | 149 +++++++++++++++++++++----------- 1 file changed, 98 insertions(+), 51 deletions(-) diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py index e664515..8f5aa4f 100644 --- a/fixtracks/widgets/classifier.py +++ b/fixtracks/widgets/classifier.py @@ -1,7 +1,8 @@ import logging import numpy as np -from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QSpinBox, QProgressBar, QGridLayout, QLabel +from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView +from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox from PySide6.QtCore import Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtGui import QBrush, QColor import pyqtgraph as pg # needs to be imported after pyside to not import pyqt @@ -13,12 +14,13 @@ class WorkerSignals(QObject): error = Signal(str) running = Signal(bool) progress = Signal(int, int, int) - finished = Signal(bool) + stopped = Signal(int) class ConsistencyWorker(QRunnable): signals = WorkerSignals() - def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, startframe=0) -> None: + def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, + startframe=0, stoponerror=False) -> None: super().__init__() self.positions = positions self.orientations = orientations @@ -28,25 +30,29 @@ class ConsistencyWorker(QRunnable): self.tracks = tracks self._startframe = startframe self._stoprequest = False + self._stoponerror = stoponerror @Slot() - def cancel(self): + def stop(self): self._stoprequest = True @Slot() def run(self): - last_pos = [self.positions[self.tracks == 1][0], self.positions[self.tracks == 2][0]] - last_frame = [self.frames[self.tracks == 1][0], self.frames[self.tracks == 2][0]] - last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]] + last_pos = [self.positions[(self.tracks == 1) & (self.frames <= self._startframe)][-1], + self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] + last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1], + self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] + # last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]] errors = 0 - processed = 0 - self._stoprequest = False - maxframes = np.max(self.frames) - steps = int((maxframes - self._startframe) // 100) + processed = 1 progress = 0 assignment_error = False - for f in range(self._startframe, np.max(self.frames), 1): - processed += 1 + self._stoprequest = False + maxframes = np.max(self.frames) + startframe = np.max(last_frame) + steps = int((maxframes - startframe) // 200) + + for f in range(startframe + 1, maxframes, 1): if self._stoprequest: break indices = np.where(self.frames == f)[0] @@ -59,15 +65,17 @@ class ConsistencyWorker(QRunnable): self.tracks[idx] = 2 last_frame[1] = f last_pos[1] = p - last_angle[1] = self.orientations[idx] + # last_angle[1] = self.orientations[idx] continue if f < last_frame[1]: last_frame[0] = f last_pos[0] = p - last_angle[0] = self.orientations[idx] + # last_angle[0] = self.orientations[idx] self.tracks[idx] = 1 continue # else, we have already seen track one and track two entries + if f - last_frame[0] == 0 or f - last_frame[1] == 0: + print(f"framecount is zero! current frame {f}, last frame {last_frame[0]} and {last_frame[1]}") distance_to_trackone = np.linalg.norm(p - last_pos[0])/(f - last_frame[0]) distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/(f - last_frame[1]) most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1 @@ -79,6 +87,8 @@ class ConsistencyWorker(QRunnable): logging.warning("frame %i: Issues assigning based on distances %s", f, str(distances)) assignment_error = True errors += 1 + if self._stoponerror: + break else: processed += 1 for i, idx in enumerate(indices): @@ -89,17 +99,16 @@ class ConsistencyWorker(QRunnable): last_pos[assignments[i]-1] = pp[i] last_frame[assignments[i]-1] = f assignment_error = False - if f % steps == 0: + if steps > 0 and f % steps == 0: progress += 1 self.signals.progress.emit(progress, processed, errors) - self.signals.finished.emit(True) - + self.signals.stopped.emit(f) class SizeClassifier(QWidget): apply = Signal() - name = "SizeClassifier" + name = "Size classifier" def __init__(self, parent=None): super().__init__(parent) @@ -286,12 +295,12 @@ class NeighborhoodValidator(QWidget): class ConsistencyClassifier(QWidget): apply = Signal() - name = "Consistency classifier" + name = "Consistency tracker" def __init__(self, parent=None): super().__init__(parent) self._data = None - self._all_cogs = None + self._all_pos = None self._all_orientations = None self._all_lengths = None self._all_bendedness = None @@ -299,41 +308,61 @@ class ConsistencyClassifier(QWidget): self._frames = None self._tracks = None self._worker = None + self._processed_frames = 0 self._errorlabel = QLabel() self._errorlabel.setStyleSheet("QLabel { color : red; }") self._assignedlabel = QLabel() + self._maxframeslabel = QLabel() self._startframe_spinner = QSpinBox() - self._startbtn = QPushButton("run") - self._startbtn.clicked.connect(self.run) + self._startbtn = QPushButton("start") + self._startbtn.clicked.connect(self.start) self._startbtn.setEnabled(False) - self._cancelbtn = QPushButton("cancel") - self._cancelbtn.clicked.connect(self.cancel) - self._cancelbtn.setEnabled(False) + self._stopbtn = QPushButton("stop") + self._stopbtn.clicked.connect(self.stop) + self._stopbtn.setEnabled(False) + + self._proceedbtn = QPushButton("proceed") + self._proceedbtn.clicked.connect(self.proceed) + self._proceedbtn.setEnabled(False) + + self._refreshbtn = QPushButton("refresh") + self._refreshbtn.clicked.connect(self.refresh) + self._refreshbtn.setEnabled(True) + self._apply_btn = QPushButton("apply") + self._apply_btn.clicked.connect(lambda: self.apply.emit()) + self._apply_btn.setEnabled(False) self._progressbar = QProgressBar() self._progressbar.setMinimum(0) self._progressbar.setMaximum(100) - self._apply_btn.clicked.connect(lambda: self.apply.emit()) - self._apply_btn.setEnabled(False) + self._stoponerror = QCheckBox("Stop processing whenever an error is encountered") + self._stoponerror.setToolTip("Stop process whenever ") + self._stoponerror.setCheckable(True) + self._stoponerror.setChecked(True) self.threadpool = QThreadPool() lyt = QGridLayout() lyt.addWidget(QLabel("Start frame:"), 0, 0 ) - lyt.addWidget(self._startframe_spinner, 0, 1 ) - lyt.addWidget(QLabel("assigned"), 1, 0) - lyt.addWidget(self._assignedlabel, 1, 1) - lyt.addWidget(QLabel("errors/issues"), 2, 0) - lyt.addWidget(self._errorlabel, 2, 1) - - lyt.addWidget(self._startbtn, 3, 0) - lyt.addWidget(self._cancelbtn, 3, 1) - lyt.addWidget(self._progressbar, 4, 0, 1, 2) - lyt.addWidget(self._apply_btn, 5, 0, 1, 2) + lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2) + lyt.addWidget(QLabel("of"), 1, 1, 1, 1) + lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1) + lyt.addWidget(self._stoponerror, 2, 0, 1, 3) + lyt.addWidget(QLabel("assigned"), 3, 0) + lyt.addWidget(self._assignedlabel, 3, 1) + lyt.addWidget(QLabel("errors/issues"), 4, 0) + lyt.addWidget(self._errorlabel, 4, 1) + + lyt.addWidget(self._startbtn, 5, 0) + lyt.addWidget(self._stopbtn, 5, 1) + lyt.addWidget(self._proceedbtn, 5, 2) + lyt.addWidget(self._apply_btn, 6, 0, 1, 2) + lyt.addWidget(self._refreshbtn, 6, 2, 1, 1) + lyt.addWidget(self._progressbar, 7, 0, 1, 3) self.setLayout(lyt) def setData(self, data:TrackingData): @@ -344,19 +373,23 @@ class ConsistencyClassifier(QWidget): data : Trackingdata The tracking data. """ - - self._all_cogs = data.centerOfGravity() + self._data = data + self._all_pos = data.centerOfGravity() self._all_orientations = data.orientation() self._all_lengths = data.animalLength() self._all_bendedness = data.bendedness() self._all_scores = data["confidence"] # ignore for now, let's see how far this carries. self._frames = data["frame"] self._tracks = data["track"] + self._maxframes = np.max(self._frames) min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1 + self._maxframeslabel.setText(str(self._maxframes)) self._startframe_spinner.setMinimum(min_frame) self._startframe_spinner.setMaximum(self._frames[-1]) self._startframe_spinner.setValue(self._frames[0] + 1) self._startbtn.setEnabled(True) + self._assignedlabel.setText("0") + self._errorlabel.setText("0") self._worker = None @Slot(float) @@ -364,30 +397,44 @@ class ConsistencyClassifier(QWidget): if self._progressbar is not None: self._progressDialog.setValue(int(value * 100)) - def cancel(self): + def stop(self): if self._worker is not None: - self._worker.cancel() + self._worker.stop() self._startbtn.setEnabled(True) - self._cancelbtn.setEnabled(False) + self._proceedbtn.setEnabled(True) + self._stopbtn.setEnabled(False) + self._refreshbtn.setEnabled(True) - def run(self): + def start(self): self._startbtn.setEnabled(False) - self._cancelbtn.setEnabled(True) - self._worker = ConsistencyWorker(self._all_cogs, self._all_orientations, self._all_lengths, - self._all_bendedness, self._frames, self._tracks, self._startframe_spinner.value()) - self._worker.signals.finished.connect(self.worker_done) + self._refreshbtn.setEnabled(False) + self._stopbtn.setEnabled(True) + self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths, + self._all_bendedness, self._frames, self._tracks, + self._startframe_spinner.value(), self._stoponerror.isChecked()) + self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.progress.connect(self.worker_progress) self.threadpool.start(self._worker) + def proceed(self): + self.start() + + def refresh(self): + self.setData(self._data) + def worker_progress(self, progress, processed, errors): self._progressbar.setValue(progress) self._errorlabel.setText(str(errors)) self._assignedlabel.setText(str(processed)) - def worker_done(self): + def worker_stopped(self, frame): self._apply_btn.setEnabled(True) self._startbtn.setEnabled(True) - self._cancelbtn.setEnabled(False) + self._stopbtn.setEnabled(False) + self._startframe_spinner.setValue(frame-1) + self._proceedbtn.setEnabled(bool(frame < self._maxframes-1)) + self._refreshbtn.setEnabled(True) + self._processed_frames = frame def assignedTracks(self): return self._tracks @@ -441,7 +488,7 @@ def main(): from IPython import embed from fixtracks.info import PACKAGE_ROOT - datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl" + datafile = PACKAGE_ROOT / "data/merged_small.pkl" with open(datafile, "rb") as f: df = pickle.load(f)