436 lines
17 KiB
Python
436 lines
17 KiB
Python
import logging
|
|
import numpy as np
|
|
|
|
from PySide6.QtCore import Qt, QThreadPool, Signal
|
|
from PySide6.QtGui import QImage, QBrush, QColor
|
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
|
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QFileDialog, QMessageBox
|
|
|
|
|
|
from fixtracks.utils.reader import PickleLoader
|
|
from fixtracks.utils.writer import PickleWriter
|
|
from fixtracks.utils.trackingdata import TrackingData
|
|
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
|
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
|
from fixtracks.widgets.skeleton import SkeletonWidget
|
|
from fixtracks.widgets.classifier import ClassifierWidget
|
|
from fixtracks.widgets.selection_control import SelectionControls
|
|
|
|
class FixTracks(QWidget):
    """Editor widget for reviewing and correcting tracking data.

    The user picks an image (shown in the detection view) and a pickled
    tracking-data file, moves an observation window over the recorded frames
    and assigns the selected detections to track one, track two or "other".
    Loading and saving are run as tasks on a ``QThreadPool``.
    """

    back = Signal()  # emitted when the "Back" button is pressed

    # Track ids written by the assign buttons of the selection controls.
    trackone_id = 1
    tracktwo_id = 2
    trackother_id = -1

    def __init__(self, parent=None):
        super().__init__(parent)
        self._files = []
        self._threadpool = QThreadPool()
        self._reader = None
        self._image = None
        self._currentWindowPos = 0  # in frames
        self._currentWindowWidth = 0  # in frames
        self._maxframes = 0
        # Set by moveWindow(); while True, on_windowChanged() does not
        # recompute the window position from the timeline and resets the flag.
        self._manualmove = False
        self._data = None

        self._detectionView = DetectionView()
        self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
        self._skeleton = SkeletonWidget()

        self._progress_bar = QProgressBar(self)
        self._progress_bar.setMaximumHeight(20)
        self._progress_bar.setValue(0)
        self._tasklabel = QLabel()

        self._timeline = DetectionTimeline()
        self._timeline.signals.windowMoved.connect(self.on_windowChanged)
        self._timeline.signals.moveRequest.connect(self.on_moveRequest)

        self._windowspinner = QSpinBox()
        self._windowspinner.setRange(10, 10000)
        self._windowspinner.setSingleStep(50)
        self._windowspinner.setValue(500)
        self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)

        self._keypointcombo = QComboBox()
        self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)

        self._goto_spinner = QSpinBox()
        self._goto_spinner.setSingleStep(1)

        self._gotobtn = QPushButton("go!")
        self._gotobtn.setToolTip("Jump to a given frame")
        self._gotobtn.clicked.connect(self.on_goto)

        combo_layout = QHBoxLayout()
        combo_layout.addWidget(QLabel("Window width:"))
        combo_layout.addWidget(self._windowspinner)
        combo_layout.addWidget(QLabel("frames"))
        combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
        combo_layout.addWidget(QLabel("Keypoint:"))
        combo_layout.addWidget(self._keypointcombo)
        combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
        combo_layout.addWidget(QLabel("Jump to frame:"))
        combo_layout.addWidget(self._goto_spinner)
        combo_layout.addWidget(self._gotobtn)
        combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        combo_layout.setSpacing(1)

        timelinebox = QVBoxLayout()
        timelinebox.setSpacing(2)
        timelinebox.addLayout(combo_layout)
        timelinebox.addWidget(self._timeline)

        self._controls_widget = SelectionControls()
        self._controls_widget.assignOne.connect(self.on_assignOne)
        self._controls_widget.assignTwo.connect(self.on_assignTwo)
        self._controls_widget.assignOther.connect(self.on_assignOther)
        self._controls_widget.fwd.connect(self.on_forward)
        self._controls_widget.back.connect(self.on_backward)
        self._controls_widget.accept.connect(self.on_setUserFlag)
        self._controls_widget.accept_until.connect(self.on_setUserFlagsUntil)
        self._controls_widget.unaccept.connect(self.on_unsetUserFlag)
        self._controls_widget.delete.connect(self.on_deleteDetection)
        self._controls_widget.revertall.connect(self.on_revertUserFlags)

        self._saveBtn = QPushButton("Save")
        self._saveBtn.setShortcut("Ctrl+S")
        self._saveBtn.setEnabled(False)  # enabled once data has been loaded
        self._saveBtn.clicked.connect(self.on_save)
        self._backBtn = QPushButton("Back")
        self._backBtn.setShortcut("ESC")
        self._backBtn.clicked.connect(self.on_back)

        self._data_combo = QComboBox()
        self._data_combo.addItems(self._files)
        self._data_combo.currentIndexChanged.connect(self.on_dataSelection)
        self._image_combo = QComboBox()
        self._image_combo.addItems(self._files)
        self._image_combo.currentIndexChanged.connect(self.on_imageSelection)

        data_selection_box = QHBoxLayout()
        data_selection_box.addWidget(QLabel("Select image file"))
        data_selection_box.addWidget(self._image_combo)
        data_selection_box.addWidget(QLabel("Select data file"))
        data_selection_box.addWidget(self._data_combo)
        data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        data_selection_box.setSpacing(0)

        btnBox = QHBoxLayout()
        btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
        btnBox.addWidget(self._backBtn)
        btnBox.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
        btnBox.addWidget(self._tasklabel)
        btnBox.addWidget(self._progress_bar)
        btnBox.addWidget(self._saveBtn)

        self._classifier = ClassifierWidget()
        self._classifier.apply_classifier.connect(self.on_autoClassify)
        self._classifier.setMaximumWidth(500)
        cntrlBox = QHBoxLayout()
        cntrlBox.addWidget(self._classifier)
        cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
        cntrlBox.addWidget(self._skeleton)
        cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
        cntrlBox.setSpacing(0)
        cntrlBox.setContentsMargins(0, 0, 0, 0)

        vbox = QVBoxLayout()
        vbox.setSpacing(0)
        vbox.setContentsMargins(0, 0, 0, 0)
        vbox.addLayout(timelinebox)
        vbox.addLayout(cntrlBox)
        vbox.addLayout(btnBox)
        container = QWidget()
        container.setLayout(vbox)

        # Detection view on top (3/4 of the height), controls below (1/4).
        splitter = QSplitter(Qt.Orientation.Vertical)
        splitter.addWidget(self._detectionView)
        splitter.addWidget(container)
        splitter.setStretchFactor(0, 3)
        splitter.setStretchFactor(1, 1)

        layout = QVBoxLayout()
        layout.addLayout(data_selection_box)
        layout.addWidget(splitter)
        layout.setSpacing(0)
        layout.setContentsMargins(5, 2, 2, 5)
        self.setLayout(layout)

    # ------------------------------------------------------------------ helpers

    def _hasData(self):
        """Return True once a tracking-data file has been loaded."""
        return self._data is not None

    def _refreshViews(self):
        """Redraw the timeline and re-select/redraw the current window."""
        self._timeline.update()
        self.update()

    def _confirm(self, text):
        """Ask a Yes/No warning question (default: No); return True on Yes."""
        msg_box = QMessageBox(self)
        msg_box.setIcon(QMessageBox.Icon.Warning)
        msg_box.setText(text)
        msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
        msg_box.setDefaultButton(QMessageBox.StandardButton.No)
        return msg_box.exec() == QMessageBox.StandardButton.Yes

    # ------------------------------------------------------------ data loading

    def on_autoClassify(self, tracks):
        """Apply the classifier's track assignments to all detections.

        Parameters
        ----------
        tracks
            Track assignments as emitted by ``ClassifierWidget.apply_classifier``;
            passed unchanged to ``TrackingData.assignTracks``.
        """
        if not self._hasData():
            return
        # Select every detection so that assignTracks operates on all of them.
        self._data.setSelectionRange("index", 0, self._data.numDetections)
        self._data.assignTracks(tracks)
        self._refreshViews()

    def on_dataSelection(self):
        """Start loading the data file chosen in the data combo box in the background."""
        filename = self._data_combo.currentText()
        if "please select" in filename.lower() or not filename.strip():
            return
        self._tasklabel.setText("Loading data...")
        self._progress_bar.setRange(0, 0)  # busy indicator while loading
        self._reader = PickleLoader(filename)
        self._reader.signals.finished.connect(self._on_dataOpenend)
        self._threadpool.start(self._reader)

    def on_imageSelection(self):
        """Show the image chosen in the image combo box in the detection view."""
        filename = self._image_combo.currentText()
        if "please select" in filename.lower() or not filename.strip():
            return
        img = QImage(filename)
        self._detectionView.setImage(img)

    def update(self):
        """Select the detections of the current window and refresh the views.

        Note: this shadows ``QWidget.update()`` and takes no arguments; it is
        called explicitly by the handlers of this class. Does nothing until
        data is loaded and a keypoint is selected in the keypoint combo.
        """
        if self._data is None:
            return
        kp = self._keypointcombo.currentText().lower()
        if not kp:
            return
        kpi = -1 if "center" in kp else int(kp)  # -1 selects the center

        start_frame = int(self._currentWindowPos)
        stop_frame = start_frame + int(self._currentWindowWidth)
        # Avoid division by zero when the data only contains frame 0.
        total = max(self._maxframes, 1)
        self._timeline.setWindow(start_frame / total, self._currentWindowWidth / total)
        logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
        self._data.setSelectionRange("frame", start_frame, stop_frame)
        self._controls_widget.setWindow(start_frame, stop_frame)
        self._detectionView.updateDetections(kpi)

    @property
    def fileList(self):
        """The image and data files currently offered in the combo boxes (as str)."""
        return self._files

    @fileList.setter
    def fileList(self, file_list):
        """Fill the image and data combo boxes from *file_list*.

        Image files (.jpg, .jpeg, .png) go into the image combo and pickle
        files (.pkl) into the data combo; suffix matching ignores case.
        Entries may be ``pathlib.Path`` objects or strings. Both combos start
        with a "Please select" placeholder entry.
        """
        logging.debug("FixTracks.fileList: set new file list")
        img_formats = (".jpg", ".jpeg", ".png")
        dataformats = (".pkl",)
        paths = [str(f) for f in file_list]
        image_files = [p for p in paths if p.lower().endswith(img_formats)]
        data_files = [p for p in paths if p.lower().endswith(dataformats)]
        self._files = image_files + data_files

        self._image_combo.clear()
        self._image_combo.addItem("Please select")
        self._image_combo.addItems(image_files)
        self._image_combo.setCurrentIndex(0)

        self._data_combo.clear()
        self._data_combo.addItem("Please select")
        self._data_combo.addItems(data_files)
        self._data_combo.setCurrentIndex(0)

    def populateKeypointCombo(self, num_keypoints):
        """Fill the keypoint combo with "Center" followed by indices 0..num_keypoints-1."""
        self._keypointcombo.clear()
        self._keypointcombo.addItem("Center")
        for i in range(num_keypoints):
            self._keypointcombo.addItem(str(i))
        self._keypointcombo.setCurrentIndex(0)

    def _on_dataOpenend(self, state):
        """Finish loading once the PickleLoader task reports completion.

        Parameters
        ----------
        state
            Truthy when the loader succeeded.
        """
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)
        if not state or self._reader is None:
            logging.warning("Tracking data could not be loaded.")
            return

        self._data = TrackingData(self._reader.asdict)
        self._saveBtn.setEnabled(True)
        self._currentWindowPos = 0
        self._currentWindowWidth = self._windowspinner.value()
        # Cast the numpy scalar so Qt setters receive a plain Python int.
        self._maxframes = int(np.max(self._data["frame"]))
        self._goto_spinner.setMaximum(self._maxframes)
        # Hand the data to the views *before* filling the keypoint combo:
        # adding the first item emits currentIndexChanged, which calls
        # update() and redraws the views with the new data.
        self._timeline.setData(self._data)
        self._detectionView.setData(self._data)
        self._classifier.setData(self._data)
        self.populateKeypointCombo(self._data.numKeypoints())
        self.update()
        logging.info("Finished loading data: %i frames", self._maxframes)

    def on_keypointSelected(self):
        """Redraw the detections for the newly selected keypoint."""
        self.update()

    # ------------------------------------------------------------------ saving

    def on_save(self):
        """Ask for a target .pkl file and write the tracking data in the background."""
        logging.debug("Saving fixtracks results")
        if not self._hasData():
            return
        file_dialog = QFileDialog(self)
        file_dialog.setAcceptMode(QFileDialog.AcceptMode.AcceptSave)
        file_dialog.setNameFilter("Pickle Files (*.pkl)")
        if not file_dialog.exec():
            return  # dialog cancelled: leave the task label untouched
        file_path = file_dialog.selectedFiles()[0]
        if not file_path.endswith(".pkl"):
            file_path += ".pkl"
        self._tasklabel.setText("Saving results to file...")
        self._progress_bar.setRange(0, 0)  # busy indicator while saving
        save_task = PickleWriter(self._data, file_path)
        save_task.signals.finished.connect(self.on_dataSaved)
        self._threadpool.start(save_task)

    def on_dataSaved(self):
        """Reset the task label and progress bar after saving finished."""
        self._tasklabel.setText("")
        self._progress_bar.setRange(0, 100)
        self._progress_bar.setValue(0)

    def on_back(self):
        """Forward the "Back" button press via the ``back`` signal."""
        logging.debug("Back button pressed!")
        self.back.emit()

    # ------------------------------------------------------ track assignment

    def on_assignOne(self):
        """Assign the selected detections to track one."""
        logging.debug("Assigning user selection to track One")
        if not self._hasData():
            return
        self._data.setTrack(self.trackone_id)
        self._refreshViews()

    def on_assignTwo(self):
        """Assign the selected detections to track two."""
        logging.debug("Assigning user selection to track Two")
        if not self._hasData():
            return
        self._data.setTrack(self.tracktwo_id)
        self._refreshViews()

    def on_assignOther(self):
        """Assign the selected detections to the "other" track."""
        logging.debug("Assigning user selection to track Other")
        if not self._hasData():
            return
        self._data.setTrack(self.trackother_id, False)
        self._refreshViews()

    def on_setUserFlag(self):
        """Mark the selected detections as user-labeled."""
        if not self._hasData():
            return
        self._data.setUserLabeledStatus(True)
        self._refreshViews()

    def on_setUserFlagsUntil(self):
        """Mark every detection from frame 0 to the end of the current window as user-labeled."""
        if not self._hasData():
            return
        self._data.setSelectionRange("frame", 0, self._currentWindowPos + self._currentWindowWidth)
        self._data.setUserLabeledStatus(True)
        # update() restores the selection to the current window.
        self._refreshViews()

    def on_unsetUserFlag(self):
        """Remove the user-labeled mark from the selected detections."""
        logging.debug("Tracks:unsetUserFlag")
        if not self._hasData():
            return
        self._data.setUserLabeledStatus(False)
        self._refreshViews()

    def on_revertUserFlags(self):
        """After confirmation, revert all user flags and track assignments."""
        logging.debug("Tracks:revert ALL UserFlags and track assignments")
        if not self._hasData():
            return
        if self._confirm("Are you sure you want to revert ALL track assignments?"):
            self._data.revertUserLabeledStatus()
            self._data.revertTrackAssignments()
            self._refreshViews()

    def on_deleteDetection(self):
        """After confirmation, delete the selected detections."""
        logging.info("Tracks:deleting detections!")
        if not self._hasData():
            return
        count = len(self._data.selectionIndices)
        if self._confirm(f"Are you sure you want to delete the selected ({count}) detections?"):
            self._data.deleteDetections()
            self._refreshViews()

    # --------------------------------------------------------- window handling

    def on_windowChanged(self):
        """Sync the window position with the timeline unless moveWindow() set it."""
        logging.debug("Tracks:Timeline reports window change ")
        if not self._manualmove:
            self._currentWindowPos = int(np.round(self._timeline.rangeStart * self._maxframes))
            self.update()
        self._manualmove = False

    def on_moveRequest(self, pos):
        """Move the window start to the relative timeline position *pos*.

        Parameters
        ----------
        pos : float
            Position relative to the whole recording (multiplied by the
            number of frames to get the start frame).
        """
        self._currentWindowPos = int(np.round(pos * self._maxframes))
        self.update()

    def on_windowSizeChanged(self, value):
        """Reacts on the user window-width selection. Selection is done in the unit of frames.

        Parameters
        ----------
        value : int
            The width of the observation window in frames.
        """
        self._currentWindowWidth = value
        logging.debug("Tracks:OnWindowSizeChanged %i frames", value)
        self.update()

    def on_goto(self):
        """Jump the window start to the frame entered in the goto spin box.

        The start is clamped so the window does not run past the last frame,
        and never below frame 0.
        """
        if not self._hasData():
            return
        target = min(self._goto_spinner.value(), self._maxframes - self._currentWindowWidth)
        target = max(0, target)
        logging.info("Jump to frame %i", target)
        self._currentWindowPos = target
        total = max(self._maxframes, 1)
        self._timeline.setWindow(self._currentWindowPos / total, self._currentWindowWidth / total)
        self.update()

    def on_detectionsSelected(self, detections):
        """Select the given detection items in the data and show their skeletons.

        Parameters
        ----------
        detections
            Items emitted by ``DetectionView.signals.itemsSelected``; per-item
            values are read via ``item.data(DetectionData.<FIELD>.value)``.
        """
        logging.debug("Tracks: %i Detections selected", len(detections))
        tracks = np.zeros(len(detections), dtype=int)
        ids = np.zeros_like(tracks)
        frames = np.zeros_like(tracks)
        scores = np.zeros(tracks.shape, dtype=float)
        coordinates = None
        if len(detections) > 0:
            # All detections are assumed to share the first one's coordinate shape.
            c = detections[0].data(DetectionData.COORDINATES.value)
            coordinates = np.zeros((len(detections), c.shape[0], c.shape[1]))

        for i, d in enumerate(detections):
            tracks[i] = d.data(DetectionData.TRACK_ID.value)
            ids[i] = d.data(DetectionData.ID.value)
            frames[i] = d.data(DetectionData.FRAME.value)
            coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
            scores[i] = d.data(DetectionData.SCORE.value)
        if self._hasData():
            self._data.setSelection(ids)
        self._controls_widget.setSelectedTracks(tracks)
        self._skeleton.clear()
        self._skeleton.addSkeletons(coordinates, ids, frames, tracks, scores, QBrush(QColor(10, 255, 65, 255)))

    def moveWindow(self, stepsize):
        """Shift the window by *stepsize* window widths (negative moves back).

        The new start is clamped to ``[0, maxframes - window width]``; when
        the window is wider than the recording it starts at frame 0.
        """
        logging.info("Tracks.moveWindow: move window with stepsize %.2f", stepsize)
        if not self._hasData():
            return
        self._manualmove = True
        new_start_frame = int(self._currentWindowPos) + int(np.round(stepsize * self._currentWindowWidth))
        new_start_frame = max(0, min(new_start_frame, self._maxframes - self._currentWindowWidth))
        self._currentWindowPos = new_start_frame
        self._controls_widget.setSelectedTracks(None)
        self.update()

    def on_forward(self, stepsize):
        """Move the window forward by *stepsize* window widths."""
        logging.debug("Tracks: receive forward command with step-size: %.2f", stepsize)
        self.moveWindow(stepsize)

    def on_backward(self, stepsize):
        """Move the window backward by *stepsize* window widths."""
        logging.debug("Tracks: receive backward command with step-size: %.2f", stepsize)
        self.moveWindow(-stepsize)
def main():
    """Open a window that contains only the SelectionControls widget.

    Handy for trying out the controls without loading any data.
    """
    from PySide6.QtWidgets import QApplication

    app = QApplication([])
    controls = SelectionControls()

    layout = QVBoxLayout()
    layout.addWidget(controls)

    window = QWidget()
    window.setMinimumSize(200, 200)
    window.setLayout(layout)
    window.show()

    app.exec()


if __name__ == "__main__":
    main()