[trackingdata] add more functionality

This commit is contained in:
Jan Grewe 2025-02-20 17:59:52 +01:00
parent 9361069b74
commit dbd5b380ba

View File

@ -54,6 +54,7 @@ class TrackingData(QObject):
return self._indices
def setSelectionRange(self, col, start, stop):
logging.debug("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop)
self._start = start
self._stop = stop
self._selection_column = col
@ -64,7 +65,7 @@ class TrackingData(QObject):
def setUserSelection(self, ids):
"""
Set the user selections. That is, e.g. when the user selected a number of ids.
Set the user selections. That is, e.g. when the user selected a number of detection ids (aka the index of the original data frame entries).
Parameters
----------
ids : array-like
@ -73,16 +74,33 @@ class TrackingData(QObject):
"""
self._user_selections = ids.astype(int)
def assignUserSelection(self, track_id:int)-> None:
def assignUserSelection(self, track_id:int, userFlag:bool=True)-> None:
"""Assign a new track_id to the user-selected detections
Parameters
----------
track_id : int
The new track id for the user-selected detections
userFlag : bool
Should the "userlabeled" state of the detections be set to True or False?
"""
self._data["track"][self._user_selections] = track_id
self._data["userlabeled"][self._user_selections] = True
self.setAssignmentStatus(userFlag)
def setAssignmentStatus(self, isTrue: bool):
self._data["userlabeled"][self._user_selections] = isTrue
def revertAssignmentStatus(self):
self._data["userlabeled"][:] = False
def deleteDetections(self):
# from IPython import embed
# if self._user_selections is not None:
# ids = self._user_selections
# for c in self.columns:
# pass
# embed()
pass
def assignTracks(self, tracks):
"""assignTracks _summary_
@ -115,7 +133,7 @@ class TrackingData(QObject):
return 0
return self._data["keypoints"][0].shape[0]
def coordinates(self):
def coordinates(self, selection=False):
"""
Returns the coordinates of all keypoints as a NumPy array.
@ -123,7 +141,10 @@ class TrackingData(QObject):
np.ndarray: A NumPy array of shape (N, M, 2) where N is the number of detections,
and M is number of keypoints
"""
return np.stack(self._data["keypoints"]).astype(np.float32)
if selection:
return np.stack(self._data["keypoints"][self._start:self._stop, :, :]).astype(np.float32)
else:
return np.stack(self._data["keypoints"]).astype(np.float32)
def keypointScores(self):
"""
@ -202,14 +223,6 @@ class TrackingData(QObject):
def __getitem__(self, key):
return self._data[key]
# def __setitem__(self, key, value):
# self._data[key] = value
"""
self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data)
self.update()
"""
def main():
import pandas as pd
@ -217,6 +230,8 @@ def main():
import matplotlib.pyplot as plt
from fixtracks.info import PACKAGE_ROOT
logging.basicConfig(level=logging.DEBUG, force=True)
def as_dict(df:pd.DataFrame):
d = {c: df[c].values for c in df.columns}
d["index"] = df.index.values
@ -251,7 +266,7 @@ def main():
plt.plot([positions[si, 0], positions[ei, 0]],
[positions[si, 1], positions[ei, 1]], color="tab:green")
datafile = PACKAGE_ROOT / "data/merged2.pkl"
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)