[trackingdata] extract to util subpackage
This commit is contained in:
parent
b7d4097e73
commit
1e86a74549
121
fixtracks/utils/trackingdata.py
Normal file
121
fixtracks/utils/trackingdata.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import pickle
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from PySide6.QtCore import QObject
|
||||||
|
|
||||||
|
|
||||||
|
class TrackingData(QObject):
|
||||||
|
def __init__(self, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._data = None
|
||||||
|
self._columns = []
|
||||||
|
self._start = 0
|
||||||
|
self._stop = 0
|
||||||
|
self._indices = None
|
||||||
|
self._selection_column = None
|
||||||
|
self._user_selections = None
|
||||||
|
|
||||||
|
def setData(self, datadict):
|
||||||
|
assert isinstance(datadict, dict)
|
||||||
|
self._data = datadict
|
||||||
|
self._columns = [k for k in self._data.keys()]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def columns(self):
|
||||||
|
return self._columns
|
||||||
|
|
||||||
|
def max(self, col):
|
||||||
|
if col in self.columns:
|
||||||
|
return np.max(self._data[col])
|
||||||
|
else:
|
||||||
|
logging.error("Column %s not in dictionary", col)
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
@property
|
||||||
|
def numDetections(self):
|
||||||
|
return self._data["track"].shape[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def selectionRange(self):
|
||||||
|
return self._start, self._stop
|
||||||
|
|
||||||
|
@property
|
||||||
|
def selectionRangeColumn(self):
|
||||||
|
return self._selection_column
|
||||||
|
|
||||||
|
@property
|
||||||
|
def selectionIndices(self):
|
||||||
|
return self._indices
|
||||||
|
|
||||||
|
def setSelectionRange(self, col, start, stop):
|
||||||
|
self._start = start
|
||||||
|
self._stop = stop
|
||||||
|
self._selection_column = col
|
||||||
|
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||||
|
|
||||||
|
def selectedData(self, col):
|
||||||
|
return self._data[col][self._indices]
|
||||||
|
|
||||||
|
def setUserSelection(self, ids):
|
||||||
|
self._user_selections = ids.astype(int)
|
||||||
|
|
||||||
|
def assignUserSelection(self, track_id):
|
||||||
|
self._data["track"][self._user_selections] = track_id
|
||||||
|
|
||||||
|
def assignTracks(self, tracks):
|
||||||
|
if len(tracks) != self.numDetections:
|
||||||
|
logging.error("DataController: Size of passed tracks does not match data!")
|
||||||
|
return
|
||||||
|
self._data["track"] = tracks
|
||||||
|
|
||||||
|
def save(self, filename):
|
||||||
|
export_columns = self._columns.copy()
|
||||||
|
export_columns.remove("index")
|
||||||
|
dictionary = {c: self._data[c] for c in export_columns}
|
||||||
|
df = pd.DataFrame(dictionary, index=self._data["index"])
|
||||||
|
with open(filename, 'wb') as f:
|
||||||
|
pickle.dump(df, f)
|
||||||
|
|
||||||
|
def numKeypoints(self):
|
||||||
|
if len(self._data["keypoints"]) == 0:
|
||||||
|
return 0
|
||||||
|
return self._data["keypoints"][0].shape[0]
|
||||||
|
|
||||||
|
def coordinates(self):
|
||||||
|
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._data.setSelectionRange("index", 0, self._data.numDetections)
|
||||||
|
self._data.assignTracks(tracks)
|
||||||
|
self._timeline.setDetectionData(self._data.data)
|
||||||
|
self.update()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import pandas as pd
|
||||||
|
from IPython import embed
|
||||||
|
from fixtracks.info import PACKAGE_ROOT
|
||||||
|
|
||||||
|
def as_dict(df:pd.DataFrame):
|
||||||
|
d = {c: df[c].values for c in df.columns}
|
||||||
|
d["index"] = df.index.values
|
||||||
|
return d
|
||||||
|
|
||||||
|
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||||
|
with open(datafile, "rb") as f:
|
||||||
|
df = pickle.load(f)
|
||||||
|
|
||||||
|
data = TrackingData()
|
||||||
|
data.setData(as_dict(df))
|
||||||
|
embed()
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import pickle
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
@ -10,6 +9,7 @@ from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QG
|
|||||||
|
|
||||||
from fixtracks.utils.reader import PickleLoader
|
from fixtracks.utils.reader import PickleLoader
|
||||||
from fixtracks.utils.writer import PickleWriter
|
from fixtracks.utils.writer import PickleWriter
|
||||||
|
from fixtracks.utils.trackingdata import TrackingData
|
||||||
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
||||||
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
||||||
from fixtracks.widgets.skeleton import SkeletonWidget
|
from fixtracks.widgets.skeleton import SkeletonWidget
|
||||||
@ -158,91 +158,6 @@ class SelectionControls(QWidget):
|
|||||||
self._total = len(tracks)
|
self._total = len(tracks)
|
||||||
|
|
||||||
|
|
||||||
class TrackingData(QObject):
|
|
||||||
def __init__(self, parent=None):
|
|
||||||
super().__init__(parent)
|
|
||||||
self._data = None
|
|
||||||
self._columns = []
|
|
||||||
self._start = 0
|
|
||||||
self._stop = 0
|
|
||||||
self._indices = None
|
|
||||||
self._selection_column = None
|
|
||||||
self._user_selections = None
|
|
||||||
|
|
||||||
def setData(self, datadict):
|
|
||||||
assert isinstance(datadict, dict)
|
|
||||||
self._data = datadict
|
|
||||||
self._columns = [k for k in self._data.keys()]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data(self):
|
|
||||||
return self._data
|
|
||||||
|
|
||||||
@property
|
|
||||||
def columns(self):
|
|
||||||
return self._columns
|
|
||||||
|
|
||||||
def max(self, col):
|
|
||||||
if col in self.columns:
|
|
||||||
return np.max(self._data[col])
|
|
||||||
else:
|
|
||||||
logging.error("Column %s not in dictionary", col)
|
|
||||||
return np.nan
|
|
||||||
|
|
||||||
@property
|
|
||||||
def numDetections(self):
|
|
||||||
return self._data["track"].shape[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def selectionRange(self):
|
|
||||||
return self._start, self._stop
|
|
||||||
|
|
||||||
@property
|
|
||||||
def selectionRangeColumn(self):
|
|
||||||
return self._selection_column
|
|
||||||
|
|
||||||
@property
|
|
||||||
def selectionIndices(self):
|
|
||||||
return self._indices
|
|
||||||
|
|
||||||
def setSelectionRange(self, col, start, stop):
|
|
||||||
self._start = start
|
|
||||||
self._stop = stop
|
|
||||||
self._selection_column = col
|
|
||||||
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
|
||||||
|
|
||||||
def selectedData(self, col):
|
|
||||||
return self._data[col][self._indices]
|
|
||||||
|
|
||||||
def setUserSelection(self, ids):
|
|
||||||
self._user_selections = ids.astype(int)
|
|
||||||
|
|
||||||
def assignUserSelection(self, track_id):
|
|
||||||
self._data["track"][self._user_selections] = track_id
|
|
||||||
|
|
||||||
def assignTracks(self, tracks):
|
|
||||||
if len(tracks) != self.numDetections:
|
|
||||||
logging.error("DataController: Size of passed tracks does not match data!")
|
|
||||||
return
|
|
||||||
self._data["track"] = tracks
|
|
||||||
|
|
||||||
def save(self, filename):
|
|
||||||
export_columns = self._columns.copy()
|
|
||||||
export_columns.remove("index")
|
|
||||||
dictionary = {c: self._data[c] for c in export_columns}
|
|
||||||
df = pd.DataFrame(dictionary, index=self._data["index"])
|
|
||||||
with open(filename, 'wb') as f:
|
|
||||||
pickle.dump(df, f)
|
|
||||||
|
|
||||||
def numKeypoints(self):
|
|
||||||
if len(self._data["keypoints"]) == 0:
|
|
||||||
return 0
|
|
||||||
return self._data["keypoints"][0].shape[0]
|
|
||||||
|
|
||||||
def coordinates(self):
|
|
||||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
class FixTracks(QWidget):
|
class FixTracks(QWidget):
|
||||||
back = Signal()
|
back = Signal()
|
||||||
trackone_id = 1
|
trackone_id = 1
|
||||||
|
Loading…
Reference in New Issue
Block a user