[classifier] refresh in the background

This commit is contained in:
Jan Grewe 2025-02-18 11:08:23 +01:00
parent ef6ff0d2b4
commit 881194ac66

View File

@ -5,6 +5,7 @@ from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGr
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
from fixtracks.utils.trackingdata import TrackingData from fixtracks.utils.trackingdata import TrackingData
@ -15,16 +16,42 @@ class WorkerSignals(QObject):
progress = Signal(int, int, int) progress = Signal(int, int, int)
stopped = Signal(int) stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
def __init__(self, data):
super().__init__()
self.signals = WorkerSignals()
self.data = data
self.bendedness = self.positions = None
self.lengths = None
self.orientations = None
self.userlabeled = None
self.scores = None
self.frames = None
self.tracks = None
@Slot()
def run(self):
self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation()
self.lengths = self.data.animalLength()
self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"]
self.tracks = self.data["track"]
self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable): class ConsistencyWorker(QRunnable):
signals = WorkerSignals()
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
startframe=0, stoponerror=False) -> None: userlabeled, startframe=0, stoponerror=False) -> None:
super().__init__() super().__init__()
self.signals = WorkerSignals()
self.positions = positions self.positions = positions
self.orientations = orientations self.orientations = orientations
self.lengths = lengths self.lengths = lengths
self._bendedness = bendedness self.bendedness = bendedness
self.userlabeled = userlabeled
self.frames = frames self.frames = frames
self.tracks = tracks self.tracks = tracks
self._startframe = startframe self._startframe = startframe
@ -41,7 +68,8 @@ class ConsistencyWorker(QRunnable):
self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1], last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
# last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]] last_angle = [self.orientations[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.orientations[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
errors = 0 errors = 0
processed = 1 processed = 1
progress = 0 progress = 0
@ -59,14 +87,24 @@ class ConsistencyWorker(QRunnable):
originaltracks = self.tracks[indices] originaltracks = self.tracks[indices]
assignments = np.zeros_like(originaltracks) assignments = np.zeros_like(originaltracks)
distances = np.zeros((len(originaltracks), 2)) distances = np.zeros((len(originaltracks), 2))
distances = np.zeros((len(originaltracks), 2))
for i, (idx, p) in enumerate(zip(indices, pp)): for i, (idx, p) in enumerate(zip(indices, pp)):
if self.userlabeled[idx]:
print("userlabeled")
processed += 1
last_pos[originaltracks[i]-1] = pp[i]
last_frame[originaltracks[i]-1] = f
last_angle[originaltracks[i]-1] = self.orientations[idx]
continue
if f < last_frame[0]: if f < last_frame[0]:
print("ping")
self.tracks[idx] = 2 self.tracks[idx] = 2
last_frame[1] = f last_frame[1] = f
last_pos[1] = p last_pos[1] = p
# last_angle[1] = self.orientations[idx] # last_angle[1] = self.orientations[idx]
continue continue
if f < last_frame[1]: if f < last_frame[1]:
print("pang")
last_frame[0] = f last_frame[0] = f
last_pos[0] = p last_pos[0] = p
# last_angle[0] = self.orientations[idx] # last_angle[0] = self.orientations[idx]
@ -87,6 +125,8 @@ class ConsistencyWorker(QRunnable):
assignment_error = True assignment_error = True
errors += 1 errors += 1
if self._stoponerror: if self._stoponerror:
from IPython import embed
embed()
break break
else: else:
processed += 1 processed += 1
@ -97,6 +137,7 @@ class ConsistencyWorker(QRunnable):
self.tracks[idx] = assignments[i] self.tracks[idx] = assignments[i]
last_pos[assignments[i]-1] = pp[i] last_pos[assignments[i]-1] = pp[i]
last_frame[assignments[i]-1] = f last_frame[assignments[i]-1] = f
last_angle[assignments[i]-1] = self.orientations[idx]
assignment_error = False assignment_error = False
if steps > 0 and f % steps == 0: if steps > 0 and f % steps == 0:
progress += 1 progress += 1
@ -304,9 +345,12 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = None self._all_lengths = None
self._all_bendedness = None self._all_bendedness = None
self._all_scores = None self._all_scores = None
self._userlabeled = None
self._maxframes = 0
self._frames = None self._frames = None
self._tracks = None self._tracks = None
self._worker = None self._worker = None
self._dataworker = None
self._processed_frames = 0 self._processed_frames = 0
self._errorlabel = QLabel() self._errorlabel = QLabel()
@ -340,7 +384,7 @@ class ConsistencyClassifier(QWidget):
self._progressbar.setMaximum(100) self._progressbar.setMaximum(100)
self._stoponerror = QCheckBox("Stop processing whenever an error is encountered") self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
self._stoponerror.setToolTip("Stop process whenever ") self._stoponerror.setToolTip("Stop process upon errors")
self._stoponerror.setCheckable(True) self._stoponerror.setCheckable(True)
self._stoponerror.setChecked(True) self._stoponerror.setChecked(True)
self.threadpool = QThreadPool() self.threadpool = QThreadPool()
@ -372,30 +416,37 @@ class ConsistencyClassifier(QWidget):
data : Trackingdata data : Trackingdata
The tracking data. The tracking data.
""" """
self._progressDialog = QProgressDialog("Updating...", "Cancel", 0, 0, self) self.setEnabled(False)
self._progressDialog.setWindowModality(Qt.WindowModal) self._progressbar.setRange(0,0)
self._progressDialog.setMinimumDuration(0)
self._progressDialog.setValue(0)
self._progressDialog.show()
self._data = data self._data = data
self._all_pos = data.centerOfGravity() self._dataworker = ConsitencyDataLoader(self._data)
self._all_orientations = data.orientation() self._dataworker.signals.stopped.connect(self.data_processed)
self._all_lengths = data.animalLength() self.threadpool.start(self._dataworker)
self._all_bendedness = data.bendedness()
self._all_scores = data["confidence"] # ignore for now, let's see how far this carries. @Slot()
self._frames = data["frame"] def data_processed(self):
self._tracks = data["track"] if self._dataworker is not None:
self._maxframes = np.max(self._frames) self._progressbar.setRange(0,100)
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1 self._progressbar.setValue(0)
self._maxframeslabel.setText(str(self._maxframes)) self._all_pos = self._dataworker.positions
self._startframe_spinner.setMinimum(min_frame) self._all_orientations = self._dataworker.orientations
self._startframe_spinner.setMaximum(self._frames[-1]) self._all_lengths = self._dataworker.lengths
self._startframe_spinner.setValue(self._frames[0] + 1) self._all_bendedness = self._dataworker.bendedness
self._startbtn.setEnabled(True) self._userlabeled = self._dataworker.userlabeled
self._assignedlabel.setText("0") self._all_scores = self._dataworker.scores
self._errorlabel.setText("0") self._frames = self._dataworker.frames
self._worker = None self._tracks = self._dataworker.tracks
self._progressDialog.close() self._maxframes = np.max(self._frames)
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_frame)
self._startframe_spinner.setMaximum(self._frames[-1])
self._startframe_spinner.setValue(self._frames[0] + 1)
self._startbtn.setEnabled(True)
self._assignedlabel.setText("0")
self._errorlabel.setText("0")
self._dataworker = None
self.setEnabled(True)
@Slot(float) @Slot(float)
def on_progress(self, value): def on_progress(self, value):
@ -415,7 +466,7 @@ class ConsistencyClassifier(QWidget):
self._refreshbtn.setEnabled(False) self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True) self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths, self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._all_bendedness, self._frames, self._tracks, self._userlabeled,
self._startframe_spinner.value(), self._stoponerror.isChecked()) self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.progress.connect(self.worker_progress)
@ -493,7 +544,7 @@ def main():
import pickle import pickle
from fixtracks.info import PACKAGE_ROOT from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)