[trackingdata] extract to util subpackage

This commit is contained in:
Jan Grewe 2025-02-09 11:32:03 +01:00
parent b7d4097e73
commit 1e86a74549
2 changed files with 122 additions and 86 deletions

View File

@ -0,0 +1,121 @@
import pickle
import logging
import numpy as np
from PySide6.QtCore import QObject
class TrackingData(QObject):
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._columns = []
self._start = 0
self._stop = 0
self._indices = None
self._selection_column = None
self._user_selections = None
def setData(self, datadict):
assert isinstance(datadict, dict)
self._data = datadict
self._columns = [k for k in self._data.keys()]
@property
def data(self):
return self._data
@property
def columns(self):
return self._columns
def max(self, col):
if col in self.columns:
return np.max(self._data[col])
else:
logging.error("Column %s not in dictionary", col)
return np.nan
@property
def numDetections(self):
return self._data["track"].shape[0]
@property
def selectionRange(self):
return self._start, self._stop
@property
def selectionRangeColumn(self):
return self._selection_column
@property
def selectionIndices(self):
return self._indices
def setSelectionRange(self, col, start, stop):
self._start = start
self._stop = stop
self._selection_column = col
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
def selectedData(self, col):
return self._data[col][self._indices]
def setUserSelection(self, ids):
self._user_selections = ids.astype(int)
def assignUserSelection(self, track_id):
self._data["track"][self._user_selections] = track_id
def assignTracks(self, tracks):
if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!")
return
self._data["track"] = tracks
def save(self, filename):
export_columns = self._columns.copy()
export_columns.remove("index")
dictionary = {c: self._data[c] for c in export_columns}
df = pd.DataFrame(dictionary, index=self._data["index"])
with open(filename, 'wb') as f:
pickle.dump(df, f)
def numKeypoints(self):
if len(self._data["keypoints"]) == 0:
return 0
return self._data["keypoints"][0].shape[0]
def coordinates(self):
return np.stack(self._data["keypoints"]).astype(np.float32)
"""
self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data)
self.update()
"""
def main():
import pandas as pd
from IPython import embed
from fixtracks.info import PACKAGE_ROOT
def as_dict(df:pd.DataFrame):
d = {c: df[c].values for c in df.columns}
d["index"] = df.index.values
return d
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)
data = TrackingData()
data.setData(as_dict(df))
embed()
pass
if __name__ == "__main__":
main()

View File

@ -1,5 +1,4 @@
import logging
import pickle
import numpy as np
import pandas as pd
@ -10,6 +9,7 @@ from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QG
from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.writer import PickleWriter
from fixtracks.utils.trackingdata import TrackingData
from fixtracks.widgets.detectionview import DetectionView, DetectionData
from fixtracks.widgets.detectiontimeline import DetectionTimeline
from fixtracks.widgets.skeleton import SkeletonWidget
@ -158,91 +158,6 @@ class SelectionControls(QWidget):
self._total = len(tracks)
class TrackingData(QObject):
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._columns = []
self._start = 0
self._stop = 0
self._indices = None
self._selection_column = None
self._user_selections = None
def setData(self, datadict):
assert isinstance(datadict, dict)
self._data = datadict
self._columns = [k for k in self._data.keys()]
@property
def data(self):
return self._data
@property
def columns(self):
return self._columns
def max(self, col):
if col in self.columns:
return np.max(self._data[col])
else:
logging.error("Column %s not in dictionary", col)
return np.nan
@property
def numDetections(self):
return self._data["track"].shape[0]
@property
def selectionRange(self):
return self._start, self._stop
@property
def selectionRangeColumn(self):
return self._selection_column
@property
def selectionIndices(self):
return self._indices
def setSelectionRange(self, col, start, stop):
self._start = start
self._stop = stop
self._selection_column = col
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
def selectedData(self, col):
return self._data[col][self._indices]
def setUserSelection(self, ids):
self._user_selections = ids.astype(int)
def assignUserSelection(self, track_id):
self._data["track"][self._user_selections] = track_id
def assignTracks(self, tracks):
if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!")
return
self._data["track"] = tracks
def save(self, filename):
export_columns = self._columns.copy()
export_columns.remove("index")
dictionary = {c: self._data[c] for c in export_columns}
df = pd.DataFrame(dictionary, index=self._data["index"])
with open(filename, 'wb') as f:
pickle.dump(df, f)
def numKeypoints(self):
if len(self._data["keypoints"]) == 0:
return 0
return self._data["keypoints"][0].shape[0]
def coordinates(self):
return np.stack(self._data["keypoints"]).astype(np.float32)
class FixTracks(QWidget):
back = Signal()
trackone_id = 1