fixtracks/fixtracks/widgets/classifier.py

101 lines
3.3 KiB
Python

import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QSlider, QPushButton, QLabel
from PySide6.QtWidgets import QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsLineItem
from PySide6.QtCore import Qt
from PySide6.QtGui import QBrush, QColor, QPen, QPainter, QFont
import pyqtgraph as pg
from IPython import embed
class SizeClassifier(QWidget):
    """Histogram of estimated body lengths with two adjustable length ranges.

    The widget plots the distribution of per-row body lengths (see
    ``estimate_length``) and overlays two ``pg.LinearRegionItem`` selections,
    one per track, whose current bounds can be read back with ``selections``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self._t1_selection = None
        self._t2_selection = None
        # PlotItem holding the histogram; created on the first setCoordinates() call.
        self._plot = None
        # BarGraphItem currently displayed, kept so new data can replace it.
        self._hist_item = None
        layout = QVBoxLayout()
        self._plot_widget = self.setupGraph()
        layout.addWidget(self._plot_widget)
        self.setLayout(layout)

    def setupGraph(self):
        """Create the (still empty) plot widget and the two region selections.

        Side effects: enables global pyqtgraph antialiasing and assigns
        ``self._t1_selection`` / ``self._t2_selection``.

        Returns:
            pg.GraphicsLayoutWidget: container the histogram plot is added to
            by ``setCoordinates``.
        """
        # QBrush.color() cannot be used to modify the brush (PySide hands back
        # a copy), so the alpha must be set on the QColor before the brush is built.
        track1_color = QColor.fromString("orange")
        track1_color.setAlphaF(0.5)
        track1_brush = QBrush(track1_color)
        track2_brush = QBrush(QColor.fromString("green"))
        pg.setConfigOptions(antialias=True)
        plot_widget = pg.GraphicsLayoutWidget(show=False)
        self._t1_selection = pg.LinearRegionItem([100, 200])
        # A negative z-value stacks the region behind sibling items that keep
        # the default z-value of 0, such as the histogram bars.
        self._t1_selection.setZValue(-10)
        self._t1_selection.setBrush(track1_brush)
        self._t2_selection = pg.LinearRegionItem([300, 400])
        self._t2_selection.setZValue(-10)  # draw behind the histogram bars, as above
        self._t2_selection.setBrush(track2_brush)
        return plot_widget

    def estimate_length(self, coords, bodyaxis=None):
        """Estimate a length per row as the polyline length through selected keypoints.

        Args:
            coords: 3-D array indexed as ``coords[row, keypoint, coordinate]``
                (presumably ``(n_frames, n_keypoints, n_dims)`` -- confirm
                against the caller).
            bodyaxis: indices along axis 1 to connect, in order. Defaults to
                ``[0, 1, 2, 5]``.

        Returns:
            1-D array with, for each row, the sum of Euclidean distances between
            consecutive selected keypoints.
        """
        if bodyaxis is None:
            bodyaxis = [0, 1, 2, 5]
        bodycoords = coords[:, bodyaxis, :]
        segment_lengths = np.sqrt(np.sum(np.diff(bodycoords, axis=1) ** 2, axis=2))
        return np.sum(segment_lengths, axis=1)

    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., n_edges=100):
        """Compute a probability-density histogram of ``dists``.

        The bin range spans from half the ``min_threshold`` percentile to 1.5x
        the ``max_threshold`` percentile, so outliers beyond that are dropped.

        Args:
            dists: 1-D array of lengths; NaN entries are ignored.
            min_threshold: percentile used for the lower end of the bin range.
            max_threshold: percentile used for the upper end of the bin range.
            n_edges: number of bin edges (yields ``n_edges - 1`` bins).

        Returns:
            ``(hist, edges)`` as returned by ``np.histogram(..., density=True)``.
        """
        # nanpercentile keeps the bin range finite even if some lengths are NaN
        # (e.g. missing keypoints); for all-finite data it equals np.percentile.
        min_length = np.nanpercentile(dists, min_threshold)
        max_length = np.nanpercentile(dists, max_threshold)
        bins = np.linspace(0.5 * min_length, 1.5 * max_length, n_edges)
        hist, edges = np.histogram(dists, bins=bins, density=True)
        return hist, edges

    def setCoordinates(self, coordinates):
        """Display the body-length histogram for ``coordinates``.

        On the first call the plot is created and the two selections are added
        to it. Later calls replace only the histogram, instead of stacking a new
        plot and moving the selections each time.
        """
        dists = self.estimate_length(coordinates)
        n, e = self.estimate_histogram(dists)
        if self._plot is None:
            self._plot = self._plot_widget.addPlot()
            self._plot.setLabel('left', "prob. density")
            self._plot.setLabel('bottom', "bodylength", units="px")
            self._plot.addItem(self._t1_selection)
            self._plot.addItem(self._t2_selection)
        elif self._hist_item is not None:
            self._plot.removeItem(self._hist_item)
        self._hist_item = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0, 0, 255, 150))
        self._plot.addItem(self._hist_item)

    def selections(self, track1=True):
        """Return the current region bounds of the track-1 (default) or track-2 selection.

        The value is whatever ``pg.LinearRegionItem.getRegion()`` reports.
        """
        if track1:
            return self._t1_selection.getRegion()
        else:
            return self._t2_selection.getRegion()
def main():
    """Interactive demo: load the bundled sample data and show a SizeClassifier.

    Clicking "get bounds" prints the current track-1 and track-2 selections.
    """
    import pickle
    from fixtracks.info import PACKAGE_ROOT
    from PySide6.QtWidgets import QApplication

    datafile = PACKAGE_ROOT / "data/merged_small.pkl"
    print(datafile)
    # pickle.load can execute arbitrary code: only use it on this bundled,
    # trusted sample file, never on user-supplied data.
    with open(datafile, "rb") as f:
        df = pickle.load(f)
    # Presumably a pandas DataFrame whose "keypoints" column holds one
    # equally-shaped array per row -- confirm against the data file.
    coords = np.stack(df.keypoints.values).astype(np.float32)

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    layout = QVBoxLayout()
    win = SizeClassifier()
    win.setCoordinates(coords)
    btn = QPushButton("get bounds")
    # Report both selections; calling win.selections() alone would discard the result.
    btn.clicked.connect(lambda: print(win.selections(track1=True), win.selections(track1=False)))
    layout.addWidget(win)
    layout.addWidget(btn)
    window.setLayout(layout)
    window.show()
    app.exec()
if __name__ == "__main__":
    # Run the interactive demo only when this module is executed as a script.
    main()