diff --git a/fixtracks/utils/trackingdata.py b/fixtracks/utils/trackingdata.py new file mode 100644 index 0000000..9853938 --- /dev/null +++ b/fixtracks/utils/trackingdata.py @@ -0,0 +1,121 @@ +import pickle +import logging +import numpy as np + +from PySide6.QtCore import QObject + + +class TrackingData(QObject): + def __init__(self, parent=None): + super().__init__(parent) + self._data = None + self._columns = [] + self._start = 0 + self._stop = 0 + self._indices = None + self._selection_column = None + self._user_selections = None + + def setData(self, datadict): + assert isinstance(datadict, dict) + self._data = datadict + self._columns = [k for k in self._data.keys()] + + @property + def data(self): + return self._data + + @property + def columns(self): + return self._columns + + def max(self, col): + if col in self.columns: + return np.max(self._data[col]) + else: + logging.error("Column %s not in dictionary", col) + return np.nan + + @property + def numDetections(self): + return self._data["track"].shape[0] + + @property + def selectionRange(self): + return self._start, self._stop + + @property + def selectionRangeColumn(self): + return self._selection_column + + @property + def selectionIndices(self): + return self._indices + + def setSelectionRange(self, col, start, stop): + self._start = start + self._stop = stop + self._selection_column = col + self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0] + + def selectedData(self, col): + return self._data[col][self._indices] + + def setUserSelection(self, ids): + self._user_selections = ids.astype(int) + + def assignUserSelection(self, track_id): + self._data["track"][self._user_selections] = track_id + + def assignTracks(self, tracks): + if len(tracks) != self.numDetections: + logging.error("DataController: Size of passed tracks does not match data!") + return + self._data["track"] = tracks + + def save(self, filename): + export_columns = self._columns.copy() + export_columns.remove("index") + dictionary = {c: self._data[c] for c in export_columns} + df = pd.DataFrame(dictionary, index=self._data["index"]) + with open(filename, 'wb') as f: + pickle.dump(df, f) + + def numKeypoints(self): + if len(self._data["keypoints"]) == 0: + return 0 + return self._data["keypoints"][0].shape[0] + + def coordinates(self): + return np.stack(self._data["keypoints"]).astype(np.float32) + + """ + self._data.setSelectionRange("index", 0, self._data.numDetections) + self._data.assignTracks(tracks) + self._timeline.setDetectionData(self._data.data) + self.update() + """ + +def main(): + import pandas as pd + from IPython import embed + from fixtracks.info import PACKAGE_ROOT + + def as_dict(df:pd.DataFrame): + d = {c: df[c].values for c in df.columns} + d["index"] = df.index.values + return d + + datafile = PACKAGE_ROOT / "data/merged_small.pkl" + with open(datafile, "rb") as f: + df = pickle.load(f) + + data = TrackingData() + data.setData(as_dict(df)) + embed() + + pass + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index 9eb6f75..85d85e6 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -1,5 +1,4 @@ import logging -import pickle import numpy as np import pandas as pd @@ -10,6 +9,7 @@ from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QG from fixtracks.utils.reader import PickleLoader from fixtracks.utils.writer import PickleWriter +from fixtracks.utils.trackingdata import TrackingData from fixtracks.widgets.detectionview import DetectionView, DetectionData from fixtracks.widgets.detectiontimeline import DetectionTimeline from fixtracks.widgets.skeleton import SkeletonWidget @@ -158,91 +158,6 @@ class SelectionControls(QWidget): self._total = len(tracks) -class TrackingData(QObject): - def __init__(self, parent=None): - super().__init__(parent) - self._data = None - self._columns = [] - self._start = 0 - self._stop = 0 - self._indices = None - self._selection_column = None - self._user_selections = None - - def setData(self, datadict): - assert isinstance(datadict, dict) - self._data = datadict - self._columns = [k for k in self._data.keys()] - - @property - def data(self): - return self._data - - @property - def columns(self): - return self._columns - - def max(self, col): - if col in self.columns: - return np.max(self._data[col]) - else: - logging.error("Column %s not in dictionary", col) - return np.nan - - @property - def numDetections(self): - return self._data["track"].shape[0] - - @property - def selectionRange(self): - return self._start, self._stop - - @property - def selectionRangeColumn(self): - return self._selection_column - - @property - def selectionIndices(self): - return self._indices - - def setSelectionRange(self, col, start, stop): - self._start = start - self._stop = stop - self._selection_column = col - self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0] - - def selectedData(self, col): - return self._data[col][self._indices] - - def setUserSelection(self, ids): - self._user_selections = ids.astype(int) - - def assignUserSelection(self, track_id): - self._data["track"][self._user_selections] = track_id - - def assignTracks(self, tracks): - if len(tracks) != self.numDetections: - logging.error("DataController: Size of passed tracks does not match data!") - return - self._data["track"] = tracks - - def save(self, filename): - export_columns = self._columns.copy() - export_columns.remove("index") - dictionary = {c: self._data[c] for c in export_columns} - df = pd.DataFrame(dictionary, index=self._data["index"]) - with open(filename, 'wb') as f: - pickle.dump(df, f) - - def numKeypoints(self): - if len(self._data["keypoints"]) == 0: - return 0 - return self._data["keypoints"][0].shape[0] - - def coordinates(self): - return np.stack(self._data["keypoints"]).astype(np.float32) - - class FixTracks(QWidget): back = Signal() trackone_id = 1