[classifier] more interactions on consistencytracker

This commit is contained in:
Jan Grewe 2025-02-17 22:26:22 +01:00
parent 1c65296008
commit 74fc43b586

View File

@ -1,7 +1,8 @@
import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QSpinBox, QProgressBar, QGridLayout, QLabel
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox
from PySide6.QtCore import Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
@ -13,12 +14,13 @@ class WorkerSignals(QObject):
error = Signal(str)
running = Signal(bool)
progress = Signal(int, int, int)
finished = Signal(bool)
stopped = Signal(int)
class ConsistencyWorker(QRunnable):
signals = WorkerSignals()
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, startframe=0) -> None:
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
startframe=0, stoponerror=False) -> None:
super().__init__()
self.positions = positions
self.orientations = orientations
@ -28,25 +30,29 @@ class ConsistencyWorker(QRunnable):
self.tracks = tracks
self._startframe = startframe
self._stoprequest = False
self._stoponerror = stoponerror
@Slot()
def cancel(self):
def stop(self):
self._stoprequest = True
@Slot()
def run(self):
last_pos = [self.positions[self.tracks == 1][0], self.positions[self.tracks == 2][0]]
last_frame = [self.frames[self.tracks == 1][0], self.frames[self.tracks == 2][0]]
last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]]
last_pos = [self.positions[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
# last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]]
errors = 0
processed = 0
self._stoprequest = False
maxframes = np.max(self.frames)
steps = int((maxframes - self._startframe) // 100)
processed = 1
progress = 0
assignment_error = False
for f in range(self._startframe, np.max(self.frames), 1):
processed += 1
self._stoprequest = False
maxframes = np.max(self.frames)
startframe = np.max(last_frame)
steps = int((maxframes - startframe) // 200)
for f in range(startframe + 1, maxframes, 1):
if self._stoprequest:
break
indices = np.where(self.frames == f)[0]
@ -59,15 +65,17 @@ class ConsistencyWorker(QRunnable):
self.tracks[idx] = 2
last_frame[1] = f
last_pos[1] = p
last_angle[1] = self.orientations[idx]
# last_angle[1] = self.orientations[idx]
continue
if f < last_frame[1]:
last_frame[0] = f
last_pos[0] = p
last_angle[0] = self.orientations[idx]
# last_angle[0] = self.orientations[idx]
self.tracks[idx] = 1
continue
# else, we have already seen track one and track two entries
if f - last_frame[0] == 0 or f - last_frame[1] == 0:
print(f"framecount is zero! current frame {f}, last frame {last_frame[0]} and {last_frame[1]}")
distance_to_trackone = np.linalg.norm(p - last_pos[0])/(f - last_frame[0])
distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/(f - last_frame[1])
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
@ -79,6 +87,8 @@ class ConsistencyWorker(QRunnable):
logging.warning("frame %i: Issues assigning based on distances %s", f, str(distances))
assignment_error = True
errors += 1
if self._stoponerror:
break
else:
processed += 1
for i, idx in enumerate(indices):
@ -89,17 +99,16 @@ class ConsistencyWorker(QRunnable):
last_pos[assignments[i]-1] = pp[i]
last_frame[assignments[i]-1] = f
assignment_error = False
if f % steps == 0:
if steps > 0 and f % steps == 0:
progress += 1
self.signals.progress.emit(progress, processed, errors)
self.signals.finished.emit(True)
self.signals.stopped.emit(f)
class SizeClassifier(QWidget):
apply = Signal()
name = "SizeClassifier"
name = "Size classifier"
def __init__(self, parent=None):
super().__init__(parent)
@ -286,12 +295,12 @@ class NeighborhoodValidator(QWidget):
class ConsistencyClassifier(QWidget):
apply = Signal()
name = "Consistency classifier"
name = "Consistency tracker"
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._all_cogs = None
self._all_pos = None
self._all_orientations = None
self._all_lengths = None
self._all_bendedness = None
@ -299,41 +308,61 @@ class ConsistencyClassifier(QWidget):
self._frames = None
self._tracks = None
self._worker = None
self._processed_frames = 0
self._errorlabel = QLabel()
self._errorlabel.setStyleSheet("QLabel { color : red; }")
self._assignedlabel = QLabel()
self._maxframeslabel = QLabel()
self._startframe_spinner = QSpinBox()
self._startbtn = QPushButton("run")
self._startbtn.clicked.connect(self.run)
self._startbtn = QPushButton("start")
self._startbtn.clicked.connect(self.start)
self._startbtn.setEnabled(False)
self._cancelbtn = QPushButton("cancel")
self._cancelbtn.clicked.connect(self.cancel)
self._cancelbtn.setEnabled(False)
self._stopbtn = QPushButton("stop")
self._stopbtn.clicked.connect(self.stop)
self._stopbtn.setEnabled(False)
self._proceedbtn = QPushButton("proceed")
self._proceedbtn.clicked.connect(self.proceed)
self._proceedbtn.setEnabled(False)
self._refreshbtn = QPushButton("refresh")
self._refreshbtn.clicked.connect(self.refresh)
self._refreshbtn.setEnabled(True)
self._apply_btn = QPushButton("apply")
self._apply_btn.clicked.connect(lambda: self.apply.emit())
self._apply_btn.setEnabled(False)
self._progressbar = QProgressBar()
self._progressbar.setMinimum(0)
self._progressbar.setMaximum(100)
self._apply_btn.clicked.connect(lambda: self.apply.emit())
self._apply_btn.setEnabled(False)
self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
self._stoponerror.setToolTip("Stop process whenever ")
self._stoponerror.setCheckable(True)
self._stoponerror.setChecked(True)
self.threadpool = QThreadPool()
lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1 )
lyt.addWidget(QLabel("assigned"), 1, 0)
lyt.addWidget(self._assignedlabel, 1, 1)
lyt.addWidget(QLabel("errors/issues"), 2, 0)
lyt.addWidget(self._errorlabel, 2, 1)
lyt.addWidget(self._startbtn, 3, 0)
lyt.addWidget(self._cancelbtn, 3, 1)
lyt.addWidget(self._progressbar, 4, 0, 1, 2)
lyt.addWidget(self._apply_btn, 5, 0, 1, 2)
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
lyt.addWidget(QLabel("of"), 1, 1, 1, 1)
lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1)
lyt.addWidget(self._stoponerror, 2, 0, 1, 3)
lyt.addWidget(QLabel("assigned"), 3, 0)
lyt.addWidget(self._assignedlabel, 3, 1)
lyt.addWidget(QLabel("errors/issues"), 4, 0)
lyt.addWidget(self._errorlabel, 4, 1)
lyt.addWidget(self._startbtn, 5, 0)
lyt.addWidget(self._stopbtn, 5, 1)
lyt.addWidget(self._proceedbtn, 5, 2)
lyt.addWidget(self._apply_btn, 6, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 6, 2, 1, 1)
lyt.addWidget(self._progressbar, 7, 0, 1, 3)
self.setLayout(lyt)
def setData(self, data:TrackingData):
@ -344,19 +373,23 @@ class ConsistencyClassifier(QWidget):
data : Trackingdata
The tracking data.
"""
self._all_cogs = data.centerOfGravity()
self._data = data
self._all_pos = data.centerOfGravity()
self._all_orientations = data.orientation()
self._all_lengths = data.animalLength()
self._all_bendedness = data.bendedness()
self._all_scores = data["confidence"] # ignore for now, let's see how far this carries.
self._frames = data["frame"]
self._tracks = data["track"]
self._maxframes = np.max(self._frames)
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_frame)
self._startframe_spinner.setMaximum(self._frames[-1])
self._startframe_spinner.setValue(self._frames[0] + 1)
self._startbtn.setEnabled(True)
self._assignedlabel.setText("0")
self._errorlabel.setText("0")
self._worker = None
@Slot(float)
@ -364,30 +397,44 @@ class ConsistencyClassifier(QWidget):
if self._progressbar is not None:
self._progressDialog.setValue(int(value * 100))
def cancel(self):
def stop(self):
if self._worker is not None:
self._worker.cancel()
self._worker.stop()
self._startbtn.setEnabled(True)
self._cancelbtn.setEnabled(False)
self._proceedbtn.setEnabled(True)
self._stopbtn.setEnabled(False)
self._refreshbtn.setEnabled(True)
def run(self):
def start(self):
self._startbtn.setEnabled(False)
self._cancelbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_cogs, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._startframe_spinner.value())
self._worker.signals.finished.connect(self.worker_done)
self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks,
self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress)
self.threadpool.start(self._worker)
def proceed(self):
self.start()
def refresh(self):
self.setData(self._data)
def worker_progress(self, progress, processed, errors):
self._progressbar.setValue(progress)
self._errorlabel.setText(str(errors))
self._assignedlabel.setText(str(processed))
def worker_done(self):
def worker_stopped(self, frame):
self._apply_btn.setEnabled(True)
self._startbtn.setEnabled(True)
self._cancelbtn.setEnabled(False)
self._stopbtn.setEnabled(False)
self._startframe_spinner.setValue(frame-1)
self._proceedbtn.setEnabled(bool(frame < self._maxframes-1))
self._refreshbtn.setEnabled(True)
self._processed_frames = frame
def assignedTracks(self):
return self._tracks
@ -441,7 +488,7 @@ def main():
from IPython import embed
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl"
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)