[trackingdata] extract to util subpackage
This commit is contained in:
parent
b7d4097e73
commit
1e86a74549
121
fixtracks/utils/trackingdata.py
Normal file
121
fixtracks/utils/trackingdata.py
Normal file
@ -0,0 +1,121 @@
|
||||
import pickle
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from PySide6.QtCore import QObject
|
||||
|
||||
|
||||
class TrackingData(QObject):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._data = None
|
||||
self._columns = []
|
||||
self._start = 0
|
||||
self._stop = 0
|
||||
self._indices = None
|
||||
self._selection_column = None
|
||||
self._user_selections = None
|
||||
|
||||
def setData(self, datadict):
|
||||
assert isinstance(datadict, dict)
|
||||
self._data = datadict
|
||||
self._columns = [k for k in self._data.keys()]
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self._columns
|
||||
|
||||
def max(self, col):
|
||||
if col in self.columns:
|
||||
return np.max(self._data[col])
|
||||
else:
|
||||
logging.error("Column %s not in dictionary", col)
|
||||
return np.nan
|
||||
|
||||
@property
|
||||
def numDetections(self):
|
||||
return self._data["track"].shape[0]
|
||||
|
||||
@property
|
||||
def selectionRange(self):
|
||||
return self._start, self._stop
|
||||
|
||||
@property
|
||||
def selectionRangeColumn(self):
|
||||
return self._selection_column
|
||||
|
||||
@property
|
||||
def selectionIndices(self):
|
||||
return self._indices
|
||||
|
||||
def setSelectionRange(self, col, start, stop):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._selection_column = col
|
||||
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
|
||||
def selectedData(self, col):
|
||||
return self._data[col][self._indices]
|
||||
|
||||
def setUserSelection(self, ids):
|
||||
self._user_selections = ids.astype(int)
|
||||
|
||||
def assignUserSelection(self, track_id):
|
||||
self._data["track"][self._user_selections] = track_id
|
||||
|
||||
def assignTracks(self, tracks):
|
||||
if len(tracks) != self.numDetections:
|
||||
logging.error("DataController: Size of passed tracks does not match data!")
|
||||
return
|
||||
self._data["track"] = tracks
|
||||
|
||||
def save(self, filename):
|
||||
export_columns = self._columns.copy()
|
||||
export_columns.remove("index")
|
||||
dictionary = {c: self._data[c] for c in export_columns}
|
||||
df = pd.DataFrame(dictionary, index=self._data["index"])
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(df, f)
|
||||
|
||||
def numKeypoints(self):
|
||||
if len(self._data["keypoints"]) == 0:
|
||||
return 0
|
||||
return self._data["keypoints"][0].shape[0]
|
||||
|
||||
def coordinates(self):
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
|
||||
"""
|
||||
self._data.setSelectionRange("index", 0, self._data.numDetections)
|
||||
self._data.assignTracks(tracks)
|
||||
self._timeline.setDetectionData(self._data.data)
|
||||
self.update()
|
||||
"""
|
||||
|
||||
def main():
|
||||
import pandas as pd
|
||||
from IPython import embed
|
||||
from fixtracks.info import PACKAGE_ROOT
|
||||
|
||||
def as_dict(df:pd.DataFrame):
|
||||
d = {c: df[c].values for c in df.columns}
|
||||
d["index"] = df.index.values
|
||||
return d
|
||||
|
||||
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||
with open(datafile, "rb") as f:
|
||||
df = pickle.load(f)
|
||||
|
||||
data = TrackingData()
|
||||
data.setData(as_dict(df))
|
||||
embed()
|
||||
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@ -10,6 +9,7 @@ from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QG
|
||||
|
||||
from fixtracks.utils.reader import PickleLoader
|
||||
from fixtracks.utils.writer import PickleWriter
|
||||
from fixtracks.utils.trackingdata import TrackingData
|
||||
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
||||
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
||||
from fixtracks.widgets.skeleton import SkeletonWidget
|
||||
@ -158,91 +158,6 @@ class SelectionControls(QWidget):
|
||||
self._total = len(tracks)
|
||||
|
||||
|
||||
class TrackingData(QObject):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._data = None
|
||||
self._columns = []
|
||||
self._start = 0
|
||||
self._stop = 0
|
||||
self._indices = None
|
||||
self._selection_column = None
|
||||
self._user_selections = None
|
||||
|
||||
def setData(self, datadict):
|
||||
assert isinstance(datadict, dict)
|
||||
self._data = datadict
|
||||
self._columns = [k for k in self._data.keys()]
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self._columns
|
||||
|
||||
def max(self, col):
|
||||
if col in self.columns:
|
||||
return np.max(self._data[col])
|
||||
else:
|
||||
logging.error("Column %s not in dictionary", col)
|
||||
return np.nan
|
||||
|
||||
@property
|
||||
def numDetections(self):
|
||||
return self._data["track"].shape[0]
|
||||
|
||||
@property
|
||||
def selectionRange(self):
|
||||
return self._start, self._stop
|
||||
|
||||
@property
|
||||
def selectionRangeColumn(self):
|
||||
return self._selection_column
|
||||
|
||||
@property
|
||||
def selectionIndices(self):
|
||||
return self._indices
|
||||
|
||||
def setSelectionRange(self, col, start, stop):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._selection_column = col
|
||||
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
|
||||
def selectedData(self, col):
|
||||
return self._data[col][self._indices]
|
||||
|
||||
def setUserSelection(self, ids):
|
||||
self._user_selections = ids.astype(int)
|
||||
|
||||
def assignUserSelection(self, track_id):
|
||||
self._data["track"][self._user_selections] = track_id
|
||||
|
||||
def assignTracks(self, tracks):
|
||||
if len(tracks) != self.numDetections:
|
||||
logging.error("DataController: Size of passed tracks does not match data!")
|
||||
return
|
||||
self._data["track"] = tracks
|
||||
|
||||
def save(self, filename):
|
||||
export_columns = self._columns.copy()
|
||||
export_columns.remove("index")
|
||||
dictionary = {c: self._data[c] for c in export_columns}
|
||||
df = pd.DataFrame(dictionary, index=self._data["index"])
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(df, f)
|
||||
|
||||
def numKeypoints(self):
|
||||
if len(self._data["keypoints"]) == 0:
|
||||
return 0
|
||||
return self._data["keypoints"][0].shape[0]
|
||||
|
||||
def coordinates(self):
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
|
||||
|
||||
class FixTracks(QWidget):
|
||||
back = Signal()
|
||||
trackone_id = 1
|
||||
|
Loading…
Reference in New Issue
Block a user