[tracks] no more tables, much more speed

This commit is contained in:
Jan Grewe 2025-01-28 14:50:24 +01:00
parent 17b907619c
commit 46d9e3d15b

View File

@ -3,7 +3,7 @@ import pathlib
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel, QSize
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel, QSize, QObject
from PySide6.QtGui import QImage, QBrush, QColor, QFont
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QTableView, QSplitter, QGridLayout
@ -134,21 +134,38 @@ class SelectionControls(QWidget):
assignOtherBtn.setShortcut("Ctrl+0")
assignOtherBtn.clicked.connect(self.on_TrackOther)
self.tone_selection = QLabel("0")
self.ttwo_selection = QLabel("0")
self.tother_selection = QLabel("0")
self._total = 0
grid = QGridLayout()
grid.addWidget(previousBtn, 0, 0, 4, 2)
grid.addWidget(nextBtn, 0, 6, 4, 2)
grid.addWidget(QLabel("Current selection:"), 0, 2, 1, 4)
grid.addWidget(QLabel("Track One:"), 1, 2, 1, 4)
grid.addWidget(QLabel("Track Two:"), 2, 2, 1, 4)
grid.addWidget(QLabel("Unassigned:"), 3, 2, 1, 4)
grid.addWidget(nextBtn, 0, 5, 4, 2)
grid.addWidget(QLabel("Track One:"), 1, 2, 1, 3)
grid.addWidget(self.tone_selection, 1, 5, 1, 1)
grid.addWidget(QLabel("Track Two:"), 2, 2, 1, 3)
grid.addWidget(self.ttwo_selection, 2, 5, 1, 1)
grid.addWidget(QLabel("Unassigned:"), 3, 2, 1, 3)
grid.addWidget(self.tother_selection, 3, 5, 1, 1)
grid.addWidget(assignOneBtn, 4, 0, 4, 3)
grid.addWidget(assignOtherBtn, 4, 3, 4, 2)
grid.addWidget(assignTwoBtn, 4, 5, 4, 3)
grid.setColumnStretch(0, 1)
grid.setColumnStretch(7, 1)
self.setLayout(grid)
self.setMaximumSize(QSize(400, 200))
def _updateNumbers(self, track):
labels = {1: self.tone_selection, 2: self.ttwo_selection, 3: self.tother_selection}
for k in labels:
if k == track:
labels[k].setText(str(self._total))
else:
labels[k].setText("0")
def on_Next(self):
self.next.emit()
@ -157,19 +174,93 @@ class SelectionControls(QWidget):
def on_TrackOne(self):
self.assignOne.emit()
self._updateNumbers(1)
def on_TrackTwo(self):
self.assignTwo.emit()
self._updateNumbers(2)
def on_TrackOther(self):
self.assignOther.emit()
self._updateNumbers(3)
def setSelectedTracks(self, tracks):
logging.debug("SelectionControl: setSelectedTracks")
tone = np.sum(tracks == 1)
ttwo = np.sum(tracks == 2)
self.tone_selection.setText(str(tone))
self.ttwo_selection.setText(str(ttwo))
self.tother_selection.setText(str(len(tracks) - tone - ttwo))
self._total = len(tracks)
class DataController(QObject):
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._columns = []
self._start = 0
self._stop = 0
self._indices = None
self._selection_column = None
self._user_selections = None
def setData(self, datadict):
assert isinstance(datadict, dict)
self._data = datadict
self._columns = [k for k in self._data.keys()]
@property
def data(self):
return self._data
@property
def columns(self):
return self._columns
def max(self, col):
if col in self.columns:
return np.max(self._data[col])
else:
logging.error("Column %s not in dictionary", col)
return np.nan
def setSelection(self, selection):
@property
def selectionRange(self):
return self._start, self._stop
@property
def selectionRangeColumn(self):
return self._selection_column
@property
def selectionIndices(self):
return self._indices
def setSelectionRange(self, col, start, stop):
self._start = start
self._stop = stop
self._selection_column = col
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
def selectedData(self, col):
return self._data[col][self._indices]
def setUserSelection(self, ids):
self._user_selections = ids.astype(int)
def assignUserSelection(self, track_id):
self._data["track"][self._user_selections] = track_id
def save(self, filename):
pass
class FixTracks(QWidget):
back = Signal()
trackone_id = 1
tracktwo_id = 2
trackother_id = -1
def __init__(self, parent=None):
super().__init__(parent)
@ -177,7 +268,7 @@ class FixTracks(QWidget):
self._threadpool = QThreadPool()
self._reader = None
self._image = None
self._data = None
self._data = DataController()
self._unassignedmodel = None
self._leftmodel = None
self._rightmodel = None
@ -207,7 +298,10 @@ class FixTracks(QWidget):
timelinebox.addWidget(QLabel("Window"))
timelinebox.addWidget(self._windowspinner)
# self._controls_widget = SelectionControls()
self._controls_widget = SelectionControls()
self._controls_widget.assignOne.connect(self.on_assignOne)
self._controls_widget.assignTwo.connect(self.on_assignTwo)
self._controls_widget.assignOther.connect(self.on_assignOther)
self._trackone_table = QTableView()
font = QFont()
@ -215,9 +309,9 @@ class FixTracks(QWidget):
font.setPointSize(8)
self._trackone_table.setFont(font)
assign1 = QPushButton("<<")
assign1.clicked.connect(self.on_assignLeft)
assign1.clicked.connect(self.on_assignOne)
assign2 = QPushButton(">>")
assign2.clicked.connect(self.on_assignRight)
assign2.clicked.connect(self.on_assignTwo)
self._unassigned_table = QTableView()
self._unassigned_table.setFont(font)
self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection)
@ -243,12 +337,12 @@ class FixTracks(QWidget):
trackother_box.addWidget(trackother_label)
trackother_box.addWidget(self._unassigned_table)
tablebox = QHBoxLayout()
tablebox.addLayout(track1_box)
tablebox.addWidget(assign1)
tablebox.addLayout(trackother_box)
tablebox.addWidget(assign2)
tablebox.addLayout(tracktwo_box)
# tablebox = QHBoxLayout()
# tablebox.addLayout(track1_box)
# tablebox.addWidget(assign1)
# tablebox.addLayout(trackother_box)
# tablebox.addWidget(assign2)
# tablebox.addLayout(tracktwo_box)
self._saveBtn = QPushButton("Save")
self._saveBtn.setEnabled(False)
@ -281,7 +375,8 @@ class FixTracks(QWidget):
vbox = QVBoxLayout()
vbox.addLayout(timelinebox)
vbox.addLayout(tablebox)
# vbox.addLayout(tablebox)
vbox.addWidget(self._controls_widget, stretch=1, alignment=Qt.AlignmentFlag.AlignCenter)
vbox.addLayout(btnBox)
container = QWidget()
container.setLayout(vbox)
@ -296,7 +391,6 @@ class FixTracks(QWidget):
layout.addWidget(splitter)
self.setLayout(layout)
def on_dataSelection(self):
filename = self._data_combo.currentText()
if "please select" in filename.lower():
@ -313,7 +407,7 @@ class FixTracks(QWidget):
img = QImage(filename)
self._detectionView.setImage(img)
def populateTables(self):
def update(self):
def update_detectionView(df, name):
if len(df) == 0:
return
@ -322,36 +416,25 @@ class FixTracks(QWidget):
ids = df.index.values.astype(int)
self._detectionView.addDetections(coords, tracks, ids, self._brushes[name])
trackone_id = 1
tracktwo_id = 2
max_frames = np.max(self._data["frame"])
max_frames = self._data.max("frame")
start = self._timeline.rangeStart
stop = self._timeline.rangeStop
print(start, stop, max_frames)
start_frame = int(np.floor(start * max_frames))
stop_frame = int(np.ceil(stop * max_frames))
logging.debug("Updating TableModel for range %i, %i", start_frame, stop_frame)
indices = np.where((self._data["frame"] >= start_frame) & (self._data["frame"] < stop_frame))[0]
# from IPython import embed
# embed()
# exit()
df = pd.DataFrame({"frame": self._data["index"][indices],
"track": self._data["track"][indices],
"keypoints": self._data["keypoints"][indices]},
index= self._data["index"][indices])
assigned_left = df[(df.track == trackone_id)]
assigned_right = df[(df.track == tracktwo_id)]
unassigned = df[(df.track != trackone_id) & (df.track != tracktwo_id)]
logging.debug("Updating TableModel: %i track one, %i unassigned, %i tracktwo", len(assigned_left),
len(unassigned), len(assigned_right))
self._unassignedmodel = PoseTableModel(unassigned)
self._unassigned_table.setModel(self._unassignedmodel)
self._leftmodel = PoseTableModel(assigned_left)
self._trackone_table.setModel(self._leftmodel)
self._rightmodel = PoseTableModel(assigned_right)
self._tracktwo_table.setModel(self._rightmodel)
logging.debug("Updating View for detection range %i, %i frames", start_frame, stop_frame)
self._data.setSelectionRange("frame", start_frame, stop_frame)
frames = self._data.selectedData("frame")
tracks = self._data.selectedData("track")
keypoints = self._data.selectedData("keypoints")
index = self._data.selectedData("index")
df = pd.DataFrame({"frame": frames,
"track": tracks,
"keypoints": keypoints},
index=index)
assigned_left = df[(df.track == self.trackone_id)]
assigned_right = df[(df.track == self.tracktwo_id)]
unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)]
self._detectionView.clearDetections()
update_detectionView(unassigned, "unassigned")
@ -380,18 +463,15 @@ class FixTracks(QWidget):
self._data_combo.addItems(self.fileList)
self._data_combo.setCurrentIndex(0)
# self._data_combo.currentIndexChanged.connect(self.on_dataSelection)
# self._image_combo.currentIndexChanged.connect(self.on_imageSelection)
def _on_dataOpenend(self, state):
logging.info("Finished loading data with state %s", state)
self._tasklabel.setText("")
self._progress_bar.setRange(0, 100)
self._progress_bar.setValue(0)
if state and self._reader is not None:
self._data = self._reader.asdict
self._timeline.setDetectionData(self._data)
self.populateTables()
self._data.setData(self._reader.asdict)
self._timeline.setDetectionData(self._data.data)
self.update()
def on_save(self):
logging.debug("Save fixtracks results")
@ -400,50 +480,41 @@ class FixTracks(QWidget):
logging.debug("Back button pressed!")
self.back.emit()
def assignTrack(self, rows, trackid):
logging.debug("Assign %i detections to Track One", len(rows))
ids = np.zeros(len(rows), dtype=int)
for i,r in enumerate(rows):
ids[i] = self._unassignedmodel.headerData(r.row(), Qt.Orientation.Vertical, Qt.ItemDataRole.DisplayRole)
def on_assignOne(self):
logging.debug("Assigning user selection to track One")
self._data.assignUserSelection(self.trackone_id)
self._timeline.setDetectionData(self._data.data)
self.update()
self._data["track"][ids] = np.zeros_like(ids, dtype=int) + trackid
self.populateTables()
self._timeline.setDetectionData(self._data)
def on_assignTwo(self):
logging.debug("Assigning user selection to track Two")
self._data.assignUserSelection(self.tracktwo_id)
self._timeline.setDetectionData(self._data.data)
self.update()
def on_assignLeft(self):
selection = self._unassigned_table.selectionModel()
rows = selection.selectedRows()
logging.debug("Assign %i detections to Track One", len(rows))
self.assignTrack(rows, 1)
def on_assignRight(self):
selection = self._unassigned_table.selectionModel()
rows = selection.selectedRows()
logging.debug("Assign %i detections to Track Two", len(rows))
self.assignTrack(rows, 2)
def on_assignOther(self):
logging.debug("Assigning user selection to track Other")
self._data.assignUserSelection(self.trackother_id)
self._timeline.setDetectionData(self._data.data)
self.update()
def on_windowChanged(self, start, stop):
logging.info("Timeline reports window change to range %f %f percent of data", start, stop)
self.populateTables()
self.update()
def on_windowSizeChanged(self, value):
self._timeline.setWindowWidth(value)
def on_detectionsSelected(self, detections):
logging.debug("Tracks: Detections selected")
selection_model = self._unassigned_table.selectionModel()
selection = selection_model.selection()
logging.debug("Tracks: Get selection model")
for d in detections:
id = d.data(1)
row = self._unassignedmodel.mapIdToRow(id)
if row == -1:
continue
index = self._unassigned_table.model().index(row, 0)
selection.select(index, index)
logging.debug("Tracks: Detections selected")
mode = QItemSelectionModel.Select | QItemSelectionModel.Rows
selection_model.select(selection, mode)
tracks = np.zeros(len(detections))
ids = np.zeros_like(tracks)
for i, d in enumerate(detections):
tracks[i] = d.data(0)
ids[i] = d.data(1)
self._data.setUserSelection(ids)
self._controls_widget.setSelectedTracks(tracks)
self.update()
def main():