[tracks, classifier] include neighborhood classifier, not yet doing
anything
This commit is contained in:
parent
96e4b0b2c5
commit
5a128cf28e
@ -1,15 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pyqtgraph as pg
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
|
||||||
from PySide6.QtCore import Signal
|
from PySide6.QtCore import Signal
|
||||||
from PySide6.QtGui import QBrush, QColor
|
from PySide6.QtGui import QBrush, QColor
|
||||||
|
|
||||||
import pyqtgraph as pg
|
from fixtracks.utils.trackingdata import TrackingData
|
||||||
|
|
||||||
|
|
||||||
class SizeClassifier(QWidget):
|
class SizeClassifier(QWidget):
|
||||||
apply = Signal()
|
apply = Signal()
|
||||||
|
name = "SizeClassifier"
|
||||||
|
|
||||||
def __init__(self, parent=None):
|
def __init__(self, parent=None):
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
@ -29,17 +31,16 @@ class SizeClassifier(QWidget):
|
|||||||
|
|
||||||
def setupGraph(self):
|
def setupGraph(self):
|
||||||
track1_brush = QBrush(QColor.fromString("orange"))
|
track1_brush = QBrush(QColor.fromString("orange"))
|
||||||
track1_brush.color().setAlphaF(0.5)
|
|
||||||
track2_brush = QBrush(QColor.fromString("green"))
|
track2_brush = QBrush(QColor.fromString("green"))
|
||||||
|
|
||||||
pg.setConfigOptions(antialias=True)
|
pg.setConfigOptions(antialias=True)
|
||||||
plot_widget = pg.GraphicsLayoutWidget(show=False)
|
plot_widget = pg.GraphicsLayoutWidget(show=False)
|
||||||
|
|
||||||
self._t1_selection = pg.LinearRegionItem([100, 200])
|
self._t1_selection = pg.LinearRegionItem([100, 200])
|
||||||
self._t1_selection.setZValue(-10) # what is that?
|
self._t1_selection.setZValue(-10)
|
||||||
self._t1_selection.setBrush(track1_brush)
|
self._t1_selection.setBrush(track1_brush)
|
||||||
self._t2_selection = pg.LinearRegionItem([300,400])
|
self._t2_selection = pg.LinearRegionItem([300,400])
|
||||||
self._t2_selection.setZValue(-10) # what is that?
|
self._t2_selection.setZValue(-10)
|
||||||
self._t2_selection.setBrush(track2_brush)
|
self._t2_selection.setBrush(track2_brush)
|
||||||
return plot_widget
|
return plot_widget
|
||||||
|
|
||||||
@ -47,8 +48,8 @@ class SizeClassifier(QWidget):
|
|||||||
if bodyaxis is None:
|
if bodyaxis is None:
|
||||||
bodyaxis = [0, 1, 2, 5]
|
bodyaxis = [0, 1, 2, 5]
|
||||||
bodycoords = coords[:, bodyaxis, :]
|
bodycoords = coords[:, bodyaxis, :]
|
||||||
dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
|
lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
|
||||||
return dists
|
return lengths
|
||||||
|
|
||||||
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
|
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
|
||||||
min_length = np.percentile(dists, min_threshold)
|
min_length = np.percentile(dists, min_threshold)
|
||||||
@ -83,6 +84,112 @@ class SizeClassifier(QWidget):
|
|||||||
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
|
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
|
||||||
return tracks
|
return tracks
|
||||||
|
|
||||||
|
class NeighborhoodValidator(QWidget):
|
||||||
|
apply = Signal()
|
||||||
|
name = "Neighborhood Validator"
|
||||||
|
|
||||||
|
def __init__(self, parent = None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._threshold = None
|
||||||
|
self._positions = None
|
||||||
|
self._distances = None
|
||||||
|
self._tracks = None
|
||||||
|
self._plot = None
|
||||||
|
|
||||||
|
self._plot_widget = self.setupGraph()
|
||||||
|
self._apply_btn = QPushButton("apply")
|
||||||
|
self._apply_btn.clicked.connect(lambda: self.apply.emit())
|
||||||
|
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
print(isinstance(self._plot_widget, QGraphicsView))
|
||||||
|
layout.addWidget(self._plot_widget)
|
||||||
|
layout.addWidget(self._apply_btn)
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def setupGraph(self):
|
||||||
|
pg.setConfigOptions(antialias=True)
|
||||||
|
plot_widget = pg.GraphicsLayoutWidget(show=False)
|
||||||
|
self._threshold = pg.LineSegmentROI([[10, 64], [120,64]], pen='r')
|
||||||
|
self._threshold.setZValue(-10) # what is that?
|
||||||
|
return plot_widget
|
||||||
|
|
||||||
|
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
|
||||||
|
min_dist = np.percentile(dists, min_threshold)
|
||||||
|
max_dist = np.percentile(dists, max_threshold)
|
||||||
|
if log:
|
||||||
|
bins = np.logspace(min_dist, max_dist, bin_count, base=10)
|
||||||
|
bins = np.linspace(min_dist, max_dist, bin_count)
|
||||||
|
hist, edges = np.histogram(dists, bins=bins, density=True)
|
||||||
|
return hist, edges
|
||||||
|
|
||||||
|
def neighborDistances(self, x, frames, n=5, symmetric=True):
|
||||||
|
logging.debug("classifier:NeighborhoodValidator neighborDistance")
|
||||||
|
pad_shape = list(x.shape)
|
||||||
|
pad_shape[0] = n
|
||||||
|
pad = np.atleast_2d(np.zeros(pad_shape))
|
||||||
|
if symmetric:
|
||||||
|
padded_x = np.vstack((pad, x, pad))
|
||||||
|
dists = np.zeros((x.shape[0]-1, 2*n))
|
||||||
|
else:
|
||||||
|
padded_x = np.vstack((pad, x))
|
||||||
|
dists = np.zeros((x.shape[0]-1, n))
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
r = range(-n, n+1) if symmetric else range(-n, 0)
|
||||||
|
for i in r:
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
shifted_x = np.roll(padded_x, i, axis=0)
|
||||||
|
dist = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
|
||||||
|
dists[:, count] = dist[n+1:]/np.diff(frames)
|
||||||
|
count += 1
|
||||||
|
return dists
|
||||||
|
|
||||||
|
def setData(self, positions, tracks, frames):
|
||||||
|
"""Set the data, the classifier/should be working on.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
positions : np.ndarray
|
||||||
|
The position estimates, e.g. the center of gravity for each detection
|
||||||
|
tracks : np.ndarray
|
||||||
|
The current track assignment.
|
||||||
|
frames : np.ndarray
|
||||||
|
respective frame.
|
||||||
|
"""
|
||||||
|
def mouseClicked(self, event):
|
||||||
|
print("mouse clicked at", event.pos())
|
||||||
|
|
||||||
|
track2_brush = QBrush(QColor.fromString("green"))
|
||||||
|
track1_brush = QBrush(QColor.fromString("orange"))
|
||||||
|
self._positions = positions
|
||||||
|
self._tracks = tracks
|
||||||
|
self._frames = frames
|
||||||
|
t1_positions = self._positions[self._tracks == 1]
|
||||||
|
t1_frames = self._frames[self._tracks == 1]
|
||||||
|
t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
|
||||||
|
t2_positions = self._positions[self._tracks == 2]
|
||||||
|
t2_frames = self._frames[self._tracks == 2]
|
||||||
|
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
|
||||||
|
|
||||||
|
n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False)
|
||||||
|
self._plot = self._plot_widget.addPlot()
|
||||||
|
bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
|
||||||
|
self._plot.addItem(bgi1)
|
||||||
|
n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False)
|
||||||
|
bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
|
||||||
|
self._plot.addItem(bgi2)
|
||||||
|
self._plot.scene().sigMouseClicked.connect(mouseClicked)
|
||||||
|
self._plot.setLogMode(x=False, y=True)
|
||||||
|
# plot.setXRange(np.min(t1_distances), np.max(t1_distances))
|
||||||
|
self._plot.setLabel('left', "prob. density")
|
||||||
|
self._plot.setLabel('bottom', "distance", units="px/frame")
|
||||||
|
# plot.addItem(self._threshold)
|
||||||
|
vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
|
||||||
|
self._plot.addItem(vLine, ignoreBounds=True)
|
||||||
|
vb = self._plot.vb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ClassifierWidget(QTabWidget):
|
class ClassifierWidget(QTabWidget):
|
||||||
apply_sizeclassifier = Signal(np.ndarray)
|
apply_sizeclassifier = Signal(np.ndarray)
|
||||||
@ -90,7 +197,9 @@ class ClassifierWidget(QTabWidget):
|
|||||||
def __init__(self, parent=None):
|
def __init__(self, parent=None):
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self._size_classifier = SizeClassifier()
|
self._size_classifier = SizeClassifier()
|
||||||
self.addTab(self._size_classifier, "Size classifier")
|
self._neigborhood_validator = NeighborhoodValidator()
|
||||||
|
self.addTab(self._size_classifier, SizeClassifier.name)
|
||||||
|
self.addTab(self._neigborhood_validator, NeighborhoodValidator.name)
|
||||||
self._size_classifier.apply.connect(self._on_applySizeClassifier)
|
self._size_classifier.apply.connect(self._on_applySizeClassifier)
|
||||||
|
|
||||||
def _on_applySizeClassifier(self):
|
def _on_applySizeClassifier(self):
|
||||||
@ -101,48 +210,51 @@ class ClassifierWidget(QTabWidget):
|
|||||||
def size_classifier(self):
|
def size_classifier(self):
|
||||||
return self._size_classifier
|
return self._size_classifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def neighborhood_validator(self):
|
||||||
|
return self._neigborhood_validator
|
||||||
|
|
||||||
|
|
||||||
def test_sizeClassifier(coords):
|
def as_dict(df):
|
||||||
app = QApplication([])
|
d = {c: df[c].values for c in df.columns}
|
||||||
window = QWidget()
|
d["index"] = df.index.values
|
||||||
window.setMinimumSize(200, 200)
|
return d
|
||||||
layout = QVBoxLayout()
|
|
||||||
win = SizeClassifier()
|
|
||||||
win.setCoordinates(coords)
|
|
||||||
|
|
||||||
layout.addWidget(win)
|
|
||||||
window.setLayout(layout)
|
|
||||||
window.show()
|
|
||||||
app.exec()
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_size = False
|
||||||
|
import pickle
|
||||||
|
from fixtracks.info import PACKAGE_ROOT
|
||||||
|
|
||||||
|
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||||
|
|
||||||
|
with open(datafile, "rb") as f:
|
||||||
|
df = pickle.load(f)
|
||||||
|
data = TrackingData()
|
||||||
|
data.setData(as_dict(df))
|
||||||
|
|
||||||
|
positions = data.centerOfGravity()
|
||||||
|
tracks = data["track"]
|
||||||
|
frames = data["frame"]
|
||||||
|
coords = data.coordinates()
|
||||||
|
|
||||||
def test_neighborhoodClassifier(coords):
|
|
||||||
app = QApplication([])
|
app = QApplication([])
|
||||||
window = QWidget()
|
window = QWidget()
|
||||||
window.setMinimumSize(200, 200)
|
window.setMinimumSize(200, 200)
|
||||||
|
# if test_size:
|
||||||
|
# win = SizeClassifier()
|
||||||
|
# win.setCoordinates(coords)
|
||||||
|
# else:
|
||||||
|
w = ClassifierWidget()
|
||||||
|
w.neighborhood_validator.setData(positions, tracks, frames)
|
||||||
|
|
||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout()
|
||||||
win = SizeClassifier()
|
layout.addWidget(w)
|
||||||
win.setCoordinates(coords)
|
|
||||||
layout.addWidget(win)
|
|
||||||
window.setLayout(layout)
|
window.setLayout(layout)
|
||||||
window.show()
|
window.show()
|
||||||
app.exec()
|
app.exec()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import pickle
|
|
||||||
from fixtracks.info import PACKAGE_ROOT
|
|
||||||
from PySide6.QtWidgets import QApplication
|
|
||||||
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
|
||||||
print(datafile)
|
|
||||||
with open(datafile, "rb") as f:
|
|
||||||
df = pickle.load(f)
|
|
||||||
|
|
||||||
coords = np.stack(df.keypoints.values,).astype(np.float32)
|
|
||||||
frames = df.frame.values
|
|
||||||
test_sizeClassifier(coords)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
from PySide6.QtWidgets import QApplication
|
||||||
main()
|
main()
|
||||||
|
@ -378,7 +378,11 @@ class FixTracks(QWidget):
|
|||||||
rel_width = self._windowspinner.value() / maxframes
|
rel_width = self._windowspinner.value() / maxframes
|
||||||
self._timeline.setWindowWidth(rel_width)
|
self._timeline.setWindowWidth(rel_width)
|
||||||
coordinates = self._data.coordinates()
|
coordinates = self._data.coordinates()
|
||||||
|
positions = self._data.centerOfGravity()
|
||||||
|
tracks = self._data["track"]
|
||||||
|
frames = self._data["frame"]
|
||||||
self._classifier.size_classifier.setCoordinates(coordinates)
|
self._classifier.size_classifier.setCoordinates(coordinates)
|
||||||
|
self._classifier.neighborhood_validator.setData(positions, tracks, frames)
|
||||||
self.update()
|
self.update()
|
||||||
self._saveBtn.setEnabled(True)
|
self._saveBtn.setEnabled(True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user