560 lines
20 KiB
Python
560 lines
20 KiB
Python
import logging
|
|
import pathlib
|
|
import pickle
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QSize, QObject
|
|
from PySide6.QtGui import QImage, QBrush, QColor, QFont
|
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
|
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QTableView, QSplitter, QGridLayout, QFileDialog
|
|
|
|
from fixtracks.utils.reader import PickleLoader
|
|
from fixtracks.utils.writer import PickleWriter
|
|
from fixtracks.widgets.detectionview import DetectionView
|
|
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
|
|
|
class PoseTableModel(QAbstractTableModel):
    """Read-only table model exposing the ``frame`` and ``track`` columns of a DataFrame.

    One row per DataFrame row. The vertical header shows the DataFrame's index
    values, which :meth:`mapIdToRow` treats as detection ids.
    """

    column_header = ["frame", "track"]
    columns = ["frame", "track"]

    def __init__(self, dataframe, parent=None):
        super().__init__(parent)
        self._data = dataframe
        # Cache plain arrays once: data() is called very often by attached views.
        self._frames = self._data.frame.values
        self._tracks = self._data.track.values
        self._indices = self._data.index.values
        self._column_data = [self._frames, self._tracks]

    def columnCount(self, parent=None):
        """Return the number of columns; 0 for a valid parent (flat table, no children)."""
        if parent is not None and parent.isValid():
            return 0
        return len(self.columns)

    def rowCount(self, parent=None):
        """Return the number of rows; 0 without data or for a valid parent."""
        if self._data is None or (parent is not None and parent.isValid()):
            return 0
        return len(self._data)

    def data(self, index, role=...):
        """Return the cell value.

        DisplayRole yields the value as a string. UserRole yields the raw value
        (used by ``FilterProxyModel`` for numeric comparisons). Any other role,
        or an invalid or out-of-range index, yields None.
        """
        # An invalid QModelIndex has row/column -1, which would silently wrap to the
        # last array element if used as an index.
        if not index.isValid():
            return None
        if role != Qt.ItemDataRole.DisplayRole and role != Qt.ItemDataRole.UserRole:
            return None
        row, column = index.row(), index.column()
        if not (0 <= column < len(self._column_data) and 0 <= row < len(self._indices)):
            return None
        value = self._column_data[column][row]
        if role == Qt.ItemDataRole.DisplayRole:
            return str(value)
        return value

    def headerData(self, section, orientation, role=...):
        """Column names horizontally, DataFrame index values vertically (DisplayRole only)."""
        if role != Qt.ItemDataRole.DisplayRole:
            return None
        if orientation == Qt.Orientation.Horizontal:
            return self.column_header[section]
        return str(self._indices[section])

    def mapIdToRow(self, id):
        """Return the row whose DataFrame index equals *id*, or -1 if there is none."""
        matches = np.where(self._indices == id)[0]
        if len(matches) == 0:
            return -1
        return int(matches[0])
|
|
|
|
class FilterProxyModel(QSortFilterProxyModel):
    """Proxy model that keeps only rows whose column-0 value lies in ``[start, stop)``.

    The value is read from the source model with ``Qt.ItemDataRole.UserRole``.
    With no range set, every row is accepted.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # Half-open (start, stop) interval, or None to accept every row.
        self._range = None

    def setFilterRange(self, start, stop):
        """Accept only rows whose value v satisfies ``start <= v < stop`` and re-filter."""
        logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
        self._range = (start, stop)
        self.invalidateRowsFilter()

    def all(self):
        """Remove the range restriction so that every source row is shown again."""
        self._range = None
        # Without re-filtering, rows hidden by a previous range would stay hidden.
        self.invalidateRowsFilter()

    def filterAcceptsRow(self, source_row, source_parent):
        """Return True if the row's column-0 UserRole value lies inside the range."""
        if self._range is None:
            return True
        source = self.sourceModel()
        idx = source.index(source_row, 0, source_parent)
        val = source.data(idx, Qt.ItemDataRole.UserRole)
        if val is None:
            # No comparable value (e.g. an invalid index): treat it as outside the range.
            return False
        start, stop = self._range
        return start <= val < stop

    def filterAcceptsColumn(self, source_column, source_parent):
        """Accept every column; only rows are filtered."""
        return True
|
|
|
|
|
|
class SelectionControls(QWidget):
    """Navigation and track-assignment buttons plus counters for the current selection.

    Signals:
        next, previous: emitted by the navigation buttons.
        assignOne, assignTwo, assignOther: emitted when the user asks to assign
            the current selection to track one, track two, or "unassigned".

    The three counters show how many selected detections currently carry
    track 1, track 2, or any other track value.
    """

    next = Signal()
    previous = Signal()
    assignOne = Signal()
    assignTwo = Signal()
    assignOther = Signal()

    def __init__(self, parent=None):
        super().__init__(parent)
        font = QFont()
        font.setBold(True)
        font.setPointSize(10)

        # NOTE(review): "previous" is bound to the right arrow and "next" to the
        # left arrow (the tooltips agree). Both buttons start disabled, and nothing
        # in this class enables them, so these shortcuts are currently inert.
        previousBtn = QPushButton("previous")
        previousBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        previousBtn.setToolTip("Go back to previous window (right-arrow)")
        previousBtn.setEnabled(False)
        previousBtn.setShortcut(Qt.Key.Key_Right)
        previousBtn.clicked.connect(self.on_Previous)
        previousBtn.setFont(font)

        nextBtn = QPushButton("next")
        nextBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        nextBtn.setToolTip("Proceed to next window (left-arrow)")
        nextBtn.setEnabled(False)
        nextBtn.clicked.connect(self.on_Next)
        nextBtn.setShortcut(Qt.Key.Key_Left)
        nextBtn.setFont(font)

        assignOneBtn = QPushButton("Track One")
        assignOneBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        assignOneBtn.setStyleSheet("QPushButton { background-color: orange; }")
        assignOneBtn.setShortcut("Ctrl+1")
        assignOneBtn.setToolTip("Assign current selection to Track One (Ctrl+1)")
        assignOneBtn.setFont(font)
        assignOneBtn.clicked.connect(self.on_TrackOne)

        assignTwoBtn = QPushButton("Track Two")
        assignTwoBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        assignTwoBtn.setStyleSheet("QPushButton { background-color: green; }")
        assignTwoBtn.setToolTip("Assign current selection to Track Two (Ctrl+2)")
        assignTwoBtn.setFont(font)
        assignTwoBtn.setShortcut("Ctrl+2")
        assignTwoBtn.clicked.connect(self.on_TrackTwo)

        assignOtherBtn = QPushButton("Other")
        assignOtherBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        assignOtherBtn.setStyleSheet("QPushButton { background-color: red; }")
        assignOtherBtn.setToolTip("Assign current selection to Unassigned (Ctrl+0)")
        assignOtherBtn.setFont(font)
        assignOtherBtn.setShortcut("Ctrl+0")
        assignOtherBtn.clicked.connect(self.on_TrackOther)

        self.tone_selection = QLabel("0")
        self.ttwo_selection = QLabel("0")
        self.tother_selection = QLabel("0")
        # Size of the current selection; shown on the target counter after an assignment.
        self._total = 0

        grid = QGridLayout()
        grid.addWidget(previousBtn, 0, 0, 4, 2)
        grid.addWidget(nextBtn, 0, 6, 4, 2)
        grid.addWidget(QLabel("Current selection:"), 0, 2, 1, 4)
        grid.addWidget(QLabel("Track One:"), 1, 2, 1, 3)
        grid.addWidget(self.tone_selection, 1, 5, 1, 1)
        grid.addWidget(QLabel("Track Two:"), 2, 2, 1, 3)
        grid.addWidget(self.ttwo_selection, 2, 5, 1, 1)
        grid.addWidget(QLabel("Unassigned:"), 3, 2, 1, 3)
        grid.addWidget(self.tother_selection, 3, 5, 1, 1)

        grid.addWidget(assignOneBtn, 4, 0, 4, 3)
        grid.addWidget(assignOtherBtn, 4, 3, 4, 2)
        grid.addWidget(assignTwoBtn, 4, 5, 4, 3)
        grid.setColumnStretch(0, 1)
        grid.setColumnStretch(7, 1)
        self.setLayout(grid)
        self.setMaximumSize(QSize(400, 200))

    def _updateNumbers(self, track):
        """Show the whole selection on counter *track* (1, 2 or 3=other); zero the rest."""
        labels = {1: self.tone_selection, 2: self.ttwo_selection, 3: self.tother_selection}
        for key, label in labels.items():
            label.setText(str(self._total) if key == track else "0")

    def on_Next(self):
        self.next.emit()

    def on_Previous(self):
        self.previous.emit()

    def on_TrackOne(self):
        self.assignOne.emit()
        self._updateNumbers(1)

    def on_TrackTwo(self):
        self.assignTwo.emit()
        self._updateNumbers(2)

    def on_TrackOther(self):
        self.assignOther.emit()
        self._updateNumbers(3)

    def setSelectedTracks(self, tracks):
        """Update the counters from the track values of the selected detections.

        Args:
            tracks: sequence of track values, one per selected detection. Values
                other than 1 and 2 count as unassigned.
        """
        logging.debug("SelectionControl: setSelectedTracks")
        # asarray so plain lists work too: `list == 1` is a single False, not a mask.
        tracks = np.asarray(tracks)
        tone = int(np.sum(tracks == 1))
        ttwo = int(np.sum(tracks == 2))
        self.tone_selection.setText(str(tone))
        self.ttwo_selection.setText(str(ttwo))
        self.tother_selection.setText(str(len(tracks) - tone - ttwo))
        self._total = len(tracks)
|
|
|
|
|
|
class DataController(QObject):
    """Holds the detection data as a dict of columns and tracks two kinds of selection.

    * A *range* selection: rows whose value in one column lies in
      ``[start, stop)``, set via :meth:`setSelectionRange`.
    * A *user* selection of detection ids, set via :meth:`setUserSelection` and
      consumed by :meth:`assignUserSelection`.

    Detection ids are the values stored in the ``"index"`` column (when present),
    not row positions.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self._data = None
        self._columns = []
        self._start = 0
        self._stop = 0
        self._indices = None          # row positions inside the current range selection
        self._selection_column = None
        self._user_selections = None  # detection ids picked by the user

    def setData(self, datadict):
        """Replace the data with *datadict* (column name -> array).

        Raises:
            TypeError: if *datadict* is not a dict.
        """
        if not isinstance(datadict, dict):
            raise TypeError(f"DataController.setData expects a dict, got {type(datadict).__name__}")
        self._data = datadict
        self._columns = list(datadict)
        # Selections refer to the previous data set and would index the new one wrongly.
        self._indices = None
        self._user_selections = None

    @property
    def data(self):
        return self._data

    @property
    def columns(self):
        return self._columns

    def max(self, col):
        """Return the maximum of column *col*, or NaN (and log an error) if it is missing."""
        if col not in self.columns:
            logging.error("Column %s not in dictionary", col)
            return np.nan
        return np.max(self._data[col])

    @property
    def selectionRange(self):
        return self._start, self._stop

    @property
    def selectionRangeColumn(self):
        return self._selection_column

    @property
    def selectionIndices(self):
        return self._indices

    def setSelectionRange(self, col, start, stop):
        """Select the rows whose value in column *col* satisfies ``start <= v < stop``."""
        self._start = start
        self._stop = stop
        self._selection_column = col
        values = np.asarray(self._data[col])
        self._indices = np.where((values >= start) & (values < stop))[0]

    def selectedData(self, col):
        """Return column *col* restricted to the current range selection.

        Before any range has been set, the whole column is returned.
        """
        column = self._data[col]
        if self._indices is None:
            # Indexing with None would add a new axis instead of selecting rows.
            return column
        return column[self._indices]

    def setUserSelection(self, ids):
        """Remember the detection ids the user has selected (converted to int)."""
        self._user_selections = np.asarray(ids).astype(int)

    def assignUserSelection(self, track_id):
        """Set the ``"track"`` value of all user-selected detections to *track_id*.

        Does nothing when no data is loaded or nothing is selected.
        """
        if self._data is None or self._user_selections is None or self._user_selections.size == 0:
            # Guard: `track[None] = x` would silently overwrite *every* track.
            logging.debug("DataController.assignUserSelection: empty selection, nothing assigned")
            return
        if "index" in self._data:
            # Selections hold detection ids (values of the "index" column), which
            # equal row positions only when that column is 0..n-1.
            rows = np.isin(np.asarray(self._data["index"]), self._user_selections)
        else:
            rows = self._user_selections
        self._data["track"][rows] = track_id

    def save(self, filename):
        """Pickle the data as a DataFrame indexed by the ``"index"`` column to *filename*.

        Raises:
            ValueError: if no data has been set.
        """
        if self._data is None:
            raise ValueError("DataController.save: no data to save")
        dictionary = {c: self._data[c] for c in self._columns if c != "index"}
        df = pd.DataFrame(dictionary, index=self._data["index"])
        with open(filename, 'wb') as f:
            pickle.dump(df, f)
|
|
|
|
|
|
class FixTracks(QWidget):
    """Widget for manually correcting the track assignment of detections.

    The user picks a background image and a pickled detection file from two
    combo boxes. Detections inside the timeline's current window are drawn in a
    ``DetectionView``: orange for track one, green for track two, red for
    anything else. Selected detections can be reassigned with the
    ``SelectionControls`` buttons, and the result can be saved to a pickle file.

    Signals:
        back: emitted when the Back button is pressed.
    """

    back = Signal()
    trackone_id = 1
    tracktwo_id = 2
    trackother_id = -1

    def __init__(self, parent=None):
        super().__init__(parent)
        self._files = []
        self._threadpool = QThreadPool()
        self._reader = None
        self._image = None
        self._data = DataController()
        # Model placeholders; nothing in this widget creates them yet.
        self._unassignedmodel = None
        self._leftmodel = None
        self._rightmodel = None
        self._proxymodel = None
        # Brush per detection category; keys match the names used in update().
        self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
                         "assigned_right": QBrush(QColor.fromString("green")),
                         "unassigned": QBrush(QColor.fromString("red")),
                         }

        self._detectionView = DetectionView()
        self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
        self._progress_bar = QProgressBar(self)
        self._progress_bar.setMaximumHeight(20)
        self._progress_bar.setValue(0)
        self._tasklabel = QLabel()

        timelinebox = self._createTimelineBox()

        self._controls_widget = SelectionControls()
        self._controls_widget.assignOne.connect(self.on_assignOne)
        self._controls_widget.assignTwo.connect(self.on_assignTwo)
        self._controls_widget.assignOther.connect(self.on_assignOther)

        self._createTrackTables()
        btnBox = self._createButtonBox()
        data_selection_box = self._createFileSelectionBox()

        vbox = QVBoxLayout()
        vbox.addLayout(timelinebox)
        vbox.addWidget(self._controls_widget, stretch=1, alignment=Qt.AlignmentFlag.AlignCenter)
        vbox.addLayout(btnBox)
        container = QWidget()
        container.setLayout(vbox)

        splitter = QSplitter(Qt.Orientation.Vertical)
        splitter.addWidget(self._detectionView)
        splitter.addWidget(container)
        splitter.setStretchFactor(0, 3)
        splitter.setStretchFactor(1, 1)

        layout = QVBoxLayout()
        layout.addLayout(data_selection_box)
        layout.addWidget(splitter)
        self.setLayout(layout)

    def _createTimelineBox(self):
        """Create the detection timeline and the window-size spinner; return their layout."""
        self._timeline = DetectionTimeline()
        self._timeline.signals.windowMoved.connect(self.on_windowChanged)

        self._windowspinner = QSpinBox()
        self._windowspinner.setRange(100, 10000)
        self._windowspinner.setSingleStep(100)
        # Set before connecting so the initial value does not trigger the slot.
        self._windowspinner.setValue(500)
        self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)

        timelinebox = QHBoxLayout()
        timelinebox.addWidget(self._timeline)
        timelinebox.addWidget(QLabel("Window"))
        timelinebox.addWidget(self._windowspinner)
        return timelinebox

    def _createTrackTables(self):
        """Create the per-track table views.

        TODO: these tables and the "<<"/">>" buttons are built but not yet added
        to the widget's layout.
        """
        font = QFont()
        font.setBold(True)
        font.setPointSize(8)

        self._trackone_table = QTableView()
        self._trackone_table.setFont(font)
        assign1 = QPushButton("<<")
        assign1.clicked.connect(self.on_assignOne)
        assign2 = QPushButton(">>")
        assign2.clicked.connect(self.on_assignTwo)
        self._unassigned_table = QTableView()
        self._unassigned_table.setFont(font)
        self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection)
        self._unassigned_table.setSelectionBehavior(QTableView.SelectionBehavior.SelectRows)
        self._tracktwo_table = QTableView()
        self._tracktwo_table.setFont(font)

        for title, color, table in (("Track 1", "orange", self._trackone_table),
                                    ("Track 2", "green", self._tracktwo_table),
                                    ("Unassigned", "red", self._unassigned_table)):
            label = QLabel(title)
            label.setStyleSheet("QLabel { color : " + color + "; }")
            box = QVBoxLayout()
            box.addWidget(label)
            box.addWidget(table)

    def _createButtonBox(self):
        """Create the Back/Save buttons and the status row (task label, progress bar)."""
        self._saveBtn = QPushButton("Save")
        self._saveBtn.setShortcut("Ctrl+S")
        # Enabled once data has been loaded successfully.
        self._saveBtn.setEnabled(False)
        self._saveBtn.clicked.connect(self.on_save)
        self._backBtn = QPushButton("Back")
        self._backBtn.setShortcut("ESC")
        self._backBtn.clicked.connect(self.on_back)

        btnBox = QHBoxLayout()
        btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
        btnBox.addWidget(self._backBtn)
        btnBox.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        btnBox.addWidget(self._tasklabel)
        btnBox.addWidget(self._progress_bar)
        btnBox.addWidget(self._saveBtn)
        return btnBox

    def _createFileSelectionBox(self):
        """Create the image and data file combo boxes; return their layout."""
        self._data_combo = QComboBox()
        self._data_combo.currentIndexChanged.connect(self.on_dataSelection)
        self._image_combo = QComboBox()
        self._image_combo.currentIndexChanged.connect(self.on_imageSelection)

        data_selection_box = QHBoxLayout()
        data_selection_box.addWidget(QLabel("Select image file"))
        data_selection_box.addWidget(self._image_combo)
        data_selection_box.addWidget(QLabel("Select data file"))
        data_selection_box.addWidget(self._data_combo)
        data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        return data_selection_box

    def on_dataSelection(self):
        """Start loading the selected pickle file in the thread pool."""
        filename = self._data_combo.currentText()
        # An empty text occurs while the combo box is being cleared.
        if not filename or "please select" in filename.lower():
            return
        self._progress_bar.setRange(0, 0)  # busy indicator until loading finishes
        self._reader = PickleLoader(filename)
        self._reader.signals.finished.connect(self._on_dataOpenend)
        self._threadpool.start(self._reader)

    def on_imageSelection(self):
        """Load the selected image file and show it as background of the detection view."""
        filename = self._image_combo.currentText()
        if not filename or "please select" in filename.lower():
            return
        img = QImage(filename)
        if img.isNull():
            logging.error("FixTracks: could not load image %s", filename)
            return
        self._detectionView.setImage(img)

    def _addDetections(self, df, brush_name):
        """Add the detections in *df* to the view, painted with ``self._brushes[brush_name]``."""
        if len(df) == 0:
            return
        # Only index 0 along the second keypoint axis is drawn; assumes each entry is
        # (n_keypoints, n_coords) — TODO confirm against the loader.
        coords = np.stack(df["keypoints"].values).astype(np.float32)[:, 0, :]
        tracks = df["track"].values.astype(int)
        ids = df.index.values.astype(int)
        self._detectionView.addDetections(coords, tracks, ids, self._brushes[brush_name])

    def update(self):
        """Redraw the detections whose frame lies in the timeline's current window.

        Note: this shadows ``QWidget.update()``. Python calls to ``widget.update()``
        run this method instead of scheduling a repaint. Without loaded data this
        method does nothing.
        """
        if self._data.data is None:
            logging.debug("FixTracks.update: no data loaded yet")
            return
        max_frames = self._data.max("frame")
        if np.isnan(max_frames):  # "frame" column missing; DataController logged it
            return
        # The timeline reports its window as fractions of the whole recording.
        start_frame = int(np.floor(self._timeline.rangeStart * max_frames))
        stop_frame = int(np.ceil(self._timeline.rangeStop * max_frames))
        logging.debug("Updating View for detection range %i, %i frames", start_frame, stop_frame)
        self._data.setSelectionRange("frame", start_frame, stop_frame)

        df = pd.DataFrame({"frame": self._data.selectedData("frame"),
                           "track": self._data.selectedData("track"),
                           "keypoints": self._data.selectedData("keypoints")},
                          index=self._data.selectedData("index"))
        assigned_left = df[df.track == self.trackone_id]
        assigned_right = df[df.track == self.tracktwo_id]
        unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)]

        self._detectionView.clearDetections()
        self._addDetections(unassigned, "unassigned")
        self._addDetections(assigned_left, "assigned_left")
        self._addDetections(assigned_right, "assigned_right")

    @property
    def fileList(self):
        """The data (.pkl) files from the most recently assigned file list, as strings."""
        return self._files

    @staticmethod
    def _fillCombo(combo, entries):
        """Replace the items of *combo* with a "Please select" placeholder plus *entries*."""
        # Block signals while refilling so the selection slots do not fire for the
        # transient empty/intermediate states. The placeholder is selected either way.
        was_blocked = combo.blockSignals(True)
        try:
            combo.clear()
            combo.addItem("Please select")
            combo.addItems(entries)
            combo.setCurrentIndex(0)
        finally:
            combo.blockSignals(was_blocked)

    @fileList.setter
    def fileList(self, file_list):
        """Populate the image (.jpg/.png) and data (.pkl) combo boxes from *file_list*.

        Entries may be ``pathlib.Path`` objects or strings. Suffix matching is
        case-insensitive. Calling this again replaces, rather than appends to,
        the combo box contents.
        """
        logging.debug("FixTracks.fileList: set new file list")
        paths = [pathlib.Path(f) for f in file_list]

        logging.debug("FixTracks.fileList: setting image combo box")
        img_formats = [".jpg", ".png"]
        image_files = [str(p) for p in paths if p.suffix.lower() in img_formats]
        self._fillCombo(self._image_combo, image_files)

        logging.debug("FixTracks.fileList: setting data combo box")
        dataformats = [".pkl"]
        data_files = [str(p) for p in paths if p.suffix.lower() in dataformats]
        self._fillCombo(self._data_combo, data_files)
        # As before, the getter reports the data files.
        self._files = data_files

    def _on_dataOpenend(self, state):
        """Handle the loader's ``finished`` signal; on success show the data and enable saving."""
        logging.info("Finished loading data with state %s", state)
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)
        if state and self._reader is not None:
            self._data.setData(self._reader.asdict)
            self._timeline.setDetectionData(self._data.data)
            self.update()
            self._saveBtn.setEnabled(True)

    def on_save(self):
        """Ask for a target file and write the results in the background."""
        logging.debug("Saving fixtracks results")
        file_dialog = QFileDialog(self)
        file_dialog.setAcceptMode(QFileDialog.AcceptMode.AcceptSave)
        file_dialog.setNameFilter("Pickle Files (*.pkl)")
        if not file_dialog.exec():
            return  # cancelled: leave the status label untouched
        file_path = file_dialog.selectedFiles()[0]
        if not file_path.endswith(".pkl"):
            file_path += ".pkl"
        self._tasklabel.setText("Saving results to file...")
        self._progress_bar.setRange(0, 0)
        save_task = PickleWriter(self._data, file_path)
        save_task.signals.finished.connect(self.on_dataSaved)
        self._threadpool.start(save_task)

    def on_dataSaved(self):
        """Reset the status label and progress bar after saving."""
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)

    def on_back(self):
        logging.debug("Back button pressed!")
        self.back.emit()

    def _assignSelection(self, track_id):
        """Assign the user selection to *track_id*, then refresh the timeline and the view."""
        if self._data.data is None:
            logging.debug("FixTracks: no data loaded, ignoring track assignment")
            return
        self._data.assignUserSelection(track_id)
        self._timeline.setDetectionData(self._data.data)
        self.update()

    def on_assignOne(self):
        logging.debug("Assigning user selection to track One")
        self._assignSelection(self.trackone_id)

    def on_assignTwo(self):
        logging.debug("Assigning user selection to track Two")
        self._assignSelection(self.tracktwo_id)

    def on_assignOther(self):
        logging.debug("Assigning user selection to track Other")
        self._assignSelection(self.trackother_id)

    def on_windowChanged(self, start, stop):
        """Redraw after the timeline window moved (start/stop as fractions of the data)."""
        logging.info("Timeline reports window change to range %f %f percent of data", start, stop)
        self.update()

    def on_windowSizeChanged(self, value):
        self._timeline.setWindowWidth(value)

    def on_detectionsSelected(self, detections):
        """Store the ids of the selected detection items and update the selection counters.

        Assumes item data slot 0 holds the track and slot 1 the detection id, as
        populated by ``DetectionView`` — TODO confirm.
        """
        logging.debug("Tracks: Detections selected")
        tracks = np.zeros(len(detections))
        ids = np.zeros_like(tracks)
        for i, item in enumerate(detections):
            tracks[i] = item.data(0)
            ids[i] = item.data(1)
        self._data.setUserSelection(ids)
        self._controls_widget.setSelectedTracks(tracks)
        self.update()
|
|
|
|
|
|
def main():
    """Open a minimal window that shows only a SelectionControls widget (manual check)."""
    from PySide6.QtWidgets import QApplication

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)

    controls = SelectionControls()
    box = QVBoxLayout()
    box.addWidget(controls)
    window.setLayout(box)

    window.show()
    app.exec()


if __name__ == "__main__":
    main()