From 801acc9547f9eae253346c301c0240fa4c4102f2 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Fri, 7 Feb 2025 15:56:20 +0100 Subject: [PATCH] [classifier] new classfier widgets for automatic classification by size --- fixtracks/widgets/classifier.py | 101 ++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 fixtracks/widgets/classifier.py diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py new file mode 100644 index 0000000..998dfcb --- /dev/null +++ b/fixtracks/widgets/classifier.py @@ -0,0 +1,101 @@ +import logging +import numpy as np + +from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QSlider, QPushButton, QLabel +from PySide6.QtWidgets import QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsLineItem +from PySide6.QtCore import Qt +from PySide6.QtGui import QBrush, QColor, QPen, QPainter, QFont + +import pyqtgraph as pg + +from IPython import embed + +class SizeClassifier(QWidget): + def __init__(self, parent=None): + super().__init__(parent) + self._t1_selection = None + self._t2_selection = None + layout = QVBoxLayout() + self._plot_widget = self.setupGraph() + layout.addWidget(self._plot_widget) + self.setLayout(layout) + + def setupGraph(self): + track1_brush = QBrush(QColor.fromString("orange")) + track1_brush.color().setAlphaF(0.5) + track2_brush = QBrush(QColor.fromString("green")) + + pg.setConfigOptions(antialias=True) + plot_widget = pg.GraphicsLayoutWidget(show=False) + + self._t1_selection = pg.LinearRegionItem([100, 200]) + self._t1_selection.setZValue(-10) # what is that? + self._t1_selection.setBrush(track1_brush) + self._t2_selection = pg.LinearRegionItem([300,400]) + self._t2_selection.setZValue(-10) # what is that? + self._t2_selection.setBrush(track2_brush) + return plot_widget + + def estimate_length(self, coords, bodyaxis =None): + if bodyaxis is None: + bodyaxis = [0, 1, 2, 5] + bodycoords = coords[:, bodyaxis, :] + dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1) + return dists + + def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.): + min_length = np.percentile(dists, min_threshold) + max_length = np.percentile(dists, max_threshold) + bins = np.linspace(0.5 * min_length, 1.5 * max_length, 100) + hist, edges = np.histogram(dists, bins=bins, density=True) + return hist, edges + + def setCoordinates(self, coordinates): + dists = self.estimate_length(coordinates) + n, e = self.estimate_histogram(dists) + plot = self._plot_widget.addPlot() + bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150)) + plot.addItem(bgi) + plot.setLabel('left', "prob. density") + plot.setLabel('bottom', "bodylength", units="px") + plot.addItem(self._t1_selection) + plot.addItem(self._t2_selection) + + def selections(self, track1=True): + if track1: + return self._t1_selection.getRegion() + else: + return self._t2_selection.getRegion() + + +def main(): + import pickle + from fixtracks.info import PACKAGE_ROOT + from PySide6.QtWidgets import QApplication + + datafile = PACKAGE_ROOT / "data/merged_small.pkl" + print(datafile) + with open(datafile, "rb") as f: + df = pickle.load(f) + + + coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:] + + app = QApplication([]) + window = QWidget() + window.setMinimumSize(200, 200) + layout = QVBoxLayout() + win = SizeClassifier() + win.setCoordinates(coords) + + btn = QPushButton("get bounds") + btn.clicked.connect(lambda: win.selections()) + + layout.addWidget(win) + layout.addWidget(btn) + window.setLayout(layout) + window.show() + app.exec() + +if __name__ == "__main__": + main() \ No newline at end of file