diff --git a/fixtracks/utils/trackingdata.py b/fixtracks/utils/trackingdata.py index a41c120..e3d8625 100644 --- a/fixtracks/utils/trackingdata.py +++ b/fixtracks/utils/trackingdata.py @@ -5,6 +5,7 @@ import pandas as pd from PySide6.QtCore import QObject + class TrackingData(QObject): def __init__(self, parent=None): super().__init__(parent) @@ -58,9 +59,14 @@ class TrackingData(QObject): self._start = start self._stop = stop self._selection_column = col - self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0] - - def selectedData(self, col): + col_indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0] + self._indices = self["index"][col_indices] + if len(col_indices) < 1: + logging.warning("TrackingData: Selection range is empty!") + + def selectedData(self, col:str): + if col not in self.columns: + logging.error("TrackingData:selectedData: Invalid column name! %s", col) return self[col][self._indices] def setUserSelection(self, ids): @@ -148,11 +154,13 @@ class TrackingData(QObject): and M is number of keypoints """ if selection: - return np.stack(self._data["keypoints"][self._start:self._stop, :, :]).astype(np.float32) - else: - return np.stack(self._data["keypoints"]).astype(np.float32) + if len(self._indices) < 1: + logging.info("TrackingData.coordinates returns empty array, not detections in range!") + return np.ndarray([]) + return np.stack(self._data["keypoints"][self._indices]).astype(np.float32) + return np.stack(self._data["keypoints"]).astype(np.float32) - def keypointScores(self): + def keypointScores(self, selection=False): """ Returns the keypoint scores as a NumPy array of type float32. @@ -161,10 +169,15 @@ class TrackingData(QObject): numpy.ndarray A NumPy array of type float32 containing the keypoint scores of the shape (N, M) with N the number of detections and M the number of keypoints. - """ + """ + if selection: + if len(self._indices) < 1: + logging.info("TrackingData.scores returns empty array, not detections in range!") + return np.ndarray([]) + return np.stack(self._data["keypoint_score"][self._indices]).astype(np.float32) return np.stack(self._data["keypoint_score"]).astype(np.float32) - def centerOfGravity(self, threshold=0.8): + def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]): """ Calculate the center of gravity of keypoints weighted by their scores. Ignores keypoints that have a score less than threshold. @@ -172,16 +185,19 @@ class TrackingData(QObject): Parameters: ----------- threshold: float - keypoints with a score less than threshold are ignored + nodes with a score less than threshold are ignored + nodes: list + nodes/keypoints to consider for estimation. Defaults to [0,1,2] Returns: -------- np.ndarray: A NumPy array of shape (N, 2) containing the center of gravity for each detection. """ - scores = self.keypointScores() + scores = self.keypointScores(selection) scores[scores < threshold] = 0.0 - weighted_coords = self.coordinates() * scores[:, :, np.newaxis] + scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0 + weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis] sum_scores = np.sum(scores, axis=1, keepdims=True) center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores return center_of_gravity @@ -285,7 +301,7 @@ def main(): frames = data["frame"] tracks = data["track"] bendedness = data.bendedness() - positions = data.coordinates()[[160388, 160389]] + # positions = data.coordinates()[[160388, 160389]] embed() tracks = data["track"] @@ -309,7 +325,6 @@ def main(): # return distances # print("estimating neighorhood distances") # neighbor_distances = compute_neighbor_distances(cogs) - embed() if __name__ == "__main__": main() \ No newline at end of file