[tracks, classifier] include neighborhood classifier, not yet doing
anything
This commit is contained in:
		
							parent
							
								
									96e4b0b2c5
								
							
						
					
					
						commit
						5a128cf28e
					
				@ -1,15 +1,17 @@
 | 
			
		||||
import logging
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pyqtgraph as pg
 | 
			
		||||
 | 
			
		||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
 | 
			
		||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
 | 
			
		||||
from PySide6.QtCore import Signal
 | 
			
		||||
from PySide6.QtGui import QBrush, QColor
 | 
			
		||||
 | 
			
		||||
import pyqtgraph as pg
 | 
			
		||||
from fixtracks.utils.trackingdata import TrackingData
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SizeClassifier(QWidget):
 | 
			
		||||
    apply = Signal()
 | 
			
		||||
    name = "SizeClassifier"
 | 
			
		||||
 | 
			
		||||
    def __init__(self, parent=None):
 | 
			
		||||
        super().__init__(parent)
 | 
			
		||||
@ -29,17 +31,16 @@ class SizeClassifier(QWidget):
 | 
			
		||||
 | 
			
		||||
    def setupGraph(self):
 | 
			
		||||
        track1_brush = QBrush(QColor.fromString("orange"))
 | 
			
		||||
        track1_brush.color().setAlphaF(0.5)
 | 
			
		||||
        track2_brush = QBrush(QColor.fromString("green"))
 | 
			
		||||
 | 
			
		||||
        pg.setConfigOptions(antialias=True)
 | 
			
		||||
        plot_widget = pg.GraphicsLayoutWidget(show=False)
 | 
			
		||||
 | 
			
		||||
        self._t1_selection = pg.LinearRegionItem([100, 200])
 | 
			
		||||
        self._t1_selection.setZValue(-10)  # what is that?
 | 
			
		||||
        self._t1_selection.setZValue(-10)
 | 
			
		||||
        self._t1_selection.setBrush(track1_brush)
 | 
			
		||||
        self._t2_selection = pg.LinearRegionItem([300,400])
 | 
			
		||||
        self._t2_selection.setZValue(-10)  # what is that?
 | 
			
		||||
        self._t2_selection.setZValue(-10)
 | 
			
		||||
        self._t2_selection.setBrush(track2_brush)
 | 
			
		||||
        return plot_widget
 | 
			
		||||
 | 
			
		||||
@ -47,8 +48,8 @@ class SizeClassifier(QWidget):
 | 
			
		||||
        if bodyaxis is None:
 | 
			
		||||
            bodyaxis = [0, 1, 2, 5]
 | 
			
		||||
        bodycoords = coords[:, bodyaxis, :]
 | 
			
		||||
        dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
 | 
			
		||||
        return dists
 | 
			
		||||
        lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
 | 
			
		||||
        return lengths
 | 
			
		||||
 | 
			
		||||
    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
 | 
			
		||||
        min_length = np.percentile(dists, min_threshold)
 | 
			
		||||
@ -83,6 +84,112 @@ class SizeClassifier(QWidget):
 | 
			
		||||
        tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
 | 
			
		||||
        return tracks
 | 
			
		||||
 | 
			
		||||
class NeighborhoodValidator(QWidget):
 | 
			
		||||
    apply = Signal()
 | 
			
		||||
    name = "Neighborhood Validator"
 | 
			
		||||
 | 
			
		||||
    def __init__(self, parent = None):
 | 
			
		||||
        super().__init__(parent)
 | 
			
		||||
        self._threshold = None
 | 
			
		||||
        self._positions = None
 | 
			
		||||
        self._distances = None
 | 
			
		||||
        self._tracks = None
 | 
			
		||||
        self._plot = None
 | 
			
		||||
 | 
			
		||||
        self._plot_widget = self.setupGraph()
 | 
			
		||||
        self._apply_btn = QPushButton("apply")
 | 
			
		||||
        self._apply_btn.clicked.connect(lambda: self.apply.emit())
 | 
			
		||||
 | 
			
		||||
        layout = QVBoxLayout()
 | 
			
		||||
        print(isinstance(self._plot_widget, QGraphicsView))
 | 
			
		||||
        layout.addWidget(self._plot_widget)
 | 
			
		||||
        layout.addWidget(self._apply_btn)
 | 
			
		||||
        self.setLayout(layout)
 | 
			
		||||
 | 
			
		||||
    def setupGraph(self):
 | 
			
		||||
        pg.setConfigOptions(antialias=True)
 | 
			
		||||
        plot_widget = pg.GraphicsLayoutWidget(show=False)
 | 
			
		||||
        self._threshold = pg.LineSegmentROI([[10, 64], [120,64]], pen='r')
 | 
			
		||||
        self._threshold.setZValue(-10)  # what is that?
 | 
			
		||||
        return plot_widget
 | 
			
		||||
 | 
			
		||||
    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
 | 
			
		||||
        min_dist = np.percentile(dists, min_threshold)
 | 
			
		||||
        max_dist = np.percentile(dists, max_threshold)
 | 
			
		||||
        if log:
 | 
			
		||||
            bins = np.logspace(min_dist, max_dist, bin_count, base=10)
 | 
			
		||||
        bins = np.linspace(min_dist, max_dist, bin_count)
 | 
			
		||||
        hist, edges = np.histogram(dists, bins=bins, density=True)
 | 
			
		||||
        return hist, edges
 | 
			
		||||
 | 
			
		||||
    def neighborDistances(self, x, frames, n=5, symmetric=True):
 | 
			
		||||
        logging.debug("classifier:NeighborhoodValidator neighborDistance")
 | 
			
		||||
        pad_shape = list(x.shape)
 | 
			
		||||
        pad_shape[0] = n
 | 
			
		||||
        pad = np.atleast_2d(np.zeros(pad_shape))
 | 
			
		||||
        if symmetric:
 | 
			
		||||
            padded_x = np.vstack((pad, x, pad))
 | 
			
		||||
            dists = np.zeros((x.shape[0]-1, 2*n))
 | 
			
		||||
        else:
 | 
			
		||||
            padded_x = np.vstack((pad, x))
 | 
			
		||||
            dists = np.zeros((x.shape[0]-1, n))
 | 
			
		||||
 | 
			
		||||
        count = 0
 | 
			
		||||
        r = range(-n, n+1) if symmetric else range(-n, 0)
 | 
			
		||||
        for i in r:
 | 
			
		||||
            if i == 0:
 | 
			
		||||
                continue
 | 
			
		||||
            shifted_x = np.roll(padded_x, i, axis=0)
 | 
			
		||||
            dist = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
 | 
			
		||||
            dists[:, count] = dist[n+1:]/np.diff(frames)
 | 
			
		||||
            count += 1
 | 
			
		||||
        return dists
 | 
			
		||||
 | 
			
		||||
    def setData(self, positions, tracks, frames):
 | 
			
		||||
        """Set the data, the classifier/should be working on.
 | 
			
		||||
 | 
			
		||||
        Parameters
 | 
			
		||||
        ----------
 | 
			
		||||
        positions : np.ndarray
 | 
			
		||||
            The position estimates, e.g. the center of gravity for each detection
 | 
			
		||||
        tracks : np.ndarray
 | 
			
		||||
            The current track assignment.
 | 
			
		||||
        frames : np.ndarray
 | 
			
		||||
            respective frame.
 | 
			
		||||
        """
 | 
			
		||||
        def mouseClicked(self, event):
 | 
			
		||||
            print("mouse clicked at", event.pos())
 | 
			
		||||
 | 
			
		||||
        track2_brush = QBrush(QColor.fromString("green"))
 | 
			
		||||
        track1_brush = QBrush(QColor.fromString("orange"))
 | 
			
		||||
        self._positions = positions
 | 
			
		||||
        self._tracks = tracks
 | 
			
		||||
        self._frames = frames
 | 
			
		||||
        t1_positions = self._positions[self._tracks == 1]
 | 
			
		||||
        t1_frames = self._frames[self._tracks == 1]
 | 
			
		||||
        t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
 | 
			
		||||
        t2_positions = self._positions[self._tracks == 2]
 | 
			
		||||
        t2_frames = self._frames[self._tracks == 2]
 | 
			
		||||
        t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
 | 
			
		||||
 | 
			
		||||
        n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False)
 | 
			
		||||
        self._plot = self._plot_widget.addPlot()
 | 
			
		||||
        bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
 | 
			
		||||
        self._plot.addItem(bgi1)
 | 
			
		||||
        n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False)
 | 
			
		||||
        bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
 | 
			
		||||
        self._plot.addItem(bgi2)
 | 
			
		||||
        self._plot.scene().sigMouseClicked.connect(mouseClicked)
 | 
			
		||||
        self._plot.setLogMode(x=False, y=True)
 | 
			
		||||
        # plot.setXRange(np.min(t1_distances), np.max(t1_distances))
 | 
			
		||||
        self._plot.setLabel('left', "prob. density")
 | 
			
		||||
        self._plot.setLabel('bottom', "distance", units="px/frame")
 | 
			
		||||
        # plot.addItem(self._threshold)
 | 
			
		||||
        vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
 | 
			
		||||
        self._plot.addItem(vLine, ignoreBounds=True)
 | 
			
		||||
        vb = self._plot.vb
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
class ClassifierWidget(QTabWidget):
 | 
			
		||||
    apply_sizeclassifier = Signal(np.ndarray)
 | 
			
		||||
@ -90,7 +197,9 @@ class ClassifierWidget(QTabWidget):
 | 
			
		||||
    def __init__(self, parent=None):
 | 
			
		||||
        super().__init__(parent)
 | 
			
		||||
        self._size_classifier = SizeClassifier()
 | 
			
		||||
        self.addTab(self._size_classifier, "Size classifier")
 | 
			
		||||
        self._neigborhood_validator = NeighborhoodValidator()
 | 
			
		||||
        self.addTab(self._size_classifier, SizeClassifier.name)
 | 
			
		||||
        self.addTab(self._neigborhood_validator, NeighborhoodValidator.name)
 | 
			
		||||
        self._size_classifier.apply.connect(self._on_applySizeClassifier)
 | 
			
		||||
 | 
			
		||||
    def _on_applySizeClassifier(self):
 | 
			
		||||
@ -101,48 +210,51 @@ class ClassifierWidget(QTabWidget):
 | 
			
		||||
    def size_classifier(self):
 | 
			
		||||
        return self._size_classifier
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def neighborhood_validator(self):
 | 
			
		||||
        return self._neigborhood_validator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_sizeClassifier(coords):
 | 
			
		||||
    app = QApplication([])
 | 
			
		||||
    window = QWidget()
 | 
			
		||||
    window.setMinimumSize(200, 200)
 | 
			
		||||
    layout = QVBoxLayout()
 | 
			
		||||
    win = SizeClassifier()
 | 
			
		||||
    win.setCoordinates(coords)
 | 
			
		||||
 | 
			
		||||
    layout.addWidget(win)
 | 
			
		||||
    window.setLayout(layout)
 | 
			
		||||
    window.show()
 | 
			
		||||
    app.exec()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_neighborhoodClassifier(coords):
 | 
			
		||||
    app = QApplication([])
 | 
			
		||||
    window = QWidget()
 | 
			
		||||
    window.setMinimumSize(200, 200)
 | 
			
		||||
    layout = QVBoxLayout()
 | 
			
		||||
    win = SizeClassifier()
 | 
			
		||||
    win.setCoordinates(coords)
 | 
			
		||||
    layout.addWidget(win)
 | 
			
		||||
    window.setLayout(layout)
 | 
			
		||||
    window.show()
 | 
			
		||||
    app.exec()
 | 
			
		||||
def as_dict(df):
 | 
			
		||||
    d = {c: df[c].values for c in df.columns}
 | 
			
		||||
    d["index"] = df.index.values
 | 
			
		||||
    return d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    test_size = False
 | 
			
		||||
    import pickle
 | 
			
		||||
    from fixtracks.info import PACKAGE_ROOT
 | 
			
		||||
    from PySide6.QtWidgets import QApplication
 | 
			
		||||
    
 | 
			
		||||
    datafile = PACKAGE_ROOT / "data/merged_small.pkl"
 | 
			
		||||
    print(datafile)
 | 
			
		||||
 | 
			
		||||
    with open(datafile, "rb") as f:
 | 
			
		||||
        df = pickle.load(f)
 | 
			
		||||
    data = TrackingData()
 | 
			
		||||
    data.setData(as_dict(df))
 | 
			
		||||
 | 
			
		||||
    coords = np.stack(df.keypoints.values,).astype(np.float32)
 | 
			
		||||
    frames = df.frame.values
 | 
			
		||||
    test_sizeClassifier(coords)
 | 
			
		||||
    positions = data.centerOfGravity()
 | 
			
		||||
    tracks = data["track"]
 | 
			
		||||
    frames = data["frame"]
 | 
			
		||||
    coords = data.coordinates()
 | 
			
		||||
 | 
			
		||||
    app = QApplication([])
 | 
			
		||||
    window = QWidget()
 | 
			
		||||
    window.setMinimumSize(200, 200)
 | 
			
		||||
    # if test_size:
 | 
			
		||||
    #     win = SizeClassifier()
 | 
			
		||||
    #     win.setCoordinates(coords)
 | 
			
		||||
    # else:
 | 
			
		||||
    w = ClassifierWidget()
 | 
			
		||||
    w.neighborhood_validator.setData(positions, tracks, frames)
 | 
			
		||||
 | 
			
		||||
    layout = QVBoxLayout()
 | 
			
		||||
    layout.addWidget(w)
 | 
			
		||||
    window.setLayout(layout)
 | 
			
		||||
    window.show()
 | 
			
		||||
    app.exec()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    from PySide6.QtWidgets import QApplication
 | 
			
		||||
    main()
 | 
			
		||||
 | 
			
		||||
@ -378,7 +378,11 @@ class FixTracks(QWidget):
 | 
			
		||||
            rel_width = self._windowspinner.value() / maxframes
 | 
			
		||||
            self._timeline.setWindowWidth(rel_width)
 | 
			
		||||
            coordinates = self._data.coordinates()
 | 
			
		||||
            positions = self._data.centerOfGravity()
 | 
			
		||||
            tracks = self._data["track"]
 | 
			
		||||
            frames = self._data["frame"]
 | 
			
		||||
            self._classifier.size_classifier.setCoordinates(coordinates)
 | 
			
		||||
            self._classifier.neighborhood_validator.setData(positions, tracks, frames)
 | 
			
		||||
            self.update()
 | 
			
		||||
            self._saveBtn.setEnabled(True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user