[tracks] include classifier widget

This commit is contained in:
Jan Grewe 2025-02-07 15:57:35 +01:00
parent 796e03ae7e
commit 15cee494f6

View File

@ -14,6 +14,8 @@ from fixtracks.utils.writer import PickleWriter
from fixtracks.widgets.detectionview import DetectionView, DetectionData
from fixtracks.widgets.detectiontimeline import DetectionTimeline
from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget
class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"]
@ -259,6 +261,10 @@ class DataController(QObject):
logging.error("Column %s not in dictionary", col)
return np.nan
@property
def numDetections(self):
return self._data["track"].shape[0]
@property
def selectionRange(self):
return self._start, self._stop
@ -286,6 +292,12 @@ class DataController(QObject):
def assignUserSelection(self, track_id):
self._data["track"][self._user_selections] = track_id
def assignTracks(self, tracks):
if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!")
return
self._data["track"] = tracks
def save(self, filename):
export_columns = self._columns.copy()
export_columns.remove("index")
@ -299,6 +311,8 @@ class DataController(QObject):
return 0
return self._data["keypoints"][0].shape[0]
def coordinates(self):
return np.stack(self._data["keypoints"]).astype(np.float32)
class FixTracks(QWidget):
back = Signal()
@ -391,9 +405,13 @@ class FixTracks(QWidget):
btnBox.addWidget(self._progress_bar)
btnBox.addWidget(self._saveBtn)
self._classifier = ClassifierWidget()
self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize)
self._classifier.setMaximumWidth(500)
cntrlBox = QHBoxLayout()
cntrlBox.addItem(QSpacerItem(200, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
cntrlBox.addWidget(self._classifier)
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
vbox = QVBoxLayout()
vbox.addLayout(timelinebox)
@ -412,6 +430,12 @@ class FixTracks(QWidget):
layout.addWidget(splitter)
self.setLayout(layout)
def on_classifyBySize(self, tracks):
self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data)
self.update()
def on_dataSelection(self):
filename = self._data_combo.currentText()
if "please select" in filename.lower() or len(filename.strip()) == 0:
@ -509,6 +533,8 @@ class FixTracks(QWidget):
maxframes = self._data.max("frame")
rel_width = self._windowspinner.value() / maxframes
self._timeline.setWindowWidth(rel_width)
coordinates = self._data.coordinates()
self._classifier.size_classifier.setCoordinates(coordinates)
self.update()
self._saveBtn.setEnabled(True)