[classifier] more interactions on consistencytracker

This commit is contained in:
Jan Grewe 2025-02-17 22:26:22 +01:00
parent 1c65296008
commit 74fc43b586

View File

@ -1,7 +1,8 @@
import logging import logging
import numpy as np import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QSpinBox, QProgressBar, QGridLayout, QLabel from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox
from PySide6.QtCore import Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtCore import Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
@ -13,12 +14,13 @@ class WorkerSignals(QObject):
error = Signal(str) error = Signal(str)
running = Signal(bool) running = Signal(bool)
progress = Signal(int, int, int) progress = Signal(int, int, int)
finished = Signal(bool) stopped = Signal(int)
class ConsistencyWorker(QRunnable): class ConsistencyWorker(QRunnable):
signals = WorkerSignals() signals = WorkerSignals()
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, startframe=0) -> None: def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
startframe=0, stoponerror=False) -> None:
super().__init__() super().__init__()
self.positions = positions self.positions = positions
self.orientations = orientations self.orientations = orientations
@ -28,25 +30,29 @@ class ConsistencyWorker(QRunnable):
self.tracks = tracks self.tracks = tracks
self._startframe = startframe self._startframe = startframe
self._stoprequest = False self._stoprequest = False
self._stoponerror = stoponerror
@Slot() @Slot()
def cancel(self): def stop(self):
self._stoprequest = True self._stoprequest = True
@Slot() @Slot()
def run(self): def run(self):
last_pos = [self.positions[self.tracks == 1][0], self.positions[self.tracks == 2][0]] last_pos = [self.positions[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
last_frame = [self.frames[self.tracks == 1][0], self.frames[self.tracks == 2][0]] self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]] last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
# last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]]
errors = 0 errors = 0
processed = 0 processed = 1
self._stoprequest = False
maxframes = np.max(self.frames)
steps = int((maxframes - self._startframe) // 100)
progress = 0 progress = 0
assignment_error = False assignment_error = False
for f in range(self._startframe, np.max(self.frames), 1): self._stoprequest = False
processed += 1 maxframes = np.max(self.frames)
startframe = np.max(last_frame)
steps = int((maxframes - startframe) // 200)
for f in range(startframe + 1, maxframes, 1):
if self._stoprequest: if self._stoprequest:
break break
indices = np.where(self.frames == f)[0] indices = np.where(self.frames == f)[0]
@ -59,15 +65,17 @@ class ConsistencyWorker(QRunnable):
self.tracks[idx] = 2 self.tracks[idx] = 2
last_frame[1] = f last_frame[1] = f
last_pos[1] = p last_pos[1] = p
last_angle[1] = self.orientations[idx] # last_angle[1] = self.orientations[idx]
continue continue
if f < last_frame[1]: if f < last_frame[1]:
last_frame[0] = f last_frame[0] = f
last_pos[0] = p last_pos[0] = p
last_angle[0] = self.orientations[idx] # last_angle[0] = self.orientations[idx]
self.tracks[idx] = 1 self.tracks[idx] = 1
continue continue
# else, we have already seen track one and track two entries # else, we have already seen track one and track two entries
if f - last_frame[0] == 0 or f - last_frame[1] == 0:
print(f"framecount is zero! current frame {f}, last frame {last_frame[0]} and {last_frame[1]}")
distance_to_trackone = np.linalg.norm(p - last_pos[0])/(f - last_frame[0]) distance_to_trackone = np.linalg.norm(p - last_pos[0])/(f - last_frame[0])
distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/(f - last_frame[1]) distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/(f - last_frame[1])
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1 most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
@ -79,6 +87,8 @@ class ConsistencyWorker(QRunnable):
logging.warning("frame %i: Issues assigning based on distances %s", f, str(distances)) logging.warning("frame %i: Issues assigning based on distances %s", f, str(distances))
assignment_error = True assignment_error = True
errors += 1 errors += 1
if self._stoponerror:
break
else: else:
processed += 1 processed += 1
for i, idx in enumerate(indices): for i, idx in enumerate(indices):
@ -89,17 +99,16 @@ class ConsistencyWorker(QRunnable):
last_pos[assignments[i]-1] = pp[i] last_pos[assignments[i]-1] = pp[i]
last_frame[assignments[i]-1] = f last_frame[assignments[i]-1] = f
assignment_error = False assignment_error = False
if f % steps == 0: if steps > 0 and f % steps == 0:
progress += 1 progress += 1
self.signals.progress.emit(progress, processed, errors) self.signals.progress.emit(progress, processed, errors)
self.signals.finished.emit(True) self.signals.stopped.emit(f)
class SizeClassifier(QWidget): class SizeClassifier(QWidget):
apply = Signal() apply = Signal()
name = "SizeClassifier" name = "Size classifier"
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
@ -286,12 +295,12 @@ class NeighborhoodValidator(QWidget):
class ConsistencyClassifier(QWidget): class ConsistencyClassifier(QWidget):
apply = Signal() apply = Signal()
name = "Consistency classifier" name = "Consistency tracker"
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self._data = None self._data = None
self._all_cogs = None self._all_pos = None
self._all_orientations = None self._all_orientations = None
self._all_lengths = None self._all_lengths = None
self._all_bendedness = None self._all_bendedness = None
@ -299,41 +308,61 @@ class ConsistencyClassifier(QWidget):
self._frames = None self._frames = None
self._tracks = None self._tracks = None
self._worker = None self._worker = None
self._processed_frames = 0
self._errorlabel = QLabel() self._errorlabel = QLabel()
self._errorlabel.setStyleSheet("QLabel { color : red; }") self._errorlabel.setStyleSheet("QLabel { color : red; }")
self._assignedlabel = QLabel() self._assignedlabel = QLabel()
self._maxframeslabel = QLabel()
self._startframe_spinner = QSpinBox() self._startframe_spinner = QSpinBox()
self._startbtn = QPushButton("run") self._startbtn = QPushButton("start")
self._startbtn.clicked.connect(self.run) self._startbtn.clicked.connect(self.start)
self._startbtn.setEnabled(False) self._startbtn.setEnabled(False)
self._cancelbtn = QPushButton("cancel") self._stopbtn = QPushButton("stop")
self._cancelbtn.clicked.connect(self.cancel) self._stopbtn.clicked.connect(self.stop)
self._cancelbtn.setEnabled(False) self._stopbtn.setEnabled(False)
self._proceedbtn = QPushButton("proceed")
self._proceedbtn.clicked.connect(self.proceed)
self._proceedbtn.setEnabled(False)
self._refreshbtn = QPushButton("refresh")
self._refreshbtn.clicked.connect(self.refresh)
self._refreshbtn.setEnabled(True)
self._apply_btn = QPushButton("apply") self._apply_btn = QPushButton("apply")
self._apply_btn.clicked.connect(lambda: self.apply.emit())
self._apply_btn.setEnabled(False)
self._progressbar = QProgressBar() self._progressbar = QProgressBar()
self._progressbar.setMinimum(0) self._progressbar.setMinimum(0)
self._progressbar.setMaximum(100) self._progressbar.setMaximum(100)
self._apply_btn.clicked.connect(lambda: self.apply.emit()) self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
self._apply_btn.setEnabled(False) self._stoponerror.setToolTip("Stop process whenever ")
self._stoponerror.setCheckable(True)
self._stoponerror.setChecked(True)
self.threadpool = QThreadPool() self.threadpool = QThreadPool()
lyt = QGridLayout() lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 ) lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1 ) lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
lyt.addWidget(QLabel("assigned"), 1, 0) lyt.addWidget(QLabel("of"), 1, 1, 1, 1)
lyt.addWidget(self._assignedlabel, 1, 1) lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1)
lyt.addWidget(QLabel("errors/issues"), 2, 0) lyt.addWidget(self._stoponerror, 2, 0, 1, 3)
lyt.addWidget(self._errorlabel, 2, 1) lyt.addWidget(QLabel("assigned"), 3, 0)
lyt.addWidget(self._assignedlabel, 3, 1)
lyt.addWidget(self._startbtn, 3, 0) lyt.addWidget(QLabel("errors/issues"), 4, 0)
lyt.addWidget(self._cancelbtn, 3, 1) lyt.addWidget(self._errorlabel, 4, 1)
lyt.addWidget(self._progressbar, 4, 0, 1, 2)
lyt.addWidget(self._apply_btn, 5, 0, 1, 2) lyt.addWidget(self._startbtn, 5, 0)
lyt.addWidget(self._stopbtn, 5, 1)
lyt.addWidget(self._proceedbtn, 5, 2)
lyt.addWidget(self._apply_btn, 6, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 6, 2, 1, 1)
lyt.addWidget(self._progressbar, 7, 0, 1, 3)
self.setLayout(lyt) self.setLayout(lyt)
def setData(self, data:TrackingData): def setData(self, data:TrackingData):
@ -344,19 +373,23 @@ class ConsistencyClassifier(QWidget):
data : Trackingdata data : Trackingdata
The tracking data. The tracking data.
""" """
self._data = data
self._all_cogs = data.centerOfGravity() self._all_pos = data.centerOfGravity()
self._all_orientations = data.orientation() self._all_orientations = data.orientation()
self._all_lengths = data.animalLength() self._all_lengths = data.animalLength()
self._all_bendedness = data.bendedness() self._all_bendedness = data.bendedness()
self._all_scores = data["confidence"] # ignore for now, let's see how far this carries. self._all_scores = data["confidence"] # ignore for now, let's see how far this carries.
self._frames = data["frame"] self._frames = data["frame"]
self._tracks = data["track"] self._tracks = data["track"]
self._maxframes = np.max(self._frames)
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1 min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_frame) self._startframe_spinner.setMinimum(min_frame)
self._startframe_spinner.setMaximum(self._frames[-1]) self._startframe_spinner.setMaximum(self._frames[-1])
self._startframe_spinner.setValue(self._frames[0] + 1) self._startframe_spinner.setValue(self._frames[0] + 1)
self._startbtn.setEnabled(True) self._startbtn.setEnabled(True)
self._assignedlabel.setText("0")
self._errorlabel.setText("0")
self._worker = None self._worker = None
@Slot(float) @Slot(float)
@ -364,30 +397,44 @@ class ConsistencyClassifier(QWidget):
if self._progressbar is not None: if self._progressbar is not None:
self._progressDialog.setValue(int(value * 100)) self._progressDialog.setValue(int(value * 100))
def cancel(self): def stop(self):
if self._worker is not None: if self._worker is not None:
self._worker.cancel() self._worker.stop()
self._startbtn.setEnabled(True) self._startbtn.setEnabled(True)
self._cancelbtn.setEnabled(False) self._proceedbtn.setEnabled(True)
self._stopbtn.setEnabled(False)
self._refreshbtn.setEnabled(True)
def run(self): def start(self):
self._startbtn.setEnabled(False) self._startbtn.setEnabled(False)
self._cancelbtn.setEnabled(True) self._refreshbtn.setEnabled(False)
self._worker = ConsistencyWorker(self._all_cogs, self._all_orientations, self._all_lengths, self._stopbtn.setEnabled(True)
self._all_bendedness, self._frames, self._tracks, self._startframe_spinner.value()) self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._worker.signals.finished.connect(self.worker_done) self._all_bendedness, self._frames, self._tracks,
self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.progress.connect(self.worker_progress)
self.threadpool.start(self._worker) self.threadpool.start(self._worker)
def proceed(self):
self.start()
def refresh(self):
self.setData(self._data)
def worker_progress(self, progress, processed, errors): def worker_progress(self, progress, processed, errors):
self._progressbar.setValue(progress) self._progressbar.setValue(progress)
self._errorlabel.setText(str(errors)) self._errorlabel.setText(str(errors))
self._assignedlabel.setText(str(processed)) self._assignedlabel.setText(str(processed))
def worker_done(self): def worker_stopped(self, frame):
self._apply_btn.setEnabled(True) self._apply_btn.setEnabled(True)
self._startbtn.setEnabled(True) self._startbtn.setEnabled(True)
self._cancelbtn.setEnabled(False) self._stopbtn.setEnabled(False)
self._startframe_spinner.setValue(frame-1)
self._proceedbtn.setEnabled(bool(frame < self._maxframes-1))
self._refreshbtn.setEnabled(True)
self._processed_frames = frame
def assignedTracks(self): def assignedTracks(self):
return self._tracks return self._tracks
@ -441,7 +488,7 @@ def main():
from IPython import embed from IPython import embed
from fixtracks.info import PACKAGE_ROOT from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl" datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)