From a8fd5375f2a65bd0c87e6e64a3e026a3b170537c Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Mon, 17 Feb 2025 18:20:25 +0100 Subject: [PATCH] [classifier] auto distance classifier --- fixtracks/utils/reader.py | 3 - fixtracks/widgets/classifier.py | 243 +++++++++++++++++++++++++++----- fixtracks/widgets/tracks.py | 9 +- 3 files changed, 209 insertions(+), 46 deletions(-) diff --git a/fixtracks/utils/reader.py b/fixtracks/utils/reader.py index 27be7a2..c4bbe60 100644 --- a/fixtracks/utils/reader.py +++ b/fixtracks/utils/reader.py @@ -20,9 +20,6 @@ class ImageReader(QRunnable): @Slot() def run(self): - ''' - Your code goes in this function - ''' logging.debug("ImageReader: trying to open file %s", self._filename) cap = cv.VideoCapture(self._filename) framecount = int(cap.get(cv.CAP_PROP_FRAME_COUNT)) diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py index 348788d..e664515 100644 --- a/fixtracks/widgets/classifier.py +++ b/fixtracks/widgets/classifier.py @@ -1,13 +1,101 @@ import logging import numpy as np -from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView -from PySide6.QtCore import Signal +from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QSpinBox, QProgressBar, QGridLayout, QLabel +from PySide6.QtCore import Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtGui import QBrush, QColor import pyqtgraph as pg # needs to be imported after pyside to not import pyqt from fixtracks.utils.trackingdata import TrackingData +from IPython import embed +class WorkerSignals(QObject): + error = Signal(str) + running = Signal(bool) + progress = Signal(int, int, int) + finished = Signal(bool) + +class ConsistencyWorker(QRunnable): + signals = WorkerSignals() + + def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, startframe=0) -> None: + super().__init__() + self.positions = positions + self.orientations = orientations + self.lengths = lengths + self._bendedness = bendedness + self.frames = frames + self.tracks = tracks + self._startframe = startframe + self._stoprequest = False + + @Slot() + def cancel(self): + self._stoprequest = True + + @Slot() + def run(self): + last_pos = [self.positions[self.tracks == 1][0], self.positions[self.tracks == 2][0]] + last_frame = [self.frames[self.tracks == 1][0], self.frames[self.tracks == 2][0]] + last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]] + errors = 0 + processed = 0 + self._stoprequest = False + maxframes = np.max(self.frames) + steps = int((maxframes - self._startframe) // 100) + progress = 0 + assignment_error = False + for f in range(self._startframe, np.max(self.frames), 1): + processed += 1 + if self._stoprequest: + break + indices = np.where(self.frames == f)[0] + pp = self.positions[indices] + originaltracks = self.tracks[indices] + assignments = np.zeros_like(originaltracks) + distances = np.zeros((len(originaltracks), 2)) + for i, (idx, p) in enumerate(zip(indices, pp)): + if f < last_frame[0]: + self.tracks[idx] = 2 + last_frame[1] = f + last_pos[1] = p + last_angle[1] = self.orientations[idx] + continue + if f < last_frame[1]: + last_frame[0] = f + last_pos[0] = p + last_angle[0] = self.orientations[idx] + self.tracks[idx] = 1 + continue + # else, we have already seen track one and track two entries + distance_to_trackone = np.linalg.norm(p - last_pos[0])/(f - last_frame[0]) + distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/(f - last_frame[1]) + most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1 + distances[i, 0] = distance_to_trackone + distances[i, 1] = distance_to_tracktwo + assignments[i] = most_likely_track + # check (re) assignment update and proceed + if len(assignments) > 1 and (np.all(assignments == 1) or np.all(assignments == 2)): + logging.warning("frame %i: Issues assigning based on distances %s", f, str(distances)) + assignment_error = True + errors += 1 + else: + processed += 1 + for i, idx in enumerate(indices): + if assignment_error: + self.tracks[idx] = -1 + else: + self.tracks[idx] = assignments[i] + last_pos[assignments[i]-1] = pp[i] + last_frame[assignments[i]-1] = f + assignment_error = False + if f % steps == 0: + progress += 1 + self.signals.progress.emit(progress, processed, errors) + + self.signals.finished.emit(True) + + class SizeClassifier(QWidget): apply = Signal() @@ -202,63 +290,144 @@ class ConsistencyClassifier(QWidget): def __init__(self, parent=None): super().__init__(parent) + self._data = None + self._all_cogs = None + self._all_orientations = None + self._all_lengths = None + self._all_bendedness = None + self._all_scores = None + self._frames = None + self._tracks = None + self._worker = None + + self._errorlabel = QLabel() + self._errorlabel.setStyleSheet("QLabel { color : red; }") + self._assignedlabel = QLabel() + self._startframe_spinner = QSpinBox() + self._startbtn = QPushButton("run") + self._startbtn.clicked.connect(self.run) + self._startbtn.setEnabled(False) - def setData(self, keypoints, tracks, frames): + self._cancelbtn = QPushButton("cancel") + self._cancelbtn.clicked.connect(self.cancel) + self._cancelbtn.setEnabled(False) + self._apply_btn = QPushButton("apply") + + self._progressbar = QProgressBar() + self._progressbar.setMinimum(0) + self._progressbar.setMaximum(100) + + self._apply_btn.clicked.connect(lambda: self.apply.emit()) + self._apply_btn.setEnabled(False) + self.threadpool = QThreadPool() + + lyt = QGridLayout() + lyt.addWidget(QLabel("Start frame:"), 0, 0 ) + lyt.addWidget(self._startframe_spinner, 0, 1 ) + lyt.addWidget(QLabel("assigned"), 1, 0) + lyt.addWidget(self._assignedlabel, 1, 1) + lyt.addWidget(QLabel("errors/issues"), 2, 0) + lyt.addWidget(self._errorlabel, 2, 1) + + lyt.addWidget(self._startbtn, 3, 0) + lyt.addWidget(self._cancelbtn, 3, 1) + lyt.addWidget(self._progressbar, 4, 0, 1, 2) + lyt.addWidget(self._apply_btn, 5, 0, 1, 2) + self.setLayout(lyt) + + def setData(self, data:TrackingData): """Set the data, the classifier/should be working on. Parameters ---------- - positions : np.ndarray - The position estimates, e.g. the center of gravity for each detection - tracks : np.ndarray - The current track assignment. - frames : np.ndarray - respective frame. + data : Trackingdata + The tracking data. """ - def mouseClicked(event): - pos = event.pos() - if self._plot.sceneBoundingRect().contains(pos): - mousePoint = vb.mapSceneToView(pos) - print("mouse clicked at", mousePoint) - vLine.setPos(mousePoint.x()) - track2_brush = QBrush(QColor.fromString("green")) - track1_brush = QBrush(QColor.fromString("orange")) - self._positions = positions - self._tracks = tracks - self._frames = frames - t1_positions = self._positions[self._tracks == 1] - t1_frames = self._frames[self._tracks == 1] - t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False) - t2_positions = self._positions[self._tracks == 2] - t2_frames = self._frames[self._tracks == 2] - t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False) + self._all_cogs = data.centerOfGravity() + self._all_orientations = data.orientation() + self._all_lengths = data.animalLength() + self._all_bendedness = data.bendedness() + self._all_scores = data["confidence"] # ignore for now, let's see how far this carries. + self._frames = data["frame"] + self._tracks = data["track"] + min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1 + self._startframe_spinner.setMinimum(min_frame) + self._startframe_spinner.setMaximum(self._frames[-1]) + self._startframe_spinner.setValue(self._frames[0] + 1) + self._startbtn.setEnabled(True) + self._worker = None + + @Slot(float) + def on_progress(self, value): + if self._progressbar is not None: + self._progressDialog.setValue(int(value * 100)) + + def cancel(self): + if self._worker is not None: + self._worker.cancel() + self._startbtn.setEnabled(True) + self._cancelbtn.setEnabled(False) + + def run(self): + self._startbtn.setEnabled(False) + self._cancelbtn.setEnabled(True) + self._worker = ConsistencyWorker(self._all_cogs, self._all_orientations, self._all_lengths, + self._all_bendedness, self._frames, self._tracks, self._startframe_spinner.value()) + self._worker.signals.finished.connect(self.worker_done) + self._worker.signals.progress.connect(self.worker_progress) + self.threadpool.start(self._worker) + + def worker_progress(self, progress, processed, errors): + self._progressbar.setValue(progress) + self._errorlabel.setText(str(errors)) + self._assignedlabel.setText(str(processed)) + + def worker_done(self): + self._apply_btn.setEnabled(True) + self._startbtn.setEnabled(True) + self._cancelbtn.setEnabled(False) + def assignedTracks(self): + return self._tracks class ClassifierWidget(QTabWidget): - apply_sizeclassifier = Signal(np.ndarray) + apply_classifier = Signal(np.ndarray) def __init__(self, parent=None): super().__init__(parent) + self._data = None self._size_classifier = SizeClassifier() - self._neigborhood_validator = NeighborhoodValidator() + # self._neigborhood_validator = NeighborhoodValidator() + self._consistency_tracker = ConsistencyClassifier() self.addTab(self._size_classifier, SizeClassifier.name) - self.addTab(self._neigborhood_validator, NeighborhoodValidator.name) + self.addTab(self._consistency_tracker, ConsistencyClassifier.name) + self.tabBarClicked.connect(self.update) self._size_classifier.apply.connect(self._on_applySizeClassifier) + self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker) def _on_applySizeClassifier(self): tracks = self.size_classifier.assignedTracks() - self.apply_sizeclassifier.emit(tracks) + self.apply_classifier.emit(tracks) + + def _on_applyConsistencyTracker(self): + tracks = self._consistency_tracker.assignedTracks() + self.apply_classifier.emit(tracks) @property def size_classifier(self): return self._size_classifier @property - def neighborhood_validator(self): - return self._neigborhood_validator + def consistency_tracker(self): + return self._consistency_tracker + + def update(self): + self.consistency_tracker.setData(self._data) + def setData(self, data:TrackingData): + self._data = data def as_dict(df): d = {c: df[c].values for c in df.columns} @@ -269,8 +438,9 @@ def as_dict(df): def main(): test_size = False import pickle + from IPython import embed from fixtracks.info import PACKAGE_ROOT - + datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl" with open(datafile, "rb") as f: @@ -278,11 +448,6 @@ def main(): data = TrackingData() data.setData(as_dict(df)) - positions = data.centerOfGravity() - tracks = data["track"] - frames = data["frame"] - coords = data.coordinates() - app = QApplication([]) window = QWidget() window.setMinimumSize(200, 200) @@ -291,7 +456,7 @@ def main(): # win.setCoordinates(coords) # else: w = ClassifierWidget() - w.neighborhood_validator.setData(positions, tracks, frames) + w.setData(data) layout = QVBoxLayout() layout.addWidget(w) diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index a4705b4..7a50fe0 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -254,7 +254,7 @@ class FixTracks(QWidget): btnBox.addWidget(self._saveBtn) self._classifier = ClassifierWidget() - self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize) + self._classifier.apply_classifier.connect(self.on_autoClassify) self._classifier.setMaximumWidth(500) cntrlBox = QHBoxLayout() cntrlBox.addWidget(self._classifier) @@ -278,7 +278,7 @@ class FixTracks(QWidget): layout.addWidget(splitter) self.setLayout(layout) - def on_classifyBySize(self, tracks): + def on_autoClassify(self, tracks): self._data.setSelectionRange("index", 0, self._data.numDetections) self._data.assignTracks(tracks) self._timeline.setDetectionData(self._data.data) @@ -333,6 +333,7 @@ class FixTracks(QWidget): update_detectionView(unassigned, "unassigned") update_detectionView(assigned_left, "assigned_left") update_detectionView(assigned_right, "assigned_right") + self._classifier.setData(self._data) @property def fileList(self): @@ -369,6 +370,7 @@ class FixTracks(QWidget): self._progress_bar.setValue(0) if state and self._reader is not None: self._data.setData(self._reader.asdict) + self._saveBtn.setEnabled(True) self._currentWindowPos = 0 self._currentWindowWidth = self._windowspinner.value() self._maxframes = self._data.max("frame") @@ -381,9 +383,8 @@ class FixTracks(QWidget): tracks = self._data["track"] frames = self._data["frame"] self._classifier.size_classifier.setCoordinates(coordinates) - self._classifier.neighborhood_validator.setData(positions, tracks, frames) + self._classifier.consistency_tracker.setData(self._data) self.update() - self._saveBtn.setEnabled(True) logging.info("Finished loading data: %i frames, %i detections", self._maxframes, len(positions)) def on_keypointSelected(self):