From 5a128cf28e2c5bbdd6b66aac128016158a1f9a28 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Mon, 10 Feb 2025 11:08:06 +0100 Subject: [PATCH] [tracks, classifier] include neighborhood classifier, not yet doing anything --- fixtracks/widgets/classifier.py | 186 +++++++++++++++++++++++++------- fixtracks/widgets/tracks.py | 4 + 2 files changed, 153 insertions(+), 37 deletions(-) diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py index 2bcf5b2..5ece6ca 100644 --- a/fixtracks/widgets/classifier.py +++ b/fixtracks/widgets/classifier.py @@ -1,15 +1,17 @@ import logging import numpy as np +import pyqtgraph as pg -from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton +from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView from PySide6.QtCore import Signal from PySide6.QtGui import QBrush, QColor -import pyqtgraph as pg +from fixtracks.utils.trackingdata import TrackingData class SizeClassifier(QWidget): apply = Signal() + name = "SizeClassifier" def __init__(self, parent=None): super().__init__(parent) @@ -29,17 +31,16 @@ class SizeClassifier(QWidget): def setupGraph(self): track1_brush = QBrush(QColor.fromString("orange")) - track1_brush.color().setAlphaF(0.5) track2_brush = QBrush(QColor.fromString("green")) pg.setConfigOptions(antialias=True) plot_widget = pg.GraphicsLayoutWidget(show=False) self._t1_selection = pg.LinearRegionItem([100, 200]) - self._t1_selection.setZValue(-10) # what is that? + self._t1_selection.setZValue(-10) self._t1_selection.setBrush(track1_brush) self._t2_selection = pg.LinearRegionItem([300,400]) - self._t2_selection.setZValue(-10) # what is that? + self._t2_selection.setZValue(-10) self._t2_selection.setBrush(track2_brush) return plot_widget @@ -47,8 +48,8 @@ class SizeClassifier(QWidget): if bodyaxis is None: bodyaxis = [0, 1, 2, 5] bodycoords = coords[:, bodyaxis, :] - dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1) - return dists + lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1) + return lengths def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.): min_length = np.percentile(dists, min_threshold) @@ -83,6 +84,112 @@ class SizeClassifier(QWidget): tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2 return tracks +class NeighborhoodValidator(QWidget): + apply = Signal() + name = "Neighborhood Validator" + + def __init__(self, parent = None): + super().__init__(parent) + self._threshold = None + self._positions = None + self._distances = None + self._tracks = None + self._plot = None + + self._plot_widget = self.setupGraph() + self._apply_btn = QPushButton("apply") + self._apply_btn.clicked.connect(lambda: self.apply.emit()) + + layout = QVBoxLayout() + print(isinstance(self._plot_widget, QGraphicsView)) + layout.addWidget(self._plot_widget) + layout.addWidget(self._apply_btn) + self.setLayout(layout) + + def setupGraph(self): + pg.setConfigOptions(antialias=True) + plot_widget = pg.GraphicsLayoutWidget(show=False) + self._threshold = pg.LineSegmentROI([[10, 64], [120,64]], pen='r') + self._threshold.setZValue(-10) # what is that? + return plot_widget + + def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False): + min_dist = np.percentile(dists, min_threshold) + max_dist = np.percentile(dists, max_threshold) + if log: + bins = np.logspace(min_dist, max_dist, bin_count, base=10) + bins = np.linspace(min_dist, max_dist, bin_count) + hist, edges = np.histogram(dists, bins=bins, density=True) + return hist, edges + + def neighborDistances(self, x, frames, n=5, symmetric=True): + logging.debug("classifier:NeighborhoodValidator neighborDistance") + pad_shape = list(x.shape) + pad_shape[0] = n + pad = np.atleast_2d(np.zeros(pad_shape)) + if symmetric: + padded_x = np.vstack((pad, x, pad)) + dists = np.zeros((x.shape[0]-1, 2*n)) + else: + padded_x = np.vstack((pad, x)) + dists = np.zeros((x.shape[0]-1, n)) + + count = 0 + r = range(-n, n+1) if symmetric else range(-n, 0) + for i in r: + if i == 0: + continue + shifted_x = np.roll(padded_x, i, axis=0) + dist = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1)) + dists[:, count] = dist[n+1:]/np.diff(frames) + count += 1 + return dists + + def setData(self, positions, tracks, frames): + """Set the data, the classifier/should be working on. + + Parameters + ---------- + positions : np.ndarray + The position estimates, e.g. the center of gravity for each detection + tracks : np.ndarray + The current track assignment. + frames : np.ndarray + respective frame. + """ + def mouseClicked(self, event): + print("mouse clicked at", event.pos()) + + track2_brush = QBrush(QColor.fromString("green")) + track1_brush = QBrush(QColor.fromString("orange")) + self._positions = positions + self._tracks = tracks + self._frames = frames + t1_positions = self._positions[self._tracks == 1] + t1_frames = self._frames[self._tracks == 1] + t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False) + t2_positions = self._positions[self._tracks == 2] + t2_frames = self._frames[self._tracks == 2] + t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False) + + n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False) + self._plot = self._plot_widget.addPlot() + bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush) + self._plot.addItem(bgi1) + n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False) + bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush) + self._plot.addItem(bgi2) + self._plot.scene().sigMouseClicked.connect(mouseClicked) + self._plot.setLogMode(x=False, y=True) + # plot.setXRange(np.min(t1_distances), np.max(t1_distances)) + self._plot.setLabel('left', "prob. density") + self._plot.setLabel('bottom', "distance", units="px/frame") + # plot.addItem(self._threshold) + vLine = pg.InfiniteLine(pos=10, angle=90, movable=False) + self._plot.addItem(vLine, ignoreBounds=True) + vb = self._plot.vb + + class ClassifierWidget(QTabWidget): apply_sizeclassifier = Signal(np.ndarray) @@ -90,7 +197,9 @@ class ClassifierWidget(QTabWidget): def __init__(self, parent=None): super().__init__(parent) self._size_classifier = SizeClassifier() - self.addTab(self._size_classifier, "Size classifier") + self._neigborhood_validator = NeighborhoodValidator() + self.addTab(self._size_classifier, SizeClassifier.name) + self.addTab(self._neigborhood_validator, NeighborhoodValidator.name) self._size_classifier.apply.connect(self._on_applySizeClassifier) def _on_applySizeClassifier(self): @@ -101,48 +210,51 @@ class ClassifierWidget(QTabWidget): def size_classifier(self): return self._size_classifier + @property + def neighborhood_validator(self): + return self._neigborhood_validator -def test_sizeClassifier(coords): - app = QApplication([]) - window = QWidget() - window.setMinimumSize(200, 200) - layout = QVBoxLayout() - win = SizeClassifier() - win.setCoordinates(coords) +def as_dict(df): + d = {c: df[c].values for c in df.columns} + d["index"] = df.index.values + return d - layout.addWidget(win) - window.setLayout(layout) - window.show() - app.exec() +def main(): + test_size = False + import pickle + from fixtracks.info import PACKAGE_ROOT + + datafile = PACKAGE_ROOT / "data/merged_small.pkl" + + with open(datafile, "rb") as f: + df = pickle.load(f) + data = TrackingData() + data.setData(as_dict(df)) + + positions = data.centerOfGravity() + tracks = data["track"] + frames = data["frame"] + coords = data.coordinates() -def test_neighborhoodClassifier(coords): app = QApplication([]) window = QWidget() window.setMinimumSize(200, 200) + # if test_size: + # win = SizeClassifier() + # win.setCoordinates(coords) + # else: + w = ClassifierWidget() + w.neighborhood_validator.setData(positions, tracks, frames) + layout = QVBoxLayout() - win = SizeClassifier() - win.setCoordinates(coords) - layout.addWidget(win) + layout.addWidget(w) window.setLayout(layout) window.show() app.exec() -def main(): - import pickle - from fixtracks.info import PACKAGE_ROOT - from PySide6.QtWidgets import QApplication - datafile = PACKAGE_ROOT / "data/merged_small.pkl" - print(datafile) - with open(datafile, "rb") as f: - df = pickle.load(f) - - coords = np.stack(df.keypoints.values,).astype(np.float32) - frames = df.frame.values - test_sizeClassifier(coords) - - if __name__ == "__main__": + from PySide6.QtWidgets import QApplication main() diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index 85d85e6..eb94232 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -378,7 +378,11 @@ class FixTracks(QWidget): rel_width = self._windowspinner.value() / maxframes self._timeline.setWindowWidth(rel_width) coordinates = self._data.coordinates() + positions = self._data.centerOfGravity() + tracks = self._data["track"] + frames = self._data["frame"] self._classifier.size_classifier.setCoordinates(coordinates) + self._classifier.neighborhood_validator.setData(positions, tracks, frames) self.update() self._saveBtn.setEnabled(True)