diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py index 5d02c94..e6569bc 100644 --- a/fixtracks/widgets/classifier.py +++ b/fixtracks/widgets/classifier.py @@ -5,6 +5,7 @@ from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGr from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtGui import QBrush, QColor + import pyqtgraph as pg # needs to be imported after pyside to not import pyqt from fixtracks.utils.trackingdata import TrackingData @@ -15,16 +16,42 @@ class WorkerSignals(QObject): progress = Signal(int, int, int) stopped = Signal(int) +class ConsitencyDataLoader(QRunnable): + def __init__(self, data): + super().__init__() + self.signals = WorkerSignals() + self.data = data + self.bendedness = self.positions = None + self.lengths = None + self.orientations = None + self.userlabeled = None + self.scores = None + self.frames = None + self.tracks = None + + @Slot() + def run(self): + self.positions = self.data.centerOfGravity() + self.orientations = self.data.orientation() + self.lengths = self.data.animalLength() + self.bendedness = self.data.bendedness() + self.userlabeled = self.data["userlabeled"] + self.scores = self.data["confidence"] # ignore for now, let's see how far this carries. + self.frames = self.data["frame"] + self.tracks = self.data["track"] + self.signals.stopped.emit(0) + class ConsistencyWorker(QRunnable): - signals = WorkerSignals() def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, - startframe=0, stoponerror=False) -> None: + userlabeled, startframe=0, stoponerror=False) -> None: super().__init__() + self.signals = WorkerSignals() self.positions = positions self.orientations = orientations self.lengths = lengths - self._bendedness = bendedness + self.bendedness = bendedness + self.userlabeled = userlabeled self.frames = frames self.tracks = tracks self._startframe = startframe @@ -41,7 +68,8 @@ class ConsistencyWorker(QRunnable): self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1], self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] - # last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]] + last_angle = [self.orientations[(self.tracks == 1) & (self.frames <= self._startframe)][-1], + self.orientations[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] errors = 0 processed = 1 progress = 0 @@ -59,14 +87,24 @@ class ConsistencyWorker(QRunnable): originaltracks = self.tracks[indices] assignments = np.zeros_like(originaltracks) distances = np.zeros((len(originaltracks), 2)) + distances = np.zeros((len(originaltracks), 2)) for i, (idx, p) in enumerate(zip(indices, pp)): + if self.userlabeled[idx]: + print("userlabeled") + processed += 1 + last_pos[originaltracks[i]-1] = pp[i] + last_frame[originaltracks[i]-1] = f + last_angle[originaltracks[i]-1] = self.orientations[idx] + continue if f < last_frame[0]: + print("ping") self.tracks[idx] = 2 last_frame[1] = f last_pos[1] = p # last_angle[1] = self.orientations[idx] continue if f < last_frame[1]: + print("pang") last_frame[0] = f last_pos[0] = p # last_angle[0] = self.orientations[idx] @@ -87,6 +125,8 @@ class ConsistencyWorker(QRunnable): assignment_error = True errors += 1 if self._stoponerror: + from IPython import embed + embed() break else: processed += 1 @@ -97,6 +137,7 @@ class ConsistencyWorker(QRunnable): self.tracks[idx] = assignments[i] last_pos[assignments[i]-1] = pp[i] last_frame[assignments[i]-1] = f + last_angle[assignments[i]-1] = self.orientations[idx] assignment_error = False if steps > 0 and f % steps == 0: progress += 1 @@ -304,9 +345,12 @@ class ConsistencyClassifier(QWidget): self._all_lengths = None self._all_bendedness = None self._all_scores = None + self._userlabeled = None + self._maxframes = 0 self._frames = None self._tracks = None self._worker = None + self._dataworker = None self._processed_frames = 0 self._errorlabel = QLabel() @@ -340,7 +384,7 @@ class ConsistencyClassifier(QWidget): self._progressbar.setMaximum(100) self._stoponerror = QCheckBox("Stop processing whenever an error is encountered") - self._stoponerror.setToolTip("Stop process whenever ") + self._stoponerror.setToolTip("Stop process upon errors") self._stoponerror.setCheckable(True) self._stoponerror.setChecked(True) self.threadpool = QThreadPool() @@ -372,30 +416,37 @@ class ConsistencyClassifier(QWidget): data : Trackingdata The tracking data. """ - self._progressDialog = QProgressDialog("Updating...", "Cancel", 0, 0, self) - self._progressDialog.setWindowModality(Qt.WindowModal) - self._progressDialog.setMinimumDuration(0) - self._progressDialog.setValue(0) - self._progressDialog.show() + self.setEnabled(False) + self._progressbar.setRange(0,0) self._data = data - self._all_pos = data.centerOfGravity() - self._all_orientations = data.orientation() - self._all_lengths = data.animalLength() - self._all_bendedness = data.bendedness() - self._all_scores = data["confidence"] # ignore for now, let's see how far this carries. - self._frames = data["frame"] - self._tracks = data["track"] - self._maxframes = np.max(self._frames) - min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1 - self._maxframeslabel.setText(str(self._maxframes)) - self._startframe_spinner.setMinimum(min_frame) - self._startframe_spinner.setMaximum(self._frames[-1]) - self._startframe_spinner.setValue(self._frames[0] + 1) - self._startbtn.setEnabled(True) - self._assignedlabel.setText("0") - self._errorlabel.setText("0") - self._worker = None - self._progressDialog.close() + self._dataworker = ConsitencyDataLoader(self._data) + self._dataworker.signals.stopped.connect(self.data_processed) + self.threadpool.start(self._dataworker) + + @Slot() + def data_processed(self): + if self._dataworker is not None: + self._progressbar.setRange(0,100) + self._progressbar.setValue(0) + self._all_pos = self._dataworker.positions + self._all_orientations = self._dataworker.orientations + self._all_lengths = self._dataworker.lengths + self._all_bendedness = self._dataworker.bendedness + self._userlabeled = self._dataworker.userlabeled + self._all_scores = self._dataworker.scores + self._frames = self._dataworker.frames + self._tracks = self._dataworker.tracks + self._maxframes = np.max(self._frames) + min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1 + self._maxframeslabel.setText(str(self._maxframes)) + self._startframe_spinner.setMinimum(min_frame) + self._startframe_spinner.setMaximum(self._frames[-1]) + self._startframe_spinner.setValue(self._frames[0] + 1) + self._startbtn.setEnabled(True) + self._assignedlabel.setText("0") + self._errorlabel.setText("0") + self._dataworker = None + self.setEnabled(True) @Slot(float) def on_progress(self, value): @@ -415,7 +466,7 @@ class ConsistencyClassifier(QWidget): self._refreshbtn.setEnabled(False) self._stopbtn.setEnabled(True) self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths, - self._all_bendedness, self._frames, self._tracks, + self._all_bendedness, self._frames, self._tracks, self._userlabeled, self._startframe_spinner.value(), self._stoponerror.isChecked()) self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.progress.connect(self.worker_progress) @@ -493,7 +544,7 @@ def main(): import pickle from fixtracks.info import PACKAGE_ROOT - datafile = PACKAGE_ROOT / "data/merged_small.pkl" + datafile = PACKAGE_ROOT / "data/merged.pkl" with open(datafile, "rb") as f: df = pickle.load(f)