fixtracks/fixtracks/widgets/classifier.py

589 lines
22 KiB
Python

import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
from fixtracks.utils.trackingdata import TrackingData
from IPython import embed
class WorkerSignals(QObject):
    """Bundle of Qt signals used by the background workers of this module.

    The workers derive from ``QRunnable`` and keep an instance of this class
    in their ``signals`` attribute to talk to the GUI.
    """

    # Carries an error message.
    error = Signal(str)
    # Carries a running/not-running flag.
    running = Signal(bool)
    # Carries three integers; ConsistencyWorker emits
    # (progress, processed, errors).
    progress = Signal(int, int, int)
    # Carries one integer; emitted by a worker at the end of its run().
    stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
    """Thread-pool task that extracts the arrays used by the consistency check.

    ``run`` queries ``data`` and stores every result as an attribute of this
    object. When all attributes are set, ``signals.stopped`` is emitted with 0
    so that the owner can read them.
    """

    def __init__(self, data):
        """Remember ``data`` and initialise every result attribute to None."""
        super().__init__()
        self.signals = WorkerSignals()
        self.data = data
        # Results, filled in by run().
        self.positions = None
        self.bendedness = None
        self.lengths = None
        self.orientations = None
        self.userlabeled = None
        self.scores = None
        self.frames = None
        self.tracks = None

    @Slot()
    def run(self):
        """Read all values from ``self.data`` and announce completion."""
        source = self.data
        # Derived quantities provided by the data object.
        self.positions = source.centerOfGravity()
        self.orientations = source.orientation()
        self.lengths = source.animalLength()
        self.bendedness = source.bendedness()
        # Columns stored in the data object.
        self.userlabeled = source["userlabeled"]
        # The confidence values are loaded, but ignored for now.
        self.scores = source["confidence"]
        self.frames = source["frame"]
        self.tracks = source["track"]
        self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable):
    """Thread-pool task that checks the track assignment frame by frame.

    Starting from the last detection of track 1 and of track 2 at or before
    ``startframe``, all later frames are visited in ascending order. Every
    detection that was not labelled by the user is assigned to the track whose
    last known position is closest, measured as distance per elapsed frame.
    ``tracks`` is modified in place: frames with a consistent assignment
    receive the new track ids, detections of inconsistent frames are set to
    -1. Detections flagged in ``userlabeled`` are never modified.

    Signals emitted through ``self.signals``:
        progress(percent, processed, errors): whenever the integer percentage
            of handled frames changes, and once more before ``stopped``.
        error(message): when the worker cannot start.
        stopped(frame): exactly once, with the frame at which the worker
            finished or was interrupted.
    """

    def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
                 userlabeled, startframe=0, stoponerror=False) -> None:
        """Store the per-detection arrays and the run options.

        Parameters
        ----------
        positions, orientations, frames, tracks, userlabeled
            Per-detection arrays, all indexed alike. ``tracks`` is written to.
        lengths, bendedness
            Stored on the instance; ``run`` does not use them.
        startframe : int
            Frame at or before which both tracks must have a detection.
        stoponerror : bool
            Interrupt the run at the first inconsistent frame.
        """
        super().__init__()
        self.signals = WorkerSignals()
        self.positions = positions
        self.orientations = orientations
        self.lengths = lengths
        self.bendedness = bendedness
        self.userlabeled = userlabeled
        self.frames = frames
        self.tracks = tracks
        self._startframe = startframe
        self._stoprequest = False
        self._stoponerror = stoponerror

    @Slot()
    def stop(self):
        """Ask ``run`` to finish before it handles the next frame."""
        self._stoprequest = True

    @Slot()
    def run(self):
        """Re-assign the detections of all frames after the start frame."""

        def needs_checking(original, new):
            """Return True when the new assignment of one frame is inconsistent."""
            # Every detection of the frame counts: one that already belonged
            # to track 1 or 2 must keep that track.
            if any((o == 1 or o == 2) and n != o for n, o in zip(new, original)):
                return True
            # Several detections of the same frame cannot all share one track.
            return len(new) > 1 and bool(np.all(new == 1) or np.all(new == 2))

        def elapsed(f):
            """Frames passed since track 1 and track 2 were last seen."""
            return np.array([f - last_frame[0], f - last_frame[1]], dtype=float)

        def assign_by_distance(f, p):
            """Choose the track whose last position is closest per elapsed frame."""
            steps = elapsed(f)
            gaps = np.array([np.linalg.norm(p - last_pos[0]),
                             np.linalg.norm(p - last_pos[1])])
            # A track that was already seen in this very frame (zero step)
            # cannot take a second detection: its distance stays infinite
            # instead of being divided by zero.
            distances = np.full(2, np.inf)
            np.divide(gaps, steps, out=distances, where=steps > 0)
            most_likely_track = int(np.argmin(distances)) + 1
            return most_likely_track, distances

        def assign_by_orientation(f, o):
            """Choose the track with the smallest orientation change per frame."""
            steps = elapsed(f)
            # The result is only logged, never used for the assignment, so the
            # floating point warnings caused by zero steps are suppressed.
            with np.errstate(divide="ignore", invalid="ignore"):
                orientationchange = np.unwrap((np.asarray(last_angle) - o) / steps)
            most_likely_track = int(np.argmin(orientationchange)) + 1
            return most_likely_track, orientationchange

        # Both tracks need a known state to start from.
        seeds = [(self.tracks == track) & (self.frames <= self._startframe)
                 for track in (1, 2)]
        if not all(np.any(mask) for mask in seeds):
            message = ("Consistency check needs a detection of track 1 and of track 2 "
                       f"at or before frame {self._startframe}.")
            logging.warning(message)
            self.signals.error.emit(message)
            self.signals.stopped.emit(int(self._startframe))
            return
        last_pos = [self.positions[mask][-1] for mask in seeds]
        last_frame = [self.frames[mask][-1] for mask in seeds]
        last_angle = [self.orientations[mask][-1] for mask in seeds]

        errors = 0
        processed = 1
        progress = 0
        self._stoprequest = False
        maxframes = np.max(self.frames)
        startframe = np.max(last_frame)
        f = startframe  # reported by stopped() when no later frame exists
        for f in np.unique(self.frames[self.frames > startframe]):
            if self._stoprequest:
                break
            indices = np.where(self.frames == f)[0]
            pp = self.positions[indices]
            originaltracks = self.tracks[indices]
            dist_assignments = np.zeros_like(originaltracks)
            angle_assignments = np.zeros_like(originaltracks)
            distances = np.zeros((len(originaltracks), 2))
            orientations = np.zeros((len(originaltracks), 2))
            for i, (idx, p) in enumerate(zip(indices, pp)):
                if self.userlabeled[idx]:
                    # Manual labels are trusted: keep the track and use the
                    # detection as the new reference state of that track.
                    track = originaltracks[i]
                    dist_assignments[i] = track
                    angle_assignments[i] = track
                    processed += 1
                    if track in (1, 2):
                        last_pos[int(track) - 1] = p
                        last_frame[int(track) - 1] = f
                        last_angle[int(track) - 1] = self.orientations[idx]
                    continue
                dist_assignments[i], distances[i, :] = assign_by_distance(f, p)
                angle_assignments[i], orientations[i, :] = assign_by_orientation(
                    f, self.orientations[idx])
            logging.debug("frame %s: distances %s", f, distances)
            logging.debug("frame %s: orientation changes %s", f, orientations)

            assignment_error = needs_checking(originaltracks, dist_assignments)
            if assignment_error:
                logging.info("frame %i: Issues assigning based on distances %s", f, str(distances))
                errors += 1
                if self._stoponerror:
                    # Leave the frame untouched so that it can be fixed by hand.
                    break
            else:
                processed += 1

            for i, idx in enumerate(indices):
                if self.userlabeled[idx]:
                    continue  # already handled above, never overwritten
                if assignment_error:
                    self.tracks[idx] = -1
                    continue
                track = int(dist_assignments[i])
                self.tracks[idx] = track
                last_pos[track - 1] = pp[i]
                last_frame[track - 1] = f
                last_angle[track - 1] = self.orientations[idx]

            # f > startframe inside the loop, hence maxframes > startframe.
            percent = int(100 * (f - startframe) / (maxframes - startframe))
            if percent != progress:
                progress = percent
                self.signals.progress.emit(progress, processed, errors)

        # Make sure the final counts are shown, also after an interruption.
        self.signals.progress.emit(progress, processed, errors)
        self.signals.stopped.emit(int(f))
class SizeClassifier(QWidget):
    """Widget that assigns detections to track 1 or 2 based on body length.

    A histogram of the estimated body lengths is shown together with two
    movable regions, one per track. ``assignedTracks`` maps every detection
    whose length falls into a region to the respective track. The ``apply``
    signal is emitted when the apply button is clicked; the button is enabled
    once coordinates have been set.
    """
    apply = Signal()
    name = "Size classifier"

    def __init__(self, parent=None):
        super().__init__(parent)
        self._t1_selection = None
        self._t2_selection = None
        self._coordinates = None
        self._sizes = None
        self._plot = None
        self._histogram = None
        self._plot_widget = self.setupGraph()
        self._apply_btn = QPushButton("apply")
        self._apply_btn.clicked.connect(lambda: self.apply.emit())
        # Nothing can be assigned before setCoordinates() provided the sizes.
        self._apply_btn.setEnabled(False)

        layout = QVBoxLayout()
        layout.addWidget(self._plot_widget)
        layout.addWidget(self._apply_btn)
        self.setLayout(layout)

    def setupGraph(self):
        """Create the plot widget and the two selection regions.

        Returns
        -------
        pg.GraphicsLayoutWidget
            The still empty widget; the plot is added by ``setCoordinates``.
        """
        pg.setConfigOptions(antialias=True)
        plot_widget = pg.GraphicsLayoutWidget(show=False)

        self._t1_selection = pg.LinearRegionItem([100, 200])
        self._t1_selection.setZValue(-10)  # draw behind the histogram bars
        self._t1_selection.setBrush("orange")
        self._t2_selection = pg.LinearRegionItem([300, 400])
        self._t2_selection.setZValue(-10)
        self._t2_selection.setBrush("green")
        return plot_widget

    def estimate_length(self, coords, bodyaxis=None):
        """Sum the distances between consecutive points along the body axis.

        Parameters
        ----------
        coords : np.ndarray
            Three-dimensional array; presumably (detection, keypoint,
            coordinate) -- confirm against the caller.
        bodyaxis : list of int, optional
            Indices along the second axis that form the body axis, in order.
            Defaults to ``[0, 1, 2, 5]``.

        Returns
        -------
        np.ndarray
            One length per entry along the first axis of ``coords``.
        """
        if bodyaxis is None:
            bodyaxis = [0, 1, 2, 5]
        bodycoords = coords[:, bodyaxis, :]
        lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
        return lengths

    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
        """Return the density histogram of ``dists`` and its bin edges.

        The 100 bin edges span from half the ``min_threshold`` percentile to
        1.5 times the ``max_threshold`` percentile of ``dists``.
        """
        min_length = np.percentile(dists, min_threshold)
        max_length = np.percentile(dists, max_threshold)
        bins = np.linspace(0.5 * min_length, 1.5 * max_length, 100)
        hist, edges = np.histogram(dists, bins=bins, density=True)
        return hist, edges

    def setCoordinates(self, coordinates):
        """Set the coordinates and show the histogram of the body lengths.

        The plot is created on the first call and reused afterwards, so that
        repeated calls replace the histogram instead of adding further plots.
        """
        self._coordinates = coordinates
        self._sizes = self.estimate_length(coordinates)
        n, e = self.estimate_histogram(self._sizes)

        if self._plot is None:
            self._plot = self._plot_widget.addPlot()
            self._plot.setLabel('left', "prob. density")
            self._plot.setLabel('bottom', "bodylength", units="px")
            self._plot.addItem(self._t1_selection)
            self._plot.addItem(self._t2_selection)
        elif self._histogram is not None:
            self._plot.removeItem(self._histogram)
        self._histogram = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w',
                                          brush=(0, 0, 255, 150))
        self._plot.addItem(self._histogram)
        self._apply_btn.setEnabled(True)

    def selections(self, track1=True):
        """Return the (lower, upper) limits of the region of track 1 or track 2."""
        if track1:
            return self._t1_selection.getRegion()
        else:
            return self._t2_selection.getRegion()

    def assignedTracks(self):
        """Return the track id of every detection according to the regions.

        Returns
        -------
        np.ndarray
            1 or 2 where the body length lies within the respective region
            (lower limit inclusive, upper limit exclusive), -1 elsewhere.
            Where the regions overlap, track 2 wins.
        """
        tracks = np.ones_like(self._sizes, dtype=int) * -1
        t1lower, t1upper = self.selections()
        t2lower, t2upper = self.selections(False)
        tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
        tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
        return tracks
class NeighborhoodValidator(QWidget):
    """Widget showing the frame-to-frame displacement histograms of both tracks.

    ``setData`` computes, separately for track 1 and track 2, the distance
    between successive detections divided by the frame difference and plots
    the two histograms. A vertical line follows mouse clicks in the plot. The
    ``apply`` signal is emitted when the apply button is clicked.
    """
    apply = Signal()
    name = "Neighborhood Validator"

    def __init__(self, parent=None):
        super().__init__(parent)
        self._threshold = None
        self._positions = None
        self._distances = None
        self._tracks = None
        self._frames = None
        self._plot = None
        self._plot_widget = self.setupGraph()
        self._apply_btn = QPushButton("apply")
        self._apply_btn.clicked.connect(lambda: self.apply.emit())

        layout = QVBoxLayout()
        logging.debug("classifier:NeighborhoodValidator plot widget is a QGraphicsView: %s",
                      isinstance(self._plot_widget, QGraphicsView))
        layout.addWidget(self._plot_widget)
        layout.addWidget(self._apply_btn)
        self.setLayout(layout)

    def setupGraph(self):
        """Create the plot widget and the threshold ROI.

        The ROI is only created here; ``setData`` does not add it to the plot.
        """
        pg.setConfigOptions(antialias=True)
        plot_widget = pg.GraphicsLayoutWidget(show=False)

        self._threshold = pg.LineSegmentROI([[10, 64], [120, 64]], pen='r')
        # The z value controls the stacking order; a negative value puts the
        # item behind items with the default z value of 0.
        self._threshold.setZValue(-10)
        return plot_widget

    def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
        """Return the density histogram of ``dists`` and its bin edges.

        Parameters
        ----------
        dists : np.ndarray
            The values to be binned.
        min_threshold, max_threshold : float
            Percentiles of ``dists`` used as first and last bin edge.
        bin_count : int
            Number of bin edges.
        log : bool
            Space the edges logarithmically instead of linearly.

        Raises
        ------
        ValueError
            If ``log`` is set and the lower percentile is not positive.
        """
        min_dist = np.percentile(dists, min_threshold)
        max_dist = np.percentile(dists, max_threshold)
        logging.debug("classifier:NeighborhoodValidator histogram range %s .. %s",
                      min_dist, max_dist)
        if log:
            if min_dist <= 0:
                raise ValueError("logarithmic bins need a positive lower percentile")
            # np.logspace expects the exponents, not the values themselves.
            bins = np.logspace(np.log10(min_dist), np.log10(max_dist), bin_count, base=10)
        else:
            bins = np.linspace(min_dist, max_dist, bin_count)
        hist, edges = np.histogram(dists, bins=bins, density=True)
        return hist, edges

    def neighborDistances(self, x, frames, n=5, symmetric=True):
        """Distances between each position and its neighbours, per frame step.

        ``x`` is padded with ``n`` rows of zeros, shifted by up to ``n`` rows
        and the Euclidean distance between original and shifted rows is
        divided by ``np.diff(frames)``.

        NOTE(review): with ``symmetric=True`` the slice ``dist[n+1:]`` has
        ``x.shape[0] + n - 1`` rows while ``dists`` has ``x.shape[0] - 1``;
        this looks like a shape mismatch for n > 0 -- confirm before using
        the symmetric mode. The only caller in this class passes n=1 and
        symmetric=False.

        Returns
        -------
        np.ndarray
            Shape (len(x) - 1, n) or (len(x) - 1, 2 * n) if symmetric.
        """
        logging.debug("classifier:NeighborhoodValidator neighborDistance")
        pad_shape = list(x.shape)
        pad_shape[0] = n
        pad = np.atleast_2d(np.zeros(pad_shape))
        if symmetric:
            padded_x = np.vstack((pad, x, pad))
            dists = np.zeros((x.shape[0]-1, 2*n))
        else:
            padded_x = np.vstack((pad, x))
            dists = np.zeros((x.shape[0]-1, n))
        count = 0
        r = range(-n, n+1) if symmetric else range(-n, 0)
        for i in r:
            if i == 0:
                continue
            # np.roll wraps around, i.e. rows at the edge are compared with
            # rows from the other end of the padded array.
            shifted_x = np.roll(padded_x, i, axis=0)
            dist = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
            dists[:, count] = dist[n+1:]/np.diff(frames)
            count += 1
        return dists

    def setData(self, positions, tracks, frames):
        """Set the data, the classifier/should be working on.

        Parameters
        ----------
        positions : np.ndarray
            The position estimates, e.g. the center of gravity for each detection
        tracks : np.ndarray
            The current track assignment.
        frames : np.ndarray
            respective frame.
        """
        def mouseClicked(event):
            """Move the vertical line to the clicked x position."""
            pos = event.pos()
            if self._plot.sceneBoundingRect().contains(pos):
                mousePoint = vb.mapSceneToView(pos)
                logging.debug("classifier:NeighborhoodValidator mouse clicked at %s", mousePoint)
                vLine.setPos(mousePoint.x())

        track2_brush = QBrush(QColor.fromString("green"))
        track1_brush = QBrush(QColor.fromString("orange"))
        self._positions = positions
        self._tracks = tracks
        self._frames = frames

        t1_positions = self._positions[self._tracks == 1]
        t1_frames = self._frames[self._tracks == 1]
        t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
        t2_positions = self._positions[self._tracks == 2]
        t2_frames = self._frames[self._tracks == 2]
        t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)

        self._plot = self._plot_widget.addPlot()
        vb = self._plot.vb
        # The first row of the distances is skipped for the histograms.
        n, e = self.estimate_histogram(t1_distances[1:], 1, 95, bin_count=100, log=False)
        bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
        self._plot.addItem(bgi1)
        n, e = self.estimate_histogram(t2_distances[1:], 1, 95, bin_count=100, log=False)
        bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
        self._plot.addItem(bgi2)
        self._plot.scene().sigMouseClicked.connect(mouseClicked)
        self._plot.setLogMode(x=False, y=True)
        self._plot.setLabel('left', "prob. density")
        self._plot.setLabel('bottom', "distance", units="px/frame")
        vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
        self._plot.addItem(vLine, ignoreBounds=True)
class ConsistencyClassifier(QWidget):
    """Control panel for the frame-by-frame consistency check.

    ``setData`` loads the required arrays in a background task
    (``ConsitencyDataLoader``); start/proceed run a ``ConsistencyWorker`` on
    them. The worker modifies the track array in place, ``assignedTracks``
    returns it. The ``apply`` signal is emitted when the apply button is
    clicked.
    """
    apply = Signal()
    name = "Consistency tracker"

    def __init__(self, parent=None):
        super().__init__(parent)
        self._data = None
        self._all_pos = None
        self._all_orientations = None
        self._all_lengths = None
        self._all_bendedness = None
        self._all_scores = None
        self._userlabeled = None
        self._maxframes = 0
        self._frames = None
        self._tracks = None
        self._worker = None
        self._dataworker = None
        self._processed_frames = 0

        self._errorlabel = QLabel()
        self._errorlabel.setStyleSheet("QLabel { color : red; }")
        self._assignedlabel = QLabel()
        self._maxframeslabel = QLabel()
        self._startframe_spinner = QSpinBox()

        self._startbtn = QPushButton("start")
        self._startbtn.clicked.connect(self.start)
        self._startbtn.setEnabled(False)
        self._stopbtn = QPushButton("stop")
        self._stopbtn.clicked.connect(self.stop)
        self._stopbtn.setEnabled(False)
        self._proceedbtn = QPushButton("proceed")
        self._proceedbtn.clicked.connect(self.proceed)
        self._proceedbtn.setEnabled(False)
        self._refreshbtn = QPushButton("refresh")
        self._refreshbtn.clicked.connect(self.refresh)
        self._refreshbtn.setEnabled(True)
        self._apply_btn = QPushButton("apply")
        self._apply_btn.clicked.connect(lambda: self.apply.emit())
        self._apply_btn.setEnabled(False)

        self._progressbar = QProgressBar()
        self._progressbar.setMinimum(0)
        self._progressbar.setMaximum(100)

        self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
        self._stoponerror.setToolTip("Stop process upon errors")
        self._stoponerror.setCheckable(True)
        self._stoponerror.setChecked(True)
        self.threadpool = QThreadPool()

        lyt = QGridLayout()
        lyt.addWidget(QLabel("Start frame:"), 0, 0)
        lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
        lyt.addWidget(QLabel("of"), 1, 1, 1, 1)
        lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1)
        lyt.addWidget(self._stoponerror, 2, 0, 1, 3)
        lyt.addWidget(QLabel("assigned"), 3, 0)
        lyt.addWidget(self._assignedlabel, 3, 1)
        lyt.addWidget(QLabel("errors/issues"), 4, 0)
        lyt.addWidget(self._errorlabel, 4, 1)

        lyt.addWidget(self._startbtn, 5, 0)
        lyt.addWidget(self._stopbtn, 5, 1)
        lyt.addWidget(self._proceedbtn, 5, 2)
        lyt.addWidget(self._apply_btn, 6, 0, 1, 2)
        lyt.addWidget(self._refreshbtn, 6, 2, 1, 1)
        lyt.addWidget(self._progressbar, 7, 0, 1, 3)
        self.setLayout(lyt)

    def setData(self, data:TrackingData):
        """Set the data, the classifier/should be working on.

        The widget is disabled and the progress bar is set to busy while the
        data are loaded in the background. Passing ``None`` only stores it.

        Parameters
        ----------
        data : Trackingdata
            The tracking data.
        """
        self._data = data
        if data is None:
            # A loader running on None would fail in its thread and never
            # report back, leaving the widget disabled.
            logging.debug("classifier:ConsistencyClassifier setData called without data")
            return
        self.setEnabled(False)
        self._progressbar.setRange(0, 0)
        self._dataworker = ConsitencyDataLoader(self._data)
        self._dataworker.signals.stopped.connect(self.data_processed)
        self.threadpool.start(self._dataworker)

    @Slot()
    def data_processed(self):
        """Take over the arrays from the data loader and reset the controls.

        The start button is only enabled if both track 1 and track 2 have at
        least one detection, because the worker needs a starting point for
        each of them.
        """
        if self._dataworker is not None:
            self._progressbar.setRange(0, 100)
            self._progressbar.setValue(0)
            self._all_pos = self._dataworker.positions
            self._all_orientations = self._dataworker.orientations
            self._all_lengths = self._dataworker.lengths
            self._all_bendedness = self._dataworker.bendedness
            self._userlabeled = self._dataworker.userlabeled
            self._all_scores = self._dataworker.scores
            self._frames = self._dataworker.frames
            self._tracks = self._dataworker.tracks
            self._dataworker = None
            self._assignedlabel.setText("0")
            self._errorlabel.setText("0")

            can_start = False
            if len(self._frames) > 0:
                self._maxframes = np.max(self._frames)
                self._maxframeslabel.setText(str(self._maxframes))
                t1_frames = self._frames[self._tracks == 1]
                t2_frames = self._frames[self._tracks == 2]
                if len(t1_frames) > 0 and len(t2_frames) > 0:
                    # Plain ints for Qt; the last selectable frame is the
                    # maximum frame also shown in the label.
                    min_frame = int(max(t1_frames[0], t2_frames[0])) + 1
                    self._startframe_spinner.setMinimum(min_frame)
                    self._startframe_spinner.setMaximum(int(self._maxframes))
                    self._startframe_spinner.setValue(int(self._frames[0]) + 1)
                    can_start = True
            if not can_start:
                logging.warning("classifier:ConsistencyClassifier needs detections "
                                "assigned to track 1 and to track 2 to start from")
            self._startbtn.setEnabled(can_start)
        self.setEnabled(True)

    @Slot(float)
    def on_progress(self, value):
        """Show ``value``, a fraction between 0 and 1, on the progress bar."""
        if self._progressbar is not None:
            self._progressbar.setValue(int(value * 100))

    def stop(self):
        """Request the running worker to stop and update the buttons."""
        if self._worker is not None:
            self._worker.stop()
        self._startbtn.setEnabled(True)
        self._proceedbtn.setEnabled(True)
        self._stopbtn.setEnabled(False)
        self._refreshbtn.setEnabled(True)

    def start(self):
        """Run a consistency worker from the frame selected in the spin box."""
        self._startbtn.setEnabled(False)
        self._refreshbtn.setEnabled(False)
        self._stopbtn.setEnabled(True)
        self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
                                         self._all_bendedness, self._frames, self._tracks, self._userlabeled,
                                         self._startframe_spinner.value(), self._stoponerror.isChecked())
        self._worker.signals.stopped.connect(self.worker_stopped)
        self._worker.signals.progress.connect(self.worker_progress)
        self._worker.signals.error.connect(self.worker_error)
        self.threadpool.start(self._worker)

    def proceed(self):
        """Continue from the frame currently shown in the spin box."""
        self.start()

    def refresh(self):
        """Reload the arrays from the stored data."""
        self.setData(self._data)

    def worker_progress(self, progress, processed, errors):
        """Show the progress and the counters reported by the worker."""
        self._progressbar.setValue(progress)
        self._errorlabel.setText(str(errors))
        self._assignedlabel.setText(str(processed))

    def worker_error(self, message):
        """Log an error message reported by the worker."""
        logging.error("classifier:ConsistencyClassifier worker error: %s", message)

    def worker_stopped(self, frame):
        """Update the controls after the worker finished at ``frame``."""
        self._apply_btn.setEnabled(True)
        self._startbtn.setEnabled(True)
        self._stopbtn.setEnabled(False)
        # Restart one frame earlier so that 'proceed' revisits ``frame``.
        self._startframe_spinner.setValue(frame-1)
        self._proceedbtn.setEnabled(bool(frame < self._maxframes-1))
        self._refreshbtn.setEnabled(True)
        self._processed_frames = frame

    def assignedTracks(self):
        """Return the track array, including the changes made by the worker."""
        return self._tracks
class ClassifierWidget(QTabWidget):
    """Tab widget hosting the size classifier and the consistency tracker.

    When one of the classifiers emits ``apply``, its track assignment is
    forwarded through ``apply_classifier``.
    """
    apply_classifier = Signal(np.ndarray)

    def __init__(self, parent=None):
        super().__init__(parent)
        self._data = None
        self._size_classifier = SizeClassifier()
        # self._neigborhood_validator = NeighborhoodValidator()
        self._consistency_tracker = ConsistencyClassifier()
        self.addTab(self._size_classifier, SizeClassifier.name)
        self.addTab(self._consistency_tracker, ConsistencyClassifier.name)
        # Any click on the tab bar hands the data to the consistency tracker.
        self.tabBarClicked.connect(self.update)
        self._size_classifier.apply.connect(self._on_applySizeClassifier)
        self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker)

    def _on_applySizeClassifier(self):
        """Forward the assignment of the size classifier."""
        tracks = self.size_classifier.assignedTracks()
        self.apply_classifier.emit(tracks)

    def _on_applyConsistencyTracker(self):
        """Forward the assignment of the consistency tracker."""
        tracks = self._consistency_tracker.assignedTracks()
        self.apply_classifier.emit(tracks)

    @property
    def size_classifier(self):
        return self._size_classifier

    @property
    def consistency_tracker(self):
        return self._consistency_tracker

    @Slot()
    def update(self):
        """Pass the current data to the consistency tracker.

        Does nothing until ``setData`` has been called: a tab click without
        data must not start a loader on ``None``.

        NOTE(review): this method shares its name with ``QWidget.update``;
        the name is kept for compatibility with existing callers.
        """
        if self._data is None:
            return
        self.consistency_tracker.setData(self._data)

    def setData(self, data:TrackingData):
        """Store the tracking data; it is handed on by ``update``."""
        self._data = data
def as_dict(df):
    """Convert a DataFrame into a dictionary of arrays.

    Every column is stored under its name; the values of the index are stored
    under the key "index" (replacing a column of that name, if there is one).
    """
    columns = {}
    for column_name in df.columns:
        columns[column_name] = df[column_name].values
    columns["index"] = df.index.values
    return columns
def main():
    """Show the classifier widget with the example data of the package.

    The data are read from ``data/merged.pkl`` below ``PACKAGE_ROOT``. The
    function blocks until the Qt event loop ends.
    """
    import pickle
    # Imported here so that main() also works when called from another module.
    from PySide6.QtWidgets import QApplication
    from fixtracks.info import PACKAGE_ROOT

    datafile = PACKAGE_ROOT / "data/merged.pkl"
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(datafile, "rb") as f:
        df = pickle.load(f)
    data = TrackingData()
    data.setData(as_dict(df))

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    w = ClassifierWidget()
    w.setData(data)

    layout = QVBoxLayout()
    layout.addWidget(w)
    window.setLayout(layout)
    window.show()
    app.exec()


if __name__ == "__main__":
    main()