[classifier] auto distance classifier

This commit is contained in:
Jan Grewe 2025-02-17 18:20:25 +01:00
parent f62c7c43e0
commit a8fd5375f2
3 changed files with 209 additions and 46 deletions

View File

@ -20,9 +20,6 @@ class ImageReader(QRunnable):
@Slot()
def run(self):
'''
Your code goes in this function
'''
logging.debug("ImageReader: trying to open file %s", self._filename)
cap = cv.VideoCapture(self._filename)
framecount = int(cap.get(cv.CAP_PROP_FRAME_COUNT))

View File

@ -1,13 +1,101 @@
import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
from PySide6.QtCore import Signal
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QSpinBox, QProgressBar, QGridLayout, QLabel
from PySide6.QtCore import Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
from fixtracks.utils.trackingdata import TrackingData
from IPython import embed
class WorkerSignals(QObject):
error = Signal(str)
running = Signal(bool)
progress = Signal(int, int, int)
finished = Signal(bool)
class ConsistencyWorker(QRunnable):
signals = WorkerSignals()
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, startframe=0) -> None:
super().__init__()
self.positions = positions
self.orientations = orientations
self.lengths = lengths
self._bendedness = bendedness
self.frames = frames
self.tracks = tracks
self._startframe = startframe
self._stoprequest = False
@Slot()
def cancel(self):
self._stoprequest = True
@Slot()
def run(self):
last_pos = [self.positions[self.tracks == 1][0], self.positions[self.tracks == 2][0]]
last_frame = [self.frames[self.tracks == 1][0], self.frames[self.tracks == 2][0]]
last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]]
errors = 0
processed = 0
self._stoprequest = False
maxframes = np.max(self.frames)
steps = int((maxframes - self._startframe) // 100)
progress = 0
assignment_error = False
for f in range(self._startframe, np.max(self.frames), 1):
processed += 1
if self._stoprequest:
break
indices = np.where(self.frames == f)[0]
pp = self.positions[indices]
originaltracks = self.tracks[indices]
assignments = np.zeros_like(originaltracks)
distances = np.zeros((len(originaltracks), 2))
for i, (idx, p) in enumerate(zip(indices, pp)):
if f < last_frame[0]:
self.tracks[idx] = 2
last_frame[1] = f
last_pos[1] = p
last_angle[1] = self.orientations[idx]
continue
if f < last_frame[1]:
last_frame[0] = f
last_pos[0] = p
last_angle[0] = self.orientations[idx]
self.tracks[idx] = 1
continue
# else, we have already seen track one and track two entries
distance_to_trackone = np.linalg.norm(p - last_pos[0])/(f - last_frame[0])
distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/(f - last_frame[1])
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
distances[i, 0] = distance_to_trackone
distances[i, 1] = distance_to_tracktwo
assignments[i] = most_likely_track
# check (re) assignment update and proceed
if len(assignments) > 1 and (np.all(assignments == 1) or np.all(assignments == 2)):
logging.warning("frame %i: Issues assigning based on distances %s", f, str(distances))
assignment_error = True
errors += 1
else:
processed += 1
for i, idx in enumerate(indices):
if assignment_error:
self.tracks[idx] = -1
else:
self.tracks[idx] = assignments[i]
last_pos[assignments[i]-1] = pp[i]
last_frame[assignments[i]-1] = f
assignment_error = False
if f % steps == 0:
progress += 1
self.signals.progress.emit(progress, processed, errors)
self.signals.finished.emit(True)
class SizeClassifier(QWidget):
apply = Signal()
@ -202,63 +290,144 @@ class ConsistencyClassifier(QWidget):
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._all_cogs = None
self._all_orientations = None
self._all_lengths = None
self._all_bendedness = None
self._all_scores = None
self._frames = None
self._tracks = None
self._worker = None
self._errorlabel = QLabel()
self._errorlabel.setStyleSheet("QLabel { color : red; }")
self._assignedlabel = QLabel()
self._startframe_spinner = QSpinBox()
self._startbtn = QPushButton("run")
self._startbtn.clicked.connect(self.run)
self._startbtn.setEnabled(False)
def setData(self, keypoints, tracks, frames):
self._cancelbtn = QPushButton("cancel")
self._cancelbtn.clicked.connect(self.cancel)
self._cancelbtn.setEnabled(False)
self._apply_btn = QPushButton("apply")
self._progressbar = QProgressBar()
self._progressbar.setMinimum(0)
self._progressbar.setMaximum(100)
self._apply_btn.clicked.connect(lambda: self.apply.emit())
self._apply_btn.setEnabled(False)
self.threadpool = QThreadPool()
lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1 )
lyt.addWidget(QLabel("assigned"), 1, 0)
lyt.addWidget(self._assignedlabel, 1, 1)
lyt.addWidget(QLabel("errors/issues"), 2, 0)
lyt.addWidget(self._errorlabel, 2, 1)
lyt.addWidget(self._startbtn, 3, 0)
lyt.addWidget(self._cancelbtn, 3, 1)
lyt.addWidget(self._progressbar, 4, 0, 1, 2)
lyt.addWidget(self._apply_btn, 5, 0, 1, 2)
self.setLayout(lyt)
def setData(self, data:TrackingData):
"""Set the data, the classifier/should be working on.
Parameters
----------
positions : np.ndarray
The position estimates, e.g. the center of gravity for each detection
tracks : np.ndarray
The current track assignment.
frames : np.ndarray
respective frame.
data : Trackingdata
The tracking data.
"""
def mouseClicked(event):
pos = event.pos()
if self._plot.sceneBoundingRect().contains(pos):
mousePoint = vb.mapSceneToView(pos)
print("mouse clicked at", mousePoint)
vLine.setPos(mousePoint.x())
track2_brush = QBrush(QColor.fromString("green"))
track1_brush = QBrush(QColor.fromString("orange"))
self._positions = positions
self._tracks = tracks
self._frames = frames
t1_positions = self._positions[self._tracks == 1]
t1_frames = self._frames[self._tracks == 1]
t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
t2_positions = self._positions[self._tracks == 2]
t2_frames = self._frames[self._tracks == 2]
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
self._all_cogs = data.centerOfGravity()
self._all_orientations = data.orientation()
self._all_lengths = data.animalLength()
self._all_bendedness = data.bendedness()
self._all_scores = data["confidence"] # ignore for now, let's see how far this carries.
self._frames = data["frame"]
self._tracks = data["track"]
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._startframe_spinner.setMinimum(min_frame)
self._startframe_spinner.setMaximum(self._frames[-1])
self._startframe_spinner.setValue(self._frames[0] + 1)
self._startbtn.setEnabled(True)
self._worker = None
@Slot(float)
def on_progress(self, value):
if self._progressbar is not None:
self._progressDialog.setValue(int(value * 100))
def cancel(self):
if self._worker is not None:
self._worker.cancel()
self._startbtn.setEnabled(True)
self._cancelbtn.setEnabled(False)
def run(self):
self._startbtn.setEnabled(False)
self._cancelbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_cogs, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._startframe_spinner.value())
self._worker.signals.finished.connect(self.worker_done)
self._worker.signals.progress.connect(self.worker_progress)
self.threadpool.start(self._worker)
def worker_progress(self, progress, processed, errors):
self._progressbar.setValue(progress)
self._errorlabel.setText(str(errors))
self._assignedlabel.setText(str(processed))
def worker_done(self):
self._apply_btn.setEnabled(True)
self._startbtn.setEnabled(True)
self._cancelbtn.setEnabled(False)
def assignedTracks(self):
return self._tracks
class ClassifierWidget(QTabWidget):
apply_sizeclassifier = Signal(np.ndarray)
apply_classifier = Signal(np.ndarray)
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._size_classifier = SizeClassifier()
self._neigborhood_validator = NeighborhoodValidator()
# self._neigborhood_validator = NeighborhoodValidator()
self._consistency_tracker = ConsistencyClassifier()
self.addTab(self._size_classifier, SizeClassifier.name)
self.addTab(self._neigborhood_validator, NeighborhoodValidator.name)
self.addTab(self._consistency_tracker, ConsistencyClassifier.name)
self.tabBarClicked.connect(self.update)
self._size_classifier.apply.connect(self._on_applySizeClassifier)
self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker)
def _on_applySizeClassifier(self):
tracks = self.size_classifier.assignedTracks()
self.apply_sizeclassifier.emit(tracks)
self.apply_classifier.emit(tracks)
def _on_applyConsistencyTracker(self):
tracks = self._consistency_tracker.assignedTracks()
self.apply_classifier.emit(tracks)
@property
def size_classifier(self):
return self._size_classifier
@property
def neighborhood_validator(self):
return self._neigborhood_validator
def consistency_tracker(self):
return self._consistency_tracker
def update(self):
self.consistency_tracker.setData(self._data)
def setData(self, data:TrackingData):
self._data = data
def as_dict(df):
d = {c: df[c].values for c in df.columns}
@ -269,8 +438,9 @@ def as_dict(df):
def main():
test_size = False
import pickle
from IPython import embed
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl"
with open(datafile, "rb") as f:
@ -278,11 +448,6 @@ def main():
data = TrackingData()
data.setData(as_dict(df))
positions = data.centerOfGravity()
tracks = data["track"]
frames = data["frame"]
coords = data.coordinates()
app = QApplication([])
window = QWidget()
window.setMinimumSize(200, 200)
@ -291,7 +456,7 @@ def main():
# win.setCoordinates(coords)
# else:
w = ClassifierWidget()
w.neighborhood_validator.setData(positions, tracks, frames)
w.setData(data)
layout = QVBoxLayout()
layout.addWidget(w)

View File

@ -254,7 +254,7 @@ class FixTracks(QWidget):
btnBox.addWidget(self._saveBtn)
self._classifier = ClassifierWidget()
self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize)
self._classifier.apply_classifier.connect(self.on_autoClassify)
self._classifier.setMaximumWidth(500)
cntrlBox = QHBoxLayout()
cntrlBox.addWidget(self._classifier)
@ -278,7 +278,7 @@ class FixTracks(QWidget):
layout.addWidget(splitter)
self.setLayout(layout)
def on_classifyBySize(self, tracks):
def on_autoClassify(self, tracks):
self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data)
@ -333,6 +333,7 @@ class FixTracks(QWidget):
update_detectionView(unassigned, "unassigned")
update_detectionView(assigned_left, "assigned_left")
update_detectionView(assigned_right, "assigned_right")
self._classifier.setData(self._data)
@property
def fileList(self):
@ -369,6 +370,7 @@ class FixTracks(QWidget):
self._progress_bar.setValue(0)
if state and self._reader is not None:
self._data.setData(self._reader.asdict)
self._saveBtn.setEnabled(True)
self._currentWindowPos = 0
self._currentWindowWidth = self._windowspinner.value()
self._maxframes = self._data.max("frame")
@ -381,9 +383,8 @@ class FixTracks(QWidget):
tracks = self._data["track"]
frames = self._data["frame"]
self._classifier.size_classifier.setCoordinates(coordinates)
self._classifier.neighborhood_validator.setData(positions, tracks, frames)
self._classifier.consistency_tracker.setData(self._data)
self.update()
self._saveBtn.setEnabled(True)
logging.info("Finished loading data: %i frames, %i detections", self._maxframes, len(positions))
def on_keypointSelected(self):