[classifier] cleanup, fixes, working version
This commit is contained in:
parent
801acc9547
commit
796e03ae7e
@ -1,23 +1,30 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QSlider, QPushButton, QLabel
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
|
||||||
from PySide6.QtWidgets import QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsLineItem
|
from PySide6.QtCore import Signal
|
||||||
from PySide6.QtCore import Qt
|
from PySide6.QtGui import QBrush, QColor
|
||||||
from PySide6.QtGui import QBrush, QColor, QPen, QPainter, QFont
|
|
||||||
|
|
||||||
import pyqtgraph as pg
|
import pyqtgraph as pg
|
||||||
|
|
||||||
from IPython import embed
|
|
||||||
|
|
||||||
class SizeClassifier(QWidget):
|
class SizeClassifier(QWidget):
|
||||||
|
apply = Signal()
|
||||||
|
|
||||||
def __init__(self, parent=None):
|
def __init__(self, parent=None):
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self._t1_selection = None
|
self._t1_selection = None
|
||||||
self._t2_selection = None
|
self._t2_selection = None
|
||||||
layout = QVBoxLayout()
|
self._coordinates = None
|
||||||
|
self._sizes = None
|
||||||
|
|
||||||
self._plot_widget = self.setupGraph()
|
self._plot_widget = self.setupGraph()
|
||||||
|
self._apply_btn = QPushButton("apply")
|
||||||
|
self._apply_btn.clicked.connect(lambda: self.apply.emit())
|
||||||
|
|
||||||
|
layout = QVBoxLayout()
|
||||||
layout.addWidget(self._plot_widget)
|
layout.addWidget(self._plot_widget)
|
||||||
|
layout.addWidget(self._apply_btn)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
|
|
||||||
def setupGraph(self):
|
def setupGraph(self):
|
||||||
@ -51,8 +58,9 @@ class SizeClassifier(QWidget):
|
|||||||
return hist, edges
|
return hist, edges
|
||||||
|
|
||||||
def setCoordinates(self, coordinates):
|
def setCoordinates(self, coordinates):
|
||||||
dists = self.estimate_length(coordinates)
|
self._coordinates = coordinates
|
||||||
n, e = self.estimate_histogram(dists)
|
self._sizes = self.estimate_length(coordinates)
|
||||||
|
n, e = self.estimate_histogram(self._sizes)
|
||||||
plot = self._plot_widget.addPlot()
|
plot = self._plot_widget.addPlot()
|
||||||
bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150))
|
bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150))
|
||||||
plot.addItem(bgi)
|
plot.addItem(bgi)
|
||||||
@ -67,6 +75,32 @@ class SizeClassifier(QWidget):
|
|||||||
else:
|
else:
|
||||||
return self._t2_selection.getRegion()
|
return self._t2_selection.getRegion()
|
||||||
|
|
||||||
|
def assignedTracks(self):
|
||||||
|
tracks = np.ones_like(self._sizes, dtype=int) * -1
|
||||||
|
t1lower, t1upper = self.selections()
|
||||||
|
t2lower, t2upper = self.selections(False)
|
||||||
|
tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
|
||||||
|
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
|
||||||
|
return tracks
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierWidget(QTabWidget):
|
||||||
|
apply_sizeclassifier = Signal(np.ndarray)
|
||||||
|
|
||||||
|
def __init__(self, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._size_classifier = SizeClassifier()
|
||||||
|
self.addTab(self._size_classifier, "Size classifier")
|
||||||
|
self._size_classifier.apply.connect(self._on_applySizeClassifier)
|
||||||
|
|
||||||
|
def _on_applySizeClassifier(self):
|
||||||
|
tracks = self.size_classifier.assignedTracks()
|
||||||
|
self.apply_sizeclassifier.emit(tracks)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size_classifier(self):
|
||||||
|
return self._size_classifier
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import pickle
|
import pickle
|
||||||
@ -78,7 +112,6 @@ def main():
|
|||||||
with open(datafile, "rb") as f:
|
with open(datafile, "rb") as f:
|
||||||
df = pickle.load(f)
|
df = pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:]
|
coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:]
|
||||||
|
|
||||||
app = QApplication([])
|
app = QApplication([])
|
||||||
|
Loading…
Reference in New Issue
Block a user