[classifier] refresh in the background
This commit is contained in:
parent
ef6ff0d2b4
commit
881194ac66
@ -5,6 +5,7 @@ from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGr
|
||||
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
|
||||
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
|
||||
from PySide6.QtGui import QBrush, QColor
|
||||
|
||||
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
|
||||
|
||||
from fixtracks.utils.trackingdata import TrackingData
|
||||
@ -15,16 +16,42 @@ class WorkerSignals(QObject):
|
||||
progress = Signal(int, int, int)
|
||||
stopped = Signal(int)
|
||||
|
||||
class ConsitencyDataLoader(QRunnable):
|
||||
def __init__(self, data):
|
||||
super().__init__()
|
||||
self.signals = WorkerSignals()
|
||||
self.data = data
|
||||
self.bendedness = self.positions = None
|
||||
self.lengths = None
|
||||
self.orientations = None
|
||||
self.userlabeled = None
|
||||
self.scores = None
|
||||
self.frames = None
|
||||
self.tracks = None
|
||||
|
||||
@Slot()
|
||||
def run(self):
|
||||
self.positions = self.data.centerOfGravity()
|
||||
self.orientations = self.data.orientation()
|
||||
self.lengths = self.data.animalLength()
|
||||
self.bendedness = self.data.bendedness()
|
||||
self.userlabeled = self.data["userlabeled"]
|
||||
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
|
||||
self.frames = self.data["frame"]
|
||||
self.tracks = self.data["track"]
|
||||
self.signals.stopped.emit(0)
|
||||
|
||||
class ConsistencyWorker(QRunnable):
|
||||
signals = WorkerSignals()
|
||||
|
||||
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
|
||||
startframe=0, stoponerror=False) -> None:
|
||||
userlabeled, startframe=0, stoponerror=False) -> None:
|
||||
super().__init__()
|
||||
self.signals = WorkerSignals()
|
||||
self.positions = positions
|
||||
self.orientations = orientations
|
||||
self.lengths = lengths
|
||||
self._bendedness = bendedness
|
||||
self.bendedness = bendedness
|
||||
self.userlabeled = userlabeled
|
||||
self.frames = frames
|
||||
self.tracks = tracks
|
||||
self._startframe = startframe
|
||||
@ -41,7 +68,8 @@ class ConsistencyWorker(QRunnable):
|
||||
self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
|
||||
last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
|
||||
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
|
||||
# last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]]
|
||||
last_angle = [self.orientations[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
|
||||
self.orientations[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
|
||||
errors = 0
|
||||
processed = 1
|
||||
progress = 0
|
||||
@ -59,14 +87,24 @@ class ConsistencyWorker(QRunnable):
|
||||
originaltracks = self.tracks[indices]
|
||||
assignments = np.zeros_like(originaltracks)
|
||||
distances = np.zeros((len(originaltracks), 2))
|
||||
distances = np.zeros((len(originaltracks), 2))
|
||||
for i, (idx, p) in enumerate(zip(indices, pp)):
|
||||
if self.userlabeled[idx]:
|
||||
print("userlabeled")
|
||||
processed += 1
|
||||
last_pos[originaltracks[i]-1] = pp[i]
|
||||
last_frame[originaltracks[i]-1] = f
|
||||
last_angle[originaltracks[i]-1] = self.orientations[idx]
|
||||
continue
|
||||
if f < last_frame[0]:
|
||||
print("ping")
|
||||
self.tracks[idx] = 2
|
||||
last_frame[1] = f
|
||||
last_pos[1] = p
|
||||
# last_angle[1] = self.orientations[idx]
|
||||
continue
|
||||
if f < last_frame[1]:
|
||||
print("pang")
|
||||
last_frame[0] = f
|
||||
last_pos[0] = p
|
||||
# last_angle[0] = self.orientations[idx]
|
||||
@ -87,6 +125,8 @@ class ConsistencyWorker(QRunnable):
|
||||
assignment_error = True
|
||||
errors += 1
|
||||
if self._stoponerror:
|
||||
from IPython import embed
|
||||
embed()
|
||||
break
|
||||
else:
|
||||
processed += 1
|
||||
@ -97,6 +137,7 @@ class ConsistencyWorker(QRunnable):
|
||||
self.tracks[idx] = assignments[i]
|
||||
last_pos[assignments[i]-1] = pp[i]
|
||||
last_frame[assignments[i]-1] = f
|
||||
last_angle[assignments[i]-1] = self.orientations[idx]
|
||||
assignment_error = False
|
||||
if steps > 0 and f % steps == 0:
|
||||
progress += 1
|
||||
@ -304,9 +345,12 @@ class ConsistencyClassifier(QWidget):
|
||||
self._all_lengths = None
|
||||
self._all_bendedness = None
|
||||
self._all_scores = None
|
||||
self._userlabeled = None
|
||||
self._maxframes = 0
|
||||
self._frames = None
|
||||
self._tracks = None
|
||||
self._worker = None
|
||||
self._dataworker = None
|
||||
self._processed_frames = 0
|
||||
|
||||
self._errorlabel = QLabel()
|
||||
@ -340,7 +384,7 @@ class ConsistencyClassifier(QWidget):
|
||||
self._progressbar.setMaximum(100)
|
||||
|
||||
self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
|
||||
self._stoponerror.setToolTip("Stop process whenever ")
|
||||
self._stoponerror.setToolTip("Stop process upon errors")
|
||||
self._stoponerror.setCheckable(True)
|
||||
self._stoponerror.setChecked(True)
|
||||
self.threadpool = QThreadPool()
|
||||
@ -372,30 +416,37 @@ class ConsistencyClassifier(QWidget):
|
||||
data : Trackingdata
|
||||
The tracking data.
|
||||
"""
|
||||
self._progressDialog = QProgressDialog("Updating...", "Cancel", 0, 0, self)
|
||||
self._progressDialog.setWindowModality(Qt.WindowModal)
|
||||
self._progressDialog.setMinimumDuration(0)
|
||||
self._progressDialog.setValue(0)
|
||||
self._progressDialog.show()
|
||||
self.setEnabled(False)
|
||||
self._progressbar.setRange(0,0)
|
||||
self._data = data
|
||||
self._all_pos = data.centerOfGravity()
|
||||
self._all_orientations = data.orientation()
|
||||
self._all_lengths = data.animalLength()
|
||||
self._all_bendedness = data.bendedness()
|
||||
self._all_scores = data["confidence"] # ignore for now, let's see how far this carries.
|
||||
self._frames = data["frame"]
|
||||
self._tracks = data["track"]
|
||||
self._maxframes = np.max(self._frames)
|
||||
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
|
||||
self._maxframeslabel.setText(str(self._maxframes))
|
||||
self._startframe_spinner.setMinimum(min_frame)
|
||||
self._startframe_spinner.setMaximum(self._frames[-1])
|
||||
self._startframe_spinner.setValue(self._frames[0] + 1)
|
||||
self._startbtn.setEnabled(True)
|
||||
self._assignedlabel.setText("0")
|
||||
self._errorlabel.setText("0")
|
||||
self._worker = None
|
||||
self._progressDialog.close()
|
||||
self._dataworker = ConsitencyDataLoader(self._data)
|
||||
self._dataworker.signals.stopped.connect(self.data_processed)
|
||||
self.threadpool.start(self._dataworker)
|
||||
|
||||
@Slot()
|
||||
def data_processed(self):
|
||||
if self._dataworker is not None:
|
||||
self._progressbar.setRange(0,100)
|
||||
self._progressbar.setValue(0)
|
||||
self._all_pos = self._dataworker.positions
|
||||
self._all_orientations = self._dataworker.orientations
|
||||
self._all_lengths = self._dataworker.lengths
|
||||
self._all_bendedness = self._dataworker.bendedness
|
||||
self._userlabeled = self._dataworker.userlabeled
|
||||
self._all_scores = self._dataworker.scores
|
||||
self._frames = self._dataworker.frames
|
||||
self._tracks = self._dataworker.tracks
|
||||
self._maxframes = np.max(self._frames)
|
||||
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
|
||||
self._maxframeslabel.setText(str(self._maxframes))
|
||||
self._startframe_spinner.setMinimum(min_frame)
|
||||
self._startframe_spinner.setMaximum(self._frames[-1])
|
||||
self._startframe_spinner.setValue(self._frames[0] + 1)
|
||||
self._startbtn.setEnabled(True)
|
||||
self._assignedlabel.setText("0")
|
||||
self._errorlabel.setText("0")
|
||||
self._dataworker = None
|
||||
self.setEnabled(True)
|
||||
|
||||
@Slot(float)
|
||||
def on_progress(self, value):
|
||||
@ -415,7 +466,7 @@ class ConsistencyClassifier(QWidget):
|
||||
self._refreshbtn.setEnabled(False)
|
||||
self._stopbtn.setEnabled(True)
|
||||
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
|
||||
self._all_bendedness, self._frames, self._tracks,
|
||||
self._all_bendedness, self._frames, self._tracks, self._userlabeled,
|
||||
self._startframe_spinner.value(), self._stoponerror.isChecked())
|
||||
self._worker.signals.stopped.connect(self.worker_stopped)
|
||||
self._worker.signals.progress.connect(self.worker_progress)
|
||||
@ -493,7 +544,7 @@ def main():
|
||||
import pickle
|
||||
from fixtracks.info import PACKAGE_ROOT
|
||||
|
||||
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||
datafile = PACKAGE_ROOT / "data/merged.pkl"
|
||||
|
||||
with open(datafile, "rb") as f:
|
||||
df = pickle.load(f)
|
||||
|
Loading…
Reference in New Issue
Block a user