[tracks] include classifier widget
This commit is contained in:
parent
796e03ae7e
commit
15cee494f6
@ -14,6 +14,8 @@ from fixtracks.utils.writer import PickleWriter
|
|||||||
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
||||||
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
||||||
from fixtracks.widgets.skeleton import SkeletonWidget
|
from fixtracks.widgets.skeleton import SkeletonWidget
|
||||||
|
from fixtracks.widgets.classifier import ClassifierWidget
|
||||||
|
|
||||||
|
|
||||||
class PoseTableModel(QAbstractTableModel):
|
class PoseTableModel(QAbstractTableModel):
|
||||||
column_header = ["frame", "track"]
|
column_header = ["frame", "track"]
|
||||||
@ -259,6 +261,10 @@ class DataController(QObject):
|
|||||||
logging.error("Column %s not in dictionary", col)
|
logging.error("Column %s not in dictionary", col)
|
||||||
return np.nan
|
return np.nan
|
||||||
|
|
||||||
|
@property
|
||||||
|
def numDetections(self):
|
||||||
|
return self._data["track"].shape[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def selectionRange(self):
|
def selectionRange(self):
|
||||||
return self._start, self._stop
|
return self._start, self._stop
|
||||||
@ -286,6 +292,12 @@ class DataController(QObject):
|
|||||||
def assignUserSelection(self, track_id):
|
def assignUserSelection(self, track_id):
|
||||||
self._data["track"][self._user_selections] = track_id
|
self._data["track"][self._user_selections] = track_id
|
||||||
|
|
||||||
|
def assignTracks(self, tracks):
|
||||||
|
if len(tracks) != self.numDetections:
|
||||||
|
logging.error("DataController: Size of passed tracks does not match data!")
|
||||||
|
return
|
||||||
|
self._data["track"] = tracks
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
export_columns = self._columns.copy()
|
export_columns = self._columns.copy()
|
||||||
export_columns.remove("index")
|
export_columns.remove("index")
|
||||||
@ -299,6 +311,8 @@ class DataController(QObject):
|
|||||||
return 0
|
return 0
|
||||||
return self._data["keypoints"][0].shape[0]
|
return self._data["keypoints"][0].shape[0]
|
||||||
|
|
||||||
|
def coordinates(self):
|
||||||
|
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||||
|
|
||||||
class FixTracks(QWidget):
|
class FixTracks(QWidget):
|
||||||
back = Signal()
|
back = Signal()
|
||||||
@ -391,9 +405,13 @@ class FixTracks(QWidget):
|
|||||||
btnBox.addWidget(self._progress_bar)
|
btnBox.addWidget(self._progress_bar)
|
||||||
btnBox.addWidget(self._saveBtn)
|
btnBox.addWidget(self._saveBtn)
|
||||||
|
|
||||||
|
self._classifier = ClassifierWidget()
|
||||||
|
self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize)
|
||||||
|
self._classifier.setMaximumWidth(500)
|
||||||
cntrlBox = QHBoxLayout()
|
cntrlBox = QHBoxLayout()
|
||||||
cntrlBox.addItem(QSpacerItem(200, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
|
cntrlBox.addWidget(self._classifier)
|
||||||
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
|
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
|
||||||
|
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
|
||||||
|
|
||||||
vbox = QVBoxLayout()
|
vbox = QVBoxLayout()
|
||||||
vbox.addLayout(timelinebox)
|
vbox.addLayout(timelinebox)
|
||||||
@ -412,6 +430,12 @@ class FixTracks(QWidget):
|
|||||||
layout.addWidget(splitter)
|
layout.addWidget(splitter)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def on_classifyBySize(self, tracks):
|
||||||
|
self._data.setSelectionRange("index", 0, self._data.numDetections)
|
||||||
|
self._data.assignTracks(tracks)
|
||||||
|
self._timeline.setDetectionData(self._data.data)
|
||||||
|
self.update()
|
||||||
|
|
||||||
def on_dataSelection(self):
|
def on_dataSelection(self):
|
||||||
filename = self._data_combo.currentText()
|
filename = self._data_combo.currentText()
|
||||||
if "please select" in filename.lower() or len(filename.strip()) == 0:
|
if "please select" in filename.lower() or len(filename.strip()) == 0:
|
||||||
@ -509,6 +533,8 @@ class FixTracks(QWidget):
|
|||||||
maxframes = self._data.max("frame")
|
maxframes = self._data.max("frame")
|
||||||
rel_width = self._windowspinner.value() / maxframes
|
rel_width = self._windowspinner.value() / maxframes
|
||||||
self._timeline.setWindowWidth(rel_width)
|
self._timeline.setWindowWidth(rel_width)
|
||||||
|
coordinates = self._data.coordinates()
|
||||||
|
self._classifier.size_classifier.setCoordinates(coordinates)
|
||||||
self.update()
|
self.update()
|
||||||
self._saveBtn.setEnabled(True)
|
self._saveBtn.setEnabled(True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user