# fixtracks/fixtracks/widgets/tracks.py
#
# Listing metadata from the repository viewer: 360 lines, 14 KiB, Python

import logging
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal
from PySide6.QtGui import QImage, QBrush, QColor
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.writer import PickleWriter
from fixtracks.utils.trackingdata import TrackingData
from fixtracks.widgets.detectionview import DetectionView, DetectionData
from fixtracks.widgets.detectiontimeline import DetectionTimeline
from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget
from fixtracks.widgets.selection_control import SelectionControls
class FixTracks(QWidget):
    """Widget for reviewing detections and (re-)assigning them to tracks.

    The widget combines a detection view, a skeleton view, a timeline with
    a movable observation window, a classifier widget and selection
    controls. Data files are read and written in worker tasks that run in
    a ``QThreadPool``.

    Signals
    -------
    back
        Emitted when the "Back" button is pressed.
    """

    back = Signal()
    # Track ids passed to ``TrackingData.assignUserSelection``.
    trackone_id = 1
    tracktwo_id = 2
    trackother_id = -1

    def __init__(self, parent=None):
        """Create all child widgets, wire their signals and build the layout.

        Parameters
        ----------
        parent : QWidget, optional
            The parent widget, forwarded to ``QWidget.__init__``.
        """
        super().__init__(parent)
        self._files = []
        self._threadpool = QThreadPool()
        self._reader = None
        self._image = None
        self._currentWindowPos = 0  # in frames
        self._currentWindowWidth = 0  # in frames
        self._maxframes = 0  # stays 0 until a data file has been loaded
        self._data = TrackingData()

        # --- upper part: detection view next to the skeleton view ---
        self._detectionView = DetectionView()
        self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
        self._skeleton = SkeletonWidget()
        # self._skeleton.setMaximumSize(QSize(400, 400))
        top_splitter = QSplitter(Qt.Orientation.Horizontal)
        top_splitter.addWidget(self._detectionView)
        top_splitter.addWidget(self._skeleton)
        top_splitter.setStretchFactor(0, 2)
        top_splitter.setStretchFactor(1, 1)

        self._progress_bar = QProgressBar(self)
        self._progress_bar.setMaximumHeight(20)
        self._progress_bar.setValue(0)
        self._tasklabel = QLabel()

        # --- timeline with window-width spinner and keypoint selection ---
        self._timeline = DetectionTimeline()
        self._timeline.signals.windowMoved.connect(self.on_windowChanged)
        self._windowspinner = QSpinBox()
        self._windowspinner.setRange(10, 10000)
        self._windowspinner.setSingleStep(50)
        # The initial value is set before the signal is connected, so it
        # does not trigger on_windowSizeChanged.
        self._windowspinner.setValue(500)
        self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)
        self._keypointcombo = QComboBox()
        self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
        combo_layout = QGridLayout()
        combo_layout.addWidget(QLabel("Window:"), 0, 0)
        combo_layout.addWidget(self._windowspinner, 0, 1)
        combo_layout.addWidget(QLabel("Keypoint:"), 1, 0)
        combo_layout.addWidget(self._keypointcombo, 1, 1)
        timelinebox = QHBoxLayout()
        timelinebox.addWidget(self._timeline)
        timelinebox.addLayout(combo_layout)

        # --- selection controls ---
        self._controls_widget = SelectionControls()
        self._controls_widget.assignOne.connect(self.on_assignOne)
        self._controls_widget.assignTwo.connect(self.on_assignTwo)
        self._controls_widget.assignOther.connect(self.on_assignOther)
        self._controls_widget.fwd.connect(self.on_forward)
        self._controls_widget.back.connect(self.on_backward)
        self._controls_widget.accept.connect(self.on_setUserFlag)
        self._controls_widget.unaccept.connect(self.on_unsetUserFlag)
        self._controls_widget.delete.connect(self.on_deleteDetection)
        self._controls_widget.revertall.connect(self.on_revertUserFlags)

        # --- save/back buttons ---
        self._saveBtn = QPushButton("Save")
        self._saveBtn.setShortcut("Ctrl+S")
        self._saveBtn.setEnabled(False)  # enabled once data has been loaded
        self._saveBtn.clicked.connect(self.on_save)
        self._backBtn = QPushButton("Back")
        self._backBtn.setShortcut("ESC")
        self._backBtn.clicked.connect(self.on_back)

        # --- file selection combos (filled by the fileList setter) ---
        self._data_combo = QComboBox()
        self._data_combo.addItems(self._files)
        self._data_combo.currentIndexChanged.connect(self.on_dataSelection)
        self._image_combo = QComboBox()
        self._image_combo.addItems(self._files)
        self._image_combo.currentIndexChanged.connect(self.on_imageSelection)
        data_selection_box = QHBoxLayout()
        data_selection_box.addWidget(QLabel("Select image file"))
        data_selection_box.addWidget(self._image_combo)
        data_selection_box.addWidget(QLabel("Select data file"))
        data_selection_box.addWidget(self._data_combo)
        data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))

        btnBox = QHBoxLayout()
        btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
        btnBox.addWidget(self._backBtn)
        btnBox.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        btnBox.addWidget(self._tasklabel)
        btnBox.addWidget(self._progress_bar)
        btnBox.addWidget(self._saveBtn)

        # --- classifier next to the selection controls ---
        self._classifier = ClassifierWidget()
        self._classifier.apply_classifier.connect(self.on_autoClassify)
        self._classifier.setMaximumWidth(500)
        cntrlBox = QHBoxLayout()
        cntrlBox.addWidget(self._classifier)
        cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
        cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))

        # --- assemble the lower container and the overall layout ---
        vbox = QVBoxLayout()
        vbox.addLayout(timelinebox)
        vbox.addLayout(cntrlBox)
        vbox.addLayout(btnBox)
        container = QWidget()
        container.setLayout(vbox)
        splitter = QSplitter(Qt.Orientation.Vertical)
        splitter.addWidget(top_splitter)
        splitter.addWidget(container)
        splitter.setStretchFactor(0, 3)
        splitter.setStretchFactor(1, 1)
        layout = QVBoxLayout()
        layout.addLayout(data_selection_box)
        layout.addWidget(splitter)
        self.setLayout(layout)

    def _refresh_views(self):
        """Redraw the timeline and refresh the detection view."""
        self._timeline.update()
        self.update()

    def on_autoClassify(self, tracks):
        """Assign the given tracks to all detections and refresh the views.

        Parameters
        ----------
        tracks
            Track assignments emitted by the classifier widget; passed on
            to ``TrackingData.assignTracks`` unchanged.
        """
        self._data.setSelectionRange("index", 0, self._data.numDetections)
        self._data.assignTracks(tracks)
        self._refresh_views()

    def on_dataSelection(self):
        """Start loading the data file currently selected in the data combo.

        Does nothing for the "Please select" placeholder or an empty entry.
        Loading runs in the thread pool; ``_on_dataOpenend`` is called when
        the reader reports that it is finished.
        """
        filename = self._data_combo.currentText()
        if "please select" in filename.lower() or len(filename.strip()) == 0:
            return
        self._progress_bar.setRange(0, 0)  # (0, 0) shows a busy indicator
        self._reader = PickleLoader(filename)
        self._reader.signals.finished.connect(self._on_dataOpenend)
        self._threadpool.start(self._reader)

    def on_imageSelection(self):
        """Load the selected image file and show it in the detection view.

        Does nothing for the "Please select" placeholder or an empty entry.
        An image that cannot be loaded is reported and not passed on.
        """
        filename = self._image_combo.currentText()
        if "please select" in filename.lower() or len(filename.strip()) == 0:
            return
        img = QImage(filename)
        if img.isNull():
            # QImage signals a failed load with a null image instead of raising.
            logging.warning("Tracks: could not load image file %s", filename)
            return
        self._detectionView.setImage(img)

    def update(self):
        """Select the detections in the current window and refresh the view.

        NOTE(review): this shadows ``QWidget.update``; the name is kept
        because it is used throughout this class and possibly by callers.
        """
        start_frame = self._currentWindowPos
        stop_frame = start_frame + self._currentWindowWidth
        logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
        self._data.setSelectionRange("frame", start_frame, stop_frame)
        self._controls_widget.setWindow(start_frame, stop_frame)
        kp = self._keypointcombo.currentText().strip()
        # "Center" maps to -1. The combo is empty while it is being
        # repopulated (clear() emits currentIndexChanged), so anything that
        # is not a keypoint number falls back to -1 instead of raising.
        kpi = int(kp) if kp.isdecimal() else -1
        self._detectionView.updateDetections(kpi)

    @property
    def fileList(self):
        """The list of file names stored by the last ``fileList`` assignment."""
        return self._files

    @fileList.setter
    def fileList(self, file_list):
        """Fill the image and data combo boxes from ``file_list``.

        Parameters
        ----------
        file_list : iterable
            Entries must provide a ``suffix`` attribute (e.g. ``pathlib.Path``).
            ``.jpg``/``.png`` files go to the image combo, ``.pkl`` files to
            the data combo; the suffix is matched case-insensitively.
        """
        logging.debug("FixTracks.fileList: set new file list")
        # Materialize once: the input is filtered twice and may be a generator.
        entries = list(file_list)
        img_formats = (".jpg", ".png")
        dataformats = (".pkl",)
        image_files = [str(f) for f in entries if f.suffix.lower() in img_formats]
        data_files = [str(f) for f in entries if f.suffix.lower() in dataformats]

        self._image_combo.clear()
        self._image_combo.addItem("Please select")
        self._image_combo.addItems(image_files)
        self._image_combo.setCurrentIndex(0)

        self._data_combo.clear()
        self._data_combo.addItem("Please select")
        self._data_combo.addItems(data_files)
        self._data_combo.setCurrentIndex(0)

        # NOTE(review): as before, the getter ends up returning only the data
        # files; confirm whether callers expect the image files as well.
        self._files = data_files

    def populateKeypointCombo(self, num_keypoints):
        """Rebuild the keypoint combo: "Center" followed by 0..num_keypoints-1.

        Parameters
        ----------
        num_keypoints : int
            Number of keypoint entries to add after "Center".
        """
        self._keypointcombo.clear()
        self._keypointcombo.addItem("Center")
        for i in range(num_keypoints):
            self._keypointcombo.addItem(str(i))
        self._keypointcombo.setCurrentIndex(0)

    def _on_dataOpenend(self, state):
        """Take over the data from the reader once loading has finished.

        Parameters
        ----------
        state
            Value emitted by the reader's ``finished`` signal; the data is
            only used when it is truthy.
        """
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)
        if not state or self._reader is None:
            logging.warning("Tracks: data file was not loaded (state: %s)", state)
            return
        self._data.setData(self._reader.asdict)
        self._saveBtn.setEnabled(True)
        self._currentWindowPos = 0
        self._currentWindowWidth = self._windowspinner.value()
        self._maxframes = self._data.max("frame")
        self.populateKeypointCombo(self._data.numKeypoints())
        self._timeline.setData(self._data)
        if self._maxframes > 0:
            # The timeline works with fractions of the total number of frames.
            self._timeline.setWindow(self._currentWindowPos / self._maxframes,
                                     self._currentWindowWidth / self._maxframes)
        else:
            logging.warning("Tracks: maximum frame is %s, cannot position the timeline window", self._maxframes)
        self._detectionView.setData(self._data)
        self._classifier.setData(self._data)
        self.update()
        logging.info("Finished loading data: %i frames", self._maxframes)

    def on_keypointSelected(self):
        """Refresh the detection view after the keypoint selection changed."""
        self.update()

    def on_save(self):
        """Ask for a target file and write the data to it in a worker task.

        The ".pkl" extension is appended when missing. Nothing is written
        and no busy state is shown when the dialog is cancelled.
        """
        logging.debug("Saving fixtracks results")
        file_dialog = QFileDialog(self)
        file_dialog.setAcceptMode(QFileDialog.AcceptMode.AcceptSave)
        file_dialog.setNameFilter("Pickle Files (*.pkl)")
        if not file_dialog.exec():
            logging.debug("Saving fixtracks results cancelled")
            return
        file_path = file_dialog.selectedFiles()[0]
        if not file_path.endswith(".pkl"):
            file_path += ".pkl"
        # Show the busy state only once a task is actually started, so that a
        # cancelled dialog does not leave a stale label behind.
        self._tasklabel.setText("Saving results to file...")
        self._progress_bar.setRange(0, 0)
        save_task = PickleWriter(self._data, file_path)
        save_task.signals.finished.connect(self.on_dataSaved)
        self._threadpool.start(save_task)

    def on_dataSaved(self):
        """Reset task label and progress bar after the writer has finished."""
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)

    def on_back(self):
        """Emit the ``back`` signal."""
        logging.debug("Back button pressed!")
        self.back.emit()

    def on_assignOne(self):
        """Assign the user selection to track one and refresh the views."""
        logging.debug("Assigning user selection to track One")
        self._data.assignUserSelection(self.trackone_id)
        self._refresh_views()

    def on_assignTwo(self):
        """Assign the user selection to track two and refresh the views."""
        logging.debug("Assigning user selection to track Two")
        self._data.assignUserSelection(self.tracktwo_id)
        self._refresh_views()

    def on_assignOther(self):
        """Assign the user selection to the "other" track and refresh the views."""
        logging.debug("Assigning user selection to track Other")
        self._data.assignUserSelection(self.trackother_id, False)
        self._refresh_views()

    def on_setUserFlag(self):
        """Set the assignment status to True and refresh the views."""
        logging.debug("Tracks:setUserFlag")
        self._data.setAssignmentStatus(True)
        self._refresh_views()

    def on_unsetUserFlag(self):
        """Set the assignment status to False and refresh the views."""
        logging.debug("Tracks:unsetUserFlag")
        self._data.setAssignmentStatus(False)
        self._refresh_views()

    def on_revertUserFlags(self):
        """Revert assignment status and track assignments, then refresh."""
        logging.debug("Tracks:revert ALL UserFlags and track assignments")
        self._data.revertAssignmentStatus()
        self._data.revertTrackAssignments()
        self._refresh_views()

    def on_deleteDetection(self):
        """Handle the delete request; deleting is not supported yet."""
        logging.warning("Tracks:delete detections is currently not supported!")
        # self._data.deleteDetections()
        self._refresh_views()

    def on_windowChanged(self):
        """Take over the window start reported by the timeline and refresh."""
        logging.debug("Tracks:Timeline reports window change ")
        # rangeStart is a fraction of the total frames; store a plain int frame.
        self._currentWindowPos = int(np.round(self._timeline.rangeStart * self._maxframes))
        self.update()

    def on_windowSizeChanged(self, value):
        """Reacts on the user window-width selection. Selection is done in the unit of frames.

        Parameters
        ----------
        value : int
            The width of the observation window in frames.
        """
        self._currentWindowWidth = value
        logging.debug("Tracks:OnWindowSizeChanged %i frames", value)
        if self._maxframes > 0:
            self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
        # Without loaded data there is nothing to scale against; the width is
        # picked up from the spinner again when a data file is opened.
        self._controls_widget.setSelectedTracks(None)

    def on_detectionsSelected(self, detections):
        """Store the selected detections and show them in the skeleton view.

        Parameters
        ----------
        detections : sequence
            The selected items; track id, detection id, frame and
            coordinates are read from each item via ``item.data(...)``.
        """
        logging.debug("Tracks: Detections selected")
        tracks = np.zeros(len(detections), dtype=int)
        ids = np.zeros_like(tracks)
        frames = np.zeros_like(tracks)
        coordinates = None  # stays None when nothing is selected
        if len(detections) > 0:
            # The first item's coordinates define the shape of all entries.
            c = detections[0].data(DetectionData.COORDINATES.value)
            coordinates = np.zeros((len(detections), c.shape[0], c.shape[1]))
        for i, d in enumerate(detections):
            tracks[i] = d.data(DetectionData.TRACK_ID.value)
            ids[i] = d.data(DetectionData.ID.value)
            frames[i] = d.data(DetectionData.FRAME.value)
            coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
        self._data.setUserSelection(ids)
        self._controls_widget.setSelectedTracks(tracks)
        self._skeleton.clear()
        self._skeleton.addSkeletons(coordinates, ids, frames, tracks, QBrush(QColor(10, 255, 65, 255)))
        self.update()

    def moveWindow(self, stepsize):
        """Move the observation window by a fraction of its width.

        Parameters
        ----------
        stepsize : float
            Step relative to the current window width; negative values move
            the window backwards.
        """
        if self._maxframes <= 0:
            # No data loaded yet: there is no frame range to move within.
            return
        step = int(np.round(stepsize * self._currentWindowWidth))
        # Keep the window start inside the available frame range.
        new_start_frame = min(max(self._currentWindowPos + step, 0), self._maxframes)
        self._timeline.setWindowPos(new_start_frame / self._maxframes)
        self._currentWindowPos = new_start_frame
        self._controls_widget.setSelectedTracks(None)
        self.update()

    def on_forward(self, stepsize):
        """Move the window forward by ``stepsize`` window widths."""
        logging.debug("Tracks: receive forward command with step-size: %.2f", stepsize)
        self.moveWindow(stepsize)

    def on_backward(self, stepsize):
        """Move the window backward by ``stepsize`` window widths."""
        logging.debug("Tracks: receive backward command with step-size: %.2f", stepsize)
        self.moveWindow(-stepsize)
def main():
    """Show a small standalone window that contains only the selection controls.

    Creates its own ``QApplication`` and blocks in the Qt event loop until
    the application quits.
    """
    from PySide6.QtWidgets import QApplication

    application = QApplication([])

    demo_window = QWidget()
    demo_window.setMinimumSize(200, 200)

    box = QVBoxLayout()
    box.addWidget(SelectionControls())
    demo_window.setLayout(box)

    demo_window.show()
    application.exec()


# Run the demo window only when this module is executed as a script.
if __name__ == "__main__":
    main()