[classifier] cleanup, fixes, working version

This commit is contained in:
Jan Grewe 2025-02-07 15:57:31 +01:00
parent 801acc9547
commit 796e03ae7e

View File

@ -1,23 +1,30 @@
import logging import logging
import numpy as np import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QSlider, QPushButton, QLabel from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
from PySide6.QtWidgets import QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsLineItem from PySide6.QtCore import Signal
from PySide6.QtCore import Qt from PySide6.QtGui import QBrush, QColor
from PySide6.QtGui import QBrush, QColor, QPen, QPainter, QFont
import pyqtgraph as pg import pyqtgraph as pg
from IPython import embed
class SizeClassifier(QWidget): class SizeClassifier(QWidget):
apply = Signal()
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self._t1_selection = None self._t1_selection = None
self._t2_selection = None self._t2_selection = None
layout = QVBoxLayout() self._coordinates = None
self._sizes = None
self._plot_widget = self.setupGraph() self._plot_widget = self.setupGraph()
self._apply_btn = QPushButton("apply")
self._apply_btn.clicked.connect(lambda: self.apply.emit())
layout = QVBoxLayout()
layout.addWidget(self._plot_widget) layout.addWidget(self._plot_widget)
layout.addWidget(self._apply_btn)
self.setLayout(layout) self.setLayout(layout)
def setupGraph(self): def setupGraph(self):
@ -51,8 +58,9 @@ class SizeClassifier(QWidget):
return hist, edges return hist, edges
def setCoordinates(self, coordinates): def setCoordinates(self, coordinates):
dists = self.estimate_length(coordinates) self._coordinates = coordinates
n, e = self.estimate_histogram(dists) self._sizes = self.estimate_length(coordinates)
n, e = self.estimate_histogram(self._sizes)
plot = self._plot_widget.addPlot() plot = self._plot_widget.addPlot()
bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150)) bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150))
plot.addItem(bgi) plot.addItem(bgi)
@ -67,6 +75,32 @@ class SizeClassifier(QWidget):
else: else:
return self._t2_selection.getRegion() return self._t2_selection.getRegion()
def assignedTracks(self):
tracks = np.ones_like(self._sizes, dtype=int) * -1
t1lower, t1upper = self.selections()
t2lower, t2upper = self.selections(False)
tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks
class ClassifierWidget(QTabWidget):
apply_sizeclassifier = Signal(np.ndarray)
def __init__(self, parent=None):
super().__init__(parent)
self._size_classifier = SizeClassifier()
self.addTab(self._size_classifier, "Size classifier")
self._size_classifier.apply.connect(self._on_applySizeClassifier)
def _on_applySizeClassifier(self):
tracks = self.size_classifier.assignedTracks()
self.apply_sizeclassifier.emit(tracks)
@property
def size_classifier(self):
return self._size_classifier
def main(): def main():
import pickle import pickle
@ -78,7 +112,6 @@ def main():
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)
coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:] coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:]
app = QApplication([]) app = QApplication([])