[classifier] cleanup, fixes, working version
This commit is contained in:
parent
801acc9547
commit
796e03ae7e
@ -1,23 +1,30 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QSlider, QPushButton, QLabel
|
||||
from PySide6.QtWidgets import QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsLineItem
|
||||
from PySide6.QtCore import Qt
|
||||
from PySide6.QtGui import QBrush, QColor, QPen, QPainter, QFont
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
|
||||
from PySide6.QtCore import Signal
|
||||
from PySide6.QtGui import QBrush, QColor
|
||||
|
||||
import pyqtgraph as pg
|
||||
|
||||
from IPython import embed
|
||||
|
||||
class SizeClassifier(QWidget):
|
||||
apply = Signal()
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._t1_selection = None
|
||||
self._t2_selection = None
|
||||
layout = QVBoxLayout()
|
||||
self._coordinates = None
|
||||
self._sizes = None
|
||||
|
||||
self._plot_widget = self.setupGraph()
|
||||
self._apply_btn = QPushButton("apply")
|
||||
self._apply_btn.clicked.connect(lambda: self.apply.emit())
|
||||
|
||||
layout = QVBoxLayout()
|
||||
layout.addWidget(self._plot_widget)
|
||||
layout.addWidget(self._apply_btn)
|
||||
self.setLayout(layout)
|
||||
|
||||
def setupGraph(self):
|
||||
@ -51,8 +58,9 @@ class SizeClassifier(QWidget):
|
||||
return hist, edges
|
||||
|
||||
def setCoordinates(self, coordinates):
|
||||
dists = self.estimate_length(coordinates)
|
||||
n, e = self.estimate_histogram(dists)
|
||||
self._coordinates = coordinates
|
||||
self._sizes = self.estimate_length(coordinates)
|
||||
n, e = self.estimate_histogram(self._sizes)
|
||||
plot = self._plot_widget.addPlot()
|
||||
bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150))
|
||||
plot.addItem(bgi)
|
||||
@ -67,6 +75,32 @@ class SizeClassifier(QWidget):
|
||||
else:
|
||||
return self._t2_selection.getRegion()
|
||||
|
||||
def assignedTracks(self):
|
||||
tracks = np.ones_like(self._sizes, dtype=int) * -1
|
||||
t1lower, t1upper = self.selections()
|
||||
t2lower, t2upper = self.selections(False)
|
||||
tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
|
||||
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
|
||||
return tracks
|
||||
|
||||
|
||||
class ClassifierWidget(QTabWidget):
|
||||
apply_sizeclassifier = Signal(np.ndarray)
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._size_classifier = SizeClassifier()
|
||||
self.addTab(self._size_classifier, "Size classifier")
|
||||
self._size_classifier.apply.connect(self._on_applySizeClassifier)
|
||||
|
||||
def _on_applySizeClassifier(self):
|
||||
tracks = self.size_classifier.assignedTracks()
|
||||
self.apply_sizeclassifier.emit(tracks)
|
||||
|
||||
@property
|
||||
def size_classifier(self):
|
||||
return self._size_classifier
|
||||
|
||||
|
||||
def main():
|
||||
import pickle
|
||||
@ -78,7 +112,6 @@ def main():
|
||||
with open(datafile, "rb") as f:
|
||||
df = pickle.load(f)
|
||||
|
||||
|
||||
coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:]
|
||||
|
||||
app = QApplication([])
|
||||
@ -98,4 +131,4 @@ def main():
|
||||
app.exec()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user