import logging
import numpy as np

from PySide6.QtCore import Qt, QThreadPool, Signal
from PySide6.QtGui import QImage, QBrush, QColor
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QFileDialog, QMessageBox

from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.writer import PickleWriter
from fixtracks.utils.trackingdata import TrackingData
from fixtracks.widgets.detectionview import DetectionView, DetectionData
from fixtracks.widgets.detectiontimeline import DetectionTimeline
from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget
from fixtracks.widgets.selection_control import SelectionControls


class FixTracks(QWidget):
    """Widget for reviewing and correcting track assignments of detections.

    The user picks an image file and a pickled tracking-data file. The data
    is loaded in a background thread (``PickleLoader``). The widget then shows
    the detections of a sliding frame window in a ``DetectionView``, together
    with a ``DetectionTimeline``, selection controls, a classifier panel and a
    skeleton preview. Selected detections can be assigned to track one, track
    two or "other". They can also be flagged or unflagged as user-labeled, or
    deleted. Results are written back to a pickle file with ``PickleWriter``.

    Signals
    -------
    back
        Emitted when the "Back" button is pressed.
    """
    back = Signal()
    trackone_id = 1
    tracktwo_id = 2
    trackother_id = -1

    def __init__(self, parent=None):
        """Build all child widgets and layouts and wire up their signals.

        Parameters
        ----------
        parent : QWidget, optional
            Qt parent widget.
        """
        super().__init__(parent)
        self._files = []
        self._threadpool = QThreadPool()
        self._reader = None
        self._image = None
        self._currentWindowPos = 0  # in frames
        self._currentWindowWidth = 0  # in frames
        self._maxframes = 0
        # Set by moveWindow so that the resulting timeline windowMoved signal
        # does not overwrite the window position that was just computed.
        self._manualmove = False
        self._data = None

        self._detectionView = DetectionView()
        self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
        self._skeleton = SkeletonWidget()

        self._progress_bar = QProgressBar(self)
        self._progress_bar.setMaximumHeight(20)
        self._progress_bar.setValue(0)
        self._tasklabel = QLabel()

        self._timeline = DetectionTimeline()
        self._timeline.signals.windowMoved.connect(self.on_windowChanged)
        self._timeline.signals.moveRequest.connect(self.on_moveRequest)

        # setValue happens before connect, so no update() runs before the
        # remaining widgets (e.g. the keypoint combo) exist.
        self._windowspinner = QSpinBox()
        self._windowspinner.setRange(10, 10000)
        self._windowspinner.setSingleStep(50)
        self._windowspinner.setValue(500)
        self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)

        self._keypointcombo = QComboBox()
        self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)

        self._goto_spinner = QSpinBox()
        self._goto_spinner.setSingleStep(1)
        self._gotobtn = QPushButton("go!")
        self._gotobtn.setToolTip("Jump to a given frame")
        self._gotobtn.clicked.connect(self.on_goto)

        combo_layout = QHBoxLayout()
        combo_layout.addWidget(QLabel("Window width:"))
        combo_layout.addWidget(self._windowspinner)
        combo_layout.addWidget(QLabel("frames"))
        combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
        combo_layout.addWidget(QLabel("Keypoint:"))
        combo_layout.addWidget(self._keypointcombo)
        combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
        combo_layout.addWidget(QLabel("Jump to frame:"))
        combo_layout.addWidget(self._goto_spinner)
        combo_layout.addWidget(self._gotobtn)
        combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        combo_layout.setSpacing(1)

        timelinebox = QVBoxLayout()
        timelinebox.setSpacing(2)
        timelinebox.addLayout(combo_layout)
        timelinebox.addWidget(self._timeline)

        self._controls_widget = SelectionControls()
        self._controls_widget.assignOne.connect(self.on_assignOne)
        self._controls_widget.assignTwo.connect(self.on_assignTwo)
        self._controls_widget.assignOther.connect(self.on_assignOther)
        self._controls_widget.fwd.connect(self.on_forward)
        self._controls_widget.back.connect(self.on_backward)
        self._controls_widget.accept.connect(self.on_setUserFlag)
        self._controls_widget.accept_until.connect(self.on_setUserFlagsUntil)
        self._controls_widget.unaccept.connect(self.on_unsetUserFlag)
        self._controls_widget.delete.connect(self.on_deleteDetection)
        self._controls_widget.revertall.connect(self.on_revertUserFlags)

        self._saveBtn = QPushButton("Save")
        self._saveBtn.setShortcut("Ctrl+S")
        self._saveBtn.setEnabled(False)
        self._saveBtn.clicked.connect(self.on_save)

        self._backBtn = QPushButton("Back")
        self._backBtn.setShortcut("ESC")
        self._backBtn.clicked.connect(self.on_back)

        self._data_combo = QComboBox()
        self._data_combo.addItems(self._files)
        self._data_combo.currentIndexChanged.connect(self.on_dataSelection)

        self._image_combo = QComboBox()
        self._image_combo.addItems(self._files)
        self._image_combo.currentIndexChanged.connect(self.on_imageSelection)

        data_selection_box = QHBoxLayout()
        data_selection_box.addWidget(QLabel("Select image file"))
        data_selection_box.addWidget(self._image_combo)
        data_selection_box.addWidget(QLabel("Select data file"))
        data_selection_box.addWidget(self._data_combo)
        data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        data_selection_box.setSpacing(0)

        btnBox = QHBoxLayout()
        btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
        btnBox.addWidget(self._backBtn)
        btnBox.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        btnBox.addWidget(self._tasklabel)
        btnBox.addWidget(self._progress_bar)
        btnBox.addWidget(self._saveBtn)

        self._classifier = ClassifierWidget()
        self._classifier.apply_classifier.connect(self.on_autoClassify)
        self._classifier.setMaximumWidth(500)
        cntrlBox = QHBoxLayout()
        cntrlBox.addWidget(self._classifier)
        cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
        cntrlBox.addWidget(self._skeleton)
        cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
        cntrlBox.setSpacing(0)
        cntrlBox.setContentsMargins(0, 0, 0, 0)

        vbox = QVBoxLayout()
        vbox.setSpacing(0)
        vbox.setContentsMargins(0, 0, 0, 0)
        vbox.addLayout(timelinebox)
        vbox.addLayout(cntrlBox)
        vbox.addLayout(btnBox)

        container = QWidget()
        container.setLayout(vbox)

        splitter = QSplitter(Qt.Orientation.Vertical)
        splitter.addWidget(self._detectionView)
        splitter.addWidget(container)
        splitter.setStretchFactor(0, 3)
        splitter.setStretchFactor(1, 1)

        layout = QVBoxLayout()
        layout.addLayout(data_selection_box)
        layout.addWidget(splitter)
        layout.setSpacing(0)
        layout.setContentsMargins(5, 2, 2, 5)
        self.setLayout(layout)

    def _clampWindowStart(self, start):
        """Clamp a window start frame so that the window stays inside the data.

        Returns an ``int`` in ``[0, max(0, maxframes - window width)]``. The
        lower bound of 0 also covers windows that are wider than the data.
        """
        upper = max(0, self._maxframes - self._currentWindowWidth)
        return int(min(max(start, 0), upper))

    def on_autoClassify(self, tracks):
        """Apply track assignments delivered by the classifier widget.

        Selects every detection (index range 0..numDetections) and passes
        ``tracks`` to ``TrackingData.assignTracks``. ``tracks`` is presumably
        one entry per detection (TODO confirm against ClassifierWidget).
        """
        if self._data is None:
            return
        self._data.setSelectionRange("index", 0, self._data.numDetections)
        self._data.assignTracks(tracks)
        self._timeline.update()
        self.update()

    def on_dataSelection(self):
        """Start loading the data file chosen in the data combo box.

        Loading runs on the thread pool; ``_on_dataOpenend`` is called when
        the loader reports completion. The placeholder entry is ignored.
        """
        filename = self._data_combo.currentText()
        if "please select" in filename.lower() or len(filename.strip()) == 0:
            return
        self._tasklabel.setText("Loading data...")
        self._progress_bar.setRange(0, 0)  # busy indicator
        self._reader = PickleLoader(filename)
        self._reader.signals.finished.connect(self._on_dataOpenend)
        self._threadpool.start(self._reader)

    def on_imageSelection(self):
        """Load the image chosen in the image combo box into the detection view."""
        filename = self._image_combo.currentText()
        if "please select" in filename.lower() or len(filename.strip()) == 0:
            return
        img = QImage(filename)
        self._detectionView.setImage(img)

    def update(self):
        """Refresh the timeline, selection, controls and detection view.

        The refresh uses the current observation window and keypoint.

        This shadows ``QWidget.update``. It does nothing until data has been
        loaded, which also populates the keypoint combo box.
        """
        kp = self._keypointcombo.currentText().lower()
        if len(kp) == 0 or self._data is None:
            return
        kpi = -1 if "center" in kp else int(kp)
        start_frame = int(self._currentWindowPos)
        stop_frame = start_frame + self._currentWindowWidth
        # Avoid dividing by zero when the data contains only frame 0.
        total = max(self._maxframes, 1)
        self._timeline.setWindow(start_frame / total, self._currentWindowWidth / total)
        logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
        self._data.setSelectionRange("frame", start_frame, stop_frame)
        self._controls_widget.setWindow(start_frame, stop_frame)
        self._detectionView.updateDetections(kpi)

    @property
    def fileList(self):
        """List of file paths (as strings) last assigned via the setter."""
        return self._files

    @fileList.setter
    def fileList(self, file_list):
        """Populate the image and data combo boxes from candidate files.

        Parameters
        ----------
        file_list : iterable
            Path-like objects exposing ``.suffix`` (e.g. ``pathlib.Path``).
            Files ending in ``.jpg``, ``.jpeg`` or ``.png`` are offered as
            images. Files ending in ``.pkl`` are offered as tracking data.
            Suffix matching is case-insensitive.
        """
        logging.debug("FixTracks.fileList: set new file list")
        file_list = list(file_list)  # iterated several times below
        img_formats = [".jpg", ".jpeg", ".png"]
        dataformats = [".pkl"]
        image_files = [str(f) for f in file_list if f.suffix.lower() in img_formats]
        data_files = [str(f) for f in file_list if f.suffix.lower() in dataformats]
        self._files = [str(f) for f in file_list]

        self._image_combo.clear()
        self._image_combo.addItem("Please select")
        self._image_combo.addItems(image_files)
        self._image_combo.setCurrentIndex(0)

        self._data_combo.clear()
        self._data_combo.addItem("Please select")
        self._data_combo.addItems(data_files)
        self._data_combo.setCurrentIndex(0)

    def populateKeypointCombo(self, num_keypoints):
        """Fill the keypoint combo with "Center" followed by ``0..num_keypoints-1``."""
        self._keypointcombo.clear()
        self._keypointcombo.addItem("Center")
        for i in range(num_keypoints):
            self._keypointcombo.addItem(str(i))
        self._keypointcombo.setCurrentIndex(0)

    def _on_dataOpenend(self, state):
        """Handle completion of the background data loader.

        Parameters
        ----------
        state : bool
            True when loading succeeded.
        """
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)
        if not state or self._reader is None:
            logging.warning("Loading tracking data failed")
            return
        self._data = TrackingData(self._reader.asdict)
        self._saveBtn.setEnabled(True)
        self._currentWindowPos = 0
        self._currentWindowWidth = self._windowspinner.value()
        # np.max yields a numpy scalar; Qt setters expect a plain int.
        self._maxframes = int(np.max(self._data["frame"]))
        self._goto_spinner.setMaximum(self._maxframes)
        self.populateKeypointCombo(self._data.numKeypoints())
        self._timeline.setData(self._data)
        self._detectionView.setData(self._data)
        self._classifier.setData(self._data)
        self.update()
        logging.info("Finished loading data: %i frames", self._maxframes)

    def on_keypointSelected(self):
        """Redraw the view when a different keypoint is selected."""
        self.update()

    def on_save(self):
        """Ask for a target file and write the tracking data in the background."""
        logging.debug("Saving fixtracks results")
        file_dialog = QFileDialog(self)
        file_dialog.setAcceptMode(QFileDialog.AcceptMode.AcceptSave)
        file_dialog.setNameFilter("Pickle Files (*.pkl)")
        if not file_dialog.exec():
            return  # user cancelled; leave the status label untouched
        file_path = file_dialog.selectedFiles()[0]
        if not file_path.endswith(".pkl"):
            file_path += ".pkl"
        self._tasklabel.setText("Saving results to file...")
        self._progress_bar.setRange(0, 0)
        save_task = PickleWriter(self._data, file_path)
        save_task.signals.finished.connect(self.on_dataSaved)
        self._threadpool.start(save_task)

    def on_dataSaved(self):
        """Reset the status label and progress bar after saving finished."""
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)

    def on_back(self):
        """Emit the ``back`` signal."""
        logging.debug("Back button pressed!")
        self.back.emit()

    def on_assignOne(self):
        """Assign the current selection to track one."""
        if self._data is None:
            return
        logging.debug("Assigning user selection to track One")
        self._data.setTrack(self.trackone_id)
        self._timeline.update()
        self.update()

    def on_assignTwo(self):
        """Assign the current selection to track two."""
        if self._data is None:
            return
        logging.debug("Assigning user selection to track Two")
        self._data.setTrack(self.tracktwo_id)
        self._timeline.update()
        self.update()

    def on_assignOther(self):
        """Assign the current selection to the "other" track.

        ``setTrack`` is called with a second argument of False here, unlike
        for tracks one and two.
        """
        if self._data is None:
            return
        logging.debug("Assigning user selection to track Other")
        self._data.setTrack(self.trackother_id, False)
        self._timeline.update()
        self.update()

    def on_setUserFlag(self):
        """Mark the current selection as user-labeled."""
        if self._data is None:
            return
        self._data.setUserLabeledStatus(True)
        self._timeline.update()
        self.update()

    def on_setUserFlagsUntil(self):
        """Mark every detection from frame 0 to the end of the window as user-labeled."""
        if self._data is None:
            return
        self._data.setSelectionRange("frame", 0, self._currentWindowPos + self._currentWindowWidth)
        self._data.setUserLabeledStatus(True)
        self._timeline.update()
        self.update()

    def on_unsetUserFlag(self):
        """Remove the user-labeled flag from the current selection."""
        if self._data is None:
            return
        logging.debug("Tracks:unsetUserFlag")
        self._data.setUserLabeledStatus(False)
        self._timeline.update()
        self.update()

    def on_revertUserFlags(self):
        """After confirmation, revert all user flags and track assignments."""
        if self._data is None:
            return
        logging.debug("Tracks:revert ALL UserFlags and track assignments")
        msg_box = QMessageBox(self)
        msg_box.setIcon(QMessageBox.Icon.Warning)
        msg_box.setText("Are you sure you want to revert ALL track assignments?")
        msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
        msg_box.setDefaultButton(QMessageBox.StandardButton.No)
        ret = msg_box.exec()
        if ret == QMessageBox.StandardButton.Yes:
            self._data.revertUserLabeledStatus()
            self._data.revertTrackAssignments()
            self._timeline.update()
            self.update()

    def on_deleteDetection(self):
        """After confirmation, delete the currently selected detections."""
        if self._data is None:
            return
        logging.info("Tracks:deleting detections!")
        msg_box = QMessageBox(self)
        msg_box.setIcon(QMessageBox.Icon.Warning)
        msg_box.setText(f"Are you sure you want to delete the selected ({len(self._data.selectionIndices)}) detections?")
        msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
        msg_box.setDefaultButton(QMessageBox.StandardButton.No)
        ret = msg_box.exec()
        if ret == QMessageBox.StandardButton.Yes:
            self._data.deleteDetections()
            self._timeline.update()
            self.update()

    def on_windowChanged(self):
        """Sync the window position with the timeline after it reports a move.

        When the move originated from ``moveWindow`` (``_manualmove``), the
        position computed there is kept instead.
        """
        logging.debug("Tracks:Timeline reports window change ")
        if not self._manualmove:
            self._currentWindowPos = int(np.round(self._timeline.rangeStart * self._maxframes))
            self.update()
        self._manualmove = False

    def on_moveRequest(self, pos):
        """Move the window start to the relative timeline position ``pos``."""
        new_pos = int(np.round(pos * self._maxframes))
        self._currentWindowPos = new_pos
        self.update()

    def on_windowSizeChanged(self, value):
        """Reacts on the user window-width selection.

        Selection is done in the unit of frames.

        Parameters
        ----------
        value : int
            The width of the observation window in frames.
        """
        self._currentWindowWidth = value
        logging.debug("Tracks:OnWindowSizeChanged %i frames", value)
        self.update()

    def on_goto(self):
        """Move the window so that it starts at the frame chosen in the goto spinner.

        The start is clamped so that the window stays inside the data.
        """
        if self._data is None:
            return
        target = self._clampWindowStart(self._goto_spinner.value())
        logging.info("Jump to frame %i", target)
        self._currentWindowPos = target
        total = max(self._maxframes, 1)
        self._timeline.setWindow(self._currentWindowPos / total,
                                 self._currentWindowWidth / total)
        self.update()

    def on_detectionsSelected(self, detections):
        """Forward the detections selected in the view to data, controls and skeleton view.

        Parameters
        ----------
        detections : sequence
            Graphics items whose ``data(role)`` returns track id, id, frame,
            coordinates and score for the ``DetectionData`` roles.
        """
        logging.debug("Tracks: %i Detections selected", len(detections))
        tracks = np.zeros(len(detections), dtype=int)
        ids = np.zeros_like(tracks)
        frames = np.zeros_like(tracks)
        scores = np.zeros(tracks.shape, dtype=float)
        # Stays None for an empty selection; the skeleton widget receives None then.
        coordinates = None
        if len(detections) > 0:
            c = detections[0].data(DetectionData.COORDINATES.value)
            coordinates = np.zeros((len(detections), c.shape[0], c.shape[1]))
        for i, d in enumerate(detections):
            tracks[i] = d.data(DetectionData.TRACK_ID.value)
            ids[i] = d.data(DetectionData.ID.value)
            frames[i] = d.data(DetectionData.FRAME.value)
            coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
            scores[i] = d.data(DetectionData.SCORE.value)
        self._data.setSelection(ids)
        self._controls_widget.setSelectedTracks(tracks)
        self._skeleton.clear()
        self._skeleton.addSkeletons(coordinates, ids, frames, tracks, scores, QBrush(QColor(10, 255, 65, 255)))

    def moveWindow(self, stepsize):
        """Shift the observation window by ``stepsize`` window widths.

        Parameters
        ----------
        stepsize : float
            Relative step. Positive values move forward and negative values
            move backward. The resulting start frame is clamped to the data.
        """
        logging.info("Tracks.moveWindow: move window with stepsize %.2f", stepsize)
        self._manualmove = True
        new_start_frame = self._currentWindowPos + int(np.round(stepsize * self._currentWindowWidth))
        self._currentWindowPos = self._clampWindowStart(new_start_frame)
        self._controls_widget.setSelectedTracks(None)
        self.update()

    def on_forward(self, stepsize):
        """Move the window forward by ``stepsize`` window widths."""
        logging.debug("Tracks: receive forward command with step-size: %.2f", stepsize)
        self.moveWindow(stepsize)

    def on_backward(self, stepsize):
        """Move the window backward by ``stepsize`` window widths."""
        logging.debug("Tracks: receive backward command with step-size: %.2f", stepsize)
        self.moveWindow(-stepsize)


def main():
    """Show a standalone window containing only the SelectionControls widget."""
    from PySide6.QtWidgets import QApplication

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    layout = QVBoxLayout()
    controls = SelectionControls()
    layout.addWidget(controls)
    window.setLayout(layout)
    window.show()
    app.exec()


if __name__ == "__main__":
    main()