Compare commits

...

6 Commits

4 changed files with 94 additions and 72 deletions

View File

@ -32,7 +32,7 @@ def set_logging(loglevel):
logging.basicConfig(level=loglevel, force=True) logging.basicConfig(level=loglevel, force=True)
def main(args): def main(args):
set_logging(logging.DEBUG) set_logging(args.loglevel)
if platform.system() == "Windows": if platform.system() == "Windows":
# from PySide6.QtWinExtras import QtWin # from PySide6.QtWinExtras import QtWin
myappid = f"{info.organization_name}.{info.application_version}" myappid = f"{info.organization_name}.{info.application_version}"
@ -75,6 +75,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FixTracks. Tools for fixing animal tracking") parser = argparse.ArgumentParser(description="FixTracks. Tools for fixing animal tracking")
parser.add_argument("-ll", "--loglevel", type=str, default="INFO", help=f"The log level that should be used. Valid levels are {[str(k) for k in levels.keys()]}") parser.add_argument("-ll", "--loglevel", type=str, default="INFO", help=f"The log level that should be used. Valid levels are {[str(k) for k in levels.keys()]}")
args = parser.parse_args() args = parser.parse_args()
args.loglevel = levels[args.loglevel if args.loglevel.lower() in levels else "info"] args.loglevel = levels[args.loglevel.lower() if args.loglevel.lower() in levels else "info"]
main(args) main(args)

View File

@ -7,21 +7,16 @@ from PySide6.QtCore import QObject
class TrackingData(QObject): class TrackingData(QObject):
def __init__(self, parent=None): def __init__(self, datadict, parent=None):
super().__init__(parent) super().__init__(parent)
self._data = None
self._columns = []
self._start = 0
self._stop = 0
self._indices = None
self._selection_column = None
self._user_selections = None
def setData(self, datadict):
assert isinstance(datadict, dict) assert isinstance(datadict, dict)
self._data = datadict self._data = datadict
if "userlabeled" not in self._data.keys():
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool) self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool)
self._columns = [k for k in self._data.keys()] self._columns = [k for k in self._data.keys()]
self._indices = self["index"]
self._selection = np.asarray([])
self._selected_ids = None
@property @property
def data(self): def data(self):
@ -42,60 +37,77 @@ class TrackingData(QObject):
def numDetections(self): def numDetections(self):
return self._data["track"].shape[0] return self._data["track"].shape[0]
@property def _find(self, ids):
def selectionRange(self): ids = np.sort(ids)
return self._start, self._stop indexes = np.ones_like(ids, dtype=int) * -1
j = 0
@property for idx, i in enumerate(self._indices):
def selectionRangeColumn(self): if i == ids[j]:
return self._selection_column indexes[j] = idx
j += 1
if j == len(indexes):
break
indexes = indexes[indexes >= 0]
return indexes
@property @property
def selectionIndices(self): def selectionIndices(self):
return self._indices return self._selection
def setSelectionRange(self, col, start, stop): def setSelectionRange(self, col, start, stop):
logging.debug("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop) logging.debug("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop)
self._start = start col_indices = np.where((self[col] >= start) & (self[col] < stop))[0]
self._stop = stop self._selection = self._indices[col_indices]
self._selection_column = col
col_indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
self._indices = self["index"][col_indices]
if len(col_indices) < 1: if len(col_indices) < 1:
logging.warning("TrackingData: Selection range is empty!") logging.warning("TrackingData: Selection range is empty!")
def selectedData(self, col:str): def selectedData(self, col:str):
if col not in self.columns: if col not in self.columns:
logging.error("TrackingData:selectedData: Invalid column name! %s", col) logging.error("TrackingData:selectedData: Invalid column name! %s", col)
return self[col][self._indices] return self[col][self._selection]
def setUserSelection(self, ids): def setSelection(self, ids):
""" """
Set the user selections. That is, e.g. when the user selected a number of detection ids (aka the index of the original data frame entries). Set the selection based on the detection IDs.
Parameters Parameters
---------- ----------
ids : array-like ids : array-like
An array-like object containing the IDs to be set as user selections. An array-like object containing the IDs to be set as user selections.
The IDs will be converted to integers.
""" """
self._user_selections = ids.astype(int) self._selection = self._find(ids)
self._selected_ids = ids
def assignUserSelection(self, track_id:int, userFlag:bool=True)-> None: def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
"""Assign a new track_id to the user-selected detections """Assign a new track_id to the user-selected detections
Parameters Parameters
---------- ----------
track_id : int track_id : int
The new track id for the user-selected detections The new track id for the user-selected detections
userFlag : bool setUserLabeled : bool
Should the "userlabeled" state of the detections be set to True or False? Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched.
""" """
self["track"][self._user_selections] = track_id self["track"][self._user_selections] = track_id
self.setAssignmentStatus(userFlag) if setUserLabeled:
self.setUserLabeledStatus(True, True)
def setAssignmentStatus(self, isTrue: bool): def setUserLabeledStatus(self, new_status: bool, selection=True):
logging.debug("TrackingData:Re-setting assignment status of user selected data to %s", str(isTrue)) """Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection.
self["userlabeled"][self._user_selections] = isTrue
Parameters
----------
new_status : bool
The new status, TRUE, if the detections are confirmed by the user (human observer) and can be treated as correct
selection : bool, optional
Whether the new status should be set for the selection only (True, default) ore not (False)
"""
logging.debug("TrackingData: (Re-)setting assignment status of %s to %s",
"user selected data" if selection else " ALL", str(new_status))
if selection:
self["userlabeled"][self._selection] = new_status
else:
self["userlabeled"][:] = new_status
def revertAssignmentStatus(self): def revertAssignmentStatus(self):
logging.debug("TrackingData:Un-setting assignment status of all data!") logging.debug("TrackingData:Un-setting assignment status of all data!")
@ -105,14 +117,15 @@ class TrackingData(QObject):
logging.debug("TrackingData: Reverting all track assignments!") logging.debug("TrackingData: Reverting all track assignments!")
self["track"][:] = -1 self["track"][:] = -1
def deleteDetections(self): def deleteDetections(self, ids=None):
# from IPython import embed if ids is not None:
# if self._user_selections is not None: del_indices = self._find(ids)
# ids = self._user_selections else:
# for c in self.columns: del_indices = self._indices
# pass for c in self._columns:
# embed() self._data[c] = np.delete(self._data[c], del_indices, axis=0)
pass self._indices = self["index"]
self._selected_ids = np.setdiff1d(self._selected_ids, del_indices)
def assignTracks(self, tracks:np.ndarray): def assignTracks(self, tracks:np.ndarray):
"""assigns the given tracks to the user-selected detections. If the sizes of """assigns the given tracks to the user-selected detections. If the sizes of
@ -154,11 +167,11 @@ class TrackingData(QObject):
and M is number of keypoints and M is number of keypoints
""" """
if selection: if selection:
if len(self._indices) < 1: if len(self._selection) < 1:
logging.info("TrackingData.coordinates returns empty array, not detections in range!") logging.info("TrackingData.coordinates returns empty array, not detections in range!")
return np.ndarray([]) return np.ndarray([])
return np.stack(self._data["keypoints"][self._indices]).astype(np.float32) return np.stack(self["keypoints"][self._selection]).astype(np.float32)
return np.stack(self._data["keypoints"]).astype(np.float32) return np.stack(self["keypoints"]).astype(np.float32)
def keypointScores(self, selection=False): def keypointScores(self, selection=False):
""" """
@ -171,11 +184,11 @@ class TrackingData(QObject):
with N the number of detections and M the number of keypoints. with N the number of detections and M the number of keypoints.
""" """
if selection: if selection:
if len(self._indices) < 1: if len(self._selection) < 1:
logging.info("TrackingData.scores returns empty array, not detections in range!") logging.info("TrackingData.scores returns empty array, not detections in range!")
return np.ndarray([]) return None
return np.stack(self._data["keypoint_score"][self._indices]).astype(np.float32) return np.stack(self["keypoint_score"][self._selection]).astype(np.float32)
return np.stack(self._data["keypoint_score"]).astype(np.float32) return np.stack(self["keypoint_score"]).astype(np.float32)
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]): def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):
""" """
@ -195,12 +208,17 @@ class TrackingData(QObject):
A NumPy array of shape (N, 2) containing the center of gravity for each detection. A NumPy array of shape (N, 2) containing the center of gravity for each detection.
""" """
scores = self.keypointScores(selection) scores = self.keypointScores(selection)
if scores is None:
return None
scores[scores < threshold] = 0.0 scores[scores < threshold] = 0.0
scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0 scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0
weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis] weighted_coords = self.coordinates(selection) * scores[:, :, np.newaxis]
sum_scores = np.sum(scores, axis=1, keepdims=True) sum_scores = np.sum(scores, axis=1, keepdims=True)
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
return center_of_gravity cogs = np.zeros((weighted_coords.shape[0], 2))
val_ids = np.where(sum_scores > 0.0)[0]
cogs[val_ids] = np.sum(weighted_coords[val_ids], axis=1) / sum_scores[val_ids]
return cogs
def animalLength(self, bodyaxis=None): def animalLength(self, bodyaxis=None):
if bodyaxis is None: if bodyaxis is None:
@ -288,22 +306,26 @@ def main():
plt.plot([positions[si, 0], positions[ei, 0]], plt.plot([positions[si, 0], positions[ei, 0]],
[positions[si, 1], positions[ei, 1]], color="tab:green") [positions[si, 1], positions[ei, 1]], color="tab:green")
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)
data = TrackingData(as_dict(df))
test_indices = [32, 88, 99, 2593]
data.deleteDetections(test_indices)
data = TrackingData() embed()
data.setData(as_dict(df)) data.deleteDetections(test_indices)
data.setSelection(test_indices)
all_cogs = data.centerOfGravity() all_cogs = data.centerOfGravity()
orientations = data.orientation() orientations = data.orientation()
lengths = data.animalLength() lengths = data.animalLength()
frames = data["frame"] frames = data["frame"]
tracks = data["track"] tracks = data["track"]
bendedness = data.bendedness() bendedness = data.bendedness()
indices = data._indices
# positions = data.coordinates()[[160388, 160389]] # positions = data.coordinates()[[160388, 160389]]
embed()
tracks = data["track"] tracks = data["track"]
cogs = all_cogs[tracks==1] cogs = all_cogs[tracks==1]
all_dists = neighborDistances(cogs, 2, False) all_dists = neighborDistances(cogs, 2, False)

View File

@ -311,8 +311,8 @@ def main():
df = pickle.load(f) df = pickle.load(f)
data = TrackingData() data = TrackingData()
data.setData(as_dict(df)) data.setData(as_dict(df))
data.setUserSelection(np.arange(0,100, 1)) data.setSelection(np.arange(0,100, 1))
data.setAssignmentStatus(True) data.setUserLabeledStatus(True)
start_x = 0.1 start_x = 0.1
app = QApplication([]) app = QApplication([])
window = QWidget() window = QWidget()

View File

@ -31,7 +31,7 @@ class FixTracks(QWidget):
self._currentWindowPos = 0 # in frames self._currentWindowPos = 0 # in frames
self._currentWindowWidth = 0 # in frames self._currentWindowWidth = 0 # in frames
self._maxframes = 0 self._maxframes = 0
self._data = TrackingData() self._data = None
self._detectionView = DetectionView() self._detectionView = DetectionView()
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected) self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
@ -204,7 +204,7 @@ class FixTracks(QWidget):
self._progress_bar.setRange(0, 100) self._progress_bar.setRange(0, 100)
self._progress_bar.setValue(0) self._progress_bar.setValue(0)
if state and self._reader is not None: if state and self._reader is not None:
self._data.setData(self._reader.asdict) self._data = TrackingData(self._reader.asdict)
self._saveBtn.setEnabled(True) self._saveBtn.setEnabled(True)
self._currentWindowPos = 0 self._currentWindowPos = 0
self._currentWindowWidth = self._windowspinner.value() self._currentWindowWidth = self._windowspinner.value()
@ -247,30 +247,30 @@ class FixTracks(QWidget):
def on_assignOne(self): def on_assignOne(self):
logging.debug("Assigning user selection to track One") logging.debug("Assigning user selection to track One")
self._data.assignUserSelection(self.trackone_id) self._data.setTrack(self.trackone_id)
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_assignTwo(self): def on_assignTwo(self):
logging.debug("Assigning user selection to track Two") logging.debug("Assigning user selection to track Two")
self._data.assignUserSelection(self.tracktwo_id) self._data.setTrack(self.tracktwo_id)
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_assignOther(self): def on_assignOther(self):
logging.debug("Assigning user selection to track Other") logging.debug("Assigning user selection to track Other")
self._data.assignUserSelection(self.trackother_id, False) self._data.setTrack(self.trackother_id, False)
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_setUserFlag(self): def on_setUserFlag(self):
self._data.setAssignmentStatus(True) self._data.setUserLabeledStatus(True)
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_unsetUserFlag(self): def on_unsetUserFlag(self):
logging.debug("Tracks:unsetUserFlag") logging.debug("Tracks:unsetUserFlag")
self._data.setAssignmentStatus(False) self._data.setUserLabeledStatus(False)
self._timeline.update() self._timeline.update()
self.update() self.update()
@ -320,7 +320,7 @@ class FixTracks(QWidget):
ids[i] = d.data(DetectionData.ID.value) ids[i] = d.data(DetectionData.ID.value)
frames[i] = d.data(DetectionData.FRAME.value) frames[i] = d.data(DetectionData.FRAME.value)
coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value) coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
self._data.setUserSelection(ids) self._data.setSelection(ids)
self._controls_widget.setSelectedTracks(tracks) self._controls_widget.setSelectedTracks(tracks)
self._skeleton.clear() self._skeleton.clear()
self._skeleton.addSkeletons(coordinates, ids, frames, tracks, QBrush(QColor(10, 255, 65, 255))) self._skeleton.addSkeletons(coordinates, ids, frames, tracks, QBrush(QColor(10, 255, 65, 255)))