diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py index 998dfcb..506cd42 100644 --- a/fixtracks/widgets/classifier.py +++ b/fixtracks/widgets/classifier.py @@ -1,23 +1,30 @@ import logging import numpy as np -from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QSlider, QPushButton, QLabel -from PySide6.QtWidgets import QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsLineItem -from PySide6.QtCore import Qt -from PySide6.QtGui import QBrush, QColor, QPen, QPainter, QFont +from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton +from PySide6.QtCore import Signal +from PySide6.QtGui import QBrush, QColor import pyqtgraph as pg -from IPython import embed class SizeClassifier(QWidget): + apply = Signal() + def __init__(self, parent=None): super().__init__(parent) self._t1_selection = None self._t2_selection = None - layout = QVBoxLayout() + self._coordinates = None + self._sizes = None + self._plot_widget = self.setupGraph() + self._apply_btn = QPushButton("apply") + self._apply_btn.clicked.connect(lambda: self.apply.emit()) + + layout = QVBoxLayout() layout.addWidget(self._plot_widget) + layout.addWidget(self._apply_btn) self.setLayout(layout) def setupGraph(self): @@ -51,8 +58,9 @@ class SizeClassifier(QWidget): return hist, edges def setCoordinates(self, coordinates): - dists = self.estimate_length(coordinates) - n, e = self.estimate_histogram(dists) + self._coordinates = coordinates + self._sizes = self.estimate_length(coordinates) + n, e = self.estimate_histogram(self._sizes) plot = self._plot_widget.addPlot() bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150)) plot.addItem(bgi) @@ -67,6 +75,32 @@ class SizeClassifier(QWidget): else: return self._t2_selection.getRegion() + def assignedTracks(self): + tracks = np.ones_like(self._sizes, dtype=int) * -1 + t1lower, t1upper = self.selections() + t2lower, t2upper = self.selections(False) + tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1 + tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2 + return tracks + + +class ClassifierWidget(QTabWidget): + apply_sizeclassifier = Signal(np.ndarray) + + def __init__(self, parent=None): + super().__init__(parent) + self._size_classifier = SizeClassifier() + self.addTab(self._size_classifier, "Size classifier") + self._size_classifier.apply.connect(self._on_applySizeClassifier) + + def _on_applySizeClassifier(self): + tracks = self.size_classifier.assignedTracks() + self.apply_sizeclassifier.emit(tracks) + + @property + def size_classifier(self): + return self._size_classifier + def main(): import pickle @@ -78,7 +112,6 @@ def main(): with open(datafile, "rb") as f: df = pickle.load(f) - coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:] app = QApplication([]) @@ -98,4 +131,4 @@ def main(): app.exec() if __name__ == "__main__": - main() \ No newline at end of file + main()