589 lines
22 KiB
Python
589 lines
22 KiB
Python
import logging
|
|
import numpy as np
|
|
|
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
|
|
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
|
|
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
|
|
from PySide6.QtGui import QBrush, QColor
|
|
|
|
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
|
|
|
|
from fixtracks.utils.trackingdata import TrackingData
|
|
|
|
from IPython import embed
|
|
|
|
class WorkerSignals(QObject):
    """Signal container owned by the background workers of this module.

    ``QRunnable`` does not derive from ``QObject`` and therefore cannot declare
    signals itself; each worker creates one instance of this class instead.
    """
    # error description (str)
    error = Signal(str)
    # bool flag; not emitted by the workers shown in this module
    running = Signal(bool)
    # (progress, processed count, error count) as emitted by ConsistencyWorker
    progress = Signal(int, int, int)
    # end of work; carries a frame number (the data loader sends 0)
    stopped = Signal(int)
|
|
|
|
class ConsitencyDataLoader(QRunnable):
    """Collect the arrays needed by the consistency tracker in a worker thread.

    Derived quantities come from ``TrackingData`` methods (center of gravity,
    orientation, animal length, bendedness); the ``userlabeled``,
    ``confidence``, ``frame`` and ``track`` columns are read by key. When done,
    ``signals.stopped`` is emitted with 0.
    """

    def __init__(self, data):
        super().__init__()
        self.signals = WorkerSignals()
        self.data = data
        # results, populated by run()
        self.positions = None
        self.bendedness = None
        self.lengths = None
        self.orientations = None
        self.userlabeled = None
        self.scores = None
        self.frames = None
        self.tracks = None

    @Slot()
    def run(self):
        source = self.data
        # quantities computed by the TrackingData object
        self.positions, self.orientations, self.lengths, self.bendedness = (
            source.centerOfGravity(),
            source.orientation(),
            source.animalLength(),
            source.bendedness(),
        )
        # plain columns; the confidence scores are stored but not used yet
        self.userlabeled, self.scores, self.frames, self.tracks = (
            source[key] for key in ("userlabeled", "confidence", "frame", "track")
        )
        self.signals.stopped.emit(0)
|
|
|
|
class ConsistencyWorker(QRunnable):
    """Background job that re-assigns detections to track 1 or 2 frame by frame.

    Starting from the last detection of each track at or before ``startframe``,
    every detection of each later frame that is not user-labeled is assigned to
    the track whose last known position is closest. The distance is divided by
    the number of frames since that track was last seen. User-labeled
    detections keep their track and become the new reference for that track.

    A frame counts as an error when the proposed assignment would change an
    existing track-1/track-2 label, or would put all (more than one) detections
    into the same track. Without ``stoponerror`` the non user-labeled
    detections of such a frame are set to -1. With ``stoponerror`` processing
    halts at that frame.

    Note: ``tracks`` is modified in place.
    """

    def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
                 userlabeled, startframe=0, stoponerror=False) -> None:
        super().__init__()
        self.signals = WorkerSignals()
        self.positions = positions
        self.orientations = orientations
        self.lengths = lengths          # stored, currently not used by run()
        self.bendedness = bendedness    # stored, currently not used by run()
        self.userlabeled = userlabeled
        self.frames = frames
        self.tracks = tracks
        self._startframe = startframe
        self._stoprequest = False
        self._stoponerror = stoponerror

    @Slot()
    def stop(self):
        """Request the running job to stop before the next frame."""
        self._stoprequest = True

    @staticmethod
    def _needs_checking(original, new):
        """Return True if the proposed assignment ``new`` of a frame is suspicious.

        Suspicious means either of the following:
        - a detection that already belonged to track 1 or 2 would change its
          track, checked over *all* detections;
        - more than one detection would end up in the same single track.
        """
        original = np.asarray(original)
        new = np.asarray(new)
        relabeled = np.any(((original == 1) | (original == 2)) & (new != original))
        collapsed = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
        return bool(relabeled or collapsed)

    def _frame_groups(self, startframe):
        """Yield ``(frame, indices)`` for every frame after ``startframe``.

        Frames are yielded in ascending order, and indices within a frame keep
        their original order. Grouping once replaces a full scan of ``frames``
        for every single frame.
        """
        candidates = np.nonzero(self.frames > startframe)[0]
        order = candidates[np.argsort(self.frames[candidates], kind="stable")]
        unique_frames, starts = np.unique(self.frames[order], return_index=True)
        yield from zip(unique_frames, np.split(order, starts[1:]))

    @Slot()
    def run(self):
        """Process all frames after the start frame and report through ``signals``.

        Emits ``progress(percent, processed, errors)`` while running and once at
        the end. Then emits ``stopped(frame)``, where ``frame`` is the last frame
        reached, or the start frame if there was nothing to process.
        """
        self._stoprequest = False

        # Reference state (position, frame, orientation) of tracks 1 and 2,
        # taken from their last detection at or before the start frame.
        last_pos, last_frame, last_angle = [None, None], [None, None], [None, None]
        for track in (1, 2):
            candidates = np.nonzero((self.tracks == track) & (self.frames <= self._startframe))[0]
            if len(candidates) == 0:
                message = f"No detection of track {track} at or before frame {self._startframe}"
                logging.error(message)
                self.signals.error.emit(message)
                # let the UI recover instead of waiting forever
                self.signals.stopped.emit(int(self._startframe))
                return
            last = candidates[-1]
            last_pos[track - 1] = self.positions[last]
            last_frame[track - 1] = self.frames[last]
            last_angle[track - 1] = self.orientations[last]

        def frame_steps(f):
            """Frames elapsed since track 1 and track 2 were last seen."""
            return np.array([f - last_frame[0], f - last_frame[1]], dtype=float)

        def assign_by_distance(f, p):
            """Most likely track for position ``p`` and the per-track scaled distances."""
            steps = frame_steps(f)
            valid = steps > 0
            if not np.all(valid):
                # A zero step means the track was already updated in this frame
                # (by a user-labeled detection); it cannot take another one.
                logging.debug("frame %i: track(s) %s already assigned in this frame",
                              f, np.nonzero(~valid)[0] + 1)
            norms = np.array([np.linalg.norm(p - last_pos[0]), np.linalg.norm(p - last_pos[1])])
            distances = np.full(2, np.inf)
            distances[valid] = norms[valid] / steps[valid]
            return int(np.argmin(distances)) + 1, distances

        def assign_by_orientation(f, o):
            """Most likely track by smallest orientation change per frame.

            Not used for decisions yet; its result is only logged. Assumes
            orientations are in radians (TODO confirm).
            """
            steps = frame_steps(f)
            valid = steps > 0
            # wrap the difference into [-pi, pi) before comparing magnitudes
            change = (np.asarray(last_angle, dtype=float) - o + np.pi) % (2 * np.pi) - np.pi
            rates = np.full(2, np.inf)
            rates[valid] = np.abs(change[valid]) / steps[valid]
            return int(np.argmin(rates)) + 1, rates

        errors = 0
        processed = 0
        percent = 0
        startframe = max(last_frame)
        span = np.max(self.frames) - startframe
        current = startframe

        for f, indices in self._frame_groups(startframe):
            current = f
            if self._stoprequest:
                break
            originaltracks = self.tracks[indices]
            # user-labeled detections keep their original track
            assignments = originaltracks.copy()
            labeled = np.zeros(len(indices), dtype=bool)
            distances = np.zeros((len(indices), 2))
            orientations = np.zeros((len(indices), 2))

            for i, idx in enumerate(indices):
                if self.userlabeled[idx]:
                    # trust the user: keep the label and use it as new reference
                    labeled[i] = True
                    processed += 1
                    track = originaltracks[i]
                    if track in (1, 2):
                        last_pos[track - 1] = self.positions[idx]
                        last_frame[track - 1] = f
                        last_angle[track - 1] = self.orientations[idx]
                    continue
                assignments[i], distances[i, :] = assign_by_distance(f, self.positions[idx])
                _, orientations[i, :] = assign_by_orientation(f, self.orientations[idx])

            logging.debug("frame %i: distances %s, orientation changes %s", f, distances, orientations)

            if self._needs_checking(originaltracks, assignments):
                logging.info("frame %i: Issues assigning based on distances %s", f, str(distances))
                errors += 1
                if self._stoponerror:
                    break
                self.tracks[indices[~labeled]] = -1
            else:
                processed += 1
                for i, idx in enumerate(indices):
                    if labeled[i]:
                        continue
                    track = assignments[i]
                    self.tracks[idx] = track
                    last_pos[track - 1] = self.positions[idx]
                    last_frame[track - 1] = f
                    last_angle[track - 1] = self.orientations[idx]

            # progress in percent of the frame range (the progress bar spans 0..100)
            if span > 0:
                new_percent = int(100 * (f - startframe) / span)
                if new_percent > percent:
                    percent = new_percent
                    self.signals.progress.emit(percent, processed, errors)

        # make sure the final counts reach the UI
        self.signals.progress.emit(percent, processed, errors)
        self.signals.stopped.emit(int(current))
|
|
|
|
|
|
class SizeClassifier(QWidget):
    """Assign detections to track 1 or 2 by their estimated body length.

    Shows a histogram of body lengths with two draggable regions. Detections
    whose length lies in the first (orange) region are assigned to track 1, and
    those in the second (green) region to track 2. Track 2 wins where the
    regions overlap; all others get -1.
    """
    apply = Signal()
    name = "Size classifier"

    def __init__(self, parent=None):
        super().__init__(parent)
        self._t1_selection = None
        self._t2_selection = None
        self._coordinates = None
        self._sizes = None
        self._plot = None            # created on the first setCoordinates call
        self._histogram_item = None  # bar graph of the current length histogram

        self._plot_widget = self.setupGraph()
        self._apply_btn = QPushButton("apply")
        self._apply_btn.clicked.connect(lambda: self.apply.emit())
        # nothing to apply before coordinates have been set
        self._apply_btn.setEnabled(False)

        layout = QVBoxLayout()
        layout.addWidget(self._plot_widget)
        layout.addWidget(self._apply_btn)
        self.setLayout(layout)

    def setupGraph(self):
        """Create the (still empty) plot widget and the two selection regions."""
        pg.setConfigOptions(antialias=True)
        plot_widget = pg.GraphicsLayoutWidget(show=False)

        self._t1_selection = pg.LinearRegionItem([100, 200])
        self._t1_selection.setZValue(-10)  # draw behind the histogram
        self._t1_selection.setBrush("orange")
        self._t2_selection = pg.LinearRegionItem([300, 400])
        self._t2_selection.setZValue(-10)
        self._t2_selection.setBrush("green")
        return plot_widget

    def estimate_length(self, coords, bodyaxis=None):
        """Return one polyline length per detection.

        ``coords`` is indexed as (detection, keypoint, coordinate). The length
        is the sum of the Euclidean distances between consecutive keypoints
        listed in ``bodyaxis``. The default ``[0, 1, 2, 5]`` presumably traces
        the body midline; TODO confirm.
        """
        if bodyaxis is None:
            bodyaxis = [0, 1, 2, 5]
        bodycoords = coords[:, bodyaxis, :]
        segment_lengths = np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2))
        return np.sum(segment_lengths, axis=1)

    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
        """Density histogram of ``dists``.

        Uses 100 bin edges spanning 0.5 x the ``min_threshold`` percentile up
        to 1.5 x the ``max_threshold`` percentile.
        """
        min_length = np.percentile(dists, min_threshold)
        max_length = np.percentile(dists, max_threshold)
        bins = np.linspace(0.5 * min_length, 1.5 * max_length, 100)
        hist, edges = np.histogram(dists, bins=bins, density=True)
        return hist, edges

    def setCoordinates(self, coordinates):
        """Set keypoint coordinates, estimate body lengths and (re)draw the histogram.

        Repeated calls replace the histogram in the existing plot instead of
        adding another plot to the layout.
        """
        self._coordinates = coordinates
        self._sizes = self.estimate_length(coordinates)
        n, e = self.estimate_histogram(self._sizes)

        if self._plot is None:
            self._plot = self._plot_widget.addPlot()
            self._plot.setLabel('left', "prob. density")
            self._plot.setLabel('bottom', "bodylength", units="px")
            self._plot.addItem(self._t1_selection)
            self._plot.addItem(self._t2_selection)
        if self._histogram_item is not None:
            self._plot.removeItem(self._histogram_item)
        self._histogram_item = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w',
                                               brush=(0, 0, 255, 150))
        self._plot.addItem(self._histogram_item)
        self._apply_btn.setEnabled(True)

    def selections(self, track1=True):
        """Return the (lower, upper) region bounds for track 1, or for track 2 if ``track1`` is False."""
        if track1:
            return self._t1_selection.getRegion()
        else:
            return self._t2_selection.getRegion()

    def assignedTracks(self):
        """Return a track id per detection: 1, 2, or -1 if outside both regions.

        The regions are half-open, [lower, upper). Returns an empty array if no
        coordinates have been set yet.
        """
        if self._sizes is None:
            return np.array([], dtype=int)
        tracks = np.full(np.shape(self._sizes), -1, dtype=int)
        t1lower, t1upper = self.selections()
        t2lower, t2upper = self.selections(False)
        tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
        tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
        return tracks
|
|
|
|
class NeighborhoodValidator(QWidget):
    """Show histograms of the frame-normalised distance between successive
    detections of track 1 and track 2.

    The apply button only emits ``apply``.
    """
    apply = Signal()
    name = "Neighborhood Validator"

    def __init__(self, parent = None):
        super().__init__(parent)
        self._threshold = None
        self._positions = None
        self._distances = None
        self._tracks = None
        self._frames = None
        self._plot = None
        self._click_handler = None  # scene click handler connected by setData

        self._plot_widget = self.setupGraph()
        self._apply_btn = QPushButton("apply")
        self._apply_btn.clicked.connect(lambda: self.apply.emit())

        layout = QVBoxLayout()
        layout.addWidget(self._plot_widget)
        layout.addWidget(self._apply_btn)
        self.setLayout(layout)

    def setupGraph(self):
        """Create the (still empty) plot widget and a threshold line ROI (not yet shown)."""
        pg.setConfigOptions(antialias=True)
        plot_widget = pg.GraphicsLayoutWidget(show=False)
        self._threshold = pg.LineSegmentROI([[10, 64], [120, 64]], pen='r')
        self._threshold.setZValue(-10)  # draw behind items with a higher z-value
        return plot_widget

    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
        """Density histogram of ``dists`` between two percentiles.

        NaN entries are ignored. With ``log=True`` the ``bin_count`` edges are
        logarithmically spaced, which requires the lower percentile to be > 0.
        """
        min_dist = np.nanpercentile(dists, min_threshold)
        max_dist = np.nanpercentile(dists, max_threshold)
        logging.debug("histogram range %s - %s", min_dist, max_dist)
        if log:
            # np.logspace expects exponents, not the bounds themselves
            bins = np.logspace(np.log10(min_dist), np.log10(max_dist), bin_count, base=10)
        else:
            bins = np.linspace(min_dist, max_dist, bin_count)
        hist, edges = np.histogram(dists, bins=bins, density=True)
        return hist, edges

    def neighborDistances(self, x, frames, n=5, symmetric=True):
        """Distances between each detection and its neighbours, per elapsed frame.

        ``x`` holds one position per row. The result has one row per detection
        ``r`` for ``0 <= r < len(x) - 1``. Its columns hold the Euclidean
        distance to the detection ``k`` rows later, divided by the frame
        difference between the two, for ``k = n, ..., 1``. If ``symmetric``,
        further columns cover the detections ``k`` rows earlier, for
        ``k = 1, ..., n``. Entries without such a neighbour are NaN.
        """
        logging.debug("classifier:NeighborhoodValidator neighborDistance")
        x = np.asarray(x, dtype=float)
        frames = np.asarray(frames)
        offsets = list(range(n, 0, -1))
        if symmetric:
            offsets += list(range(-1, -n - 1, -1))
        rows = np.arange(max(len(x) - 1, 0))
        dists = np.full((len(rows), len(offsets)), np.nan)
        for column, k in enumerate(offsets):
            partners = rows + k
            valid = (partners >= 0) & (partners < len(x))
            r, p = rows[valid], partners[valid]
            dist = np.sqrt(np.sum((x[r] - x[p])**2, axis=1))
            dists[r, column] = dist / np.abs(frames[p] - frames[r])
        return dists

    def setData(self, positions, tracks, frames):
        """Set the data, the classifier/should be working on.

        Parameters
        ----------
        positions : np.ndarray
            The position estimates, e.g. the center of gravity for each detection
        tracks : np.ndarray
            The current track assignment.
        frames : np.ndarray
            respective frame.
        """
        track2_brush = QBrush(QColor.fromString("green"))
        track1_brush = QBrush(QColor.fromString("orange"))
        self._positions = positions
        self._tracks = tracks
        self._frames = frames

        t1_mask = self._tracks == 1
        t1_distances = self.neighborDistances(self._positions[t1_mask], self._frames[t1_mask], 1, False)
        t2_mask = self._tracks == 2
        t2_distances = self.neighborDistances(self._positions[t2_mask], self._frames[t2_mask], 1, False)

        # start from a clean layout so repeated calls do not stack plots/handlers
        scene = self._plot_widget.scene()
        if self._click_handler is not None:
            scene.sigMouseClicked.disconnect(self._click_handler)
            self._click_handler = None
        self._plot_widget.clear()

        plot = self._plot_widget.addPlot()
        self._plot = plot
        vb = plot.vb
        # TODO confirm: the first row of each distance array is skipped
        n, e = self.estimate_histogram(t1_distances[1:], 1, 95, bin_count=100, log=False)
        plot.addItem(pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush))
        n, e = self.estimate_histogram(t2_distances[1:], 1, 95, bin_count=100, log=False)
        plot.addItem(pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush))
        plot.setLogMode(x=False, y=True)
        plot.setLabel('left', "prob. density")
        plot.setLabel('bottom', "distance", units="px/frame")

        vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
        plot.addItem(vLine, ignoreBounds=True)

        def mouseClicked(event):
            # sceneBoundingRect and mapSceneToView both work in scene coordinates
            pos = event.scenePos()
            if plot.sceneBoundingRect().contains(pos):
                mousePoint = vb.mapSceneToView(pos)
                logging.debug("mouse clicked at %s", mousePoint)
                vLine.setPos(mousePoint.x())

        scene.sigMouseClicked.connect(mouseClicked)
        self._click_handler = mouseClicked
|
|
|
|
|
|
class ConsistencyClassifier(QWidget):
    """Widget that runs ConsistencyWorker on the loaded tracking data.

    Data arrays are prepared in the background by ConsitencyDataLoader. The
    user picks a start frame and can start, stop and proceed the worker. The
    resulting track assignment is returned by ``assignedTracks``.
    """
    apply = Signal()
    name = "Consistency tracker"

    def __init__(self, parent=None):
        super().__init__(parent)
        self._data = None
        self._all_pos = None
        self._all_orientations = None
        self._all_lengths = None
        self._all_bendedness = None
        self._all_scores = None
        self._userlabeled = None
        self._maxframes = 0
        self._frames = None
        self._tracks = None
        self._worker = None
        self._dataworker = None
        self._processed_frames = 0

        self._errorlabel = QLabel()
        self._errorlabel.setStyleSheet("QLabel { color : red; }")
        self._assignedlabel = QLabel()
        self._maxframeslabel = QLabel()
        self._startframe_spinner = QSpinBox()

        self._startbtn = QPushButton("start")
        self._startbtn.clicked.connect(self.start)
        self._startbtn.setEnabled(False)

        self._stopbtn = QPushButton("stop")
        self._stopbtn.clicked.connect(self.stop)
        self._stopbtn.setEnabled(False)

        self._proceedbtn = QPushButton("proceed")
        self._proceedbtn.clicked.connect(self.proceed)
        self._proceedbtn.setEnabled(False)

        self._refreshbtn = QPushButton("refresh")
        self._refreshbtn.clicked.connect(self.refresh)
        self._refreshbtn.setEnabled(True)

        self._apply_btn = QPushButton("apply")
        self._apply_btn.clicked.connect(lambda: self.apply.emit())
        self._apply_btn.setEnabled(False)

        self._progressbar = QProgressBar()
        self._progressbar.setMinimum(0)
        self._progressbar.setMaximum(100)

        self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
        self._stoponerror.setToolTip("Stop process upon errors")
        self._stoponerror.setCheckable(True)
        self._stoponerror.setChecked(True)
        self.threadpool = QThreadPool()

        lyt = QGridLayout()
        lyt.addWidget(QLabel("Start frame:"), 0, 0)
        lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
        lyt.addWidget(QLabel("of"), 1, 1, 1, 1)
        lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1)
        lyt.addWidget(self._stoponerror, 2, 0, 1, 3)
        lyt.addWidget(QLabel("assigned"), 3, 0)
        lyt.addWidget(self._assignedlabel, 3, 1)
        lyt.addWidget(QLabel("errors/issues"), 4, 0)
        lyt.addWidget(self._errorlabel, 4, 1)

        lyt.addWidget(self._startbtn, 5, 0)
        lyt.addWidget(self._stopbtn, 5, 1)
        lyt.addWidget(self._proceedbtn, 5, 2)
        lyt.addWidget(self._apply_btn, 6, 0, 1, 2)
        lyt.addWidget(self._refreshbtn, 6, 2, 1, 1)
        lyt.addWidget(self._progressbar, 7, 0, 1, 3)
        self.setLayout(lyt)

    def setData(self, data:TrackingData):
        """Set the data, the classifier/should be working on.

        The widget is disabled while the arrays are extracted in the thread
        pool; ``data_processed`` re-enables it.

        Parameters
        ----------
        data : Trackingdata
            The tracking data. ``None`` is ignored.
        """
        if data is None:
            logging.warning("ConsistencyClassifier: no data to work on")
            return
        self.setEnabled(False)
        self._progressbar.setRange(0, 0)  # busy indicator
        self._data = data
        self._dataworker = ConsitencyDataLoader(self._data)
        self._dataworker.signals.stopped.connect(self.data_processed)
        self.threadpool.start(self._dataworker)

    @Slot()
    def data_processed(self):
        """Take over the arrays computed by the data loader and reset the controls."""
        if self._dataworker is None:
            return
        loader = self._dataworker
        self._progressbar.setRange(0, 100)
        self._progressbar.setValue(0)
        self._all_pos = loader.positions
        self._all_orientations = loader.orientations
        self._all_lengths = loader.lengths
        self._all_bendedness = loader.bendedness
        self._userlabeled = loader.userlabeled
        self._all_scores = loader.scores
        self._frames = loader.frames
        self._tracks = loader.tracks
        self._dataworker = None

        self._maxframes = np.max(self._frames)
        self._maxframeslabel.setText(str(self._maxframes))
        self._assignedlabel.setText("0")
        self._errorlabel.setText("0")
        self.setEnabled(True)

        t1_frames = self._frames[self._tracks == 1]
        t2_frames = self._frames[self._tracks == 2]
        if len(t1_frames) == 0 or len(t2_frames) == 0:
            # the worker needs a reference detection of both tracks
            logging.warning("ConsistencyClassifier: both tracks need at least one detection")
            self._startbtn.setEnabled(False)
            return
        # start after the first frame in which both tracks have been seen
        min_frame = int(max(t1_frames[0], t2_frames[0])) + 1
        # Qt setters expect Python ints, not numpy scalars
        self._startframe_spinner.setMinimum(min_frame)
        self._startframe_spinner.setMaximum(int(self._maxframes))
        self._startframe_spinner.setValue(int(self._frames[0]) + 1)
        self._startbtn.setEnabled(True)

    @Slot(float)
    def on_progress(self, value):
        """Show ``value`` on the progress bar (presumably a fraction in [0, 1])."""
        if self._progressbar is not None:
            self._progressbar.setValue(int(value * 100))

    def stop(self):
        """Ask a running worker to stop and re-enable the controls."""
        if self._worker is not None:
            self._worker.stop()
        self._startbtn.setEnabled(True)
        self._proceedbtn.setEnabled(True)
        self._stopbtn.setEnabled(False)
        self._refreshbtn.setEnabled(True)

    def start(self):
        """Launch a ConsistencyWorker from the frame selected in the spinner."""
        self._startbtn.setEnabled(False)
        self._refreshbtn.setEnabled(False)
        self._stopbtn.setEnabled(True)
        self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
                                         self._all_bendedness, self._frames, self._tracks, self._userlabeled,
                                         self._startframe_spinner.value(), self._stoponerror.isChecked())
        self._worker.signals.stopped.connect(self.worker_stopped)
        self._worker.signals.progress.connect(self.worker_progress)
        self.threadpool.start(self._worker)

    def proceed(self):
        """Continue processing; the spinner already points at where the last run stopped."""
        self.start()

    def refresh(self):
        """Reload all arrays from the current tracking data."""
        self.setData(self._data)

    def worker_progress(self, progress, processed, errors):
        """Update the progress bar and the counters shown in the widget."""
        self._progressbar.setValue(progress)
        self._errorlabel.setText(str(errors))
        self._assignedlabel.setText(str(processed))

    def worker_stopped(self, frame):
        """Re-enable the controls and move the start frame to where the worker stopped."""
        self._apply_btn.setEnabled(True)
        self._startbtn.setEnabled(True)
        self._stopbtn.setEnabled(False)
        self._startframe_spinner.setValue(int(frame) - 1)
        self._proceedbtn.setEnabled(bool(frame < self._maxframes - 1))
        self._refreshbtn.setEnabled(True)
        self._processed_frames = frame

    def assignedTracks(self):
        """Return the track array, modified in place by the worker."""
        return self._tracks
|
|
|
|
class ClassifierWidget(QTabWidget):
    """Tab widget hosting the size classifier and the consistency tracker.

    ``apply_classifier`` carries the track array produced by whichever
    classifier's apply button was pressed.
    """
    apply_classifier = Signal(np.ndarray)

    def __init__(self, parent=None):
        super().__init__(parent)
        self._data = None
        self._size_classifier = SizeClassifier()
        self._consistency_tracker = ConsistencyClassifier()
        self.addTab(self._size_classifier, SizeClassifier.name)
        self.addTab(self._consistency_tracker, ConsistencyClassifier.name)
        self.tabBarClicked.connect(self._on_tabBarClicked)
        self._size_classifier.apply.connect(self._on_applySizeClassifier)
        self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker)

    def _on_tabBarClicked(self, index):
        """(Re)load the consistency tracker's data when its tab is clicked.

        Clicks on other tabs no longer discard the tracker's state.
        """
        if index == self.indexOf(self._consistency_tracker):
            self.update()

    def _on_applySizeClassifier(self):
        tracks = self.size_classifier.assignedTracks()
        self.apply_classifier.emit(tracks)

    def _on_applyConsistencyTracker(self):
        tracks = self._consistency_tracker.assignedTracks()
        self.apply_classifier.emit(tracks)

    @property
    def size_classifier(self):
        return self._size_classifier

    @property
    def consistency_tracker(self):
        return self._consistency_tracker

    @Slot()
    def update(self):
        """Hand the current data to the consistency tracker, if any is set.

        NOTE: this shadows ``QWidget.update`` for Python callers.
        """
        if self._data is None:
            return
        self.consistency_tracker.setData(self._data)

    def setData(self, data:TrackingData):
        """Store the data; it is loaded when the consistency tab is clicked."""
        self._data = data
|
|
|
|
def as_dict(df):
    """Turn a DataFrame into a plain dict of column arrays plus its index.

    Every column ``c`` maps to ``df[c].values``. The index values are stored
    under the key ``"index"``; they replace a column of that name but keep its
    position in the key order.
    """
    result = {}
    for column in df.columns:
        result[column] = df[column].values
    result["index"] = df.index.values
    return result
|
|
|
|
|
|
def main():
    """Stand-alone demo: load ``data/merged.pkl`` from the package and show the classifiers."""
    import pickle
    # imported here so main() also works when called from outside this script
    from PySide6.QtWidgets import QApplication
    from fixtracks.info import PACKAGE_ROOT

    datafile = PACKAGE_ROOT / "data/merged.pkl"
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(datafile, "rb") as f:
        df = pickle.load(f)
    data = TrackingData()
    data.setData(as_dict(df))

    # reuse an existing application instead of failing on a second instance
    app = QApplication.instance() or QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    w = ClassifierWidget()
    w.setData(data)

    layout = QVBoxLayout()
    layout.addWidget(w)
    window.setLayout(layout)
    window.show()
    app.exec()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from PySide6.QtWidgets import QApplication
|
|
main()
|