diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index fcb6e88..7ec07e8 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -1,9 +1,10 @@ import logging import pathlib import numpy as np +import pandas as pd -from PySide6.QtCore import Qt, QThreadPool, Signal -from PySide6.QtGui import QImage, QStandardItemModel, QStandardItem +from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel +from PySide6.QtGui import QImage, QBrush, QColor from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy from PySide6.QtWidgets import QSpinBox, QSpacerItem, QFileDialog, QProgressBar, QTableView, QSplitter @@ -11,6 +12,76 @@ from fixtracks.utils.reader import PickleLoader from fixtracks.widgets.detectionview import DetectionView from fixtracks.widgets.timeline import Timeline + +class PoseTableModel(QAbstractTableModel): + column_header = ["frame", "track"] + columns = ["frame", "track"] + + def __init__(self, dataframe, parent=None): + super().__init__(parent) + self._data = dataframe + self._frames = self._data.frame.values + self._tracks = self._data.track.values + self._indices = self._data.index.values + self._column_data = [self._frames, self._tracks] + + def columnCount(self, parent=None): + return len(self.columns) + + def rowCount(self, parent=None): + if self._data is not None: + return len(self._data) + else: + return 0 + + def data(self, index, role = ...): + value = self._column_data[index.column()][index.row()] + if role == Qt.ItemDataRole.DisplayRole: + return str(value) + elif role == Qt.ItemDataRole.UserRole: + return value + return None + + def headerData(self, section, orientation, role = ...): + if role == Qt.ItemDataRole.DisplayRole: + if orientation == Qt.Orientation.Horizontal: + return self.column_header[section] + else: + return str(self._indices[section]) + else: + return None + + def mapIdToRow(self, id): + row = np.where(self._indices == id)[0] + if len(row) == 0: + return -1 + return row[0] + +class FilterProxyModel(QSortFilterProxyModel): + def __init__(self, parent=None): + super().__init__(parent) + self._range = None + + def setFilterRange(self, start, stop): + logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop) + self._range = (start, stop) + self.invalidateRowsFilter() + + def all(self): + self._range = None + + def filterAcceptsRow(self, source_row, source_parent): + if self._range is None: + return True + else: + idx = self.sourceModel().index(source_row, 0, source_parent); + val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole) + print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] ) + return val >= self._range[0] and val < self._range[1] + + def filterAcceptsColumn(self, source_column, source_parent): + return True + class FixTracks(QWidget): back = Signal() @@ -24,8 +95,13 @@ class FixTracks(QWidget): self._unassignedmodel = None self._leftmodel = None self._rightmodel = None - + self._proxymodel = None + self._brushes = {"assigned_left": QBrush(QColor.fromString("red")), + "assigned_right": QBrush(QColor.fromString("blue")), + "unassigned": QBrush(QColor.fromString("white")) + } self._detectionView = DetectionView() + self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected) self._progress_bar = QProgressBar(self) self._progress_bar.setMaximumHeight(20) # self._progress_bar.setRange(0, 0) # Set the progress bar to be indeterminate @@ -37,6 +113,7 @@ class FixTracks(QWidget): self._windowspinner = QSpinBox() self._windowspinner.setRange(100, 10000) self._windowspinner.setSingleStep(100) + self._windowspinner.setValue(500) self._windowspinner.valueChanged.connect(self.on_windowSizeChanged) timelinebox = QHBoxLayout() @@ -50,6 +127,8 @@ class FixTracks(QWidget): assign2 = QPushButton(">>") assign2.clicked.connect(self.on_assignRight) self._unassigned_table = QTableView() + self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection) + self._unassigned_table.setSelectionBehavior(QTableView.SelectionBehavior.SelectRows) self._right_table = QTableView() tablebox = QHBoxLayout() tablebox.addWidget(self._left_table) @@ -83,7 +162,7 @@ class FixTracks(QWidget): container = QWidget() container.setLayout(vbox) - splitter = QSplitter(Qt.Orientation.Horizontal) + splitter = QSplitter(Qt.Orientation.Vertical) splitter.addWidget(self._detectionView) splitter.addWidget(container) splitter.setStretchFactor(0, 3) @@ -118,49 +197,36 @@ class FixTracks(QWidget): self._reader.signals.finished.connect(self._on_dataOpenend) self._threadpool.start(self._reader) - def populateTables(self): + def update_detectionView(df, name): + if len(df) == 0: + return + coords = np.stack(df.keypoints.values).astype(np.float32)[:,0,:] + tracks = df.track.values.astype(int) + ids = df.index.values.astype(int) + self._detectionView.addDetections(coords, tracks, ids, self._brushes[name]) + left_trackid = 1 right_trackid = 2 start_frame = self._timeline.sliderPosition - self._windowspinner.value() // 2 stop_frame = self._timeline.sliderPosition + self._windowspinner.value() // 2 - df = self._dataframe[(self._dataframe.frame >= start_frame) & (self._dataframe.frame < stop_frame)] + df = self._dataframe[(self._dataframe.frame >= start_frame) & (self._dataframe.frame < stop_frame)] assigned_left = df[(df.track == left_trackid)] assigned_right = df[(df.track == right_trackid)] unassigned = df[(df.track != left_trackid) & (df.track != right_trackid)] - print(len(assigned_left), len(assigned_right), len(unassigned)) - columns = ["frame", "track id"] - self._unassignedmodel = QStandardItemModel(len(unassigned), 2) - self._unassignedmodel.setHorizontalHeaderLabels(columns) - self._leftmodel = QStandardItemModel(len(assigned_left), 2) - self._leftmodel.setHorizontalHeaderLabels(columns) - self._rightmodel = QStandardItemModel(len(assigned_right), 2) - self._rightmodel.setHorizontalHeaderLabels(columns) - - # Populate the models with data - for i in range(len(unassigned)): - row = unassigned.iloc[i] - if i == 0: print(row) - for j in range(len(columns)): - item = QStandardItem(f"{i, j}") - self._unassignedmodel.setItem(i, j, item) - self._unassigned_table.setModel(self._unassignedmodel) - for i in range(len(assigned_left)): - row = assigned_left.iloc[i] - for j in range(len(columns)): - item = QStandardItem(f"{i, j}") - self._leftmodel.setItem(i, j, item) + self._unassignedmodel = PoseTableModel(unassigned) + self._unassigned_table.setModel(self._unassignedmodel) + self._leftmodel = PoseTableModel(assigned_left) self._left_table.setModel(self._leftmodel) - - for i in range(len(assigned_right)): - row = assigned_right.iloc[i] - for j in range(len(columns)): - item = QStandardItem(f"{i, j}") - self._rightmodel.setItem(i, j, item) + self._rightmodel = PoseTableModel(assigned_right) self._right_table.setModel(self._rightmodel) + self._detectionView.clearDetections() + update_detectionView(unassigned, "unassigned") + update_detectionView(assigned_left, "assigned_left") + update_detectionView(assigned_right, "assigned_right") def _on_dataOpenend(self, state): logging.info("Finished loading data with state %s", state) @@ -186,7 +252,24 @@ class FixTracks(QWidget): pass def on_windowChanged(self, value): + logging.info("Timeline reports window change") + self.populateTables() def on_windowSizeChanged(self, value): self._timeline.setWindowWidth(value) + + def on_detectionsSelected(self, detections): + selection_model = self._unassigned_table.selectionModel() + selection = selection_model.selection() + + for d in detections: + id = d.data(1) + row = self._unassignedmodel.mapIdToRow(id) + if row == -1: + continue + index = self._unassigned_table.model().index(row, 0) + selection.select(index, index) + + mode = QItemSelectionModel.Select | QItemSelectionModel.Rows + selection_model.select(selection, mode)