[trackingdata] fixes and support for selections when getting data of columns
This commit is contained in:
parent
f09c78adb5
commit
af5dbc7dfc
@ -5,6 +5,7 @@ import pandas as pd
|
||||
|
||||
from PySide6.QtCore import QObject
|
||||
|
||||
|
||||
class TrackingData(QObject):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
@ -58,9 +59,14 @@ class TrackingData(QObject):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._selection_column = col
|
||||
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
|
||||
def selectedData(self, col):
|
||||
col_indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
self._indices = self["index"][col_indices]
|
||||
if len(col_indices) < 1:
|
||||
logging.warning("TrackingData: Selection range is empty!")
|
||||
|
||||
def selectedData(self, col:str):
|
||||
if col not in self.columns:
|
||||
logging.error("TrackingData:selectedData: Invalid column name! %s", col)
|
||||
return self[col][self._indices]
|
||||
|
||||
def setUserSelection(self, ids):
|
||||
@ -148,11 +154,13 @@ class TrackingData(QObject):
|
||||
and M is number of keypoints
|
||||
"""
|
||||
if selection:
|
||||
return np.stack(self._data["keypoints"][self._start:self._stop, :, :]).astype(np.float32)
|
||||
else:
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
if len(self._indices) < 1:
|
||||
logging.info("TrackingData.coordinates returns empty array, not detections in range!")
|
||||
return np.ndarray([])
|
||||
return np.stack(self._data["keypoints"][self._indices]).astype(np.float32)
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
|
||||
def keypointScores(self):
|
||||
def keypointScores(self, selection=False):
|
||||
"""
|
||||
Returns the keypoint scores as a NumPy array of type float32.
|
||||
|
||||
@ -161,10 +169,15 @@ class TrackingData(QObject):
|
||||
numpy.ndarray
|
||||
A NumPy array of type float32 containing the keypoint scores of the shape (N, M)
|
||||
with N the number of detections and M the number of keypoints.
|
||||
"""
|
||||
"""
|
||||
if selection:
|
||||
if len(self._indices) < 1:
|
||||
logging.info("TrackingData.scores returns empty array, not detections in range!")
|
||||
return np.ndarray([])
|
||||
return np.stack(self._data["keypoint_score"][self._indices]).astype(np.float32)
|
||||
return np.stack(self._data["keypoint_score"]).astype(np.float32)
|
||||
|
||||
def centerOfGravity(self, threshold=0.8):
|
||||
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):
|
||||
"""
|
||||
Calculate the center of gravity of keypoints weighted by their scores. Ignores keypoints that have a score
|
||||
less than threshold.
|
||||
@ -172,16 +185,19 @@ class TrackingData(QObject):
|
||||
Parameters:
|
||||
-----------
|
||||
threshold: float
|
||||
keypoints with a score less than threshold are ignored
|
||||
nodes with a score less than threshold are ignored
|
||||
nodes: list
|
||||
nodes/keypoints to consider for estimation. Defaults to [0,1,2]
|
||||
|
||||
Returns:
|
||||
--------
|
||||
np.ndarray:
|
||||
A NumPy array of shape (N, 2) containing the center of gravity for each detection.
|
||||
"""
|
||||
scores = self.keypointScores()
|
||||
scores = self.keypointScores(selection)
|
||||
scores[scores < threshold] = 0.0
|
||||
weighted_coords = self.coordinates() * scores[:, :, np.newaxis]
|
||||
scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0
|
||||
weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis]
|
||||
sum_scores = np.sum(scores, axis=1, keepdims=True)
|
||||
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
|
||||
return center_of_gravity
|
||||
@ -285,7 +301,7 @@ def main():
|
||||
frames = data["frame"]
|
||||
tracks = data["track"]
|
||||
bendedness = data.bendedness()
|
||||
positions = data.coordinates()[[160388, 160389]]
|
||||
# positions = data.coordinates()[[160388, 160389]]
|
||||
|
||||
embed()
|
||||
tracks = data["track"]
|
||||
@ -309,7 +325,6 @@ def main():
|
||||
# return distances
|
||||
# print("estimating neighorhood distances")
|
||||
# neighbor_distances = compute_neighbor_distances(cogs)
|
||||
embed()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user