Compare commits

..

8 Commits

6 changed files with 110 additions and 58 deletions

View File

@ -7,6 +7,7 @@ class DetectionData(Enum):
FRAME = 1
COORDINATES = 2
TRACK_ID = 3
USERLABELED = 4
class Tracks(Enum):
TRACKONE = 1

View File

@ -14,8 +14,8 @@ class TrackingData(QObject):
if "userlabeled" not in self._data.keys():
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool)
self._columns = [k for k in self._data.keys()]
self._indices = self["index"]
self._selection = np.asarray([])
self._indices = np.arange(len(self["index"]), dtype=int)
self._selection_indices = np.asarray([])
self._selected_ids = None
@property
@ -43,8 +43,8 @@ class TrackingData(QObject):
ids = np.sort(ids)
indexes = np.ones_like(ids, dtype=int) * -1
j = 0
for idx, i in enumerate(self._indices):
if i == ids[j]:
for idx in self._indices:
if self["index"][idx] == ids[j]:
indexes[j] = idx
j += 1
if j == len(indexes):
@ -54,19 +54,23 @@ class TrackingData(QObject):
@property
def selectionIndices(self):
return self._selection
return self._selection_indices
@property
def selectionIDs(self):
return self._selected_ids
def setSelectionRange(self, col, start, stop):
logging.info("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop)
col_indices = np.where((self[col] >= start) & (self[col] < stop))[0]
self._selection = self._indices[col_indices]
self._selection_indices = self._indices[col_indices]
if len(col_indices) < 1:
logging.warning("TrackingData: Selection range is empty!")
def selectedData(self, col:str):
if col not in self.columns:
logging.error("TrackingData:selectedData: Invalid column name! %s", col)
return self[col][self._selection]
return self[col][self._selection_indices]
def setSelection(self, ids):
"""
@ -78,8 +82,9 @@ class TrackingData(QObject):
An array-like object containing the IDs to be set as user selections.
"""
logging.debug("TrackingData.setSelection: %i number of ids", len(ids))
self._selection = self._find(ids)
self._selection_indices = self._find(ids)
self._selected_ids = ids
print(self._selected_ids, self._selection_indices)
def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
"""Assign a new track_id to the user-selected detections
@ -92,9 +97,12 @@ class TrackingData(QObject):
Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched.
"""
logging.info("TrackingData: set track id %i for selection, set user-labeled status %s", track_id, str(setUserLabeled))
self["track"][self._selection] = track_id
print(self._selected_ids, self._selection_indices)
print("before: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
self["track"][self._selection_indices] = track_id
if setUserLabeled:
self.setUserLabeledStatus(True, True)
print("after: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
def setUserLabeledStatus(self, new_status: bool, selection=True):
"""Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection.
@ -109,7 +117,7 @@ class TrackingData(QObject):
logging.debug("TrackingData: (Re-)setting assignment status of %s to %s",
"user selected data" if selection else " ALL", str(new_status))
if selection:
self["userlabeled"][self._selection] = new_status
self["userlabeled"][self._selection_indices] = new_status
else:
self["userlabeled"][:] = new_status
@ -123,12 +131,14 @@ class TrackingData(QObject):
def deleteDetections(self, ids=None):
if ids is not None:
logging.debug("TrackingData.deleteDetections of %i detections", len(ids))
del_indices = self._find(ids)
else:
del_indices = self._indices
logging.debug("TrackingData.deleteDetections of all selected detections (%i)", len(self._selected_ids))
del_indices = self._selected_ids
for c in self._columns:
self._data[c] = np.delete(self._data[c], del_indices, axis=0)
self._indices = self["index"]
self._indices = self._indices[:-len(del_indices)]
self._selected_ids = np.setdiff1d(self._selected_ids, del_indices)
def assignTracks(self, tracks:np.ndarray):
@ -171,10 +181,10 @@ class TrackingData(QObject):
and M is number of keypoints
"""
if selection:
if len(self._selection) < 1:
if len(self._selection_indices) < 1:
logging.info("TrackingData.coordinates returns empty array, not detections in range!")
return np.ndarray([])
return np.stack(self["keypoints"][self._selection]).astype(np.float32)
return np.stack(self["keypoints"][self._selection_indices]).astype(np.float32)
return np.stack(self["keypoints"]).astype(np.float32)
def keypointScores(self, selection=False):
@ -188,10 +198,10 @@ class TrackingData(QObject):
with N the number of detections and M the number of keypoints.
"""
if selection:
if len(self._selection) < 1:
if len(self._selection_indices) < 1:
logging.info("TrackingData.scores returns empty array, not detections in range!")
return None
return np.stack(self["keypoint_score"][self._selection]).astype(np.float32)
return np.stack(self["keypoint_score"][self._selection_indices]).astype(np.float32)
return np.stack(self["keypoint_score"]).astype(np.float32)
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):

View File

@ -247,7 +247,7 @@ class ConsistencyWorker(QRunnable):
if steps > 0 and f % steps == 0:
progress += 1
self.signals.progress.emit(progress, processed, errors)
self.signals.message.emit("Tracking stopped at frame %i.", f)
self.signals.message.emit(f"Tracking stopped at frame {f}.")
self.signals.stopped.emit(f)
@ -548,11 +548,17 @@ class ConsistencyClassifier(QWidget):
self._tracks = self._dataworker.tracks
self._dataworker = None
if np.sum(self._userlabeled) < 1:
logging.error("ConsistencyTracker: I need at least 1 user-labeled frame to start with!")
msg = "ConsistencyTracker: I need at least 1 user-labeled frame to start with!"
logging.error(msg)
self._messagebox.append(msg)
self.setEnabled(False)
else:
t1_userlabeled = self._frames[self._userlabeled & (self._tracks == 1)]
t2_userlabeled = self._frames[self._userlabeled & (self._tracks == 2)]
if any([len(t1_userlabeled) == 0, len(t2_userlabeled)== 0]):
self._messagebox.append("Error preparing data! Make sure that the first user-labeled frames contain both tracks!")
self.setEnabled(False)
return
max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]])
min_startframe = np.max([t1_userlabeled[0], t2_userlabeled[0]])
self._maxframes = np.max(self._frames)
@ -574,10 +580,7 @@ class ConsistencyClassifier(QWidget):
def stop(self):
if self._worker is not None:
self._worker.stop()
self._startbtn.setEnabled(True)
self._proceedbtn.setEnabled(True)
self._stopbtn.setEnabled(False)
self._refreshbtn.setEnabled(True)
self._messagebox.append("Stopping tracking.")
def start(self):
self._startbtn.setEnabled(False)
@ -590,6 +593,7 @@ class ConsistencyClassifier(QWidget):
self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error)
self._worker.signals.currentframe.connect(self.worker_frame)
self._messagebox.append("Tracking in progress ...")
self.threadpool.start(self._worker)
def worker_frame(self, frame):
@ -602,8 +606,11 @@ class ConsistencyClassifier(QWidget):
self.start()
def refresh(self):
self.setEnabled(False)
self._dataworker = ConsitencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed)
self._messagebox.clear()
self._messagebox.append("Refreshing...")
self.threadpool.start(self._dataworker)
def worker_progress(self, progress, processed, errors):
@ -612,13 +619,15 @@ class ConsistencyClassifier(QWidget):
self._assignedlabel.setText(str(processed))
def worker_stopped(self, frame):
self._apply_btn.setEnabled(True)
self._startbtn.setEnabled(True)
self._proceedbtn.setEnabled(True)
self._stopbtn.setEnabled(False)
self._apply_btn.setEnabled(True)
self._refreshbtn.setEnabled(True)
self._startframe_spinner.setValue(frame-1)
self._proceedbtn.setEnabled(bool(frame < self._maxframes-1))
self._refreshbtn.setEnabled(True)
self._processed_frames = frame
self._messagebox.append("... done.")
def assignedTracks(self):
return self._tracks
@ -683,7 +692,7 @@ def main():
import pickle
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small_starter.pkl"
datafile = PACKAGE_ROOT / "data/merged_small_beginning.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)

View File

@ -128,20 +128,20 @@ class DetectionView(QWidget):
del it
def updateDetections(self, keypoint=-1):
logging.info("DetectionView.updateDetections!")
self.clearDetections()
if self._data is None:
return
frames = self._data.selectedData("frame")
tracks = self._data.selectedData("track")
ids = self._data.selectedData("index")
coordinates = self._data.coordinates(selection=True)
centercoordinates = self._data.centerOfGravity(selection=True)
userlabeled = self._data.selectedData("userlabeled")
indices = self._data.selectionIndices
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
for i, idx in enumerate(indices):
t = tracks[i]
for i, (id, f, t, l) in enumerate(zip(ids, frames, tracks, userlabeled)):
c = Tracks.fromValue(t).toColor()
if keypoint >= 0:
x = coordinates[i, keypoint, 0]
@ -151,10 +151,11 @@ class DetectionView(QWidget):
y = centercoordinates[i, 1]
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(c))
item.setData(DetectionData.TRACK_ID.value, tracks[i])
item.setData(DetectionData.ID.value, idx)
item.setData(DetectionData.TRACK_ID.value, t)
item.setData(DetectionData.ID.value, id)
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
item.setData(DetectionData.FRAME.value, frames[i])
item.setData(DetectionData.FRAME.value, f)
item.setData(DetectionData.USERLABELED.value, l)
item = self._scene.addItem(item)
def fit_image_to_view(self):

View File

@ -4,7 +4,7 @@ import numpy as np
from PySide6.QtCore import Qt, Signal, QSize
from PySide6.QtGui import QFont
from PySide6.QtWidgets import QWidget, QLabel, QPushButton, QSizePolicy
from PySide6.QtWidgets import QGridLayout, QVBoxLayout
from PySide6.QtWidgets import QGridLayout, QVBoxLayout, QApplication
from fixtracks.utils.styles import pushBtnStyle
@ -15,6 +15,7 @@ class SelectionControls(QWidget):
assignTwo = Signal()
assignOther = Signal()
accept = Signal()
accept_until = Signal()
unaccept = Signal()
delete = Signal()
revertall = Signal()
@ -102,7 +103,7 @@ class SelectionControls(QWidget):
acceptBtn.setFont(font)
acceptBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
acceptBtn.setStyleSheet(pushBtnStyle("darkgray"))
acceptBtn.setToolTip(f"Accept assignments of current selection as TRUE")
acceptBtn.setToolTip(f"Accept assignments of current selection as TRUE, Hold shift while clicking to accept all until here.")
acceptBtn.clicked.connect(self.on_Accept)
unacceptBtn = QPushButton("un-accept")
@ -117,8 +118,9 @@ class SelectionControls(QWidget):
deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
deleteBtn.setStyleSheet(pushBtnStyle("red"))
deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!")
deleteBtn.setEnabled(False)
deleteBtn.setShortcut("Ctrl+D")
deleteBtn.clicked.connect(self.on_Delete)
deleteBtn.setEnabled(False)
revertBtn = QPushButton("revert assignments")
revertBtn.setFont(font)
@ -210,7 +212,12 @@ class SelectionControls(QWidget):
def on_Accept(self):
logging.debug("SelectionControl: accept AssignmentBtn")
self.accept.emit()
modifiers = QApplication.keyboardModifiers()
if modifiers == Qt.KeyboardModifier.ShiftModifier:
logging.debug("Shift key was pressed during accept")
self.accept_until.emit()
else:
self.accept.emit()
def on_Unaccept(self):
logging.debug("SelectionControl: revoke user assignmentBtn")

View File

@ -1,11 +1,11 @@
import logging
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal
from PySide6.QtGui import QImage, QBrush, QColor
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QFileDialog, QMessageBox
from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.writer import PickleWriter
@ -37,12 +37,6 @@ class FixTracks(QWidget):
self._detectionView = DetectionView()
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
self._skeleton = SkeletonWidget()
# self._skeleton.setMaximumSize(QSize(400, 400))
top_splitter = QSplitter(Qt.Orientation.Horizontal)
top_splitter.addWidget(self._detectionView)
top_splitter.addWidget(self._skeleton)
top_splitter.setStretchFactor(0, 2)
top_splitter.setStretchFactor(1, 1)
self._progress_bar = QProgressBar(self)
self._progress_bar.setMaximumHeight(20)
@ -61,15 +55,18 @@ class FixTracks(QWidget):
self._keypointcombo = QComboBox()
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
combo_layout = QGridLayout()
combo_layout.addWidget(QLabel("Window:"), 0, 0)
combo_layout.addWidget(self._windowspinner, 0, 1)
combo_layout.addWidget(QLabel("Keypoint:"), 1, 0)
combo_layout.addWidget(self._keypointcombo, 1, 1)
combo_layout = QHBoxLayout()
combo_layout.addWidget(QLabel("Window width:"))
combo_layout.addWidget(self._windowspinner)
combo_layout.addWidget(QLabel("frames"))
combo_layout.addWidget(QLabel("Keypoint:"))
combo_layout.addWidget(self._keypointcombo)
combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
timelinebox = QHBoxLayout()
timelinebox.addWidget(self._timeline)
timelinebox = QVBoxLayout()
timelinebox.setSpacing(2)
timelinebox.addLayout(combo_layout)
timelinebox.addWidget(self._timeline)
self._controls_widget = SelectionControls()
self._controls_widget.assignOne.connect(self.on_assignOne)
@ -78,6 +75,7 @@ class FixTracks(QWidget):
self._controls_widget.fwd.connect(self.on_forward)
self._controls_widget.back.connect(self.on_backward)
self._controls_widget.accept.connect(self.on_setUserFlag)
self._controls_widget.accept_until.connect(self.on_setUserFlagsUntil)
self._controls_widget.unaccept.connect(self.on_unsetUserFlag)
self._controls_widget.delete.connect(self.on_deleteDetection)
self._controls_widget.revertall.connect(self.on_revertUserFlags)
@ -118,7 +116,8 @@ class FixTracks(QWidget):
cntrlBox = QHBoxLayout()
cntrlBox.addWidget(self._classifier)
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
cntrlBox.addWidget(self._skeleton)
cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
vbox = QVBoxLayout()
vbox.addLayout(timelinebox)
@ -128,7 +127,7 @@ class FixTracks(QWidget):
container.setLayout(vbox)
splitter = QSplitter(Qt.Orientation.Vertical)
splitter.addWidget(top_splitter)
splitter.addWidget(self._detectionView)
splitter.addWidget(container)
splitter.setStretchFactor(0, 3)
splitter.setStretchFactor(1, 1)
@ -270,6 +269,12 @@ class FixTracks(QWidget):
self._timeline.update()
self.update()
def on_setUserFlagsUntil(self):
self._data.setSelectionRange("frame", 0, self._currentWindowPos + self._currentWindowWidth)
self._data.setUserLabeledStatus(True)
self._timeline.update()
self.update()
def on_unsetUserFlag(self):
logging.debug("Tracks:unsetUserFlag")
self._data.setUserLabeledStatus(False)
@ -278,14 +283,30 @@ class FixTracks(QWidget):
def on_revertUserFlags(self):
logging.debug("Tracks:revert ALL UserFlags and track assignments")
self._data.revertUserLabeledStatus()
self._data.revertTrackAssignments()
msg_box = QMessageBox()
msg_box.setIcon(QMessageBox.Icon.Warning)
msg_box.setText(f"Are you sure you want to revert ALL track assignments?")
msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
msg_box.setDefaultButton(QMessageBox.StandardButton.No)
ret = msg_box.exec()
if ret == QMessageBox.StandardButton.Yes:
self._data.revertUserLabeledStatus()
self._data.revertTrackAssignments()
self._timeline.update()
self.update()
def on_deleteDetection(self):
logging.warning("Tracks:delete detections is currently not supported!")
# self._data.deleteDetections()
logging.info("Tracks:deleting detections!")
msg_box = QMessageBox()
msg_box.setIcon(QMessageBox.Icon.Warning)
msg_box.setText(f"Are you sure you want to delete the selected ({len(self._data.selectionIndices)})detections?")
msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
msg_box.setDefaultButton(QMessageBox.StandardButton.No)
ret = msg_box.exec()
if ret == QMessageBox.StandardButton.Yes:
self._data.deleteDetections()
self._timeline.update()
self.update()
@ -306,7 +327,10 @@ class FixTracks(QWidget):
"""
self._currentWindowWidth = value
logging.debug("Tracks:OnWindowSizeChanged %i franes", value)
self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
if self._maxframes == 0:
self._timeline.setWindowWidth(self._currentWindowWidth / 2000)
else:
self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
self._controls_widget.setSelectedTracks(None)
def on_detectionsSelected(self, detections):