# fixtracks/fixtracks/widgets/classifier.py
#
# 135 lines
# 4.3 KiB
# Python

import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
from PySide6.QtCore import Signal
from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg
class SizeClassifier(QWidget):
    """Widget that splits detections into two tracks by estimated body length.

    Shows a histogram of per-detection body lengths together with two draggable
    range selections (orange for track 1, green for track 2). Detections whose
    length falls inside a selection are assigned to the corresponding track.
    """

    # Emitted when the user presses the "apply" button.
    apply = Signal()

    def __init__(self, parent=None):
        super().__init__(parent)
        self._t1_selection = None
        self._t2_selection = None
        self._coordinates = None
        self._sizes = None
        # Created lazily by setCoordinates(); kept so repeated calls update the
        # existing plot instead of stacking new plots into the layout.
        self._plot = None
        self._hist_item = None
        self._plot_widget = self.setupGraph()

        self._apply_btn = QPushButton("apply")
        # The lambda takes no arguments, so the button's clicked payload is not
        # forwarded to the argument-less `apply` signal.
        self._apply_btn.clicked.connect(lambda: self.apply.emit())

        layout = QVBoxLayout()
        layout.addWidget(self._plot_widget)
        layout.addWidget(self._apply_btn)
        self.setLayout(layout)

    def setupGraph(self):
        """Create the (still empty) plot container and the two range selections.

        Returns:
            pg.GraphicsLayoutWidget: container to which setCoordinates() adds
            the histogram plot.
        """
        # QBrush.color() does not hand out a mutable reference to the brush's
        # colour, so the alpha has to be set on the QColor *before* the brush
        # is built for the transparency to take effect.
        track1_color = QColor.fromString("orange")
        track1_color.setAlphaF(0.5)
        track1_brush = QBrush(track1_color)
        track2_brush = QBrush(QColor.fromString("green"))

        pg.setConfigOptions(antialias=True)
        plot_widget = pg.GraphicsLayoutWidget(show=False)

        self._t1_selection = pg.LinearRegionItem([100, 200])
        # A negative z-value stacks the region below the other plot items
        # (the histogram bars), so it is drawn as a background band.
        self._t1_selection.setZValue(-10)
        self._t1_selection.setBrush(track1_brush)

        self._t2_selection = pg.LinearRegionItem([300, 400])
        self._t2_selection.setZValue(-10)
        self._t2_selection.setBrush(track2_brush)
        return plot_widget

    def estimate_length(self, coords, bodyaxis=None):
        """Estimate each detection's body length along a chain of keypoints.

        Args:
            coords: array indexed as (detection, keypoint, coordinate)
                (assumption based on the indexing below — confirm with caller).
            bodyaxis: ordered keypoint indices forming the body axis.
                Defaults to [0, 1, 2, 5].

        Returns:
            np.ndarray: one value per detection, the summed Euclidean length of
            the segments between consecutive body-axis keypoints.
        """
        if bodyaxis is None:
            bodyaxis = [0, 1, 2, 5]
        bodycoords = coords[:, bodyaxis, :]
        dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
        return dists

    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
        """Compute a normalised histogram of body lengths.

        The bin range spans from half the `min_threshold` percentile to 1.5x
        the `max_threshold` percentile, split into 99 bins.

        Args:
            dists: 1-D array of body lengths; NaN entries are ignored.
            min_threshold: lower percentile (0-100) used for the bin range.
            max_threshold: upper percentile (0-100) used for the bin range.

        Returns:
            tuple[np.ndarray, np.ndarray]: probability densities and bin edges.
        """
        # nanpercentile skips NaN lengths (which would otherwise turn every bin
        # edge into NaN); for NaN-free input it equals np.percentile.
        min_length = np.nanpercentile(dists, min_threshold)
        max_length = np.nanpercentile(dists, max_threshold)
        bins = np.linspace(0.5 * min_length, 1.5 * max_length, 100)
        hist, edges = np.histogram(dists, bins=bins, density=True)
        return hist, edges

    def setCoordinates(self, coordinates):
        """Set keypoint data, compute body lengths and (re)draw the histogram.

        May be called repeatedly: the plot is created once and only the
        histogram bars are replaced, so the range selections keep their place.
        """
        self._coordinates = coordinates
        self._sizes = self.estimate_length(coordinates)
        n, e = self.estimate_histogram(self._sizes)

        first_call = self._plot is None
        if first_call:
            self._plot = self._plot_widget.addPlot()
            self._plot.setLabel('left', "prob. density")
            self._plot.setLabel('bottom', "bodylength", units="px")
        else:
            self._plot.removeItem(self._hist_item)

        self._hist_item = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0, 0, 255, 150))
        self._plot.addItem(self._hist_item)
        if first_call:
            self._plot.addItem(self._t1_selection)
            self._plot.addItem(self._t2_selection)

    def selections(self, track1=True):
        """Return the (lower, upper) length range selected for a track.

        Args:
            track1: True for the track-1 selection, False for track 2.
        """
        selection = self._t1_selection if track1 else self._t2_selection
        return selection.getRegion()

    def assignedTracks(self):
        """Label every detection with the track whose length range contains it.

        Returns:
            np.ndarray[int]: 1 or 2 for detections whose length lies in the
            half-open range [lower, upper) of the track-1 / track-2 selection,
            -1 otherwise. Where the ranges overlap, track 2 wins because it is
            assigned last.
        """
        tracks = np.full(np.shape(self._sizes), -1, dtype=int)
        t1lower, t1upper = self.selections()
        t2lower, t2upper = self.selections(False)
        tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
        tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
        return tracks
class ClassifierWidget(QTabWidget):
    """Tabbed container for the track classifiers.

    Hosts a SizeClassifier tab and re-emits its track assignment when the user
    applies it.
    """

    # Carries the per-detection track labels computed by the size classifier.
    apply_sizeclassifier = Signal(np.ndarray)

    def __init__(self, parent=None):
        super().__init__(parent)
        classifier = SizeClassifier()
        self._size_classifier = classifier
        classifier.apply.connect(self._on_applySizeClassifier)
        self.addTab(classifier, "Size classifier")

    def _on_applySizeClassifier(self):
        """Forward the size classifier's track labels via apply_sizeclassifier."""
        self.apply_sizeclassifier.emit(self.size_classifier.assignedTracks())

    @property
    def size_classifier(self):
        """The SizeClassifier instance shown in the "Size classifier" tab."""
        return self._size_classifier
def main():
    """Run a standalone demo of SizeClassifier on the bundled sample data.

    The "get bounds" button prints the currently selected length ranges of
    track 1 and track 2.
    """
    import pickle
    from fixtracks.info import PACKAGE_ROOT
    from PySide6.QtWidgets import QApplication

    datafile = PACKAGE_ROOT / "data/merged_small.pkl"
    print(datafile)
    # NOTE: pickle.load can execute arbitrary code; only use it on trusted
    # files such as this bundled sample.
    with open(datafile, "rb") as f:
        df = pickle.load(f)
    # Presumably a DataFrame whose `keypoints` column holds one equally shaped
    # keypoint array per row — TODO confirm against the data file.
    coords = np.stack(df.keypoints.values).astype(np.float32)

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    layout = QVBoxLayout()
    win = SizeClassifier()
    win.setCoordinates(coords)

    btn = QPushButton("get bounds")
    # Print both selected ranges (track 1, then track 2) so the button has a
    # visible effect.
    btn.clicked.connect(lambda: print(win.selections(), win.selections(False)))

    layout.addWidget(win)
    layout.addWidget(btn)
    window.setLayout(layout)
    window.show()
    app.exec()


if __name__ == "__main__":
    main()