[tracks] no more tables, much more speed

This commit is contained in:
Jan Grewe 2025-01-28 14:50:24 +01:00
parent 17b907619c
commit 46d9e3d15b

View File

@ -3,7 +3,7 @@ import pathlib
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel, QSize from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel, QSize, QObject
from PySide6.QtGui import QImage, QBrush, QColor, QFont from PySide6.QtGui import QImage, QBrush, QColor, QFont
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QTableView, QSplitter, QGridLayout from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QTableView, QSplitter, QGridLayout
@ -134,21 +134,38 @@ class SelectionControls(QWidget):
assignOtherBtn.setShortcut("Ctrl+0") assignOtherBtn.setShortcut("Ctrl+0")
assignOtherBtn.clicked.connect(self.on_TrackOther) assignOtherBtn.clicked.connect(self.on_TrackOther)
self.tone_selection = QLabel("0")
self.ttwo_selection = QLabel("0")
self.tother_selection = QLabel("0")
self._total = 0
grid = QGridLayout() grid = QGridLayout()
grid.addWidget(previousBtn, 0, 0, 4, 2) grid.addWidget(previousBtn, 0, 0, 4, 2)
grid.addWidget(nextBtn, 0, 6, 4, 2)
grid.addWidget(QLabel("Current selection:"), 0, 2, 1, 4) grid.addWidget(QLabel("Current selection:"), 0, 2, 1, 4)
grid.addWidget(QLabel("Track One:"), 1, 2, 1, 4) grid.addWidget(QLabel("Track One:"), 1, 2, 1, 3)
grid.addWidget(QLabel("Track Two:"), 2, 2, 1, 4) grid.addWidget(self.tone_selection, 1, 5, 1, 1)
grid.addWidget(QLabel("Unassigned:"), 3, 2, 1, 4) grid.addWidget(QLabel("Track Two:"), 2, 2, 1, 3)
grid.addWidget(nextBtn, 0, 5, 4, 2) grid.addWidget(self.ttwo_selection, 2, 5, 1, 1)
grid.addWidget(QLabel("Unassigned:"), 3, 2, 1, 3)
grid.addWidget(self.tother_selection, 3, 5, 1, 1)
grid.addWidget(assignOneBtn, 4, 0, 4, 3) grid.addWidget(assignOneBtn, 4, 0, 4, 3)
grid.addWidget(assignOtherBtn, 4, 3, 4, 2) grid.addWidget(assignOtherBtn, 4, 3, 4, 2)
grid.addWidget(assignTwoBtn, 4, 5, 4, 3) grid.addWidget(assignTwoBtn, 4, 5, 4, 3)
grid.setColumnStretch(0, 1)
grid.setColumnStretch(7, 1)
self.setLayout(grid) self.setLayout(grid)
self.setMaximumSize(QSize(400, 200)) self.setMaximumSize(QSize(400, 200))
def _updateNumbers(self, track):
labels = {1: self.tone_selection, 2: self.ttwo_selection, 3: self.tother_selection}
for k in labels:
if k == track:
labels[k].setText(str(self._total))
else:
labels[k].setText("0")
def on_Next(self): def on_Next(self):
self.next.emit() self.next.emit()
@ -157,19 +174,93 @@ class SelectionControls(QWidget):
def on_TrackOne(self): def on_TrackOne(self):
self.assignOne.emit() self.assignOne.emit()
self._updateNumbers(1)
def on_TrackTwo(self): def on_TrackTwo(self):
self.assignTwo.emit() self.assignTwo.emit()
self._updateNumbers(2)
def on_TrackOther(self): def on_TrackOther(self):
self.assignOther.emit() self.assignOther.emit()
self._updateNumbers(3)
def setSelectedTracks(self, tracks):
logging.debug("SelectionControl: setSelectedTracks")
tone = np.sum(tracks == 1)
ttwo = np.sum(tracks == 2)
self.tone_selection.setText(str(tone))
self.ttwo_selection.setText(str(ttwo))
self.tother_selection.setText(str(len(tracks) - tone - ttwo))
self._total = len(tracks)
class DataController(QObject):
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._columns = []
self._start = 0
self._stop = 0
self._indices = None
self._selection_column = None
self._user_selections = None
def setData(self, datadict):
assert isinstance(datadict, dict)
self._data = datadict
self._columns = [k for k in self._data.keys()]
@property
def data(self):
return self._data
@property
def columns(self):
return self._columns
def max(self, col):
if col in self.columns:
return np.max(self._data[col])
else:
logging.error("Column %s not in dictionary", col)
return np.nan
def setSelection(self, selection): @property
def selectionRange(self):
return self._start, self._stop
@property
def selectionRangeColumn(self):
return self._selection_column
@property
def selectionIndices(self):
return self._indices
def setSelectionRange(self, col, start, stop):
self._start = start
self._stop = stop
self._selection_column = col
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
def selectedData(self, col):
return self._data[col][self._indices]
def setUserSelection(self, ids):
self._user_selections = ids.astype(int)
def assignUserSelection(self, track_id):
self._data["track"][self._user_selections] = track_id
def save(self, filename):
pass pass
class FixTracks(QWidget): class FixTracks(QWidget):
back = Signal() back = Signal()
trackone_id = 1
tracktwo_id = 2
trackother_id = -1
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
@ -177,7 +268,7 @@ class FixTracks(QWidget):
self._threadpool = QThreadPool() self._threadpool = QThreadPool()
self._reader = None self._reader = None
self._image = None self._image = None
self._data = None self._data = DataController()
self._unassignedmodel = None self._unassignedmodel = None
self._leftmodel = None self._leftmodel = None
self._rightmodel = None self._rightmodel = None
@ -207,7 +298,10 @@ class FixTracks(QWidget):
timelinebox.addWidget(QLabel("Window")) timelinebox.addWidget(QLabel("Window"))
timelinebox.addWidget(self._windowspinner) timelinebox.addWidget(self._windowspinner)
# self._controls_widget = SelectionControls() self._controls_widget = SelectionControls()
self._controls_widget.assignOne.connect(self.on_assignOne)
self._controls_widget.assignTwo.connect(self.on_assignTwo)
self._controls_widget.assignOther.connect(self.on_assignOther)
self._trackone_table = QTableView() self._trackone_table = QTableView()
font = QFont() font = QFont()
@ -215,9 +309,9 @@ class FixTracks(QWidget):
font.setPointSize(8) font.setPointSize(8)
self._trackone_table.setFont(font) self._trackone_table.setFont(font)
assign1 = QPushButton("<<") assign1 = QPushButton("<<")
assign1.clicked.connect(self.on_assignLeft) assign1.clicked.connect(self.on_assignOne)
assign2 = QPushButton(">>") assign2 = QPushButton(">>")
assign2.clicked.connect(self.on_assignRight) assign2.clicked.connect(self.on_assignTwo)
self._unassigned_table = QTableView() self._unassigned_table = QTableView()
self._unassigned_table.setFont(font) self._unassigned_table.setFont(font)
self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection) self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection)
@ -243,12 +337,12 @@ class FixTracks(QWidget):
trackother_box.addWidget(trackother_label) trackother_box.addWidget(trackother_label)
trackother_box.addWidget(self._unassigned_table) trackother_box.addWidget(self._unassigned_table)
tablebox = QHBoxLayout() # tablebox = QHBoxLayout()
tablebox.addLayout(track1_box) # tablebox.addLayout(track1_box)
tablebox.addWidget(assign1) # tablebox.addWidget(assign1)
tablebox.addLayout(trackother_box) # tablebox.addLayout(trackother_box)
tablebox.addWidget(assign2) # tablebox.addWidget(assign2)
tablebox.addLayout(tracktwo_box) # tablebox.addLayout(tracktwo_box)
self._saveBtn = QPushButton("Save") self._saveBtn = QPushButton("Save")
self._saveBtn.setEnabled(False) self._saveBtn.setEnabled(False)
@ -281,7 +375,8 @@ class FixTracks(QWidget):
vbox = QVBoxLayout() vbox = QVBoxLayout()
vbox.addLayout(timelinebox) vbox.addLayout(timelinebox)
vbox.addLayout(tablebox) # vbox.addLayout(tablebox)
vbox.addWidget(self._controls_widget, stretch=1, alignment=Qt.AlignmentFlag.AlignCenter)
vbox.addLayout(btnBox) vbox.addLayout(btnBox)
container = QWidget() container = QWidget()
container.setLayout(vbox) container.setLayout(vbox)
@ -296,7 +391,6 @@ class FixTracks(QWidget):
layout.addWidget(splitter) layout.addWidget(splitter)
self.setLayout(layout) self.setLayout(layout)
def on_dataSelection(self): def on_dataSelection(self):
filename = self._data_combo.currentText() filename = self._data_combo.currentText()
if "please select" in filename.lower(): if "please select" in filename.lower():
@ -313,7 +407,7 @@ class FixTracks(QWidget):
img = QImage(filename) img = QImage(filename)
self._detectionView.setImage(img) self._detectionView.setImage(img)
def populateTables(self): def update(self):
def update_detectionView(df, name): def update_detectionView(df, name):
if len(df) == 0: if len(df) == 0:
return return
@ -322,36 +416,25 @@ class FixTracks(QWidget):
ids = df.index.values.astype(int) ids = df.index.values.astype(int)
self._detectionView.addDetections(coords, tracks, ids, self._brushes[name]) self._detectionView.addDetections(coords, tracks, ids, self._brushes[name])
trackone_id = 1 max_frames = self._data.max("frame")
tracktwo_id = 2
max_frames = np.max(self._data["frame"])
start = self._timeline.rangeStart start = self._timeline.rangeStart
stop = self._timeline.rangeStop stop = self._timeline.rangeStop
print(start, stop, max_frames)
start_frame = int(np.floor(start * max_frames)) start_frame = int(np.floor(start * max_frames))
stop_frame = int(np.ceil(stop * max_frames)) stop_frame = int(np.ceil(stop * max_frames))
logging.debug("Updating TableModel for range %i, %i", start_frame, stop_frame) logging.debug("Updating View for detection range %i, %i frames", start_frame, stop_frame)
indices = np.where((self._data["frame"] >= start_frame) & (self._data["frame"] < stop_frame))[0] self._data.setSelectionRange("frame", start_frame, stop_frame)
# from IPython import embed frames = self._data.selectedData("frame")
# embed() tracks = self._data.selectedData("track")
# exit() keypoints = self._data.selectedData("keypoints")
df = pd.DataFrame({"frame": self._data["index"][indices], index = self._data.selectedData("index")
"track": self._data["track"][indices],
"keypoints": self._data["keypoints"][indices]}, df = pd.DataFrame({"frame": frames,
index= self._data["index"][indices]) "track": tracks,
assigned_left = df[(df.track == trackone_id)] "keypoints": keypoints},
assigned_right = df[(df.track == tracktwo_id)] index=index)
unassigned = df[(df.track != trackone_id) & (df.track != tracktwo_id)] assigned_left = df[(df.track == self.trackone_id)]
logging.debug("Updating TableModel: %i track one, %i unassigned, %i tracktwo", len(assigned_left), assigned_right = df[(df.track == self.tracktwo_id)]
len(unassigned), len(assigned_right)) unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)]
self._unassignedmodel = PoseTableModel(unassigned)
self._unassigned_table.setModel(self._unassignedmodel)
self._leftmodel = PoseTableModel(assigned_left)
self._trackone_table.setModel(self._leftmodel)
self._rightmodel = PoseTableModel(assigned_right)
self._tracktwo_table.setModel(self._rightmodel)
self._detectionView.clearDetections() self._detectionView.clearDetections()
update_detectionView(unassigned, "unassigned") update_detectionView(unassigned, "unassigned")
@ -380,18 +463,15 @@ class FixTracks(QWidget):
self._data_combo.addItems(self.fileList) self._data_combo.addItems(self.fileList)
self._data_combo.setCurrentIndex(0) self._data_combo.setCurrentIndex(0)
# self._data_combo.currentIndexChanged.connect(self.on_dataSelection)
# self._image_combo.currentIndexChanged.connect(self.on_imageSelection)
def _on_dataOpenend(self, state): def _on_dataOpenend(self, state):
logging.info("Finished loading data with state %s", state) logging.info("Finished loading data with state %s", state)
self._tasklabel.setText("") self._tasklabel.setText("")
self._progress_bar.setRange(0, 100) self._progress_bar.setRange(0, 100)
self._progress_bar.setValue(0) self._progress_bar.setValue(0)
if state and self._reader is not None: if state and self._reader is not None:
self._data = self._reader.asdict self._data.setData(self._reader.asdict)
self._timeline.setDetectionData(self._data) self._timeline.setDetectionData(self._data.data)
self.populateTables() self.update()
def on_save(self): def on_save(self):
logging.debug("Save fixtracks results") logging.debug("Save fixtracks results")
@ -400,50 +480,41 @@ class FixTracks(QWidget):
logging.debug("Back button pressed!") logging.debug("Back button pressed!")
self.back.emit() self.back.emit()
def assignTrack(self, rows, trackid): def on_assignOne(self):
logging.debug("Assign %i detections to Track One", len(rows)) logging.debug("Assigning user selection to track One")
ids = np.zeros(len(rows), dtype=int) self._data.assignUserSelection(self.trackone_id)
for i,r in enumerate(rows): self._timeline.setDetectionData(self._data.data)
ids[i] = self._unassignedmodel.headerData(r.row(), Qt.Orientation.Vertical, Qt.ItemDataRole.DisplayRole) self.update()
self._data["track"][ids] = np.zeros_like(ids, dtype=int) + trackid def on_assignTwo(self):
self.populateTables() logging.debug("Assigning user selection to track Two")
self._timeline.setDetectionData(self._data) self._data.assignUserSelection(self.tracktwo_id)
self._timeline.setDetectionData(self._data.data)
self.update()
def on_assignLeft(self): def on_assignOther(self):
selection = self._unassigned_table.selectionModel() logging.debug("Assigning user selection to track Other")
rows = selection.selectedRows() self._data.assignUserSelection(self.trackother_id)
logging.debug("Assign %i detections to Track One", len(rows)) self._timeline.setDetectionData(self._data.data)
self.assignTrack(rows, 1) self.update()
def on_assignRight(self):
selection = self._unassigned_table.selectionModel()
rows = selection.selectedRows()
logging.debug("Assign %i detections to Track Two", len(rows))
self.assignTrack(rows, 2)
def on_windowChanged(self, start, stop): def on_windowChanged(self, start, stop):
logging.info("Timeline reports window change to range %f %f percent of data", start, stop) logging.info("Timeline reports window change to range %f %f percent of data", start, stop)
self.populateTables() self.update()
def on_windowSizeChanged(self, value): def on_windowSizeChanged(self, value):
self._timeline.setWindowWidth(value) self._timeline.setWindowWidth(value)
def on_detectionsSelected(self, detections): def on_detectionsSelected(self, detections):
logging.debug("Tracks: Detections selected") logging.debug("Tracks: Detections selected")
selection_model = self._unassigned_table.selectionModel() tracks = np.zeros(len(detections))
selection = selection_model.selection() ids = np.zeros_like(tracks)
logging.debug("Tracks: Get selection model") for i, d in enumerate(detections):
for d in detections: tracks[i] = d.data(0)
id = d.data(1) ids[i] = d.data(1)
row = self._unassignedmodel.mapIdToRow(id) self._data.setUserSelection(ids)
if row == -1: self._controls_widget.setSelectedTracks(tracks)
continue self.update()
index = self._unassigned_table.model().index(row, 0)
selection.select(index, index)
logging.debug("Tracks: Detections selected")
mode = QItemSelectionModel.Select | QItemSelectionModel.Rows
selection_model.select(selection, mode)
def main(): def main():