[classifier] cleanup, fixes, working version

This commit is contained in:
Jan Grewe 2025-02-07 15:57:31 +01:00
parent 801acc9547
commit 796e03ae7e

View File

@ -1,23 +1,30 @@
import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QSlider, QPushButton, QLabel
from PySide6.QtWidgets import QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsLineItem
from PySide6.QtCore import Qt
from PySide6.QtGui import QBrush, QColor, QPen, QPainter, QFont
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
from PySide6.QtCore import Signal
from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg
from IPython import embed
class SizeClassifier(QWidget):
apply = Signal()
def __init__(self, parent=None):
super().__init__(parent)
self._t1_selection = None
self._t2_selection = None
layout = QVBoxLayout()
self._coordinates = None
self._sizes = None
self._plot_widget = self.setupGraph()
self._apply_btn = QPushButton("apply")
self._apply_btn.clicked.connect(lambda: self.apply.emit())
layout = QVBoxLayout()
layout.addWidget(self._plot_widget)
layout.addWidget(self._apply_btn)
self.setLayout(layout)
def setupGraph(self):
@ -51,8 +58,9 @@ class SizeClassifier(QWidget):
return hist, edges
def setCoordinates(self, coordinates):
dists = self.estimate_length(coordinates)
n, e = self.estimate_histogram(dists)
self._coordinates = coordinates
self._sizes = self.estimate_length(coordinates)
n, e = self.estimate_histogram(self._sizes)
plot = self._plot_widget.addPlot()
bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150))
plot.addItem(bgi)
@ -67,6 +75,32 @@ class SizeClassifier(QWidget):
else:
return self._t2_selection.getRegion()
def assignedTracks(self):
tracks = np.ones_like(self._sizes, dtype=int) * -1
t1lower, t1upper = self.selections()
t2lower, t2upper = self.selections(False)
tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks
class ClassifierWidget(QTabWidget):
apply_sizeclassifier = Signal(np.ndarray)
def __init__(self, parent=None):
super().__init__(parent)
self._size_classifier = SizeClassifier()
self.addTab(self._size_classifier, "Size classifier")
self._size_classifier.apply.connect(self._on_applySizeClassifier)
def _on_applySizeClassifier(self):
tracks = self.size_classifier.assignedTracks()
self.apply_sizeclassifier.emit(tracks)
@property
def size_classifier(self):
return self._size_classifier
def main():
import pickle
@ -78,7 +112,6 @@ def main():
with open(datafile, "rb") as f:
df = pickle.load(f)
coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:]
app = QApplication([])
@ -98,4 +131,4 @@ def main():
app.exec()
if __name__ == "__main__":
main()
main()