Compare commits

...

4 Commits

2 changed files with 128 additions and 55 deletions

View File

@ -19,6 +19,7 @@ class TrackingData(QObject):
def setData(self, datadict): def setData(self, datadict):
assert isinstance(datadict, dict) assert isinstance(datadict, dict)
self._data = datadict self._data = datadict
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool)
self._columns = [k for k in self._data.keys()] self._columns = [k for k in self._data.keys()]
@property @property
@ -81,6 +82,7 @@ class TrackingData(QObject):
The new track id for the user-selected detections The new track id for the user-selected detections
""" """
self._data["track"][self._user_selections] = track_id self._data["track"][self._user_selections] = track_id
self._data["userlabeled"][self._user_selections] = True
def assignTracks(self, tracks): def assignTracks(self, tracks):
"""assignTracks _summary_ """assignTracks _summary_

View File

@ -2,30 +2,58 @@ import logging
import numpy as np import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtCore import Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
from fixtracks.utils.trackingdata import TrackingData from fixtracks.utils.trackingdata import TrackingData
from IPython import embed from IPython import embed
class WorkerSignals(QObject): class WorkerSignals(QObject):
error = Signal(str) error = Signal(str)
running = Signal(bool) running = Signal(bool)
progress = Signal(int, int, int) progress = Signal(int, int, int)
stopped = Signal(int) stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
def __init__(self, data):
super().__init__()
self.signals = WorkerSignals()
self.data = data
self.bendedness = self.positions = None
self.lengths = None
self.orientations = None
self.userlabeled = None
self.scores = None
self.frames = None
self.tracks = None
@Slot()
def run(self):
self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation()
self.lengths = self.data.animalLength()
self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"]
self.tracks = self.data["track"]
self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable): class ConsistencyWorker(QRunnable):
signals = WorkerSignals()
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
startframe=0, stoponerror=False) -> None: userlabeled, startframe=0, stoponerror=False) -> None:
super().__init__() super().__init__()
self.signals = WorkerSignals()
self.positions = positions self.positions = positions
self.orientations = orientations self.orientations = orientations
self.lengths = lengths self.lengths = lengths
self._bendedness = bendedness self.bendedness = bendedness
self.userlabeled = userlabeled
self.frames = frames self.frames = frames
self.tracks = tracks self.tracks = tracks
self._startframe = startframe self._startframe = startframe
@ -38,11 +66,41 @@ class ConsistencyWorker(QRunnable):
@Slot() @Slot()
def run(self): def run(self):
def needs_checking(original, new):
res = False
for n, o in zip(new, original):
res = (o == 1 or o == 2) and n != o
if not res:
res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
return res
def assign_by_distance(f, p):
t1_step = f - last_frame[0]
t2_step = f - last_frame[1]
if t1_step == 0 or t2_step == 0:
print(f"framecount is zero! current frame {f}, last frame {last_frame[0]} and {last_frame[1]}")
distance_to_trackone = np.linalg.norm(p - last_pos[0])/t1_step
distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/t2_step
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
distances = np.zeros(2)
distances[0] = distance_to_trackone
distances[1] = distance_to_tracktwo
return most_likely_track, distances
def assign_by_orientation(f, o):
t1_step = f - last_frame[0]
t2_step = f - last_frame[1]
orientationchange = np.unwrap((last_angle - o)/np.array([t1_step, t2_step]))
most_likely_track = np.argmin(orientationchange) + 1
return most_likely_track, orientationchange
last_pos = [self.positions[(self.tracks == 1) & (self.frames <= self._startframe)][-1], last_pos = [self.positions[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1], last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]] self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
# last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]] last_angle = [self.orientations[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
self.orientations[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
errors = 0 errors = 0
processed = 1 processed = 1
progress = 0 progress = 0
@ -52,42 +110,38 @@ class ConsistencyWorker(QRunnable):
startframe = np.max(last_frame) startframe = np.max(last_frame)
steps = int((maxframes - startframe) // 200) steps = int((maxframes - startframe) // 200)
for f in range(startframe + 1, maxframes, 1): for f in np.unique(self.frames[self.frames > startframe]):
if self._stoprequest: if self._stoprequest:
break break
indices = np.where(self.frames == f)[0] indices = np.where(self.frames == f)[0]
pp = self.positions[indices] pp = self.positions[indices]
originaltracks = self.tracks[indices] originaltracks = self.tracks[indices]
assignments = np.zeros_like(originaltracks) dist_assignments = np.zeros_like(originaltracks)
angle_assignments = np.zeros_like(originaltracks)
# userlabeld = np.zeros_like(originaltracks)
distances = np.zeros((len(originaltracks), 2)) distances = np.zeros((len(originaltracks), 2))
orientations = np.zeros((len(originaltracks), 2))
for i, (idx, p) in enumerate(zip(indices, pp)): for i, (idx, p) in enumerate(zip(indices, pp)):
if f < last_frame[0]: if self.userlabeled[idx]:
self.tracks[idx] = 2 print("user")
last_frame[1] = f processed += 1
last_pos[1] = p last_pos[originaltracks[i]-1] = pp[i]
# last_angle[1] = self.orientations[idx] last_frame[originaltracks[i]-1] = f
last_angle[originaltracks[i]-1] = self.orientations[idx]
continue continue
if f < last_frame[1]: dist_assignments[i], distances[i, :] = assign_by_distance(f, p)
last_frame[0] = f angle_assignments[i], orientations[i,:] = assign_by_orientation(f, self.orientations[idx])
last_pos[0] = p
# last_angle[0] = self.orientations[idx]
self.tracks[idx] = 1
continue
# else, we have already seen track one and track two entries
if f - last_frame[0] == 0 or f - last_frame[1] == 0:
print(f"framecount is zero! current frame {f}, last frame {last_frame[0]} and {last_frame[1]}")
distance_to_trackone = np.linalg.norm(p - last_pos[0])/(f - last_frame[0])
distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/(f - last_frame[1])
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
distances[i, 0] = distance_to_trackone
distances[i, 1] = distance_to_tracktwo
assignments[i] = most_likely_track
# check (re) assignment update and proceed # check (re) assignment update and proceed
if len(assignments) > 1 and (np.all(assignments == 1) or np.all(assignments == 2)): print("dist", distances)
logging.warning("frame %i: Issues assigning based on distances %s", f, str(distances)) print("angle", orientations)
if needs_checking(originaltracks, dist_assignments):
logging.info("frame %i: Issues assigning based on distances %s", f, str(distances))
assignment_error = True assignment_error = True
errors += 1 errors += 1
if self._stoponerror: if self._stoponerror:
embed()
break break
else: else:
processed += 1 processed += 1
@ -95,9 +149,10 @@ class ConsistencyWorker(QRunnable):
if assignment_error: if assignment_error:
self.tracks[idx] = -1 self.tracks[idx] = -1
else: else:
self.tracks[idx] = assignments[i] self.tracks[idx] = dist_assignments[i]
last_pos[assignments[i]-1] = pp[i] last_pos[dist_assignments[i]-1] = pp[i]
last_frame[assignments[i]-1] = f last_frame[dist_assignments[i]-1] = f
last_angle[dist_assignments[i]-1] = self.orientations[idx]
assignment_error = False assignment_error = False
if steps > 0 and f % steps == 0: if steps > 0 and f % steps == 0:
progress += 1 progress += 1
@ -305,9 +360,12 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = None self._all_lengths = None
self._all_bendedness = None self._all_bendedness = None
self._all_scores = None self._all_scores = None
self._userlabeled = None
self._maxframes = 0
self._frames = None self._frames = None
self._tracks = None self._tracks = None
self._worker = None self._worker = None
self._dataworker = None
self._processed_frames = 0 self._processed_frames = 0
self._errorlabel = QLabel() self._errorlabel = QLabel()
@ -327,7 +385,7 @@ class ConsistencyClassifier(QWidget):
self._proceedbtn = QPushButton("proceed") self._proceedbtn = QPushButton("proceed")
self._proceedbtn.clicked.connect(self.proceed) self._proceedbtn.clicked.connect(self.proceed)
self._proceedbtn.setEnabled(False) self._proceedbtn.setEnabled(False)
self._refreshbtn = QPushButton("refresh") self._refreshbtn = QPushButton("refresh")
self._refreshbtn.clicked.connect(self.refresh) self._refreshbtn.clicked.connect(self.refresh)
self._refreshbtn.setEnabled(True) self._refreshbtn.setEnabled(True)
@ -341,7 +399,7 @@ class ConsistencyClassifier(QWidget):
self._progressbar.setMaximum(100) self._progressbar.setMaximum(100)
self._stoponerror = QCheckBox("Stop processing whenever an error is encountered") self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
self._stoponerror.setToolTip("Stop process whenever ") self._stoponerror.setToolTip("Stop process upon errors")
self._stoponerror.setCheckable(True) self._stoponerror.setCheckable(True)
self._stoponerror.setChecked(True) self._stoponerror.setChecked(True)
self.threadpool = QThreadPool() self.threadpool = QThreadPool()
@ -373,24 +431,37 @@ class ConsistencyClassifier(QWidget):
data : Trackingdata data : Trackingdata
The tracking data. The tracking data.
""" """
self.setEnabled(False)
self._progressbar.setRange(0,0)
self._data = data self._data = data
self._all_pos = data.centerOfGravity() self._dataworker = ConsitencyDataLoader(self._data)
self._all_orientations = data.orientation() self._dataworker.signals.stopped.connect(self.data_processed)
self._all_lengths = data.animalLength() self.threadpool.start(self._dataworker)
self._all_bendedness = data.bendedness()
self._all_scores = data["confidence"] # ignore for now, let's see how far this carries. @Slot()
self._frames = data["frame"] def data_processed(self):
self._tracks = data["track"] if self._dataworker is not None:
self._maxframes = np.max(self._frames) self._progressbar.setRange(0,100)
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1 self._progressbar.setValue(0)
self._maxframeslabel.setText(str(self._maxframes)) self._all_pos = self._dataworker.positions
self._startframe_spinner.setMinimum(min_frame) self._all_orientations = self._dataworker.orientations
self._startframe_spinner.setMaximum(self._frames[-1]) self._all_lengths = self._dataworker.lengths
self._startframe_spinner.setValue(self._frames[0] + 1) self._all_bendedness = self._dataworker.bendedness
self._startbtn.setEnabled(True) self._userlabeled = self._dataworker.userlabeled
self._assignedlabel.setText("0") self._all_scores = self._dataworker.scores
self._errorlabel.setText("0") self._frames = self._dataworker.frames
self._worker = None self._tracks = self._dataworker.tracks
self._maxframes = np.max(self._frames)
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_frame)
self._startframe_spinner.setMaximum(self._frames[-1])
self._startframe_spinner.setValue(self._frames[0] + 1)
self._startbtn.setEnabled(True)
self._assignedlabel.setText("0")
self._errorlabel.setText("0")
self._dataworker = None
self.setEnabled(True)
@Slot(float) @Slot(float)
def on_progress(self, value): def on_progress(self, value):
@ -410,7 +481,7 @@ class ConsistencyClassifier(QWidget):
self._refreshbtn.setEnabled(False) self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True) self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths, self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._all_bendedness, self._frames, self._tracks, self._userlabeled,
self._startframe_spinner.value(), self._stoponerror.isChecked()) self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.progress.connect(self.worker_progress)
@ -470,6 +541,7 @@ class ClassifierWidget(QTabWidget):
def consistency_tracker(self): def consistency_tracker(self):
return self._consistency_tracker return self._consistency_tracker
@Slot()
def update(self): def update(self):
self.consistency_tracker.setData(self._data) self.consistency_tracker.setData(self._data)
@ -485,10 +557,9 @@ def as_dict(df):
def main(): def main():
test_size = False test_size = False
import pickle import pickle
from IPython import embed
from fixtracks.info import PACKAGE_ROOT from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)