[trackingdata] extract to util subpackage
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -10,6 +9,7 @@ from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QG
|
||||
|
||||
from fixtracks.utils.reader import PickleLoader
|
||||
from fixtracks.utils.writer import PickleWriter
|
||||
from fixtracks.utils.trackingdata import TrackingData
|
||||
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
||||
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
||||
from fixtracks.widgets.skeleton import SkeletonWidget
|
||||
@@ -158,91 +158,6 @@ class SelectionControls(QWidget):
|
||||
self._total = len(tracks)
|
||||
|
||||
|
||||
class TrackingData(QObject):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._data = None
|
||||
self._columns = []
|
||||
self._start = 0
|
||||
self._stop = 0
|
||||
self._indices = None
|
||||
self._selection_column = None
|
||||
self._user_selections = None
|
||||
|
||||
def setData(self, datadict):
|
||||
assert isinstance(datadict, dict)
|
||||
self._data = datadict
|
||||
self._columns = [k for k in self._data.keys()]
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self._columns
|
||||
|
||||
def max(self, col):
|
||||
if col in self.columns:
|
||||
return np.max(self._data[col])
|
||||
else:
|
||||
logging.error("Column %s not in dictionary", col)
|
||||
return np.nan
|
||||
|
||||
@property
|
||||
def numDetections(self):
|
||||
return self._data["track"].shape[0]
|
||||
|
||||
@property
|
||||
def selectionRange(self):
|
||||
return self._start, self._stop
|
||||
|
||||
@property
|
||||
def selectionRangeColumn(self):
|
||||
return self._selection_column
|
||||
|
||||
@property
|
||||
def selectionIndices(self):
|
||||
return self._indices
|
||||
|
||||
def setSelectionRange(self, col, start, stop):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._selection_column = col
|
||||
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
|
||||
def selectedData(self, col):
|
||||
return self._data[col][self._indices]
|
||||
|
||||
def setUserSelection(self, ids):
|
||||
self._user_selections = ids.astype(int)
|
||||
|
||||
def assignUserSelection(self, track_id):
|
||||
self._data["track"][self._user_selections] = track_id
|
||||
|
||||
def assignTracks(self, tracks):
|
||||
if len(tracks) != self.numDetections:
|
||||
logging.error("DataController: Size of passed tracks does not match data!")
|
||||
return
|
||||
self._data["track"] = tracks
|
||||
|
||||
def save(self, filename):
|
||||
export_columns = self._columns.copy()
|
||||
export_columns.remove("index")
|
||||
dictionary = {c: self._data[c] for c in export_columns}
|
||||
df = pd.DataFrame(dictionary, index=self._data["index"])
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(df, f)
|
||||
|
||||
def numKeypoints(self):
|
||||
if len(self._data["keypoints"]) == 0:
|
||||
return 0
|
||||
return self._data["keypoints"][0].shape[0]
|
||||
|
||||
def coordinates(self):
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
|
||||
|
||||
class FixTracks(QWidget):
|
||||
back = Signal()
|
||||
trackone_id = 1
|
||||
|
||||
Reference in New Issue
Block a user