diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index feb798f..2e48f58 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -14,6 +14,8 @@ from fixtracks.utils.writer import PickleWriter from fixtracks.widgets.detectionview import DetectionView, DetectionData from fixtracks.widgets.detectiontimeline import DetectionTimeline from fixtracks.widgets.skeleton import SkeletonWidget +from fixtracks.widgets.classifier import ClassifierWidget + class PoseTableModel(QAbstractTableModel): column_header = ["frame", "track"] @@ -259,6 +261,10 @@ class DataController(QObject): logging.error("Column %s not in dictionary", col) return np.nan + @property + def numDetections(self): + return self._data["track"].shape[0] + @property def selectionRange(self): return self._start, self._stop @@ -286,6 +292,12 @@ class DataController(QObject): def assignUserSelection(self, track_id): self._data["track"][self._user_selections] = track_id + def assignTracks(self, tracks): + if len(tracks) != self.numDetections: + logging.error("DataController: Size of passed tracks does not match data!") + return + self._data["track"] = tracks + def save(self, filename): export_columns = self._columns.copy() export_columns.remove("index") @@ -299,6 +311,8 @@ class DataController(QObject): return 0 return self._data["keypoints"][0].shape[0] + def coordinates(self): + return np.stack(self._data["keypoints"]).astype(np.float32) class FixTracks(QWidget): back = Signal() @@ -391,9 +405,13 @@ class FixTracks(QWidget): btnBox.addWidget(self._progress_bar) btnBox.addWidget(self._saveBtn) + self._classifier = ClassifierWidget() + self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize) + self._classifier.setMaximumWidth(500) cntrlBox = QHBoxLayout() - cntrlBox.addItem(QSpacerItem(200, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding)) + cntrlBox.addWidget(self._classifier) cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter) + cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding)) vbox = QVBoxLayout() vbox.addLayout(timelinebox) @@ -412,6 +430,12 @@ class FixTracks(QWidget): layout.addWidget(splitter) self.setLayout(layout) + def on_classifyBySize(self, tracks): + self._data.setSelectionRange("index", 0, self._data.numDetections) + self._data.assignTracks(tracks) + self._timeline.setDetectionData(self._data.data) + self.update() + def on_dataSelection(self): filename = self._data_combo.currentText() if "please select" in filename.lower() or len(filename.strip()) == 0: @@ -509,6 +533,8 @@ class FixTracks(QWidget): maxframes = self._data.max("frame") rel_width = self._windowspinner.value() / maxframes self._timeline.setWindowWidth(rel_width) + coordinates = self._data.coordinates() + self._classifier.size_classifier.setCoordinates(coordinates) self.update() self._saveBtn.setEnabled(True)