diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py index 8ed57fc..758c228 100644 --- a/fixtracks/widgets/classifier.py +++ b/fixtracks/widgets/classifier.py @@ -2,7 +2,7 @@ import logging import numpy as np from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit -from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog +from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QDoubleSpinBox from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtGui import QBrush, QColor @@ -13,16 +13,17 @@ from fixtracks.utils.trackingdata import TrackingData from IPython import embed class Detection(): - def __init__(self, id, frame, track, position, orientation, length, userlabeled): + def __init__(self, id, frame, track, position, orientation, length, userlabeled, confidence): self.id = id self.frame = frame self.track = track self.position = position - self.score = 0.0 + self.confidence = confidence self.angle = orientation self.length = length self.userlabeled = userlabeled + class WorkerSignals(QObject): message = Signal(str) running = Signal(bool) @@ -30,7 +31,8 @@ class WorkerSignals(QObject): currentframe = Signal(int) stopped = Signal(int) -class ConsitencyDataLoader(QRunnable): + +class ConsistencyDataLoader(QRunnable): def __init__(self, data): super().__init__() self.signals = WorkerSignals() @@ -40,7 +42,7 @@ class ConsitencyDataLoader(QRunnable): self.lengths = None self.orientations = None self.userlabeled = None - self.scores = None + self.confidence = None self.frames = None self.tracks = None @@ -54,15 +56,16 @@ class ConsitencyDataLoader(QRunnable): self.lengths = self.data.animalLength() # self.bendedness = self.data.bendedness() self.userlabeled = self.data["userlabeled"] - self.scores = self.data["confidence"] # ignore for now, let's see how far this carries. + self.confidence = self.data["confidence"] # ignore for now, let's see how far this carries. self.frames = self.data["frame"] self.tracks = self.data["track"] self.signals.stopped.emit(0) + class ConsistencyWorker(QRunnable): def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, - userlabeled, startframe=0, stoponerror=False) -> None: + userlabeled, confidence, startframe=0, stoponerror=False, min_confidence=0.0) -> None: super().__init__() self.signals = WorkerSignals() self.positions = positions @@ -70,6 +73,8 @@ class ConsistencyWorker(QRunnable): self.lengths = lengths self.bendedness = bendedness self.userlabeled = userlabeled + self.confidence = confidence + self._min_confidence = min_confidence self.frames = frames self.tracks = tracks self._startframe = startframe @@ -88,9 +93,11 @@ class ConsistencyWorker(QRunnable): if np.any(self.positions[i] < 0.1): logging.debug("Encountered probably invalid position %s", str(self.positions[i])) continue + if self._min_confidence > 0.0 and self.confidence[i] < self._min_confidence: + continue d = Detection(i, frame, self.tracks[i], self.positions[i], self.orientations[i], self.lengths[i], - self.userlabeled[i]) + self.userlabeled[i], self.confidence[i]) detections.append(d) return detections @@ -127,6 +134,10 @@ class ConsistencyWorker(QRunnable): return most_likely_track, length_differences def check_multiple_detections(detections): + if self._min_confidence > 0.0: + for i, d in detections: + if d.confidence < self._min_confidence: + del detections[i] distances = np.zeros((len(detections), len(detections))) for i, d1 in enumerate(detections): for j, d2 in enumerate(detections): @@ -139,9 +150,11 @@ class ConsistencyWorker(QRunnable): t1index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 1))[0][-1] t2index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 2))[0][-1] d1 = Detection(t1index, self.frames[t1index], self.tracks[t1index], self.positions[t1index], - self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index]) + self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index], + self.confidence[t1index]) d2 = Detection(t1index, self.frames[t2index], self.tracks[t2index], self.positions[t2index], - self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index]) + self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index], + self.confidence[t1index]) last_detections[1] = d1 last_detections[2] = d2 @@ -337,6 +350,7 @@ class SizeClassifier(QWidget): tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2 return tracks + class NeighborhoodValidator(QWidget): apply = Signal() name = "Neighborhood Validator" @@ -506,30 +520,38 @@ class ConsistencyClassifier(QWidget): self._stoponerror.setChecked(True) self.threadpool = QThreadPool() + self._ignore_confidence = QCheckBox("Ignore detections widh confidence less than") + self._confidence_spinner = QDoubleSpinBox() + self._confidence_spinner.setRange(0.0, 1.0) + self._confidence_spinner.setSingleStep(0.01) + self._confidence_spinner.setDecimals(2) + self._confidence_spinner.setValue(0.5) self._messagebox = QTextEdit() self._messagebox.setFocusPolicy(Qt.NoFocus) self._messagebox.setReadOnly(True) lyt = QGridLayout() lyt.addWidget(QLabel("Start frame:"), 0, 0 ) - lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2) - lyt.addWidget(QLabel("of"), 1, 1, 1, 1) - lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1) - lyt.addWidget(self._stoponerror, 2, 0, 1, 3) - lyt.addWidget(QLabel("Current frame"), 3,0) - lyt.addWidget(self._framelabel, 3,1) - lyt.addWidget(QLabel("assigned"), 4, 0) - lyt.addWidget(self._assignedlabel, 4, 1) - lyt.addWidget(QLabel("errors/issues"), 5, 0) - lyt.addWidget(self._errorlabel, 5, 1) - lyt.addWidget(self._messagebox, 6, 0, 2, 3) - - lyt.addWidget(self._startbtn, 8, 0) - lyt.addWidget(self._stopbtn, 8, 1) - lyt.addWidget(self._proceedbtn, 8, 2) - lyt.addWidget(self._apply_btn, 9, 0, 1, 2) - lyt.addWidget(self._refreshbtn, 9, 2, 1, 1) - lyt.addWidget(self._progressbar, 10, 0, 1, 3) + lyt.addWidget(self._startframe_spinner, 0, 1, 1, 1) + lyt.addWidget(QLabel("of"), 0, 2, 1, 1) + lyt.addWidget(self._maxframeslabel, 0, 3, 1, 1) + lyt.addWidget(self._stoponerror, 1, 0, 1, 3) + lyt.addWidget(self._ignore_confidence, 3, 0, 1, 3) + lyt.addWidget(self._confidence_spinner, 3, 3, 1, 1) + lyt.addWidget(QLabel("Current frame"), 4, 0) + lyt.addWidget(self._framelabel, 4, 1) + lyt.addWidget(QLabel("(Re-)Assigned"), 5, 0) + lyt.addWidget(self._assignedlabel, 5, 1) + lyt.addWidget(QLabel("Errors/issues"), 5, 2) + lyt.addWidget(self._errorlabel, 5, 3, 1, 1) + lyt.addWidget(self._messagebox, 6, 0, 2, 4) + + lyt.addWidget(self._startbtn, 8, 0, 1, 2) + lyt.addWidget(self._stopbtn, 8, 2) + # lyt.addWidget(self._proceedbtn, 8, 2) + lyt.addWidget(self._refreshbtn, 8, 3, 1, 1) + lyt.addWidget(self._apply_btn, 9, 0, 1, 4) + lyt.addWidget(self._progressbar, 10, 0, 1, 4) self.setLayout(lyt) def setData(self, data:TrackingData): @@ -554,7 +576,7 @@ class ConsistencyClassifier(QWidget): self._all_lengths = self._dataworker.lengths self._all_bendedness = self._dataworker.bendedness self._userlabeled = self._dataworker.userlabeled - self._all_scores = self._dataworker.scores + self._confidence = self._dataworker.confidence self._frames = self._dataworker.frames self._tracks = self._dataworker.tracks self._dataworker = None @@ -597,12 +619,14 @@ class ConsistencyClassifier(QWidget): self._messagebox.append("Stopping tracking.") def start(self): + confidence_level = self._confidence_spinner.value() if self._ignore_confidence.isChecked() else 0.0 self._startbtn.setEnabled(False) self._refreshbtn.setEnabled(False) self._stopbtn.setEnabled(True) self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths, self._all_bendedness, self._frames, self._tracks, self._userlabeled, - self._startframe_spinner.value(), self._stoponerror.isChecked()) + self._confidence, self._startframe_spinner.value(), self._stoponerror.isChecked(), + min_confidence=confidence_level) self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.message.connect(self.worker_error) @@ -621,7 +645,7 @@ class ConsistencyClassifier(QWidget): def refresh(self): self.setEnabled(False) - self._dataworker = ConsitencyDataLoader(self._data) + self._dataworker = ConsistencyDataLoader(self._data) self._dataworker.signals.stopped.connect(self.data_processed) self._messagebox.clear() self._messagebox.append("Refreshing...") diff --git a/fixtracks/widgets/detectiontimeline.py b/fixtracks/widgets/detectiontimeline.py index 758c9a6..2f785ee 100644 --- a/fixtracks/widgets/detectiontimeline.py +++ b/fixtracks/widgets/detectiontimeline.py @@ -151,6 +151,8 @@ class DetectionTimeline(QWidget): self._position_label.setFont(f) layout = QVBoxLayout() + layout.setSpacing(0) + layout.setContentsMargins(5, 2, 5, 2) layout.addWidget(self._view) layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight) self.setLayout(layout) @@ -310,8 +312,7 @@ def main(): datafile = PACKAGE_ROOT / "data/merged_small.pkl" with open(datafile, "rb") as f: df = pickle.load(f) - data = TrackingData() - data.setData(as_dict(df)) + data = TrackingData(as_dict(df)) data.setSelection(np.arange(0,100, 1)) data.setUserLabeledStatus(True) start_x = 0.1 @@ -330,12 +331,14 @@ def main(): backBtn.clicked.connect(lambda: back(0.2)) btnLyt = QHBoxLayout() + btnLyt.setSpacing(1) btnLyt.addWidget(backBtn) btnLyt.addWidget(zeroBtn) btnLyt.addWidget(fwdBtn) view.setWindowPos(start_x) layout = QVBoxLayout() + layout.setSpacing(1) layout.addWidget(view) layout.addLayout(btnLyt) window.setLayout(layout) diff --git a/fixtracks/widgets/detectionview.py b/fixtracks/widgets/detectionview.py index 059be63..4ee7e61 100644 --- a/fixtracks/widgets/detectionview.py +++ b/fixtracks/widgets/detectionview.py @@ -157,10 +157,8 @@ class DetectionView(QWidget): item.setData(DetectionData.ID.value, id) item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :]) item.setData(DetectionData.FRAME.value, f) - # item.setData(DetectionData.USERLABELED.value, l) + item.setData(DetectionData.USERLABELED.value, l) item.setData(DetectionData.SCORE.value, s) - print(s) - print(item.data(DetectionData.SCORE.value)) item = self._scene.addItem(item) def fit_image_to_view(self): diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index 40badb2..b22c6cf 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -73,6 +73,7 @@ class FixTracks(QWidget): combo_layout.addWidget(self._gotoframe) combo_layout.addWidget(self._gotobtn) combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) + combo_layout.setSpacing(1) timelinebox = QVBoxLayout() timelinebox.setSpacing(2) @@ -112,6 +113,7 @@ class FixTracks(QWidget): data_selection_box.addWidget(QLabel("Select data file")) data_selection_box.addWidget(self._data_combo) data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) + data_selection_box.setSpacing(0) btnBox = QHBoxLayout() btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft) @@ -129,8 +131,12 @@ class FixTracks(QWidget): cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter) cntrlBox.addWidget(self._skeleton) cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding)) + cntrlBox.setSpacing(0) + cntrlBox.setContentsMargins(0,0,0,0) vbox = QVBoxLayout() + vbox.setSpacing(0) + vbox.setContentsMargins(0,0,0,0) vbox.addLayout(timelinebox) vbox.addLayout(cntrlBox) vbox.addLayout(btnBox) @@ -142,9 +148,12 @@ class FixTracks(QWidget): splitter.addWidget(container) splitter.setStretchFactor(0, 3) splitter.setStretchFactor(1, 1) + layout = QVBoxLayout() layout.addLayout(data_selection_box) layout.addWidget(splitter) + layout.setSpacing(0) + layout.setContentsMargins(5,2,2,5) self.setLayout(layout) def on_autoClassify(self, tracks): @@ -364,7 +373,7 @@ class FixTracks(QWidget): tracks = np.zeros(len(detections), dtype=int) ids = np.zeros_like(tracks) frames = np.zeros_like(tracks) - scores = np.zeros_like(tracks) + scores = np.zeros(tracks.shape, dtype=float) coordinates = None if len(detections) > 0: c = detections[0].data(DetectionData.COORDINATES.value) @@ -376,7 +385,6 @@ class FixTracks(QWidget): frames[i] = d.data(DetectionData.FRAME.value) coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value) scores[i] = d.data(DetectionData.SCORE.value) - print(scores[i]) self._data.setSelection(ids) self._controls_widget.setSelectedTracks(tracks) self._skeleton.clear()