From 46d9e3d15bcdde13b2f7c47d2e956cee9f6b9a92 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Tue, 28 Jan 2025 14:50:24 +0100 Subject: [PATCH] [tracks] no more tables, much more speed --- fixtracks/widgets/tracks.py | 241 +++++++++++++++++++++++------------- 1 file changed, 156 insertions(+), 85 deletions(-) diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index b4b4501..5c356b2 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -3,7 +3,7 @@ import pathlib import numpy as np import pandas as pd -from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel, QSize +from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel, QSize, QObject from PySide6.QtGui import QImage, QBrush, QColor, QFont from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QTableView, QSplitter, QGridLayout @@ -134,21 +134,38 @@ class SelectionControls(QWidget): assignOtherBtn.setShortcut("Ctrl+0") assignOtherBtn.clicked.connect(self.on_TrackOther) + self.tone_selection = QLabel("0") + self.ttwo_selection = QLabel("0") + self.tother_selection = QLabel("0") + self._total = 0 + grid = QGridLayout() grid.addWidget(previousBtn, 0, 0, 4, 2) + grid.addWidget(nextBtn, 0, 6, 4, 2) grid.addWidget(QLabel("Current selection:"), 0, 2, 1, 4) - grid.addWidget(QLabel("Track One:"), 1, 2, 1, 4) - grid.addWidget(QLabel("Track Two:"), 2, 2, 1, 4) - grid.addWidget(QLabel("Unassigned:"), 3, 2, 1, 4) - grid.addWidget(nextBtn, 0, 5, 4, 2) + grid.addWidget(QLabel("Track One:"), 1, 2, 1, 3) + grid.addWidget(self.tone_selection, 1, 5, 1, 1) + grid.addWidget(QLabel("Track Two:"), 2, 2, 1, 3) + grid.addWidget(self.ttwo_selection, 2, 5, 1, 1) + grid.addWidget(QLabel("Unassigned:"), 3, 2, 1, 3) + grid.addWidget(self.tother_selection, 3, 5, 1, 1) grid.addWidget(assignOneBtn, 4, 0, 4, 3) grid.addWidget(assignOtherBtn, 4, 3, 4, 2) grid.addWidget(assignTwoBtn, 4, 5, 4, 3) - + grid.setColumnStretch(0, 1) + grid.setColumnStretch(7, 1) self.setLayout(grid) self.setMaximumSize(QSize(400, 200)) + def _updateNumbers(self, track): + labels = {1: self.tone_selection, 2: self.ttwo_selection, 3: self.tother_selection} + for k in labels: + if k == track: + labels[k].setText(str(self._total)) + else: + labels[k].setText("0") + def on_Next(self): self.next.emit() @@ -157,19 +174,93 @@ class SelectionControls(QWidget): def on_TrackOne(self): self.assignOne.emit() + self._updateNumbers(1) def on_TrackTwo(self): self.assignTwo.emit() + self._updateNumbers(2) def on_TrackOther(self): self.assignOther.emit() + self._updateNumbers(3) + + def setSelectedTracks(self, tracks): + logging.debug("SelectionControl: setSelectedTracks") + tone = np.sum(tracks == 1) + ttwo = np.sum(tracks == 2) + self.tone_selection.setText(str(tone)) + self.ttwo_selection.setText(str(ttwo)) + self.tother_selection.setText(str(len(tracks) - tone - ttwo)) + self._total = len(tracks) + + +class DataController(QObject): + def __init__(self, parent=None): + super().__init__(parent) + self._data = None + self._columns = [] + self._start = 0 + self._stop = 0 + self._indices = None + self._selection_column = None + self._user_selections = None + + def setData(self, datadict): + assert isinstance(datadict, dict) + self._data = datadict + self._columns = [k for k in self._data.keys()] + + @property + def data(self): + return self._data + + @property + def columns(self): + return self._columns + + def max(self, col): + if col in self.columns: + return np.max(self._data[col]) + else: + logging.error("Column %s not in dictionary", col) + return np.nan - def setSelection(self, selection): + @property + def selectionRange(self): + return self._start, self._stop + + @property + def selectionRangeColumn(self): + return self._selection_column + + @property + def selectionIndices(self): + return self._indices + + def setSelectionRange(self, col, start, stop): + self._start = start + self._stop = stop + self._selection_column = col + self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0] + + def selectedData(self, col): + return self._data[col][self._indices] + + def setUserSelection(self, ids): + self._user_selections = ids.astype(int) + + def assignUserSelection(self, track_id): + self._data["track"][self._user_selections] = track_id + + def save(self, filename): pass class FixTracks(QWidget): back = Signal() + trackone_id = 1 + tracktwo_id = 2 + trackother_id = -1 def __init__(self, parent=None): super().__init__(parent) @@ -177,7 +268,7 @@ class FixTracks(QWidget): self._threadpool = QThreadPool() self._reader = None self._image = None - self._data = None + self._data = DataController() self._unassignedmodel = None self._leftmodel = None self._rightmodel = None @@ -207,7 +298,10 @@ class FixTracks(QWidget): timelinebox.addWidget(QLabel("Window")) timelinebox.addWidget(self._windowspinner) - # self._controls_widget = SelectionControls() + self._controls_widget = SelectionControls() + self._controls_widget.assignOne.connect(self.on_assignOne) + self._controls_widget.assignTwo.connect(self.on_assignTwo) + self._controls_widget.assignOther.connect(self.on_assignOther) self._trackone_table = QTableView() font = QFont() @@ -215,9 +309,9 @@ class FixTracks(QWidget): font.setPointSize(8) self._trackone_table.setFont(font) assign1 = QPushButton("<<") - assign1.clicked.connect(self.on_assignLeft) + assign1.clicked.connect(self.on_assignOne) assign2 = QPushButton(">>") - assign2.clicked.connect(self.on_assignRight) + assign2.clicked.connect(self.on_assignTwo) self._unassigned_table = QTableView() self._unassigned_table.setFont(font) self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection) @@ -243,12 +337,12 @@ class FixTracks(QWidget): trackother_box.addWidget(trackother_label) trackother_box.addWidget(self._unassigned_table) - tablebox = QHBoxLayout() - tablebox.addLayout(track1_box) - tablebox.addWidget(assign1) - tablebox.addLayout(trackother_box) - tablebox.addWidget(assign2) - tablebox.addLayout(tracktwo_box) + # tablebox = QHBoxLayout() + # tablebox.addLayout(track1_box) + # tablebox.addWidget(assign1) + # tablebox.addLayout(trackother_box) + # tablebox.addWidget(assign2) + # tablebox.addLayout(tracktwo_box) self._saveBtn = QPushButton("Save") self._saveBtn.setEnabled(False) @@ -281,7 +375,8 @@ class FixTracks(QWidget): vbox = QVBoxLayout() vbox.addLayout(timelinebox) - vbox.addLayout(tablebox) + # vbox.addLayout(tablebox) + vbox.addWidget(self._controls_widget, stretch=1, alignment=Qt.AlignmentFlag.AlignCenter) vbox.addLayout(btnBox) container = QWidget() container.setLayout(vbox) @@ -296,7 +391,6 @@ class FixTracks(QWidget): layout.addWidget(splitter) self.setLayout(layout) - def on_dataSelection(self): filename = self._data_combo.currentText() if "please select" in filename.lower(): @@ -313,7 +407,7 @@ class FixTracks(QWidget): img = QImage(filename) self._detectionView.setImage(img) - def populateTables(self): + def update(self): def update_detectionView(df, name): if len(df) == 0: return @@ -322,36 +416,25 @@ class FixTracks(QWidget): ids = df.index.values.astype(int) self._detectionView.addDetections(coords, tracks, ids, self._brushes[name]) - trackone_id = 1 - tracktwo_id = 2 - - max_frames = np.max(self._data["frame"]) + max_frames = self._data.max("frame") start = self._timeline.rangeStart stop = self._timeline.rangeStop - print(start, stop, max_frames) start_frame = int(np.floor(start * max_frames)) stop_frame = int(np.ceil(stop * max_frames)) - logging.debug("Updating TableModel for range %i, %i", start_frame, stop_frame) - indices = np.where((self._data["frame"] >= start_frame) & (self._data["frame"] < stop_frame))[0] - # from IPython import embed - # embed() - # exit() - df = pd.DataFrame({"frame": self._data["index"][indices], - "track": self._data["track"][indices], - "keypoints": self._data["keypoints"][indices]}, - index= self._data["index"][indices]) - assigned_left = df[(df.track == trackone_id)] - assigned_right = df[(df.track == tracktwo_id)] - unassigned = df[(df.track != trackone_id) & (df.track != tracktwo_id)] - logging.debug("Updating TableModel: %i track one, %i unassigned, %i tracktwo", len(assigned_left), - len(unassigned), len(assigned_right)) - - self._unassignedmodel = PoseTableModel(unassigned) - self._unassigned_table.setModel(self._unassignedmodel) - self._leftmodel = PoseTableModel(assigned_left) - self._trackone_table.setModel(self._leftmodel) - self._rightmodel = PoseTableModel(assigned_right) - self._tracktwo_table.setModel(self._rightmodel) + logging.debug("Updating View for detection range %i, %i frames", start_frame, stop_frame) + self._data.setSelectionRange("frame", start_frame, stop_frame) + frames = self._data.selectedData("frame") + tracks = self._data.selectedData("track") + keypoints = self._data.selectedData("keypoints") + index = self._data.selectedData("index") + + df = pd.DataFrame({"frame": frames, + "track": tracks, + "keypoints": keypoints}, + index=index) + assigned_left = df[(df.track == self.trackone_id)] + assigned_right = df[(df.track == self.tracktwo_id)] + unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)] self._detectionView.clearDetections() update_detectionView(unassigned, "unassigned") @@ -380,18 +463,15 @@ class FixTracks(QWidget): self._data_combo.addItems(self.fileList) self._data_combo.setCurrentIndex(0) - # self._data_combo.currentIndexChanged.connect(self.on_dataSelection) - # self._image_combo.currentIndexChanged.connect(self.on_imageSelection) - def _on_dataOpenend(self, state): logging.info("Finished loading data with state %s", state) self._tasklabel.setText("") self._progress_bar.setRange(0, 100) self._progress_bar.setValue(0) if state and self._reader is not None: - self._data = self._reader.asdict - self._timeline.setDetectionData(self._data) - self.populateTables() + self._data.setData(self._reader.asdict) + self._timeline.setDetectionData(self._data.data) + self.update() def on_save(self): logging.debug("Save fixtracks results") @@ -400,50 +480,41 @@ class FixTracks(QWidget): logging.debug("Back button pressed!") self.back.emit() - def assignTrack(self, rows, trackid): - logging.debug("Assign %i detections to Track One", len(rows)) - ids = np.zeros(len(rows), dtype=int) - for i,r in enumerate(rows): - ids[i] = self._unassignedmodel.headerData(r.row(), Qt.Orientation.Vertical, Qt.ItemDataRole.DisplayRole) + def on_assignOne(self): + logging.debug("Assigning user selection to track One") + self._data.assignUserSelection(self.trackone_id) + self._timeline.setDetectionData(self._data.data) + self.update() - self._data["track"][ids] = np.zeros_like(ids, dtype=int) + trackid - self.populateTables() - self._timeline.setDetectionData(self._data) + def on_assignTwo(self): + logging.debug("Assigning user selection to track Two") + self._data.assignUserSelection(self.tracktwo_id) + self._timeline.setDetectionData(self._data.data) + self.update() - def on_assignLeft(self): - selection = self._unassigned_table.selectionModel() - rows = selection.selectedRows() - logging.debug("Assign %i detections to Track One", len(rows)) - self.assignTrack(rows, 1) - - def on_assignRight(self): - selection = self._unassigned_table.selectionModel() - rows = selection.selectedRows() - logging.debug("Assign %i detections to Track Two", len(rows)) - self.assignTrack(rows, 2) + def on_assignOther(self): + logging.debug("Assigning user selection to track Other") + self._data.assignUserSelection(self.trackother_id) + self._timeline.setDetectionData(self._data.data) + self.update() def on_windowChanged(self, start, stop): logging.info("Timeline reports window change to range %f %f percent of data", start, stop) - self.populateTables() + self.update() def on_windowSizeChanged(self, value): self._timeline.setWindowWidth(value) def on_detectionsSelected(self, detections): logging.debug("Tracks: Detections selected") - selection_model = self._unassigned_table.selectionModel() - selection = selection_model.selection() - logging.debug("Tracks: Get selection model") - for d in detections: - id = d.data(1) - row = self._unassignedmodel.mapIdToRow(id) - if row == -1: - continue - index = self._unassigned_table.model().index(row, 0) - selection.select(index, index) - logging.debug("Tracks: Detections selected") - mode = QItemSelectionModel.Select | QItemSelectionModel.Rows - selection_model.select(selection, mode) + tracks = np.zeros(len(detections)) + ids = np.zeros_like(tracks) + for i, d in enumerate(detections): + tracks[i] = d.data(0) + ids[i] = d.data(1) + self._data.setUserSelection(ids) + self._controls_widget.setSelectedTracks(tracks) + self.update() def main():