From d176925796a5c8dea370fee33b6284752d38f673 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Mon, 24 Feb 2025 16:03:05 +0100 Subject: [PATCH] [trackingdata] change selections, constructor ... renaming of some functions --- fixtracks/utils/trackingdata.py | 66 ++++++++++++++++++--------------- fixtracks/widgets/tracks.py | 16 ++++---- 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/fixtracks/utils/trackingdata.py b/fixtracks/utils/trackingdata.py index 0ae80f6..21769a7 100644 --- a/fixtracks/utils/trackingdata.py +++ b/fixtracks/utils/trackingdata.py @@ -7,21 +7,16 @@ from PySide6.QtCore import QObject class TrackingData(QObject): - def __init__(self, parent=None): + def __init__(self, datadict, parent=None): super().__init__(parent) - self._data = None - self._columns = [] - self._start = 0 - self._stop = 0 - self._indices = None - self._selection_column = None - self._user_selections = None - - def setData(self, datadict): assert isinstance(datadict, dict) self._data = datadict - self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool) + if "userlabeled" not in self._data.keys(): + self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool) self._columns = [k for k in self._data.keys()] + self._indices = self["index"] + self._selection = np.asarray([]) + self._selected_ids = None @property def data(self): @@ -56,11 +51,8 @@ class TrackingData(QObject): def setSelectionRange(self, col, start, stop): logging.debug("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop) - self._start = start - self._stop = stop - self._selection_column = col - col_indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0] - self._indices = self["index"][col_indices] + col_indices = np.where((self[col] >= start) & (self[col] < stop))[0] + self._selection = self._indices[col_indices] if len(col_indices) < 1: logging.warning("TrackingData: Selection range is empty!") @@ -80,22 +72,36 @@ class TrackingData(QObject): """ self._user_selections = ids.astype(int) - def assignUserSelection(self, track_id:int, userFlag:bool=True)-> None: + def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None: """Assign a new track_id to the user-selected detections Parameters ---------- track_id : int The new track id for the user-selected detections - userFlag : bool - Should the "userlabeled" state of the detections be set to True or False? + setUserLabeled : bool + Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched. """ self["track"][self._user_selections] = track_id - self.setAssignmentStatus(userFlag) + if setUserLabeled: + self.setUserLabeledStatus(True, True) - def setAssignmentStatus(self, isTrue: bool): - logging.debug("TrackingData:Re-setting assignment status of user selected data to %s", str(isTrue)) - self["userlabeled"][self._user_selections] = isTrue + def setUserLabeledStatus(self, new_status: bool, selection=True): + """Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection. + + Parameters + ---------- + new_status : bool + The new status, TRUE, if the detections are confirmed by the user (human observer) and can be treated as correct + selection : bool, optional + Whether the new status should be set for the selection only (True, default) ore not (False) + """ + logging.debug("TrackingData: (Re-)setting assignment status of %s to %s", + "user selected data" if selection else " ALL", str(new_status)) + if selection: + self["userlabeled"][self._selection] = new_status + else: + self["userlabeled"][:] = new_status def revertAssignmentStatus(self): logging.debug("TrackingData:Un-setting assignment status of all data!") @@ -154,11 +160,11 @@ class TrackingData(QObject): and M is number of keypoints """ if selection: - if len(self._indices) < 1: + if len(self._selection) < 1: logging.info("TrackingData.coordinates returns empty array, not detections in range!") return np.ndarray([]) - return np.stack(self._data["keypoints"][self._indices]).astype(np.float32) - return np.stack(self._data["keypoints"]).astype(np.float32) + return np.stack(self["keypoints"][self._selection]).astype(np.float32) + return np.stack(self["keypoints"]).astype(np.float32) def keypointScores(self, selection=False): """ @@ -171,11 +177,11 @@ class TrackingData(QObject): with N the number of detections and M the number of keypoints. """ if selection: - if len(self._indices) < 1: + if len(self._selection) < 1: logging.info("TrackingData.scores returns empty array, not detections in range!") return None - return np.stack(self._data["keypoint_score"][self._indices]).astype(np.float32) - return np.stack(self._data["keypoint_score"]).astype(np.float32) + return np.stack(self["keypoint_score"][self._selection]).astype(np.float32) + return np.stack(self["keypoint_score"]).astype(np.float32) def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]): """ @@ -199,7 +205,7 @@ class TrackingData(QObject): return None scores[scores < threshold] = 0.0 scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0 - weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis] + weighted_coords = self.coordinates(selection) * scores[:, :, np.newaxis] sum_scores = np.sum(scores, axis=1, keepdims=True) cogs = np.zeros((weighted_coords.shape[0], 2)) diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index 4c4a524..8944959 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -31,7 +31,7 @@ class FixTracks(QWidget): self._currentWindowPos = 0 # in frames self._currentWindowWidth = 0 # in frames self._maxframes = 0 - self._data = TrackingData() + self._data = None self._detectionView = DetectionView() self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected) @@ -204,7 +204,7 @@ class FixTracks(QWidget): self._progress_bar.setRange(0, 100) self._progress_bar.setValue(0) if state and self._reader is not None: - self._data.setData(self._reader.asdict) + self._data = TrackingData(self._reader.asdict) self._saveBtn.setEnabled(True) self._currentWindowPos = 0 self._currentWindowWidth = self._windowspinner.value() @@ -247,30 +247,30 @@ class FixTracks(QWidget): def on_assignOne(self): logging.debug("Assigning user selection to track One") - self._data.assignUserSelection(self.trackone_id) + self._data.setTrack(self.trackone_id) self._timeline.update() self.update() def on_assignTwo(self): logging.debug("Assigning user selection to track Two") - self._data.assignUserSelection(self.tracktwo_id) + self._data.setTrack(self.tracktwo_id) self._timeline.update() self.update() def on_assignOther(self): logging.debug("Assigning user selection to track Other") - self._data.assignUserSelection(self.trackother_id, False) + self._data.setTrack(self.trackother_id, False) self._timeline.update() self.update() def on_setUserFlag(self): - self._data.setAssignmentStatus(True) + self._data.setUserLabeledStatus(True) self._timeline.update() self.update() def on_unsetUserFlag(self): logging.debug("Tracks:unsetUserFlag") - self._data.setAssignmentStatus(False) + self._data.setUserLabeledStatus(False) self._timeline.update() self.update() @@ -320,7 +320,7 @@ class FixTracks(QWidget): ids[i] = d.data(DetectionData.ID.value) frames[i] = d.data(DetectionData.FRAME.value) coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value) - self._data.setUserSelection(ids) + self._data.setSelection(ids) self._controls_widget.setSelectedTracks(tracks) self._skeleton.clear() self._skeleton.addSkeletons(coordinates, ids, frames, tracks, QBrush(QColor(10, 255, 65, 255)))