# fixtracks/fixtracks/widgets/tracks.py
#
# 560 lines
# 20 KiB
# Python

import logging
import pathlib
import pickle
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QSize, QObject
from PySide6.QtGui import QImage, QBrush, QColor, QFont
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QTableView, QSplitter, QGridLayout, QFileDialog
from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.writer import PickleWriter
from fixtracks.widgets.detectionview import DetectionView
from fixtracks.widgets.detectiontimeline import DetectionTimeline
class PoseTableModel(QAbstractTableModel):
    """Read-only Qt table model exposing the ``frame`` and ``track`` columns of a DataFrame.

    Each model row corresponds to one DataFrame row; the vertical header shows
    that row's DataFrame index value.
    """

    column_header = ["frame", "track"]
    columns = ["frame", "track"]

    def __init__(self, dataframe, parent=None):
        """Wrap *dataframe*, which must provide ``frame`` and ``track`` columns."""
        super().__init__(parent)
        self._data = dataframe
        self._frames = self._data.frame.values
        self._tracks = self._data.track.values
        self._indices = self._data.index.values
        # Per-column value arrays, in the same order as ``columns``.
        self._column_data = [self._frames, self._tracks]

    def columnCount(self, parent=None):
        """Return the number of columns (0 for a valid *parent*: this is a flat table)."""
        if parent is not None and parent.isValid():
            return 0
        return len(self.columns)

    def rowCount(self, parent=None):
        """Return the number of rows (0 for a valid *parent* or when there is no data)."""
        if parent is not None and parent.isValid():
            return 0
        if self._data is None:
            return 0
        return len(self._data)

    def data(self, index, role=Qt.ItemDataRole.DisplayRole):
        """Return the value of the cell at *index* for *role*.

        DisplayRole yields the value as ``str``; UserRole yields the raw value.
        Other roles and invalid indexes yield ``None``.
        """
        if not index.isValid():
            # An invalid QModelIndex reports row/column -1, which would silently
            # wrap around to the last element of the numpy arrays.
            return None
        value = self._column_data[index.column()][index.row()]
        if role == Qt.ItemDataRole.DisplayRole:
            return str(value)
        if role == Qt.ItemDataRole.UserRole:
            return value
        return None

    def headerData(self, section, orientation, role=Qt.ItemDataRole.DisplayRole):
        """Return column names horizontally and DataFrame index values vertically."""
        if role != Qt.ItemDataRole.DisplayRole:
            return None
        if orientation == Qt.Orientation.Horizontal:
            return self.column_header[section]
        return str(self._indices[section])

    def mapIdToRow(self, id):
        """Return the first row whose DataFrame index equals *id*, or -1 if there is none."""
        rows = np.flatnonzero(self._indices == id)
        if rows.size == 0:
            return -1
        return int(rows[0])
class FilterProxyModel(QSortFilterProxyModel):
    """Proxy model accepting only source rows whose column-0 value lies in a half-open range."""

    def __init__(self, parent=None):
        super().__init__(parent)
        # (start, stop) range applied to column 0, or None to accept every row.
        self._range = None

    def setFilterRange(self, start, stop):
        """Accept only rows with ``start <= value < stop`` in column 0 and re-run the filter."""
        logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
        self._range = (start, stop)
        self.invalidateRowsFilter()

    def all(self):
        """Remove the range filter so that every source row is accepted again."""
        self._range = None
        # Without invalidating, the proxy keeps exposing the previously filtered rows.
        self.invalidateRowsFilter()

    def filterAcceptsRow(self, source_row, source_parent):
        """Return True if the row's column-0 value (UserRole) lies inside the range."""
        if self._range is None:
            return True
        source = self.sourceModel()
        idx = source.index(source_row, 0, source_parent)
        val = source.data(idx, Qt.ItemDataRole.UserRole)
        if val is None:
            return False
        start, stop = self._range
        # bool() turns a numpy bool into a plain Python bool for Qt.
        return bool(start <= val < stop)

    def filterAcceptsColumn(self, source_column, source_parent):
        """Accept every column; only rows are filtered."""
        return True
class SelectionControls(QWidget):
    """Panel with navigation buttons, track-assignment buttons and selection counters.

    Signals:
        next, previous: emitted by the navigation buttons.
        assignOne, assignTwo, assignOther: emitted when the user assigns the
            current selection to track one, track two or "unassigned".
    """
    next = Signal()
    previous = Signal()
    assignOne = Signal()
    assignTwo = Signal()
    assignOther = Signal()

    def __init__(self, parent=None):
        super().__init__(parent)
        font = QFont()
        font.setBold(True)
        font.setPointSize(10)

        # The previous button sits on the left of the grid and the next button
        # on the right, so they are bound to the left and right arrow keys.
        previousBtn = self._makeButton("previous", font,
                                       "Go back to previous window (left-arrow)",
                                       Qt.Key.Key_Left, self.on_Previous)
        previousBtn.setEnabled(False)
        nextBtn = self._makeButton("next", font,
                                   "Proceed to next window (right-arrow)",
                                   Qt.Key.Key_Right, self.on_Next)
        nextBtn.setEnabled(False)

        assignOneBtn = self._makeButton("Track One", font,
                                        "Assign current selection to Track One (Ctrl+1)",
                                        "Ctrl+1", self.on_TrackOne,
                                        "QPushButton { background-color: orange; }")
        assignTwoBtn = self._makeButton("Track Two", font,
                                        "Assign current selection to Track Two (Ctrl+2)",
                                        "Ctrl+2", self.on_TrackTwo,
                                        "QPushButton { background-color: green; }")
        assignOtherBtn = self._makeButton("Other", font,
                                          "Assign current selection to Unassigned (Ctrl+0)",
                                          "Ctrl+0", self.on_TrackOther,
                                          "QPushButton { background-color: red; }")

        # Counters for how many selected detections belong to each track.
        self.tone_selection = QLabel("0")
        self.ttwo_selection = QLabel("0")
        self.tother_selection = QLabel("0")
        # Size of the current selection (set in setSelectedTracks).
        self._total = 0

        grid = QGridLayout()
        grid.addWidget(previousBtn, 0, 0, 4, 2)
        grid.addWidget(nextBtn, 0, 6, 4, 2)
        grid.addWidget(QLabel("Current selection:"), 0, 2, 1, 4)
        grid.addWidget(QLabel("Track One:"), 1, 2, 1, 3)
        grid.addWidget(self.tone_selection, 1, 5, 1, 1)
        grid.addWidget(QLabel("Track Two:"), 2, 2, 1, 3)
        grid.addWidget(self.ttwo_selection, 2, 5, 1, 1)
        grid.addWidget(QLabel("Unassigned:"), 3, 2, 1, 3)
        grid.addWidget(self.tother_selection, 3, 5, 1, 1)
        grid.addWidget(assignOneBtn, 4, 0, 4, 3)
        grid.addWidget(assignOtherBtn, 4, 3, 4, 2)
        grid.addWidget(assignTwoBtn, 4, 5, 4, 3)
        grid.setColumnStretch(0, 1)
        grid.setColumnStretch(7, 1)
        self.setLayout(grid)
        self.setMaximumSize(QSize(400, 200))

    @staticmethod
    def _makeButton(text, font, tooltip, shortcut, slot, style=None):
        """Create an expanding QPushButton with *font*, *tooltip* and *shortcut*, wired to *slot*."""
        btn = QPushButton(text)
        btn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        if style is not None:
            btn.setStyleSheet(style)
        btn.setToolTip(tooltip)
        btn.setShortcut(shortcut)
        btn.setFont(font)
        btn.clicked.connect(slot)
        return btn

    def _updateNumbers(self, track):
        """Show the whole selection count on *track*'s label and "0" on the others.

        *track* is 1 (track one), 2 (track two) or 3 (unassigned).
        """
        labels = {1: self.tone_selection, 2: self.ttwo_selection, 3: self.tother_selection}
        for key, label in labels.items():
            label.setText(str(self._total) if key == track else "0")

    def on_Next(self):
        self.next.emit()

    def on_Previous(self):
        self.previous.emit()

    def on_TrackOne(self):
        self.assignOne.emit()
        self._updateNumbers(1)

    def on_TrackTwo(self):
        self.assignTwo.emit()
        self._updateNumbers(2)

    def on_TrackOther(self):
        self.assignOther.emit()
        self._updateNumbers(3)

    def setSelectedTracks(self, tracks):
        """Update the counters from the track ids of the current selection.

        *tracks* is compared element-wise against 1 and 2, so a numpy array is
        expected. Every id other than 1 or 2 is counted as unassigned.
        """
        logging.debug("SelectionControl: setSelectedTracks")
        tone = np.sum(tracks == 1)
        ttwo = np.sum(tracks == 2)
        self.tone_selection.setText(str(tone))
        self.ttwo_selection.setText(str(ttwo))
        self.tother_selection.setText(str(len(tracks) - tone - ttwo))
        self._total = len(tracks)
class DataController(QObject):
    """Holds detection data as a dict of column arrays, plus the current selections.

    Two selections are kept:
    - a range selection (``setSelectionRange``) of array positions, used for display;
    - a user selection of detection ids (``setUserSelection``), which
      ``assignUserSelection`` writes track ids to.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self._data = None               # dict: column name -> array
        self._columns = []
        self._start = 0
        self._stop = 0
        self._indices = None            # positions selected by setSelectionRange
        self._selection_column = None
        self._user_selections = None    # detection ids chosen by the user

    def setData(self, datadict):
        """Replace the data with *datadict* (column name -> array) and drop stale selections.

        Raises:
            TypeError: if *datadict* is not a dict.
        """
        if not isinstance(datadict, dict):
            raise TypeError(f"DataController.setData expects a dict, got {type(datadict).__name__}")
        self._data = datadict
        self._columns = list(datadict)
        # Both selections refer to the previous data set and must not be
        # applied to the new one.
        self._indices = None
        self._user_selections = None

    @property
    def data(self):
        return self._data

    @property
    def columns(self):
        return self._columns

    def max(self, col):
        """Return the maximum of column *col*, or NaN (after logging an error) if it is missing."""
        if col not in self.columns:
            logging.error("Column %s not in dictionary", col)
            return np.nan
        return np.max(self._data[col])

    @property
    def selectionRange(self):
        return self._start, self._stop

    @property
    def selectionRangeColumn(self):
        return self._selection_column

    @property
    def selectionIndices(self):
        return self._indices

    def setSelectionRange(self, col, start, stop):
        """Select the positions whose *col* value satisfies ``start <= value < stop``."""
        self._start = start
        self._stop = stop
        self._selection_column = col
        values = self._data[col]
        self._indices = np.where((values >= self._start) & (values < self._stop))[0]

    def selectedData(self, col):
        """Return the entries of column *col* at the positions chosen by setSelectionRange."""
        return self._data[col][self._indices]

    def setUserSelection(self, ids):
        """Store the detection ids (converted to int) the user selected."""
        self._user_selections = np.asarray(ids).astype(int)

    def assignUserSelection(self, track_id):
        """Set ``track`` to *track_id* for every user-selected detection.

        The ids are expected to be values of the ``index`` column (FixTracks
        passes the displayed DataFrame's index). They are translated to array
        positions instead of being assumed equal to them. Nothing is changed
        when no data is loaded or nothing is selected.
        """
        if self._data is None or self._user_selections is None or self._user_selections.size == 0:
            # Indexing with None would be np.newaxis and overwrite every track.
            logging.warning("DataController.assignUserSelection: no data or no selection, nothing assigned")
            return
        positions = np.flatnonzero(np.isin(self._data["index"], self._user_selections))
        self._data["track"][positions] = track_id

    def save(self, filename):
        """Pickle all columns except ``index`` as a DataFrame indexed by ``index`` to *filename*."""
        export_columns = self._columns.copy()
        export_columns.remove("index")
        dictionary = {c: self._data[c] for c in export_columns}
        df = pd.DataFrame(dictionary, index=self._data["index"])
        with open(filename, 'wb') as f:
            pickle.dump(df, f)
class FixTracks(QWidget):
    """Widget for reviewing detections and assigning them to track one, track two or "other".

    The user picks an image and a ``.pkl`` data file, moves a window along the
    detection timeline, selects detections in the view and assigns them.
    Results can be saved to a pickle file.
    """
    back = Signal()
    trackone_id = 1
    tracktwo_id = 2
    trackother_id = -1

    def __init__(self, parent=None):
        super().__init__(parent)
        self._files = []
        self._threadpool = QThreadPool()
        self._reader = None
        self._image = None
        self._data = DataController()
        self._unassignedmodel = None
        self._leftmodel = None
        self._rightmodel = None
        self._proxymodel = None
        self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
                         "assigned_right": QBrush(QColor.fromString("green")),
                         "unassigned": QBrush(QColor.fromString("red"))
                         }
        self._detectionView = DetectionView()
        self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
        self._progress_bar = QProgressBar(self)
        self._progress_bar.setMaximumHeight(20)
        self._progress_bar.setValue(0)
        self._tasklabel = QLabel()

        self._timeline = DetectionTimeline()
        self._timeline.signals.windowMoved.connect(self.on_windowChanged)
        self._windowspinner = QSpinBox()
        self._windowspinner.setRange(100, 10000)
        self._windowspinner.setSingleStep(100)
        self._windowspinner.setValue(500)
        self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)
        timelinebox = QHBoxLayout()
        timelinebox.addWidget(self._timeline)
        timelinebox.addWidget(QLabel("Window"))
        timelinebox.addWidget(self._windowspinner)

        self._controls_widget = SelectionControls()
        self._controls_widget.assignOne.connect(self.on_assignOne)
        self._controls_widget.assignTwo.connect(self.on_assignTwo)
        self._controls_widget.assignOther.connect(self.on_assignOther)

        # NOTE: the track tables below are built but not added to the visible
        # layout (see the commented-out ``tablebox``).
        self._trackone_table = QTableView()
        font = QFont()
        font.setBold(True)
        font.setPointSize(8)
        self._trackone_table.setFont(font)
        assign1 = QPushButton("<<")
        assign1.clicked.connect(self.on_assignOne)
        assign2 = QPushButton(">>")
        assign2.clicked.connect(self.on_assignTwo)
        self._unassigned_table = QTableView()
        self._unassigned_table.setFont(font)
        self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection)
        self._unassigned_table.setSelectionBehavior(QTableView.SelectionBehavior.SelectRows)
        self._tracktwo_table = QTableView()
        self._tracktwo_table.setFont(font)
        trackone_label = QLabel("Track 1")
        trackone_label.setStyleSheet("QLabel { color : orange; }")
        track1_box = QVBoxLayout()
        track1_box.addWidget(trackone_label)
        track1_box.addWidget(self._trackone_table)
        tracktwo_label = QLabel("Track 2")
        tracktwo_label.setStyleSheet("QLabel { color : green; }")
        tracktwo_box = QVBoxLayout()
        tracktwo_box.addWidget(tracktwo_label)
        tracktwo_box.addWidget(self._tracktwo_table)
        trackother_label = QLabel("Unassigned")
        trackother_label.setStyleSheet("QLabel { color : red; }")
        trackother_box = QVBoxLayout()
        trackother_box.addWidget(trackother_label)
        trackother_box.addWidget(self._unassigned_table)
        # tablebox = QHBoxLayout()
        # tablebox.addLayout(track1_box)
        # tablebox.addWidget(assign1)
        # tablebox.addLayout(trackother_box)
        # tablebox.addWidget(assign2)
        # tablebox.addLayout(tracktwo_box)

        self._saveBtn = QPushButton("Save")
        self._saveBtn.setShortcut("Ctrl+S")
        self._saveBtn.setEnabled(False)
        self._saveBtn.clicked.connect(self.on_save)
        self._backBtn = QPushButton("Back")
        self._backBtn.setShortcut("ESC")
        self._backBtn.clicked.connect(self.on_back)

        self._data_combo = QComboBox()
        self._data_combo.addItems(self._files)
        self._data_combo.currentIndexChanged.connect(self.on_dataSelection)
        self._image_combo = QComboBox()
        self._image_combo.addItems(self._files)
        self._image_combo.currentIndexChanged.connect(self.on_imageSelection)
        data_selection_box = QHBoxLayout()
        data_selection_box.addWidget(QLabel("Select image file"))
        data_selection_box.addWidget(self._image_combo)
        data_selection_box.addWidget(QLabel("Select data file"))
        data_selection_box.addWidget(self._data_combo)
        data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))

        btnBox = QHBoxLayout()
        btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
        btnBox.addWidget(self._backBtn)
        btnBox.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        btnBox.addWidget(self._tasklabel)
        btnBox.addWidget(self._progress_bar)
        # btnBox.addWidget(self._openBtn)
        btnBox.addWidget(self._saveBtn)

        vbox = QVBoxLayout()
        vbox.addLayout(timelinebox)
        # vbox.addLayout(tablebox)
        vbox.addWidget(self._controls_widget, stretch=1, alignment=Qt.AlignmentFlag.AlignCenter)
        vbox.addLayout(btnBox)
        container = QWidget()
        container.setLayout(vbox)

        # Detection view on top, controls below.
        splitter = QSplitter(Qt.Orientation.Vertical)
        splitter.addWidget(self._detectionView)
        splitter.addWidget(container)
        splitter.setStretchFactor(0, 3)
        splitter.setStretchFactor(1, 1)
        layout = QVBoxLayout()
        layout.addLayout(data_selection_box)
        layout.addWidget(splitter)
        self.setLayout(layout)

    def on_dataSelection(self):
        """Load the data file chosen in the data combo box in the thread pool."""
        filename = self._data_combo.currentText()
        if not filename or "please select" in filename.lower():
            return
        self._progress_bar.setRange(0, 0)  # busy indicator while loading
        self._reader = PickleLoader(filename)
        self._reader.signals.finished.connect(self._on_dataOpenend)
        self._threadpool.start(self._reader)

    def on_imageSelection(self):
        """Show the image chosen in the image combo box as the detection view background."""
        filename = self._image_combo.currentText()
        if not filename or "please select" in filename.lower():
            return
        img = QImage(filename)
        self._detectionView.setImage(img)

    def update(self):
        """Redraw the detections that fall into the current timeline window.

        Note: this shadows ``QWidget.update()``. It does nothing until data has
        been loaded.
        """
        def update_detectionView(df, name):
            # Add the first keypoint of each detection, painted with brush *name*.
            if len(df) == 0:
                return
            coords = np.stack(df["keypoints"].values).astype(np.float32)[:, 0, :]
            tracks = df["track"].values.astype(int)
            ids = df.index.values.astype(int)
            self._detectionView.addDetections(coords, tracks, ids, self._brushes[name])

        if self._data.data is None:
            logging.debug("FixTracks.update: no data loaded yet")
            return
        max_frames = self._data.max("frame")
        if not np.isfinite(max_frames):
            return
        # The timeline reports its window as fractions of the data range.
        start = self._timeline.rangeStart
        stop = self._timeline.rangeStop
        start_frame = int(np.floor(start * max_frames))
        stop_frame = int(np.ceil(stop * max_frames))
        logging.debug("Updating View for detection range %i, %i frames", start_frame, stop_frame)
        self._data.setSelectionRange("frame", start_frame, stop_frame)
        frames = self._data.selectedData("frame")
        tracks = self._data.selectedData("track")
        keypoints = self._data.selectedData("keypoints")
        index = self._data.selectedData("index")
        df = pd.DataFrame({"frame": frames,
                           "track": tracks,
                           "keypoints": keypoints},
                          index=index)
        assigned_left = df[(df.track == self.trackone_id)]
        assigned_right = df[(df.track == self.tracktwo_id)]
        unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)]
        self._detectionView.clearDetections()
        update_detectionView(unassigned, "unassigned")
        update_detectionView(assigned_left, "assigned_left")
        update_detectionView(assigned_right, "assigned_right")

    @property
    def fileList(self):
        """The data (``.pkl``) files from the most recently set file list."""
        return self._files

    @fileList.setter
    def fileList(self, file_list):
        """Fill the image (.jpg/.png) and data (.pkl) combo boxes from *file_list* (pathlib paths)."""
        logging.debug("FixTracks.fileList: set new file list")
        logging.debug("FixTracks.fileList: setting image combo box")
        img_formats = [".jpg", ".png"]
        image_files = [str(f) for f in file_list if f.suffix.lower() in img_formats]
        self._fillCombo(self._image_combo, image_files)
        logging.debug("FixTracks.fileList: setting data combo box")
        dataformats = [".pkl"]
        data_files = [str(f) for f in file_list if f.suffix.lower() in dataformats]
        self._fillCombo(self._data_combo, data_files)
        self._files = data_files

    @staticmethod
    def _fillCombo(combo, items):
        """Replace *combo*'s entries with a "Please select" placeholder followed by *items*.

        Signals are blocked while repopulating, so clearing the box does not
        trigger a load of an empty file name.
        """
        was_blocked = combo.blockSignals(True)
        try:
            combo.clear()
            combo.addItem("Please select")
            combo.addItems(items)
            combo.setCurrentIndex(0)
        finally:
            combo.blockSignals(was_blocked)

    def _on_dataOpenend(self, state):
        """Handle the loader's finished signal: install the data on success and redraw."""
        logging.info("Finished loading data with state %s", state)
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)
        if state and self._reader is not None:
            self._data.setData(self._reader.asdict)
            self._timeline.setDetectionData(self._data.data)
            self.update()
            self._saveBtn.setEnabled(True)

    def on_save(self):
        """Ask for a target file and write the results in the background with PickleWriter."""
        logging.debug("Saving fixtracks results")
        file_dialog = QFileDialog(self)
        file_dialog.setAcceptMode(QFileDialog.AcceptMode.AcceptSave)
        file_dialog.setNameFilter("Pickle Files (*.pkl)")
        if not file_dialog.exec():
            # Cancelled: no "Saving..." label may linger.
            return
        file_path = file_dialog.selectedFiles()[0]
        if not file_path.endswith(".pkl"):
            file_path += ".pkl"
        self._tasklabel.setText("Saving results to file...")
        self._progress_bar.setRange(0, 0)
        save_task = PickleWriter(self._data, file_path)
        save_task.signals.finished.connect(self.on_dataSaved)
        self._threadpool.start(save_task)

    def on_dataSaved(self):
        """Reset the task label and progress bar after saving finished."""
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)

    def on_back(self):
        logging.debug("Back button pressed!")
        self.back.emit()

    def _assignSelection(self, track_id):
        """Assign the current user selection to *track_id*, then refresh timeline and view."""
        if self._data.data is None:
            logging.warning("FixTracks: no data loaded, cannot assign selection")
            return
        self._data.assignUserSelection(track_id)
        self._timeline.setDetectionData(self._data.data)
        self.update()

    def on_assignOne(self):
        logging.debug("Assigning user selection to track One")
        self._assignSelection(self.trackone_id)

    def on_assignTwo(self):
        logging.debug("Assigning user selection to track Two")
        self._assignSelection(self.tracktwo_id)

    def on_assignOther(self):
        logging.debug("Assigning user selection to track Other")
        self._assignSelection(self.trackother_id)

    def on_windowChanged(self, start, stop):
        logging.info("Timeline reports window change to range %f %f percent of data", start, stop)
        self.update()

    def on_windowSizeChanged(self, value):
        self._timeline.setWindowWidth(value)

    def on_detectionsSelected(self, detections):
        """Store the ids of the selected detection items and show their track counts.

        Each item carries its track id in ``data(0)`` and its detection id in ``data(1)``.
        """
        logging.debug("Tracks: Detections selected")
        tracks = np.zeros(len(detections))
        ids = np.zeros_like(tracks)
        for i, d in enumerate(detections):
            tracks[i] = d.data(0)
            ids[i] = d.data(1)
        self._data.setUserSelection(ids)
        self._controls_widget.setSelectedTracks(tracks)
        self.update()
def main():
    """Open a bare window that shows a SelectionControls panel (manual demo)."""
    from PySide6.QtWidgets import QApplication

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    box = QVBoxLayout()
    box.addWidget(SelectionControls())
    window.setLayout(box)
    window.show()
    app.exec()


if __name__ == "__main__":
    main()