[tracks] include classifier widget

This commit is contained in:
Jan Grewe 2025-02-07 15:57:35 +01:00
parent 796e03ae7e
commit 15cee494f6

View File

@ -14,6 +14,8 @@ from fixtracks.utils.writer import PickleWriter
from fixtracks.widgets.detectionview import DetectionView, DetectionData from fixtracks.widgets.detectionview import DetectionView, DetectionData
from fixtracks.widgets.detectiontimeline import DetectionTimeline from fixtracks.widgets.detectiontimeline import DetectionTimeline
from fixtracks.widgets.skeleton import SkeletonWidget from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget
class PoseTableModel(QAbstractTableModel): class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"] column_header = ["frame", "track"]
@ -259,6 +261,10 @@ class DataController(QObject):
logging.error("Column %s not in dictionary", col) logging.error("Column %s not in dictionary", col)
return np.nan return np.nan
@property
def numDetections(self):
return self._data["track"].shape[0]
@property @property
def selectionRange(self): def selectionRange(self):
return self._start, self._stop return self._start, self._stop
@ -286,6 +292,12 @@ class DataController(QObject):
def assignUserSelection(self, track_id): def assignUserSelection(self, track_id):
self._data["track"][self._user_selections] = track_id self._data["track"][self._user_selections] = track_id
def assignTracks(self, tracks):
if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!")
return
self._data["track"] = tracks
def save(self, filename): def save(self, filename):
export_columns = self._columns.copy() export_columns = self._columns.copy()
export_columns.remove("index") export_columns.remove("index")
@ -299,6 +311,8 @@ class DataController(QObject):
return 0 return 0
return self._data["keypoints"][0].shape[0] return self._data["keypoints"][0].shape[0]
def coordinates(self):
return np.stack(self._data["keypoints"]).astype(np.float32)
class FixTracks(QWidget): class FixTracks(QWidget):
back = Signal() back = Signal()
@ -391,9 +405,13 @@ class FixTracks(QWidget):
btnBox.addWidget(self._progress_bar) btnBox.addWidget(self._progress_bar)
btnBox.addWidget(self._saveBtn) btnBox.addWidget(self._saveBtn)
self._classifier = ClassifierWidget()
self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize)
self._classifier.setMaximumWidth(500)
cntrlBox = QHBoxLayout() cntrlBox = QHBoxLayout()
cntrlBox.addItem(QSpacerItem(200, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding)) cntrlBox.addWidget(self._classifier)
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter) cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
vbox = QVBoxLayout() vbox = QVBoxLayout()
vbox.addLayout(timelinebox) vbox.addLayout(timelinebox)
@ -412,6 +430,12 @@ class FixTracks(QWidget):
layout.addWidget(splitter) layout.addWidget(splitter)
self.setLayout(layout) self.setLayout(layout)
def on_classifyBySize(self, tracks):
self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data)
self.update()
def on_dataSelection(self): def on_dataSelection(self):
filename = self._data_combo.currentText() filename = self._data_combo.currentText()
if "please select" in filename.lower() or len(filename.strip()) == 0: if "please select" in filename.lower() or len(filename.strip()) == 0:
@ -509,6 +533,8 @@ class FixTracks(QWidget):
maxframes = self._data.max("frame") maxframes = self._data.max("frame")
rel_width = self._windowspinner.value() / maxframes rel_width = self._windowspinner.value() / maxframes
self._timeline.setWindowWidth(rel_width) self._timeline.setWindowWidth(rel_width)
coordinates = self._data.coordinates()
self._classifier.size_classifier.setCoordinates(coordinates)
self.update() self.update()
self._saveBtn.setEnabled(True) self._saveBtn.setEnabled(True)