[tracks, classifier] include neighborhood classifier, not yet doing

anything
This commit is contained in:
Jan Grewe 2025-02-10 11:08:06 +01:00
parent 96e4b0b2c5
commit 5a128cf28e
2 changed files with 153 additions and 37 deletions

View File

@ -1,15 +1,17 @@
import logging import logging
import numpy as np import numpy as np
import pyqtgraph as pg
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
from PySide6.QtCore import Signal from PySide6.QtCore import Signal
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg from fixtracks.utils.trackingdata import TrackingData
class SizeClassifier(QWidget): class SizeClassifier(QWidget):
apply = Signal() apply = Signal()
name = "SizeClassifier"
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
@ -29,17 +31,16 @@ class SizeClassifier(QWidget):
def setupGraph(self): def setupGraph(self):
track1_brush = QBrush(QColor.fromString("orange")) track1_brush = QBrush(QColor.fromString("orange"))
track1_brush.color().setAlphaF(0.5)
track2_brush = QBrush(QColor.fromString("green")) track2_brush = QBrush(QColor.fromString("green"))
pg.setConfigOptions(antialias=True) pg.setConfigOptions(antialias=True)
plot_widget = pg.GraphicsLayoutWidget(show=False) plot_widget = pg.GraphicsLayoutWidget(show=False)
self._t1_selection = pg.LinearRegionItem([100, 200]) self._t1_selection = pg.LinearRegionItem([100, 200])
self._t1_selection.setZValue(-10) # what is that? self._t1_selection.setZValue(-10)
self._t1_selection.setBrush(track1_brush) self._t1_selection.setBrush(track1_brush)
self._t2_selection = pg.LinearRegionItem([300,400]) self._t2_selection = pg.LinearRegionItem([300,400])
self._t2_selection.setZValue(-10) # what is that? self._t2_selection.setZValue(-10)
self._t2_selection.setBrush(track2_brush) self._t2_selection.setBrush(track2_brush)
return plot_widget return plot_widget
@ -47,8 +48,8 @@ class SizeClassifier(QWidget):
if bodyaxis is None: if bodyaxis is None:
bodyaxis = [0, 1, 2, 5] bodyaxis = [0, 1, 2, 5]
bodycoords = coords[:, bodyaxis, :] bodycoords = coords[:, bodyaxis, :]
dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1) lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
return dists return lengths
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.): def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
min_length = np.percentile(dists, min_threshold) min_length = np.percentile(dists, min_threshold)
@ -83,6 +84,112 @@ class SizeClassifier(QWidget):
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2 tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks return tracks
class NeighborhoodValidator(QWidget):
apply = Signal()
name = "Neighborhood Validator"
def __init__(self, parent = None):
super().__init__(parent)
self._threshold = None
self._positions = None
self._distances = None
self._tracks = None
self._plot = None
self._plot_widget = self.setupGraph()
self._apply_btn = QPushButton("apply")
self._apply_btn.clicked.connect(lambda: self.apply.emit())
layout = QVBoxLayout()
print(isinstance(self._plot_widget, QGraphicsView))
layout.addWidget(self._plot_widget)
layout.addWidget(self._apply_btn)
self.setLayout(layout)
def setupGraph(self):
pg.setConfigOptions(antialias=True)
plot_widget = pg.GraphicsLayoutWidget(show=False)
self._threshold = pg.LineSegmentROI([[10, 64], [120,64]], pen='r')
self._threshold.setZValue(-10) # what is that?
return plot_widget
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
min_dist = np.percentile(dists, min_threshold)
max_dist = np.percentile(dists, max_threshold)
if log:
bins = np.logspace(min_dist, max_dist, bin_count, base=10)
bins = np.linspace(min_dist, max_dist, bin_count)
hist, edges = np.histogram(dists, bins=bins, density=True)
return hist, edges
def neighborDistances(self, x, frames, n=5, symmetric=True):
logging.debug("classifier:NeighborhoodValidator neighborDistance")
pad_shape = list(x.shape)
pad_shape[0] = n
pad = np.atleast_2d(np.zeros(pad_shape))
if symmetric:
padded_x = np.vstack((pad, x, pad))
dists = np.zeros((x.shape[0]-1, 2*n))
else:
padded_x = np.vstack((pad, x))
dists = np.zeros((x.shape[0]-1, n))
count = 0
r = range(-n, n+1) if symmetric else range(-n, 0)
for i in r:
if i == 0:
continue
shifted_x = np.roll(padded_x, i, axis=0)
dist = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
dists[:, count] = dist[n+1:]/np.diff(frames)
count += 1
return dists
def setData(self, positions, tracks, frames):
"""Set the data, the classifier/should be working on.
Parameters
----------
positions : np.ndarray
The position estimates, e.g. the center of gravity for each detection
tracks : np.ndarray
The current track assignment.
frames : np.ndarray
respective frame.
"""
def mouseClicked(self, event):
print("mouse clicked at", event.pos())
track2_brush = QBrush(QColor.fromString("green"))
track1_brush = QBrush(QColor.fromString("orange"))
self._positions = positions
self._tracks = tracks
self._frames = frames
t1_positions = self._positions[self._tracks == 1]
t1_frames = self._frames[self._tracks == 1]
t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
t2_positions = self._positions[self._tracks == 2]
t2_frames = self._frames[self._tracks == 2]
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False)
self._plot = self._plot_widget.addPlot()
bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
self._plot.addItem(bgi1)
n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False)
bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
self._plot.addItem(bgi2)
self._plot.scene().sigMouseClicked.connect(mouseClicked)
self._plot.setLogMode(x=False, y=True)
# plot.setXRange(np.min(t1_distances), np.max(t1_distances))
self._plot.setLabel('left', "prob. density")
self._plot.setLabel('bottom', "distance", units="px/frame")
# plot.addItem(self._threshold)
vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
self._plot.addItem(vLine, ignoreBounds=True)
vb = self._plot.vb
class ClassifierWidget(QTabWidget): class ClassifierWidget(QTabWidget):
apply_sizeclassifier = Signal(np.ndarray) apply_sizeclassifier = Signal(np.ndarray)
@ -90,7 +197,9 @@ class ClassifierWidget(QTabWidget):
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self._size_classifier = SizeClassifier() self._size_classifier = SizeClassifier()
self.addTab(self._size_classifier, "Size classifier") self._neigborhood_validator = NeighborhoodValidator()
self.addTab(self._size_classifier, SizeClassifier.name)
self.addTab(self._neigborhood_validator, NeighborhoodValidator.name)
self._size_classifier.apply.connect(self._on_applySizeClassifier) self._size_classifier.apply.connect(self._on_applySizeClassifier)
def _on_applySizeClassifier(self): def _on_applySizeClassifier(self):
@ -101,48 +210,51 @@ class ClassifierWidget(QTabWidget):
def size_classifier(self): def size_classifier(self):
return self._size_classifier return self._size_classifier
@property
def neighborhood_validator(self):
return self._neigborhood_validator
def test_sizeClassifier(coords): def as_dict(df):
app = QApplication([]) d = {c: df[c].values for c in df.columns}
window = QWidget() d["index"] = df.index.values
window.setMinimumSize(200, 200) return d
layout = QVBoxLayout()
win = SizeClassifier()
win.setCoordinates(coords)
layout.addWidget(win)
window.setLayout(layout)
window.show()
app.exec()
def main():
test_size = False
import pickle
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)
data = TrackingData()
data.setData(as_dict(df))
positions = data.centerOfGravity()
tracks = data["track"]
frames = data["frame"]
coords = data.coordinates()
def test_neighborhoodClassifier(coords):
app = QApplication([]) app = QApplication([])
window = QWidget() window = QWidget()
window.setMinimumSize(200, 200) window.setMinimumSize(200, 200)
# if test_size:
# win = SizeClassifier()
# win.setCoordinates(coords)
# else:
w = ClassifierWidget()
w.neighborhood_validator.setData(positions, tracks, frames)
layout = QVBoxLayout() layout = QVBoxLayout()
win = SizeClassifier() layout.addWidget(w)
win.setCoordinates(coords)
layout.addWidget(win)
window.setLayout(layout) window.setLayout(layout)
window.show() window.show()
app.exec() app.exec()
def main():
import pickle
from fixtracks.info import PACKAGE_ROOT
from PySide6.QtWidgets import QApplication
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
print(datafile)
with open(datafile, "rb") as f:
df = pickle.load(f)
coords = np.stack(df.keypoints.values,).astype(np.float32)
frames = df.frame.values
test_sizeClassifier(coords)
if __name__ == "__main__": if __name__ == "__main__":
from PySide6.QtWidgets import QApplication
main() main()

View File

@ -378,7 +378,11 @@ class FixTracks(QWidget):
rel_width = self._windowspinner.value() / maxframes rel_width = self._windowspinner.value() / maxframes
self._timeline.setWindowWidth(rel_width) self._timeline.setWindowWidth(rel_width)
coordinates = self._data.coordinates() coordinates = self._data.coordinates()
positions = self._data.centerOfGravity()
tracks = self._data["track"]
frames = self._data["frame"]
self._classifier.size_classifier.setCoordinates(coordinates) self._classifier.size_classifier.setCoordinates(coordinates)
self._classifier.neighborhood_validator.setData(positions, tracks, frames)
self.update() self.update()
self._saveBtn.setEnabled(True) self._saveBtn.setEnabled(True)