[trackingdata] implement better selection handling ...

allow deletion of entries
This commit is contained in:
Jan Grewe 2025-02-24 16:03:42 +01:00
parent d176925796
commit 6fbbb52370

View File

@ -37,17 +37,22 @@ class TrackingData(QObject):
def numDetections(self): def numDetections(self):
return self._data["track"].shape[0] return self._data["track"].shape[0]
@property def _find(self, ids):
def selectionRange(self): ids = np.sort(ids)
return self._start, self._stop indexes = np.ones_like(ids, dtype=int) * -1
j = 0
@property for idx, i in enumerate(self._indices):
def selectionRangeColumn(self): if i == ids[j]:
return self._selection_column indexes[j] = idx
j += 1
if j == len(indexes):
break
indexes = indexes[indexes >= 0]
return indexes
@property @property
def selectionIndices(self): def selectionIndices(self):
return self._indices return self._selection
def setSelectionRange(self, col, start, stop): def setSelectionRange(self, col, start, stop):
logging.debug("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop) logging.debug("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop)
@ -59,18 +64,19 @@ class TrackingData(QObject):
def selectedData(self, col:str): def selectedData(self, col:str):
if col not in self.columns: if col not in self.columns:
logging.error("TrackingData:selectedData: Invalid column name! %s", col) logging.error("TrackingData:selectedData: Invalid column name! %s", col)
return self[col][self._indices] return self[col][self._selection]
def setUserSelection(self, ids): def setSelection(self, ids):
""" """
Set the user selections. That is, e.g. when the user selected a number of detection ids (aka the index of the original data frame entries). Set the selection based on the detection IDs.
Parameters Parameters
---------- ----------
ids : array-like ids : array-like
An array-like object containing the IDs to be set as user selections. An array-like object containing the IDs to be set as user selections.
The IDs will be converted to integers.
""" """
self._user_selections = ids.astype(int) self._selection = self._find(ids)
self._selected_ids = ids
def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None: def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
"""Assign a new track_id to the user-selected detections """Assign a new track_id to the user-selected detections
@ -111,14 +117,15 @@ class TrackingData(QObject):
logging.debug("TrackingData: Reverting all track assignments!") logging.debug("TrackingData: Reverting all track assignments!")
self["track"][:] = -1 self["track"][:] = -1
def deleteDetections(self): def deleteDetections(self, ids=None):
# from IPython import embed if ids is not None:
# if self._user_selections is not None: del_indices = self._find(ids)
# ids = self._user_selections else:
# for c in self.columns: del_indices = self._indices
# pass for c in self._columns:
# embed() self._data[c] = np.delete(self._data[c], del_indices, axis=0)
pass self._indices = self["index"]
self._selected_ids = np.setdiff1d(self._selected_ids, del_indices)
def assignTracks(self, tracks:np.ndarray): def assignTracks(self, tracks:np.ndarray):
"""assigns the given tracks to the user-selected detections. If the sizes of """assigns the given tracks to the user-selected detections. If the sizes of
@ -299,22 +306,26 @@ def main():
plt.plot([positions[si, 0], positions[ei, 0]], plt.plot([positions[si, 0], positions[ei, 0]],
[positions[si, 1], positions[ei, 1]], color="tab:green") [positions[si, 1], positions[ei, 1]], color="tab:green")
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)
data = TrackingData(as_dict(df))
test_indices = [32, 88, 99, 2593]
data.deleteDetections(test_indices)
data = TrackingData() embed()
data.setData(as_dict(df)) data.deleteDetections(test_indices)
data.setSelection(test_indices)
all_cogs = data.centerOfGravity() all_cogs = data.centerOfGravity()
orientations = data.orientation() orientations = data.orientation()
lengths = data.animalLength() lengths = data.animalLength()
frames = data["frame"] frames = data["frame"]
tracks = data["track"] tracks = data["track"]
bendedness = data.bendedness() bendedness = data.bendedness()
indices = data._indices
# positions = data.coordinates()[[160388, 160389]] # positions = data.coordinates()[[160388, 160389]]
embed()
tracks = data["track"] tracks = data["track"]
cogs = all_cogs[tracks==1] cogs = all_cogs[tracks==1]
all_dists = neighborDistances(cogs, 2, False) all_dists = neighborDistances(cogs, 2, False)