# Source: fixtracks/fixtracks/widgets/tracks.py (436 lines, 17 KiB, Python)

import logging
import numpy as np
from PySide6.QtCore import Qt, QThreadPool, Signal
from PySide6.QtGui import QImage, QBrush, QColor
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QFileDialog, QMessageBox
from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.writer import PickleWriter
from fixtracks.utils.trackingdata import TrackingData
from fixtracks.widgets.detectionview import DetectionView, DetectionData
from fixtracks.widgets.detectiontimeline import DetectionTimeline
from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget
from fixtracks.widgets.selection_control import SelectionControls
class FixTracks(QWidget):
    """Widget for inspecting and correcting the track assignment of detections.

    The widget combines a ``DetectionView`` (detections of the currently
    visible frame window), a ``DetectionTimeline``, a ``ClassifierWidget``,
    ``SelectionControls``, a ``SkeletonWidget`` and back/save buttons.
    Tracking data is loaded from and saved to pickle files in the
    widget's ``QThreadPool``.

    Signals
    -------
    back
        Emitted when the "Back" button is pressed.
    """

    back = Signal()
    # Track ids handed to ``TrackingData.setTrack`` by the assign handlers.
    trackone_id = 1
    tracktwo_id = 2
    trackother_id = -1

    def __init__(self, parent=None):
        """Create all child widgets, wire their signals and build the layout.

        Parameters
        ----------
        parent : QWidget, optional
            The Qt parent widget.
        """
        super().__init__(parent)
        self._files = []
        self._threadpool = QThreadPool()
        self._reader = None
        self._image = None
        self._currentWindowPos = 0  # in frames
        self._currentWindowWidth = 0  # in frames
        self._maxframes = 0
        # Set by moveWindow() so that the window-moved notification triggered
        # by our own update is not interpreted as a user drag.
        self._manualmove = False
        self._data = None

        self._detectionView = DetectionView()
        self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
        self._skeleton = SkeletonWidget()

        self._progress_bar = QProgressBar(self)
        self._progress_bar.setMaximumHeight(20)
        self._progress_bar.setValue(0)
        self._tasklabel = QLabel()

        # --- timeline and its controls --------------------------------
        self._timeline = DetectionTimeline()
        self._timeline.signals.windowMoved.connect(self.on_windowChanged)
        self._timeline.signals.moveRequest.connect(self.on_moveRequest)

        self._windowspinner = QSpinBox()
        self._windowspinner.setRange(10, 10000)
        self._windowspinner.setSingleStep(50)
        # The value is set before connecting, so no change is reported here.
        self._windowspinner.setValue(500)
        self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)

        self._keypointcombo = QComboBox()
        self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)

        self._goto_spinner = QSpinBox()
        self._goto_spinner.setSingleStep(1)
        self._gotobtn = QPushButton("go!")
        self._gotobtn.setToolTip("Jump to a given frame")
        self._gotobtn.clicked.connect(self.on_goto)

        combo_layout = QHBoxLayout()
        combo_layout.addWidget(QLabel("Window width:"))
        combo_layout.addWidget(self._windowspinner)
        combo_layout.addWidget(QLabel("frames"))
        combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
        combo_layout.addWidget(QLabel("Keypoint:"))
        combo_layout.addWidget(self._keypointcombo)
        combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
        combo_layout.addWidget(QLabel("Jump to frame:"))
        combo_layout.addWidget(self._goto_spinner)
        combo_layout.addWidget(self._gotobtn)
        combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        combo_layout.setSpacing(1)

        timelinebox = QVBoxLayout()
        timelinebox.setSpacing(2)
        timelinebox.addLayout(combo_layout)
        timelinebox.addWidget(self._timeline)

        # --- selection controls ---------------------------------------
        self._controls_widget = SelectionControls()
        self._controls_widget.assignOne.connect(self.on_assignOne)
        self._controls_widget.assignTwo.connect(self.on_assignTwo)
        self._controls_widget.assignOther.connect(self.on_assignOther)
        self._controls_widget.fwd.connect(self.on_forward)
        self._controls_widget.back.connect(self.on_backward)
        self._controls_widget.accept.connect(self.on_setUserFlag)
        self._controls_widget.accept_until.connect(self.on_setUserFlagsUntil)
        self._controls_widget.unaccept.connect(self.on_unsetUserFlag)
        self._controls_widget.delete.connect(self.on_deleteDetection)
        self._controls_widget.revertall.connect(self.on_revertUserFlags)

        # --- save / back buttons --------------------------------------
        self._saveBtn = QPushButton("Save")
        self._saveBtn.setShortcut("Ctrl+S")
        self._saveBtn.setEnabled(False)  # enabled once data has been loaded
        self._saveBtn.clicked.connect(self.on_save)
        self._backBtn = QPushButton("Back")
        self._backBtn.setShortcut("ESC")
        self._backBtn.clicked.connect(self.on_back)

        # --- file selection -------------------------------------------
        self._data_combo = QComboBox()
        self._data_combo.addItems(self._files)
        self._data_combo.currentIndexChanged.connect(self.on_dataSelection)
        self._image_combo = QComboBox()
        self._image_combo.addItems(self._files)
        self._image_combo.currentIndexChanged.connect(self.on_imageSelection)

        data_selection_box = QHBoxLayout()
        data_selection_box.addWidget(QLabel("Select image file"))
        data_selection_box.addWidget(self._image_combo)
        data_selection_box.addWidget(QLabel("Select data file"))
        data_selection_box.addWidget(self._data_combo)
        data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        data_selection_box.setSpacing(0)

        btnBox = QHBoxLayout()
        btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
        btnBox.addWidget(self._backBtn)
        btnBox.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        btnBox.addWidget(self._tasklabel)
        btnBox.addWidget(self._progress_bar)
        btnBox.addWidget(self._saveBtn)

        # --- classifier / controls / skeleton row ---------------------
        self._classifier = ClassifierWidget()
        self._classifier.apply_classifier.connect(self.on_autoClassify)
        self._classifier.setMaximumWidth(500)

        cntrlBox = QHBoxLayout()
        cntrlBox.addWidget(self._classifier)
        cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
        cntrlBox.addWidget(self._skeleton)
        cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
        cntrlBox.setSpacing(0)
        cntrlBox.setContentsMargins(0, 0, 0, 0)

        # --- overall layout -------------------------------------------
        vbox = QVBoxLayout()
        vbox.setSpacing(0)
        vbox.setContentsMargins(0, 0, 0, 0)
        vbox.addLayout(timelinebox)
        vbox.addLayout(cntrlBox)
        vbox.addLayout(btnBox)
        container = QWidget()
        container.setLayout(vbox)

        splitter = QSplitter(Qt.Orientation.Vertical)
        splitter.addWidget(self._detectionView)
        splitter.addWidget(container)
        splitter.setStretchFactor(0, 3)
        splitter.setStretchFactor(1, 1)

        layout = QVBoxLayout()
        layout.addLayout(data_selection_box)
        layout.addWidget(splitter)
        layout.setSpacing(0)
        layout.setContentsMargins(5, 2, 2, 5)
        self.setLayout(layout)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _fraction(self, frames):
        """Express a frame position/count as a fraction of ``_maxframes``.

        Returns 0.0 while ``_maxframes`` is not positive so that callers
        never divide by zero.
        """
        if self._maxframes <= 0:
            return 0.0
        return frames / self._maxframes

    def _clampWindowStart(self, start):
        """Return ``start`` as int, limited to ``0 .. maxframes - window width``.

        The upper limit itself never drops below zero, so a window wider
        than the recording starts at frame 0 instead of a negative frame.
        """
        latest_start = max(0, self._maxframes - self._currentWindowWidth)
        return int(min(max(start, 0), latest_start))

    def _refresh(self):
        """Update the timeline and then the detection view."""
        self._timeline.update()
        self.update()

    def _confirm(self, text):
        """Ask a Yes/No question (default: No) and return True on "Yes"."""
        msg_box = QMessageBox(self)
        msg_box.setIcon(QMessageBox.Icon.Warning)
        msg_box.setText(text)
        msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
        msg_box.setDefaultButton(QMessageBox.StandardButton.No)
        return msg_box.exec() == QMessageBox.StandardButton.Yes

    # ------------------------------------------------------------------
    # data / image loading
    # ------------------------------------------------------------------
    def on_autoClassify(self, tracks):
        """Assign the classifier result ``tracks`` to all detections."""
        if self._data is None:
            return
        self._data.setSelectionRange("index", 0, self._data.numDetections)
        self._data.assignTracks(tracks)
        self._refresh()

    def on_dataSelection(self):
        """Start loading the data file selected in the data combo box."""
        filename = self._data_combo.currentText()
        # Skip the "Please select" placeholder and the empty combo box.
        if "please select" in filename.lower() or not filename.strip():
            return
        self._progress_bar.setRange(0, 0)  # (0, 0) shows a busy indicator
        self._reader = PickleLoader(filename)
        self._reader.signals.finished.connect(self._on_dataOpenend)
        self._threadpool.start(self._reader)

    def on_imageSelection(self):
        """Load the image selected in the image combo box into the view."""
        filename = self._image_combo.currentText()
        if "please select" in filename.lower() or not filename.strip():
            return
        img = QImage(filename)
        if img.isNull():
            # QImage yields a null image when the file cannot be read.
            logging.warning("Tracks: could not load image file %s", filename)
            return
        self._detectionView.setImage(img)

    def update(self):
        """Show the detections of the current frame window.

        Does nothing while no keypoint is selected or no data is loaded.

        NOTE(review): this shadows ``QWidget.update()`` with different
        semantics; kept because callers rely on the name.
        """
        kp = self._keypointcombo.currentText().lower()
        if not kp or self._data is None:
            return
        kpi = -1 if "center" in kp else int(kp)
        start_frame = int(self._currentWindowPos)
        stop_frame = start_frame + int(self._currentWindowWidth)
        self._timeline.setWindow(self._fraction(start_frame),
                                 self._fraction(self._currentWindowWidth))
        logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
        self._data.setSelectionRange("frame", start_frame, stop_frame)
        self._controls_widget.setWindow(start_frame, stop_frame)
        self._detectionView.updateDetections(kpi)

    @property
    def fileList(self):
        """The list of file names stored by the last ``fileList`` assignment."""
        return self._files

    @fileList.setter
    def fileList(self, file_list):
        """Fill the image and data combo boxes from ``file_list``.

        Parameters
        ----------
        file_list : iterable
            Objects providing a ``suffix`` attribute (e.g. ``pathlib.Path``).
            ``.jpg``/``.png`` files go to the image combo, ``.pkl`` files to
            the data combo; suffixes are matched case-insensitively.
        """
        logging.debug("FixTracks.fileList: set new file list")
        # Materialize once: the input is scanned twice and may be a generator.
        candidates = list(file_list)
        img_formats = (".jpg", ".png")
        dataformats = (".pkl",)
        image_files = [str(f) for f in candidates if f.suffix.lower() in img_formats]
        data_files = [str(f) for f in candidates if f.suffix.lower() in dataformats]

        # The placeholder is added first so that no real file is auto-selected.
        self._image_combo.clear()
        self._image_combo.addItem("Please select")
        self._image_combo.addItems(image_files)
        self._image_combo.setCurrentIndex(0)

        self._data_combo.clear()
        self._data_combo.addItem("Please select")
        self._data_combo.addItems(data_files)
        self._data_combo.setCurrentIndex(0)

        # NOTE(review): as before, only the data files remain in ``_files``
        # (and thus in the ``fileList`` getter) - confirm this is intended.
        self._files = data_files

    def populateKeypointCombo(self, num_keypoints):
        """Offer "Center" plus the indices ``0 .. num_keypoints - 1``."""
        self._keypointcombo.clear()
        self._keypointcombo.addItem("Center")
        for i in range(num_keypoints):
            self._keypointcombo.addItem(str(i))
        self._keypointcombo.setCurrentIndex(0)

    def _on_dataOpenend(self, state):
        """Take over the data once the loader reports that it has finished.

        Parameters
        ----------
        state
            Value delivered by the loader's ``finished`` signal; a falsy
            value is treated as "nothing to take over".
        """
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)
        if not state or self._reader is None:
            logging.warning("Tracks: loading of tracking data did not succeed (state: %s)", state)
            return
        self._data = TrackingData(self._reader.asdict)
        self._saveBtn.setEnabled(True)
        self._currentWindowPos = 0
        self._currentWindowWidth = self._windowspinner.value()
        # Plain int: the value is handed to Qt and used in frame arithmetic.
        self._maxframes = int(np.max(self._data["frame"]))
        self._goto_spinner.setMaximum(self._maxframes)
        self.populateKeypointCombo(self._data.numKeypoints())
        self._timeline.setData(self._data)
        self._detectionView.setData(self._data)
        self._classifier.setData(self._data)
        self.update()
        logging.info("Finished loading data: %i frames", self._maxframes)

    def on_keypointSelected(self):
        """Redraw the view for the newly selected keypoint."""
        self.update()

    # ------------------------------------------------------------------
    # saving / navigation back
    # ------------------------------------------------------------------
    def on_save(self):
        """Ask for a target ``.pkl`` file and save the data in the thread pool."""
        logging.debug("Saving fixtracks results")
        if self._data is None:
            return
        file_dialog = QFileDialog(self)
        file_dialog.setAcceptMode(QFileDialog.AcceptMode.AcceptSave)
        file_dialog.setNameFilter("Pickle Files (*.pkl)")
        if not file_dialog.exec():
            # Cancelled: nothing is saved, so no task message must remain.
            return
        file_path = file_dialog.selectedFiles()[0]
        if not file_path.endswith(".pkl"):
            file_path += ".pkl"
        self._tasklabel.setText("Saving results to file...")
        self._progress_bar.setRange(0, 0)
        save_task = PickleWriter(self._data, file_path)
        save_task.signals.finished.connect(self.on_dataSaved)
        self._threadpool.start(save_task)

    def on_dataSaved(self):
        """Reset task label and progress bar after saving has finished."""
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)

    def on_back(self):
        """Forward the "Back" button press via the ``back`` signal."""
        logging.debug("Back button pressed!")
        self.back.emit()

    # ------------------------------------------------------------------
    # track assignment and user flags
    # ------------------------------------------------------------------
    def on_assignOne(self):
        """Assign the current selection to track one."""
        if self._data is None:
            return
        logging.debug("Assigning user selection to track One")
        self._data.setTrack(self.trackone_id)
        self._refresh()

    def on_assignTwo(self):
        """Assign the current selection to track two."""
        if self._data is None:
            return
        logging.debug("Assigning user selection to track Two")
        self._data.setTrack(self.tracktwo_id)
        self._refresh()

    def on_assignOther(self):
        """Assign the current selection to the "other" track."""
        if self._data is None:
            return
        logging.debug("Assigning user selection to track Other")
        self._data.setTrack(self.trackother_id, False)
        self._refresh()

    def on_setUserFlag(self):
        """Mark the current selection as user-labeled."""
        if self._data is None:
            return
        self._data.setUserLabeledStatus(True)
        self._refresh()

    def on_setUserFlagsUntil(self):
        """Mark everything from frame 0 to the end of the window as user-labeled."""
        if self._data is None:
            return
        self._data.setSelectionRange("frame", 0, self._currentWindowPos + self._currentWindowWidth)
        self._data.setUserLabeledStatus(True)
        self._refresh()

    def on_unsetUserFlag(self):
        """Remove the user-labeled mark from the current selection."""
        if self._data is None:
            return
        logging.debug("Tracks:unsetUserFlag")
        self._data.setUserLabeledStatus(False)
        self._refresh()

    def on_revertUserFlags(self):
        """After confirmation, revert all user flags and track assignments."""
        if self._data is None:
            return
        logging.debug("Tracks:revert ALL UserFlags and track assignments")
        if self._confirm("Are you sure you want to revert ALL track assignments?"):
            self._data.revertUserLabeledStatus()
            self._data.revertTrackAssignments()
            self._refresh()

    def on_deleteDetection(self):
        """After confirmation, delete the currently selected detections."""
        if self._data is None:
            return
        count = len(self._data.selectionIndices)
        if self._confirm(f"Are you sure you want to delete the selected ({count}) detections?"):
            logging.info("Tracks:deleting detections!")
            self._data.deleteDetections()
            self._refresh()

    # ------------------------------------------------------------------
    # window handling
    # ------------------------------------------------------------------
    def on_windowChanged(self):
        """Adopt the window start reported by the timeline.

        The notification is ignored once after ``moveWindow`` because that
        move originated here and the position is already up to date.
        """
        logging.debug("Tracks:Timeline reports window change ")
        if not self._manualmove:
            self._currentWindowPos = int(np.round(self._timeline.rangeStart * self._maxframes))
            self.update()
        self._manualmove = False

    def on_moveRequest(self, pos):
        """Move the window start to ``pos``, given as a fraction of ``_maxframes``."""
        new_pos = int(np.round(pos * self._maxframes))
        self._currentWindowPos = new_pos
        self.update()

    def on_windowSizeChanged(self, value):
        """Reacts on the user window-width selection. Selection is done in the unit of frames.

        Parameters
        ----------
        value : int
            The width of the observation window in frames.
        """
        self._currentWindowWidth = value
        logging.debug("Tracks:OnWindowSizeChanged %i frames", value)
        self.update()

    def on_goto(self):
        """Jump to the frame entered in the "Jump to frame" spin box.

        The target is limited so that the window does not extend beyond
        the last frame. Does nothing while no data is loaded.
        """
        if self._data is None:
            logging.debug("Tracks:on_goto ignored, no data loaded")
            return
        target = self._clampWindowStart(self._goto_spinner.value())
        logging.info("Jump to frame %i", target)
        self._currentWindowPos = target
        self._timeline.setWindow(self._fraction(self._currentWindowPos),
                                 self._fraction(self._currentWindowWidth))
        self.update()

    def on_detectionsSelected(self, detections):
        """Propagate a selection made in the detection view.

        Parameters
        ----------
        detections : sequence
            Selected items; their attributes are read through
            ``item.data(DetectionData.<FIELD>.value)``.
        """
        logging.debug("Tracks: %i Detections selected", len(detections))
        tracks = np.zeros(len(detections), dtype=int)
        ids = np.zeros_like(tracks)
        frames = np.zeros_like(tracks)
        scores = np.zeros(tracks.shape, dtype=float)
        coordinates = None  # stays None when nothing is selected
        if len(detections) > 0:
            # The first item determines the 2-D shape of one coordinate array.
            c = detections[0].data(DetectionData.COORDINATES.value)
            coordinates = np.zeros((len(detections), c.shape[0], c.shape[1]))
        for i, d in enumerate(detections):
            tracks[i] = d.data(DetectionData.TRACK_ID.value)
            ids[i] = d.data(DetectionData.ID.value)
            frames[i] = d.data(DetectionData.FRAME.value)
            coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
            scores[i] = d.data(DetectionData.SCORE.value)
        if self._data is not None:
            self._data.setSelection(ids)
        self._controls_widget.setSelectedTracks(tracks)
        self._skeleton.clear()
        self._skeleton.addSkeletons(coordinates, ids, frames, tracks, scores, QBrush(QColor(10, 255, 65, 255)))

    def moveWindow(self, stepsize):
        """Shift the window by ``stepsize`` window widths (negative: backwards).

        The new start frame is limited to the valid range. Does nothing
        while no data is loaded, so the manual-move flag cannot go stale.
        """
        logging.info("Tracks.moveWindow: move window with stepsize %.2f", stepsize)
        if self._data is None:
            return
        self._manualmove = True
        step = int(np.round(stepsize * self._currentWindowWidth))
        self._currentWindowPos = self._clampWindowStart(self._currentWindowPos + step)
        self._controls_widget.setSelectedTracks(None)
        self.update()

    def on_forward(self, stepsize):
        """Move the window forward by ``stepsize`` window widths."""
        logging.debug("Tracks: receive forward command with step-size: %.2f", stepsize)
        self.moveWindow(stepsize)

    def on_backward(self, stepsize):
        """Move the window backward by ``stepsize`` window widths."""
        logging.debug("Tracks: receive backward command with step-size: %.2f", stepsize)
        self.moveWindow(-stepsize)
def main():
    """Show a minimal top-level window that hosts a ``SelectionControls`` widget."""
    from PySide6.QtWidgets import QApplication

    # The application object has to exist before any widget is created.
    application = QApplication([])

    top_level = QWidget()
    top_level.setMinimumSize(200, 200)

    box = QVBoxLayout()
    box.addWidget(SelectionControls())
    top_level.setLayout(box)

    top_level.show()
    application.exec()
# Start the demo window only when this file is executed as a script.
if __name__ == "__main__":
    main()