[classifier] refresh in the background

This commit is contained in:
Jan Grewe 2025-02-18 11:08:23 +01:00
parent ef6ff0d2b4
commit 881194ac66

View File

@ -5,6 +5,7 @@ from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGr
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
from fixtracks.utils.trackingdata import TrackingData
@ -15,16 +16,42 @@ class WorkerSignals(QObject):
progress = Signal(int, int, int)
stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
def __init__(self, data):
super().__init__()
self.signals = WorkerSignals()
self.data = data
self.bendedness = self.positions = None
self.lengths = None
self.orientations = None
self.userlabeled = None
self.scores = None
self.frames = None
self.tracks = None
@Slot()
def run(self):
self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation()
self.lengths = self.data.animalLength()
self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"]
self.tracks = self.data["track"]
self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable):
signals = WorkerSignals()
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
startframe=0, stoponerror=False) -> None:
userlabeled, startframe=0, stoponerror=False) -> None:
super().__init__()
self.signals = WorkerSignals()
self.positions = positions
self.orientations = orientations
self.lengths = lengths
self._bendedness = bendedness
self.bendedness = bendedness
self.userlabeled = userlabeled
self.frames = frames
self.tracks = tracks
self._startframe = startframe
@ -41,7 +68,8 @@ class ConsistencyWorker(QRunnable):
self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
# last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]]
last_angle = [self.orientations[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.orientations[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
errors = 0
processed = 1
progress = 0
@ -59,14 +87,24 @@ class ConsistencyWorker(QRunnable):
originaltracks = self.tracks[indices]
assignments = np.zeros_like(originaltracks)
distances = np.zeros((len(originaltracks), 2))
distances = np.zeros((len(originaltracks), 2))
for i, (idx, p) in enumerate(zip(indices, pp)):
if self.userlabeled[idx]:
print("userlabeled")
processed += 1
last_pos[originaltracks[i]-1] = pp[i]
last_frame[originaltracks[i]-1] = f
last_angle[originaltracks[i]-1] = self.orientations[idx]
continue
if f < last_frame[0]:
print("ping")
self.tracks[idx] = 2
last_frame[1] = f
last_pos[1] = p
# last_angle[1] = self.orientations[idx]
continue
if f < last_frame[1]:
print("pang")
last_frame[0] = f
last_pos[0] = p
# last_angle[0] = self.orientations[idx]
@ -87,6 +125,8 @@ class ConsistencyWorker(QRunnable):
assignment_error = True
errors += 1
if self._stoponerror:
from IPython import embed
embed()
break
else:
processed += 1
@ -97,6 +137,7 @@ class ConsistencyWorker(QRunnable):
self.tracks[idx] = assignments[i]
last_pos[assignments[i]-1] = pp[i]
last_frame[assignments[i]-1] = f
last_angle[assignments[i]-1] = self.orientations[idx]
assignment_error = False
if steps > 0 and f % steps == 0:
progress += 1
@ -304,9 +345,12 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = None
self._all_bendedness = None
self._all_scores = None
self._userlabeled = None
self._maxframes = 0
self._frames = None
self._tracks = None
self._worker = None
self._dataworker = None
self._processed_frames = 0
self._errorlabel = QLabel()
@ -340,7 +384,7 @@ class ConsistencyClassifier(QWidget):
self._progressbar.setMaximum(100)
self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
self._stoponerror.setToolTip("Stop process whenever ")
self._stoponerror.setToolTip("Stop process upon errors")
self._stoponerror.setCheckable(True)
self._stoponerror.setChecked(True)
self.threadpool = QThreadPool()
@ -372,30 +416,37 @@ class ConsistencyClassifier(QWidget):
data : Trackingdata
The tracking data.
"""
self._progressDialog = QProgressDialog("Updating...", "Cancel", 0, 0, self)
self._progressDialog.setWindowModality(Qt.WindowModal)
self._progressDialog.setMinimumDuration(0)
self._progressDialog.setValue(0)
self._progressDialog.show()
self.setEnabled(False)
self._progressbar.setRange(0,0)
self._data = data
self._all_pos = data.centerOfGravity()
self._all_orientations = data.orientation()
self._all_lengths = data.animalLength()
self._all_bendedness = data.bendedness()
self._all_scores = data["confidence"] # ignore for now, let's see how far this carries.
self._frames = data["frame"]
self._tracks = data["track"]
self._maxframes = np.max(self._frames)
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_frame)
self._startframe_spinner.setMaximum(self._frames[-1])
self._startframe_spinner.setValue(self._frames[0] + 1)
self._startbtn.setEnabled(True)
self._assignedlabel.setText("0")
self._errorlabel.setText("0")
self._worker = None
self._progressDialog.close()
self._dataworker = ConsitencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed)
self.threadpool.start(self._dataworker)
@Slot()
def data_processed(self):
if self._dataworker is not None:
self._progressbar.setRange(0,100)
self._progressbar.setValue(0)
self._all_pos = self._dataworker.positions
self._all_orientations = self._dataworker.orientations
self._all_lengths = self._dataworker.lengths
self._all_bendedness = self._dataworker.bendedness
self._userlabeled = self._dataworker.userlabeled
self._all_scores = self._dataworker.scores
self._frames = self._dataworker.frames
self._tracks = self._dataworker.tracks
self._maxframes = np.max(self._frames)
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_frame)
self._startframe_spinner.setMaximum(self._frames[-1])
self._startframe_spinner.setValue(self._frames[0] + 1)
self._startbtn.setEnabled(True)
self._assignedlabel.setText("0")
self._errorlabel.setText("0")
self._dataworker = None
self.setEnabled(True)
@Slot(float)
def on_progress(self, value):
@ -415,7 +466,7 @@ class ConsistencyClassifier(QWidget):
self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks,
self._all_bendedness, self._frames, self._tracks, self._userlabeled,
self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress)
@ -493,7 +544,7 @@ def main():
import pickle
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
datafile = PACKAGE_ROOT / "data/merged.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)