[tracks] specializes tablemodel, automatic selection of tablerows and more

This commit is contained in:
Jan Grewe 2025-01-25 14:49:04 +01:00
parent a7f6d65e62
commit 723ff8d3b9

View File

@ -1,9 +1,10 @@
import logging
import pathlib
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal
from PySide6.QtGui import QImage, QStandardItemModel, QStandardItem
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel
from PySide6.QtGui import QImage, QBrush, QColor
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QFileDialog, QProgressBar, QTableView, QSplitter
@ -11,6 +12,76 @@ from fixtracks.utils.reader import PickleLoader
from fixtracks.widgets.detectionview import DetectionView
from fixtracks.widgets.timeline import Timeline
class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"]
columns = ["frame", "track"]
def __init__(self, dataframe, parent=None):
super().__init__(parent)
self._data = dataframe
self._frames = self._data.frame.values
self._tracks = self._data.track.values
self._indices = self._data.index.values
self._column_data = [self._frames, self._tracks]
def columnCount(self, parent=None):
return len(self.columns)
def rowCount(self, parent=None):
if self._data is not None:
return len(self._data)
else:
return 0
def data(self, index, role = ...):
value = self._column_data[index.column()][index.row()]
if role == Qt.ItemDataRole.DisplayRole:
return str(value)
elif role == Qt.ItemDataRole.UserRole:
return value
return None
def headerData(self, section, orientation, role = ...):
if role == Qt.ItemDataRole.DisplayRole:
if orientation == Qt.Orientation.Horizontal:
return self.column_header[section]
else:
return str(self._indices[section])
else:
return None
def mapIdToRow(self, id):
row = np.where(self._indices == id)[0]
if len(row) == 0:
return -1
return row[0]
class FilterProxyModel(QSortFilterProxyModel):
def __init__(self, parent=None):
super().__init__(parent)
self._range = None
def setFilterRange(self, start, stop):
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
self._range = (start, stop)
self.invalidateRowsFilter()
def all(self):
self._range = None
def filterAcceptsRow(self, source_row, source_parent):
if self._range is None:
return True
else:
idx = self.sourceModel().index(source_row, 0, source_parent);
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
return val >= self._range[0] and val < self._range[1]
def filterAcceptsColumn(self, source_column, source_parent):
return True
class FixTracks(QWidget):
back = Signal()
@ -24,8 +95,13 @@ class FixTracks(QWidget):
self._unassignedmodel = None
self._leftmodel = None
self._rightmodel = None
self._proxymodel = None
self._brushes = {"assigned_left": QBrush(QColor.fromString("red")),
"assigned_right": QBrush(QColor.fromString("blue")),
"unassigned": QBrush(QColor.fromString("white"))
}
self._detectionView = DetectionView()
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
self._progress_bar = QProgressBar(self)
self._progress_bar.setMaximumHeight(20)
# self._progress_bar.setRange(0, 0) # Set the progress bar to be indeterminate
@ -37,6 +113,7 @@ class FixTracks(QWidget):
self._windowspinner = QSpinBox()
self._windowspinner.setRange(100, 10000)
self._windowspinner.setSingleStep(100)
self._windowspinner.setValue(500)
self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)
timelinebox = QHBoxLayout()
@ -50,6 +127,8 @@ class FixTracks(QWidget):
assign2 = QPushButton(">>")
assign2.clicked.connect(self.on_assignRight)
self._unassigned_table = QTableView()
self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection)
self._unassigned_table.setSelectionBehavior(QTableView.SelectionBehavior.SelectRows)
self._right_table = QTableView()
tablebox = QHBoxLayout()
tablebox.addWidget(self._left_table)
@ -83,7 +162,7 @@ class FixTracks(QWidget):
container = QWidget()
container.setLayout(vbox)
splitter = QSplitter(Qt.Orientation.Horizontal)
splitter = QSplitter(Qt.Orientation.Vertical)
splitter.addWidget(self._detectionView)
splitter.addWidget(container)
splitter.setStretchFactor(0, 3)
@ -118,49 +197,36 @@ class FixTracks(QWidget):
self._reader.signals.finished.connect(self._on_dataOpenend)
self._threadpool.start(self._reader)
def populateTables(self):
def update_detectionView(df, name):
if len(df) == 0:
return
coords = np.stack(df.keypoints.values).astype(np.float32)[:,0,:]
tracks = df.track.values.astype(int)
ids = df.index.values.astype(int)
self._detectionView.addDetections(coords, tracks, ids, self._brushes[name])
left_trackid = 1
right_trackid = 2
start_frame = self._timeline.sliderPosition - self._windowspinner.value() // 2
stop_frame = self._timeline.sliderPosition + self._windowspinner.value() // 2
df = self._dataframe[(self._dataframe.frame >= start_frame) & (self._dataframe.frame < stop_frame)]
df = self._dataframe[(self._dataframe.frame >= start_frame) & (self._dataframe.frame < stop_frame)]
assigned_left = df[(df.track == left_trackid)]
assigned_right = df[(df.track == right_trackid)]
unassigned = df[(df.track != left_trackid) & (df.track != right_trackid)]
print(len(assigned_left), len(assigned_right), len(unassigned))
columns = ["frame", "track id"]
self._unassignedmodel = QStandardItemModel(len(unassigned), 2)
self._unassignedmodel.setHorizontalHeaderLabels(columns)
self._leftmodel = QStandardItemModel(len(assigned_left), 2)
self._leftmodel.setHorizontalHeaderLabels(columns)
self._rightmodel = QStandardItemModel(len(assigned_right), 2)
self._rightmodel.setHorizontalHeaderLabels(columns)
# Populate the models with data
for i in range(len(unassigned)):
row = unassigned.iloc[i]
if i == 0: print(row)
for j in range(len(columns)):
item = QStandardItem(f"{i, j}")
self._unassignedmodel.setItem(i, j, item)
self._unassigned_table.setModel(self._unassignedmodel)
for i in range(len(assigned_left)):
row = assigned_left.iloc[i]
for j in range(len(columns)):
item = QStandardItem(f"{i, j}")
self._leftmodel.setItem(i, j, item)
self._unassignedmodel = PoseTableModel(unassigned)
self._unassigned_table.setModel(self._unassignedmodel)
self._leftmodel = PoseTableModel(assigned_left)
self._left_table.setModel(self._leftmodel)
for i in range(len(assigned_right)):
row = assigned_right.iloc[i]
for j in range(len(columns)):
item = QStandardItem(f"{i, j}")
self._rightmodel.setItem(i, j, item)
self._rightmodel = PoseTableModel(assigned_right)
self._right_table.setModel(self._rightmodel)
self._detectionView.clearDetections()
update_detectionView(unassigned, "unassigned")
update_detectionView(assigned_left, "assigned_left")
update_detectionView(assigned_right, "assigned_right")
def _on_dataOpenend(self, state):
logging.info("Finished loading data with state %s", state)
@ -186,7 +252,24 @@ class FixTracks(QWidget):
pass
def on_windowChanged(self, value):
logging.info("Timeline reports window change")
self.populateTables()
def on_windowSizeChanged(self, value):
self._timeline.setWindowWidth(value)
def on_detectionsSelected(self, detections):
selection_model = self._unassigned_table.selectionModel()
selection = selection_model.selection()
for d in detections:
id = d.data(1)
row = self._unassignedmodel.mapIdToRow(id)
if row == -1:
continue
index = self._unassigned_table.model().index(row, 0)
selection.select(index, index)
mode = QItemSelectionModel.Select | QItemSelectionModel.Rows
selection_model.select(selection, mode)