[tracks, classifier] include neighborhood classifier, not yet doing
anything
This commit is contained in:
		
							parent
							
								
									96e4b0b2c5
								
							
						
					
					
						commit
						5a128cf28e
					
				@ -1,15 +1,17 @@
 | 
				
			|||||||
import logging
 | 
					import logging
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import pyqtgraph as pg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
 | 
					from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
 | 
				
			||||||
from PySide6.QtCore import Signal
 | 
					from PySide6.QtCore import Signal
 | 
				
			||||||
from PySide6.QtGui import QBrush, QColor
 | 
					from PySide6.QtGui import QBrush, QColor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pyqtgraph as pg
 | 
					from fixtracks.utils.trackingdata import TrackingData
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SizeClassifier(QWidget):
 | 
					class SizeClassifier(QWidget):
 | 
				
			||||||
    apply = Signal()
 | 
					    apply = Signal()
 | 
				
			||||||
 | 
					    name = "SizeClassifier"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, parent=None):
 | 
					    def __init__(self, parent=None):
 | 
				
			||||||
        super().__init__(parent)
 | 
					        super().__init__(parent)
 | 
				
			||||||
@ -29,17 +31,16 @@ class SizeClassifier(QWidget):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def setupGraph(self):
 | 
					    def setupGraph(self):
 | 
				
			||||||
        track1_brush = QBrush(QColor.fromString("orange"))
 | 
					        track1_brush = QBrush(QColor.fromString("orange"))
 | 
				
			||||||
        track1_brush.color().setAlphaF(0.5)
 | 
					 | 
				
			||||||
        track2_brush = QBrush(QColor.fromString("green"))
 | 
					        track2_brush = QBrush(QColor.fromString("green"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        pg.setConfigOptions(antialias=True)
 | 
					        pg.setConfigOptions(antialias=True)
 | 
				
			||||||
        plot_widget = pg.GraphicsLayoutWidget(show=False)
 | 
					        plot_widget = pg.GraphicsLayoutWidget(show=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._t1_selection = pg.LinearRegionItem([100, 200])
 | 
					        self._t1_selection = pg.LinearRegionItem([100, 200])
 | 
				
			||||||
        self._t1_selection.setZValue(-10)  # what is that?
 | 
					        self._t1_selection.setZValue(-10)
 | 
				
			||||||
        self._t1_selection.setBrush(track1_brush)
 | 
					        self._t1_selection.setBrush(track1_brush)
 | 
				
			||||||
        self._t2_selection = pg.LinearRegionItem([300,400])
 | 
					        self._t2_selection = pg.LinearRegionItem([300,400])
 | 
				
			||||||
        self._t2_selection.setZValue(-10)  # what is that?
 | 
					        self._t2_selection.setZValue(-10)
 | 
				
			||||||
        self._t2_selection.setBrush(track2_brush)
 | 
					        self._t2_selection.setBrush(track2_brush)
 | 
				
			||||||
        return plot_widget
 | 
					        return plot_widget
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -47,8 +48,8 @@ class SizeClassifier(QWidget):
 | 
				
			|||||||
        if bodyaxis is None:
 | 
					        if bodyaxis is None:
 | 
				
			||||||
            bodyaxis = [0, 1, 2, 5]
 | 
					            bodyaxis = [0, 1, 2, 5]
 | 
				
			||||||
        bodycoords = coords[:, bodyaxis, :]
 | 
					        bodycoords = coords[:, bodyaxis, :]
 | 
				
			||||||
        dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
 | 
					        lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
 | 
				
			||||||
        return dists
 | 
					        return lengths
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
 | 
					    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
 | 
				
			||||||
        min_length = np.percentile(dists, min_threshold)
 | 
					        min_length = np.percentile(dists, min_threshold)
 | 
				
			||||||
@ -83,6 +84,112 @@ class SizeClassifier(QWidget):
 | 
				
			|||||||
        tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
 | 
					        tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
 | 
				
			||||||
        return tracks
 | 
					        return tracks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class NeighborhoodValidator(QWidget):
 | 
				
			||||||
 | 
					    apply = Signal()
 | 
				
			||||||
 | 
					    name = "Neighborhood Validator"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, parent = None):
 | 
				
			||||||
 | 
					        super().__init__(parent)
 | 
				
			||||||
 | 
					        self._threshold = None
 | 
				
			||||||
 | 
					        self._positions = None
 | 
				
			||||||
 | 
					        self._distances = None
 | 
				
			||||||
 | 
					        self._tracks = None
 | 
				
			||||||
 | 
					        self._plot = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._plot_widget = self.setupGraph()
 | 
				
			||||||
 | 
					        self._apply_btn = QPushButton("apply")
 | 
				
			||||||
 | 
					        self._apply_btn.clicked.connect(lambda: self.apply.emit())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        layout = QVBoxLayout()
 | 
				
			||||||
 | 
					        print(isinstance(self._plot_widget, QGraphicsView))
 | 
				
			||||||
 | 
					        layout.addWidget(self._plot_widget)
 | 
				
			||||||
 | 
					        layout.addWidget(self._apply_btn)
 | 
				
			||||||
 | 
					        self.setLayout(layout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setupGraph(self):
 | 
				
			||||||
 | 
					        pg.setConfigOptions(antialias=True)
 | 
				
			||||||
 | 
					        plot_widget = pg.GraphicsLayoutWidget(show=False)
 | 
				
			||||||
 | 
					        self._threshold = pg.LineSegmentROI([[10, 64], [120,64]], pen='r')
 | 
				
			||||||
 | 
					        self._threshold.setZValue(-10)  # what is that?
 | 
				
			||||||
 | 
					        return plot_widget
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
 | 
				
			||||||
 | 
					        min_dist = np.percentile(dists, min_threshold)
 | 
				
			||||||
 | 
					        max_dist = np.percentile(dists, max_threshold)
 | 
				
			||||||
 | 
					        if log:
 | 
				
			||||||
 | 
					            bins = np.logspace(min_dist, max_dist, bin_count, base=10)
 | 
				
			||||||
 | 
					        bins = np.linspace(min_dist, max_dist, bin_count)
 | 
				
			||||||
 | 
					        hist, edges = np.histogram(dists, bins=bins, density=True)
 | 
				
			||||||
 | 
					        return hist, edges
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def neighborDistances(self, x, frames, n=5, symmetric=True):
 | 
				
			||||||
 | 
					        logging.debug("classifier:NeighborhoodValidator neighborDistance")
 | 
				
			||||||
 | 
					        pad_shape = list(x.shape)
 | 
				
			||||||
 | 
					        pad_shape[0] = n
 | 
				
			||||||
 | 
					        pad = np.atleast_2d(np.zeros(pad_shape))
 | 
				
			||||||
 | 
					        if symmetric:
 | 
				
			||||||
 | 
					            padded_x = np.vstack((pad, x, pad))
 | 
				
			||||||
 | 
					            dists = np.zeros((x.shape[0]-1, 2*n))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            padded_x = np.vstack((pad, x))
 | 
				
			||||||
 | 
					            dists = np.zeros((x.shape[0]-1, n))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        count = 0
 | 
				
			||||||
 | 
					        r = range(-n, n+1) if symmetric else range(-n, 0)
 | 
				
			||||||
 | 
					        for i in r:
 | 
				
			||||||
 | 
					            if i == 0:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            shifted_x = np.roll(padded_x, i, axis=0)
 | 
				
			||||||
 | 
					            dist = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
 | 
				
			||||||
 | 
					            dists[:, count] = dist[n+1:]/np.diff(frames)
 | 
				
			||||||
 | 
					            count += 1
 | 
				
			||||||
 | 
					        return dists
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setData(self, positions, tracks, frames):
 | 
				
			||||||
 | 
					        """Set the data, the classifier/should be working on.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ----------
 | 
				
			||||||
 | 
					        positions : np.ndarray
 | 
				
			||||||
 | 
					            The position estimates, e.g. the center of gravity for each detection
 | 
				
			||||||
 | 
					        tracks : np.ndarray
 | 
				
			||||||
 | 
					            The current track assignment.
 | 
				
			||||||
 | 
					        frames : np.ndarray
 | 
				
			||||||
 | 
					            respective frame.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        def mouseClicked(self, event):
 | 
				
			||||||
 | 
					            print("mouse clicked at", event.pos())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        track2_brush = QBrush(QColor.fromString("green"))
 | 
				
			||||||
 | 
					        track1_brush = QBrush(QColor.fromString("orange"))
 | 
				
			||||||
 | 
					        self._positions = positions
 | 
				
			||||||
 | 
					        self._tracks = tracks
 | 
				
			||||||
 | 
					        self._frames = frames
 | 
				
			||||||
 | 
					        t1_positions = self._positions[self._tracks == 1]
 | 
				
			||||||
 | 
					        t1_frames = self._frames[self._tracks == 1]
 | 
				
			||||||
 | 
					        t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
 | 
				
			||||||
 | 
					        t2_positions = self._positions[self._tracks == 2]
 | 
				
			||||||
 | 
					        t2_frames = self._frames[self._tracks == 2]
 | 
				
			||||||
 | 
					        t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False)
 | 
				
			||||||
 | 
					        self._plot = self._plot_widget.addPlot()
 | 
				
			||||||
 | 
					        bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
 | 
				
			||||||
 | 
					        self._plot.addItem(bgi1)
 | 
				
			||||||
 | 
					        n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False)
 | 
				
			||||||
 | 
					        bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
 | 
				
			||||||
 | 
					        self._plot.addItem(bgi2)
 | 
				
			||||||
 | 
					        self._plot.scene().sigMouseClicked.connect(mouseClicked)
 | 
				
			||||||
 | 
					        self._plot.setLogMode(x=False, y=True)
 | 
				
			||||||
 | 
					        # plot.setXRange(np.min(t1_distances), np.max(t1_distances))
 | 
				
			||||||
 | 
					        self._plot.setLabel('left', "prob. density")
 | 
				
			||||||
 | 
					        self._plot.setLabel('bottom', "distance", units="px/frame")
 | 
				
			||||||
 | 
					        # plot.addItem(self._threshold)
 | 
				
			||||||
 | 
					        vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
 | 
				
			||||||
 | 
					        self._plot.addItem(vLine, ignoreBounds=True)
 | 
				
			||||||
 | 
					        vb = self._plot.vb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ClassifierWidget(QTabWidget):
 | 
					class ClassifierWidget(QTabWidget):
 | 
				
			||||||
    apply_sizeclassifier = Signal(np.ndarray)
 | 
					    apply_sizeclassifier = Signal(np.ndarray)
 | 
				
			||||||
@ -90,7 +197,9 @@ class ClassifierWidget(QTabWidget):
 | 
				
			|||||||
    def __init__(self, parent=None):
 | 
					    def __init__(self, parent=None):
 | 
				
			||||||
        super().__init__(parent)
 | 
					        super().__init__(parent)
 | 
				
			||||||
        self._size_classifier = SizeClassifier()
 | 
					        self._size_classifier = SizeClassifier()
 | 
				
			||||||
        self.addTab(self._size_classifier, "Size classifier")
 | 
					        self._neigborhood_validator = NeighborhoodValidator()
 | 
				
			||||||
 | 
					        self.addTab(self._size_classifier, SizeClassifier.name)
 | 
				
			||||||
 | 
					        self.addTab(self._neigborhood_validator, NeighborhoodValidator.name)
 | 
				
			||||||
        self._size_classifier.apply.connect(self._on_applySizeClassifier)
 | 
					        self._size_classifier.apply.connect(self._on_applySizeClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _on_applySizeClassifier(self):
 | 
					    def _on_applySizeClassifier(self):
 | 
				
			||||||
@ -101,48 +210,51 @@ class ClassifierWidget(QTabWidget):
 | 
				
			|||||||
    def size_classifier(self):
 | 
					    def size_classifier(self):
 | 
				
			||||||
        return self._size_classifier
 | 
					        return self._size_classifier
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def neighborhood_validator(self):
 | 
				
			||||||
 | 
					        return self._neigborhood_validator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_sizeClassifier(coords):
 | 
					def as_dict(df):
 | 
				
			||||||
    app = QApplication([])
 | 
					    d = {c: df[c].values for c in df.columns}
 | 
				
			||||||
    window = QWidget()
 | 
					    d["index"] = df.index.values
 | 
				
			||||||
    window.setMinimumSize(200, 200)
 | 
					    return d
 | 
				
			||||||
    layout = QVBoxLayout()
 | 
					 | 
				
			||||||
    win = SizeClassifier()
 | 
					 | 
				
			||||||
    win.setCoordinates(coords)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    layout.addWidget(win)
 | 
					 | 
				
			||||||
    window.setLayout(layout)
 | 
					 | 
				
			||||||
    window.show()
 | 
					 | 
				
			||||||
    app.exec()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def test_neighborhoodClassifier(coords):
 | 
					 | 
				
			||||||
    app = QApplication([])
 | 
					 | 
				
			||||||
    window = QWidget()
 | 
					 | 
				
			||||||
    window.setMinimumSize(200, 200)
 | 
					 | 
				
			||||||
    layout = QVBoxLayout()
 | 
					 | 
				
			||||||
    win = SizeClassifier()
 | 
					 | 
				
			||||||
    win.setCoordinates(coords)
 | 
					 | 
				
			||||||
    layout.addWidget(win)
 | 
					 | 
				
			||||||
    window.setLayout(layout)
 | 
					 | 
				
			||||||
    window.show()
 | 
					 | 
				
			||||||
    app.exec()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
 | 
					    test_size = False
 | 
				
			||||||
    import pickle
 | 
					    import pickle
 | 
				
			||||||
    from fixtracks.info import PACKAGE_ROOT
 | 
					    from fixtracks.info import PACKAGE_ROOT
 | 
				
			||||||
    from PySide6.QtWidgets import QApplication
 | 
					    
 | 
				
			||||||
    datafile = PACKAGE_ROOT / "data/merged_small.pkl"
 | 
					    datafile = PACKAGE_ROOT / "data/merged_small.pkl"
 | 
				
			||||||
    print(datafile)
 | 
					
 | 
				
			||||||
    with open(datafile, "rb") as f:
 | 
					    with open(datafile, "rb") as f:
 | 
				
			||||||
        df = pickle.load(f)
 | 
					        df = pickle.load(f)
 | 
				
			||||||
 | 
					    data = TrackingData()
 | 
				
			||||||
 | 
					    data.setData(as_dict(df))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    coords = np.stack(df.keypoints.values,).astype(np.float32)
 | 
					    positions = data.centerOfGravity()
 | 
				
			||||||
    frames = df.frame.values
 | 
					    tracks = data["track"]
 | 
				
			||||||
    test_sizeClassifier(coords)
 | 
					    frames = data["frame"]
 | 
				
			||||||
 | 
					    coords = data.coordinates()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    app = QApplication([])
 | 
				
			||||||
 | 
					    window = QWidget()
 | 
				
			||||||
 | 
					    window.setMinimumSize(200, 200)
 | 
				
			||||||
 | 
					    # if test_size:
 | 
				
			||||||
 | 
					    #     win = SizeClassifier()
 | 
				
			||||||
 | 
					    #     win.setCoordinates(coords)
 | 
				
			||||||
 | 
					    # else:
 | 
				
			||||||
 | 
					    w = ClassifierWidget()
 | 
				
			||||||
 | 
					    w.neighborhood_validator.setData(positions, tracks, frames)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    layout = QVBoxLayout()
 | 
				
			||||||
 | 
					    layout.addWidget(w)
 | 
				
			||||||
 | 
					    window.setLayout(layout)
 | 
				
			||||||
 | 
					    window.show()
 | 
				
			||||||
 | 
					    app.exec()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    from PySide6.QtWidgets import QApplication
 | 
				
			||||||
    main()
 | 
					    main()
 | 
				
			||||||
 | 
				
			|||||||
@ -378,7 +378,11 @@ class FixTracks(QWidget):
 | 
				
			|||||||
            rel_width = self._windowspinner.value() / maxframes
 | 
					            rel_width = self._windowspinner.value() / maxframes
 | 
				
			||||||
            self._timeline.setWindowWidth(rel_width)
 | 
					            self._timeline.setWindowWidth(rel_width)
 | 
				
			||||||
            coordinates = self._data.coordinates()
 | 
					            coordinates = self._data.coordinates()
 | 
				
			||||||
 | 
					            positions = self._data.centerOfGravity()
 | 
				
			||||||
 | 
					            tracks = self._data["track"]
 | 
				
			||||||
 | 
					            frames = self._data["frame"]
 | 
				
			||||||
            self._classifier.size_classifier.setCoordinates(coordinates)
 | 
					            self._classifier.size_classifier.setCoordinates(coordinates)
 | 
				
			||||||
 | 
					            self._classifier.neighborhood_validator.setData(positions, tracks, frames)
 | 
				
			||||||
            self.update()
 | 
					            self.update()
 | 
				
			||||||
            self._saveBtn.setEnabled(True)
 | 
					            self._saveBtn.setEnabled(True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user