[tracks] include classifier widget
This commit is contained in:
parent
796e03ae7e
commit
15cee494f6
@ -14,6 +14,8 @@ from fixtracks.utils.writer import PickleWriter
|
||||
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
||||
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
||||
from fixtracks.widgets.skeleton import SkeletonWidget
|
||||
from fixtracks.widgets.classifier import ClassifierWidget
|
||||
|
||||
|
||||
class PoseTableModel(QAbstractTableModel):
|
||||
column_header = ["frame", "track"]
|
||||
@ -259,6 +261,10 @@ class DataController(QObject):
|
||||
logging.error("Column %s not in dictionary", col)
|
||||
return np.nan
|
||||
|
||||
@property
|
||||
def numDetections(self):
|
||||
return self._data["track"].shape[0]
|
||||
|
||||
@property
|
||||
def selectionRange(self):
|
||||
return self._start, self._stop
|
||||
@ -286,6 +292,12 @@ class DataController(QObject):
|
||||
def assignUserSelection(self, track_id):
|
||||
self._data["track"][self._user_selections] = track_id
|
||||
|
||||
def assignTracks(self, tracks):
|
||||
if len(tracks) != self.numDetections:
|
||||
logging.error("DataController: Size of passed tracks does not match data!")
|
||||
return
|
||||
self._data["track"] = tracks
|
||||
|
||||
def save(self, filename):
|
||||
export_columns = self._columns.copy()
|
||||
export_columns.remove("index")
|
||||
@ -299,6 +311,8 @@ class DataController(QObject):
|
||||
return 0
|
||||
return self._data["keypoints"][0].shape[0]
|
||||
|
||||
def coordinates(self):
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
|
||||
class FixTracks(QWidget):
|
||||
back = Signal()
|
||||
@ -391,9 +405,13 @@ class FixTracks(QWidget):
|
||||
btnBox.addWidget(self._progress_bar)
|
||||
btnBox.addWidget(self._saveBtn)
|
||||
|
||||
self._classifier = ClassifierWidget()
|
||||
self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize)
|
||||
self._classifier.setMaximumWidth(500)
|
||||
cntrlBox = QHBoxLayout()
|
||||
cntrlBox.addItem(QSpacerItem(200, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
|
||||
cntrlBox.addWidget(self._classifier)
|
||||
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
|
||||
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
|
||||
|
||||
vbox = QVBoxLayout()
|
||||
vbox.addLayout(timelinebox)
|
||||
@ -412,6 +430,12 @@ class FixTracks(QWidget):
|
||||
layout.addWidget(splitter)
|
||||
self.setLayout(layout)
|
||||
|
||||
def on_classifyBySize(self, tracks):
|
||||
self._data.setSelectionRange("index", 0, self._data.numDetections)
|
||||
self._data.assignTracks(tracks)
|
||||
self._timeline.setDetectionData(self._data.data)
|
||||
self.update()
|
||||
|
||||
def on_dataSelection(self):
|
||||
filename = self._data_combo.currentText()
|
||||
if "please select" in filename.lower() or len(filename.strip()) == 0:
|
||||
@ -509,6 +533,8 @@ class FixTracks(QWidget):
|
||||
maxframes = self._data.max("frame")
|
||||
rel_width = self._windowspinner.value() / maxframes
|
||||
self._timeline.setWindowWidth(rel_width)
|
||||
coordinates = self._data.coordinates()
|
||||
self._classifier.size_classifier.setCoordinates(coordinates)
|
||||
self.update()
|
||||
self._saveBtn.setEnabled(True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user