[trackingdata] change selections, constructor ...

renaming of some functions
This commit is contained in:
Jan Grewe 2025-02-24 16:03:05 +01:00
parent 35be41282a
commit d176925796
2 changed files with 44 additions and 38 deletions

View File

@ -7,21 +7,16 @@ from PySide6.QtCore import QObject
class TrackingData(QObject): class TrackingData(QObject):
def __init__(self, parent=None): def __init__(self, datadict, parent=None):
super().__init__(parent) super().__init__(parent)
self._data = None
self._columns = []
self._start = 0
self._stop = 0
self._indices = None
self._selection_column = None
self._user_selections = None
def setData(self, datadict):
assert isinstance(datadict, dict) assert isinstance(datadict, dict)
self._data = datadict self._data = datadict
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool) if "userlabeled" not in self._data.keys():
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool)
self._columns = [k for k in self._data.keys()] self._columns = [k for k in self._data.keys()]
self._indices = self["index"]
self._selection = np.asarray([])
self._selected_ids = None
@property @property
def data(self): def data(self):
@ -56,11 +51,8 @@ class TrackingData(QObject):
def setSelectionRange(self, col, start, stop): def setSelectionRange(self, col, start, stop):
logging.debug("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop) logging.debug("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop)
self._start = start col_indices = np.where((self[col] >= start) & (self[col] < stop))[0]
self._stop = stop self._selection = self._indices[col_indices]
self._selection_column = col
col_indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
self._indices = self["index"][col_indices]
if len(col_indices) < 1: if len(col_indices) < 1:
logging.warning("TrackingData: Selection range is empty!") logging.warning("TrackingData: Selection range is empty!")
@ -80,22 +72,36 @@ class TrackingData(QObject):
""" """
self._user_selections = ids.astype(int) self._user_selections = ids.astype(int)
def assignUserSelection(self, track_id:int, userFlag:bool=True)-> None: def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
"""Assign a new track_id to the user-selected detections """Assign a new track_id to the user-selected detections
Parameters Parameters
---------- ----------
track_id : int track_id : int
The new track id for the user-selected detections The new track id for the user-selected detections
userFlag : bool setUserLabeled : bool
Should the "userlabeled" state of the detections be set to True or False? Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched.
""" """
self["track"][self._user_selections] = track_id self["track"][self._user_selections] = track_id
self.setAssignmentStatus(userFlag) if setUserLabeled:
self.setUserLabeledStatus(True, True)
def setAssignmentStatus(self, isTrue: bool): def setUserLabeledStatus(self, new_status: bool, selection=True):
logging.debug("TrackingData:Re-setting assignment status of user selected data to %s", str(isTrue)) """Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection.
self["userlabeled"][self._user_selections] = isTrue
Parameters
----------
new_status : bool
The new status, TRUE, if the detections are confirmed by the user (human observer) and can be treated as correct
selection : bool, optional
Whether the new status should be set for the selection only (True, default) ore not (False)
"""
logging.debug("TrackingData: (Re-)setting assignment status of %s to %s",
"user selected data" if selection else " ALL", str(new_status))
if selection:
self["userlabeled"][self._selection] = new_status
else:
self["userlabeled"][:] = new_status
def revertAssignmentStatus(self): def revertAssignmentStatus(self):
logging.debug("TrackingData:Un-setting assignment status of all data!") logging.debug("TrackingData:Un-setting assignment status of all data!")
@ -154,11 +160,11 @@ class TrackingData(QObject):
and M is number of keypoints and M is number of keypoints
""" """
if selection: if selection:
if len(self._indices) < 1: if len(self._selection) < 1:
logging.info("TrackingData.coordinates returns empty array, not detections in range!") logging.info("TrackingData.coordinates returns empty array, not detections in range!")
return np.ndarray([]) return np.ndarray([])
return np.stack(self._data["keypoints"][self._indices]).astype(np.float32) return np.stack(self["keypoints"][self._selection]).astype(np.float32)
return np.stack(self._data["keypoints"]).astype(np.float32) return np.stack(self["keypoints"]).astype(np.float32)
def keypointScores(self, selection=False): def keypointScores(self, selection=False):
""" """
@ -171,11 +177,11 @@ class TrackingData(QObject):
with N the number of detections and M the number of keypoints. with N the number of detections and M the number of keypoints.
""" """
if selection: if selection:
if len(self._indices) < 1: if len(self._selection) < 1:
logging.info("TrackingData.scores returns empty array, not detections in range!") logging.info("TrackingData.scores returns empty array, not detections in range!")
return None return None
return np.stack(self._data["keypoint_score"][self._indices]).astype(np.float32) return np.stack(self["keypoint_score"][self._selection]).astype(np.float32)
return np.stack(self._data["keypoint_score"]).astype(np.float32) return np.stack(self["keypoint_score"]).astype(np.float32)
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]): def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):
""" """
@ -199,7 +205,7 @@ class TrackingData(QObject):
return None return None
scores[scores < threshold] = 0.0 scores[scores < threshold] = 0.0
scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0 scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0
weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis] weighted_coords = self.coordinates(selection) * scores[:, :, np.newaxis]
sum_scores = np.sum(scores, axis=1, keepdims=True) sum_scores = np.sum(scores, axis=1, keepdims=True)
cogs = np.zeros((weighted_coords.shape[0], 2)) cogs = np.zeros((weighted_coords.shape[0], 2))

View File

@ -31,7 +31,7 @@ class FixTracks(QWidget):
self._currentWindowPos = 0 # in frames self._currentWindowPos = 0 # in frames
self._currentWindowWidth = 0 # in frames self._currentWindowWidth = 0 # in frames
self._maxframes = 0 self._maxframes = 0
self._data = TrackingData() self._data = None
self._detectionView = DetectionView() self._detectionView = DetectionView()
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected) self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
@ -204,7 +204,7 @@ class FixTracks(QWidget):
self._progress_bar.setRange(0, 100) self._progress_bar.setRange(0, 100)
self._progress_bar.setValue(0) self._progress_bar.setValue(0)
if state and self._reader is not None: if state and self._reader is not None:
self._data.setData(self._reader.asdict) self._data = TrackingData(self._reader.asdict)
self._saveBtn.setEnabled(True) self._saveBtn.setEnabled(True)
self._currentWindowPos = 0 self._currentWindowPos = 0
self._currentWindowWidth = self._windowspinner.value() self._currentWindowWidth = self._windowspinner.value()
@ -247,30 +247,30 @@ class FixTracks(QWidget):
def on_assignOne(self): def on_assignOne(self):
logging.debug("Assigning user selection to track One") logging.debug("Assigning user selection to track One")
self._data.assignUserSelection(self.trackone_id) self._data.setTrack(self.trackone_id)
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_assignTwo(self): def on_assignTwo(self):
logging.debug("Assigning user selection to track Two") logging.debug("Assigning user selection to track Two")
self._data.assignUserSelection(self.tracktwo_id) self._data.setTrack(self.tracktwo_id)
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_assignOther(self): def on_assignOther(self):
logging.debug("Assigning user selection to track Other") logging.debug("Assigning user selection to track Other")
self._data.assignUserSelection(self.trackother_id, False) self._data.setTrack(self.trackother_id, False)
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_setUserFlag(self): def on_setUserFlag(self):
self._data.setAssignmentStatus(True) self._data.setUserLabeledStatus(True)
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_unsetUserFlag(self): def on_unsetUserFlag(self):
logging.debug("Tracks:unsetUserFlag") logging.debug("Tracks:unsetUserFlag")
self._data.setAssignmentStatus(False) self._data.setUserLabeledStatus(False)
self._timeline.update() self._timeline.update()
self.update() self.update()
@ -320,7 +320,7 @@ class FixTracks(QWidget):
ids[i] = d.data(DetectionData.ID.value) ids[i] = d.data(DetectionData.ID.value)
frames[i] = d.data(DetectionData.FRAME.value) frames[i] = d.data(DetectionData.FRAME.value)
coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value) coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
self._data.setUserSelection(ids) self._data.setSelection(ids)
self._controls_widget.setSelectedTracks(tracks) self._controls_widget.setSelectedTracks(tracks)
self._skeleton.clear() self._skeleton.clear()
self._skeleton.addSkeletons(coordinates, ids, frames, tracks, QBrush(QColor(10, 255, 65, 255))) self._skeleton.addSkeletons(coordinates, ids, frames, tracks, QBrush(QColor(10, 255, 65, 255)))