135 lines
4.3 KiB
Python
135 lines
4.3 KiB
Python
import logging
|
|
import numpy as np
|
|
|
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
|
|
from PySide6.QtCore import Signal
|
|
from PySide6.QtGui import QBrush, QColor
|
|
|
|
import pyqtgraph as pg
|
|
|
|
|
|
class SizeClassifier(QWidget):
    """Assign rows of keypoint data to one of two tracks by estimated body length.

    The widget shows a histogram of body lengths (see :meth:`setCoordinates`)
    with two draggable ranges -- orange for track 1, green for track 2 -- and an
    "apply" button that emits :attr:`apply`. The per-row labels are obtained
    with :meth:`assignedTracks`.
    """

    apply = Signal()  # emitted when the "apply" button is clicked

    def __init__(self, parent=None):
        super().__init__(parent)
        self._t1_selection = None  # LinearRegionItem for track 1 (created in setupGraph)
        self._t2_selection = None  # LinearRegionItem for track 2 (created in setupGraph)
        self._coordinates = None  # last array passed to setCoordinates()
        self._sizes = None  # body length per row of _coordinates
        self._plot = None  # PlotItem, created on the first setCoordinates() call
        self._histogram = None  # BarGraphItem currently drawn in _plot

        self._plot_widget = self.setupGraph()
        self._apply_btn = QPushButton("apply")
        # `clicked` passes a `checked` bool; the lambda drops it because the
        # `apply` signal takes no arguments.
        self._apply_btn.clicked.connect(lambda: self.apply.emit())

        layout = QVBoxLayout()
        layout.addWidget(self._plot_widget)
        layout.addWidget(self._apply_btn)
        self.setLayout(layout)

    def setupGraph(self):
        """Create the plot container and the two (not yet plotted) selection regions.

        The regions are stored in ``self._t1_selection`` / ``self._t2_selection``
        and are added to a plot by :meth:`setCoordinates`.

        Returns:
            pg.GraphicsLayoutWidget: an empty layout widget created with
            ``show=False``; it is meant to be embedded in this widget's layout.
        """
        # QBrush.color() returns a copy, so changing its alpha afterwards has no
        # effect; set the alpha on the QColor *before* building the brush.
        track1_color = QColor.fromString("orange")
        track1_color.setAlphaF(0.5)
        track2_color = QColor.fromString("green")
        track2_color.setAlphaF(0.5)

        # Global pyqtgraph setting: affects every plot in the process, not only this one.
        pg.setConfigOptions(antialias=True)
        plot_widget = pg.GraphicsLayoutWidget(show=False)

        self._t1_selection = pg.LinearRegionItem([100, 200])
        # A negative Z value stacks the region below sibling items that keep the
        # default Z of 0, such as the histogram bars.
        self._t1_selection.setZValue(-10)
        self._t1_selection.setBrush(QBrush(track1_color))

        self._t2_selection = pg.LinearRegionItem([300, 400])
        self._t2_selection.setZValue(-10)
        self._t2_selection.setBrush(QBrush(track2_color))
        return plot_widget

    def estimate_length(self, coords, bodyaxis=None):
        """Estimate a body length for every row of *coords*.

        The length is the summed Euclidean distance between consecutive
        keypoints listed in *bodyaxis*, i.e. the length of the polyline through
        those keypoints.

        Args:
            coords: array indexed as ``coords[row, keypoint, dim]``
                (presumably frames x keypoints x (x, y) -- TODO confirm).
            bodyaxis: ordered keypoint indices forming the polyline.
                Defaults to ``[0, 1, 2, 5]``.

        Returns:
            np.ndarray: one length per row of *coords*.
        """
        if bodyaxis is None:
            bodyaxis = [0, 1, 2, 5]
        segments = np.diff(coords[:, bodyaxis, :], axis=1)
        return np.linalg.norm(segments, axis=2).sum(axis=1)

    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
        """Compute a probability-density histogram of *dists*.

        The 100 bin edges are spaced linearly from half the *min_threshold*
        percentile to 1.5 times the *max_threshold* percentile (99 bins).
        NaN entries are ignored when choosing the range, and they are not
        counted in any bin.

        Args:
            dists: 1-D array of lengths.
            min_threshold: lower percentile (0-100) used for the range.
            max_threshold: upper percentile (0-100) used for the range.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``(hist, edges)`` as returned by
            ``np.histogram(..., density=True)``.
        """
        # nanpercentile == percentile for NaN-free data; with NaNs (e.g. missing
        # keypoints) plain percentile would return NaN and break the bin edges.
        min_length = np.nanpercentile(dists, min_threshold)
        max_length = np.nanpercentile(dists, max_threshold)
        bins = np.linspace(0.5 * min_length, 1.5 * max_length, 100)
        hist, edges = np.histogram(dists, bins=bins, density=True)
        return hist, edges

    def setCoordinates(self, coordinates):
        """Store *coordinates*, estimate body lengths and (re)draw their histogram.

        May be called repeatedly. The plot, its labels and the two selection
        regions are created on the first call and reused afterwards. Only the
        histogram bars are replaced, so selections survive a data reload.
        """
        self._coordinates = coordinates
        self._sizes = self.estimate_length(coordinates)
        hist, edges = self.estimate_histogram(self._sizes)

        if self._plot is None:
            self._plot = self._plot_widget.addPlot()
            self._plot.setLabel('left', "prob. density")
            self._plot.setLabel('bottom', "bodylength", units="px")
            self._plot.addItem(self._t1_selection)
            self._plot.addItem(self._t2_selection)
        if self._histogram is not None:
            self._plot.removeItem(self._histogram)

        self._histogram = pg.BarGraphItem(x0=edges[:-1], x1=edges[1:], height=hist,
                                          pen='w', brush=(0, 0, 255, 150))
        self._plot.addItem(self._histogram)

    def selections(self, track1=True):
        """Return the ``(lower, upper)`` edges of the track-1 (default) or track-2 region."""
        if track1:
            return self._t1_selection.getRegion()
        else:
            return self._t2_selection.getRegion()

    def assignedTracks(self):
        """Label every row by the region its body length falls into.

        A length ``s`` belongs to a track when ``lower <= s < upper`` for that
        track's region. Where the two regions overlap, track 2 wins because it
        is assigned last.

        Returns:
            np.ndarray: integer labels per row: 1, 2, or -1 when unassigned.

        Raises:
            RuntimeError: if :meth:`setCoordinates` has not been called yet.
        """
        if self._sizes is None:
            raise RuntimeError("No coordinates set; call setCoordinates() first.")
        tracks = np.full_like(self._sizes, -1, dtype=int)
        t1lower, t1upper = self.selections()
        t2lower, t2upper = self.selections(False)
        tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
        tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
        return tracks
|
|
|
|
|
|
class ClassifierWidget(QTabWidget):
    """Tab container for track classifiers; currently hosts one size classifier.

    When the size classifier's "apply" button is pressed, this widget emits
    :attr:`apply_sizeclassifier` with the labels returned by
    ``SizeClassifier.assignedTracks()``.
    """

    apply_sizeclassifier = Signal(np.ndarray)

    def __init__(self, parent=None):
        super().__init__(parent)
        classifier = SizeClassifier()
        self._size_classifier = classifier
        self.addTab(classifier, "Size classifier")
        classifier.apply.connect(self._on_applySizeClassifier)

    @property
    def size_classifier(self):
        """The ``SizeClassifier`` shown in the "Size classifier" tab."""
        return self._size_classifier

    def _on_applySizeClassifier(self):
        """Slot: forward the size classifier's per-row track labels."""
        self.apply_sizeclassifier.emit(self.size_classifier.assignedTracks())
|
|
|
|
|
|
def main():
    """Interactive demo: show a SizeClassifier for the bundled sample data.

    Loads ``data/merged_small.pkl`` below the package root, stacks its
    ``keypoints`` column into one float32 array and opens a window with the
    classifier plus a "get bounds" button that prints both selected ranges.
    """
    import pickle

    from fixtracks.info import PACKAGE_ROOT
    from PySide6.QtWidgets import QApplication

    datafile = PACKAGE_ROOT / "data/merged_small.pkl"
    print(datafile)
    # NOTE: pickle can execute arbitrary code on load; only use it on trusted
    # files such as this bundled sample.
    with open(datafile, "rb") as f:
        df = pickle.load(f)

    # One keypoint array per dataframe row, stacked along a new first axis.
    coords = np.stack(df.keypoints.values).astype(np.float32)

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    layout = QVBoxLayout()
    win = SizeClassifier()
    win.setCoordinates(coords)

    def print_bounds():
        # Report the current selections; a bare `win.selections()` call
        # would discard them and make the button a no-op.
        print("track 1 bounds:", win.selections(True))
        print("track 2 bounds:", win.selections(False))

    btn = QPushButton("get bounds")
    btn.clicked.connect(print_bounds)

    layout.addWidget(win)
    layout.addWidget(btn)
    window.setLayout(layout)
    window.show()
    app.exec()


if __name__ == "__main__":
    main()
|