Compare commits
4 Commits
6c46d834eb
...
7a2084e159
Author | SHA1 | Date | |
---|---|---|---|
7a2084e159 | |||
881194ac66 | |||
ef6ff0d2b4 | |||
2737fed192 |
@ -19,6 +19,7 @@ class TrackingData(QObject):
|
|||||||
def setData(self, datadict):
|
def setData(self, datadict):
|
||||||
assert isinstance(datadict, dict)
|
assert isinstance(datadict, dict)
|
||||||
self._data = datadict
|
self._data = datadict
|
||||||
|
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool)
|
||||||
self._columns = [k for k in self._data.keys()]
|
self._columns = [k for k in self._data.keys()]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -81,6 +82,7 @@ class TrackingData(QObject):
|
|||||||
The new track id for the user-selected detections
|
The new track id for the user-selected detections
|
||||||
"""
|
"""
|
||||||
self._data["track"][self._user_selections] = track_id
|
self._data["track"][self._user_selections] = track_id
|
||||||
|
self._data["userlabeled"][self._user_selections] = True
|
||||||
|
|
||||||
def assignTracks(self, tracks):
|
def assignTracks(self, tracks):
|
||||||
"""assignTracks _summary_
|
"""assignTracks _summary_
|
||||||
|
@ -2,30 +2,58 @@ import logging
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
|
||||||
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox
|
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
|
||||||
from PySide6.QtCore import Signal, Slot, QRunnable, QObject, QThreadPool
|
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
|
||||||
from PySide6.QtGui import QBrush, QColor
|
from PySide6.QtGui import QBrush, QColor
|
||||||
|
|
||||||
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
|
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
|
||||||
|
|
||||||
from fixtracks.utils.trackingdata import TrackingData
|
from fixtracks.utils.trackingdata import TrackingData
|
||||||
|
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
|
|
||||||
class WorkerSignals(QObject):
|
class WorkerSignals(QObject):
|
||||||
error = Signal(str)
|
error = Signal(str)
|
||||||
running = Signal(bool)
|
running = Signal(bool)
|
||||||
progress = Signal(int, int, int)
|
progress = Signal(int, int, int)
|
||||||
stopped = Signal(int)
|
stopped = Signal(int)
|
||||||
|
|
||||||
|
class ConsitencyDataLoader(QRunnable):
|
||||||
|
def __init__(self, data):
|
||||||
|
super().__init__()
|
||||||
|
self.signals = WorkerSignals()
|
||||||
|
self.data = data
|
||||||
|
self.bendedness = self.positions = None
|
||||||
|
self.lengths = None
|
||||||
|
self.orientations = None
|
||||||
|
self.userlabeled = None
|
||||||
|
self.scores = None
|
||||||
|
self.frames = None
|
||||||
|
self.tracks = None
|
||||||
|
|
||||||
|
@Slot()
|
||||||
|
def run(self):
|
||||||
|
self.positions = self.data.centerOfGravity()
|
||||||
|
self.orientations = self.data.orientation()
|
||||||
|
self.lengths = self.data.animalLength()
|
||||||
|
self.bendedness = self.data.bendedness()
|
||||||
|
self.userlabeled = self.data["userlabeled"]
|
||||||
|
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
|
||||||
|
self.frames = self.data["frame"]
|
||||||
|
self.tracks = self.data["track"]
|
||||||
|
self.signals.stopped.emit(0)
|
||||||
|
|
||||||
class ConsistencyWorker(QRunnable):
|
class ConsistencyWorker(QRunnable):
|
||||||
signals = WorkerSignals()
|
|
||||||
|
|
||||||
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
|
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
|
||||||
startframe=0, stoponerror=False) -> None:
|
userlabeled, startframe=0, stoponerror=False) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.signals = WorkerSignals()
|
||||||
self.positions = positions
|
self.positions = positions
|
||||||
self.orientations = orientations
|
self.orientations = orientations
|
||||||
self.lengths = lengths
|
self.lengths = lengths
|
||||||
self._bendedness = bendedness
|
self.bendedness = bendedness
|
||||||
|
self.userlabeled = userlabeled
|
||||||
self.frames = frames
|
self.frames = frames
|
||||||
self.tracks = tracks
|
self.tracks = tracks
|
||||||
self._startframe = startframe
|
self._startframe = startframe
|
||||||
@ -38,11 +66,41 @@ class ConsistencyWorker(QRunnable):
|
|||||||
|
|
||||||
@Slot()
|
@Slot()
|
||||||
def run(self):
|
def run(self):
|
||||||
|
def needs_checking(original, new):
|
||||||
|
res = False
|
||||||
|
for n, o in zip(new, original):
|
||||||
|
res = (o == 1 or o == 2) and n != o
|
||||||
|
if not res:
|
||||||
|
res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
|
||||||
|
return res
|
||||||
|
|
||||||
|
def assign_by_distance(f, p):
|
||||||
|
t1_step = f - last_frame[0]
|
||||||
|
t2_step = f - last_frame[1]
|
||||||
|
if t1_step == 0 or t2_step == 0:
|
||||||
|
print(f"framecount is zero! current frame {f}, last frame {last_frame[0]} and {last_frame[1]}")
|
||||||
|
|
||||||
|
distance_to_trackone = np.linalg.norm(p - last_pos[0])/t1_step
|
||||||
|
distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/t2_step
|
||||||
|
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
|
||||||
|
distances = np.zeros(2)
|
||||||
|
distances[0] = distance_to_trackone
|
||||||
|
distances[1] = distance_to_tracktwo
|
||||||
|
return most_likely_track, distances
|
||||||
|
|
||||||
|
def assign_by_orientation(f, o):
|
||||||
|
t1_step = f - last_frame[0]
|
||||||
|
t2_step = f - last_frame[1]
|
||||||
|
orientationchange = np.unwrap((last_angle - o)/np.array([t1_step, t2_step]))
|
||||||
|
most_likely_track = np.argmin(orientationchange) + 1
|
||||||
|
return most_likely_track, orientationchange
|
||||||
|
|
||||||
last_pos = [self.positions[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
|
last_pos = [self.positions[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
|
||||||
self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
|
self.positions[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
|
||||||
last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
|
last_frame = [self.frames[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
|
||||||
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
|
self.frames[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
|
||||||
# last_angle = [self.orientations[self.tracks == 1][0], self.orientations[self.tracks == 2][0]]
|
last_angle = [self.orientations[(self.tracks == 1) & (self.frames <= self._startframe)][-1],
|
||||||
|
self.orientations[(self.tracks == 2) & (self.frames <= self._startframe)][-1]]
|
||||||
errors = 0
|
errors = 0
|
||||||
processed = 1
|
processed = 1
|
||||||
progress = 0
|
progress = 0
|
||||||
@ -52,42 +110,38 @@ class ConsistencyWorker(QRunnable):
|
|||||||
startframe = np.max(last_frame)
|
startframe = np.max(last_frame)
|
||||||
steps = int((maxframes - startframe) // 200)
|
steps = int((maxframes - startframe) // 200)
|
||||||
|
|
||||||
for f in range(startframe + 1, maxframes, 1):
|
for f in np.unique(self.frames[self.frames > startframe]):
|
||||||
if self._stoprequest:
|
if self._stoprequest:
|
||||||
break
|
break
|
||||||
indices = np.where(self.frames == f)[0]
|
indices = np.where(self.frames == f)[0]
|
||||||
pp = self.positions[indices]
|
pp = self.positions[indices]
|
||||||
originaltracks = self.tracks[indices]
|
originaltracks = self.tracks[indices]
|
||||||
assignments = np.zeros_like(originaltracks)
|
dist_assignments = np.zeros_like(originaltracks)
|
||||||
|
angle_assignments = np.zeros_like(originaltracks)
|
||||||
|
# userlabeld = np.zeros_like(originaltracks)
|
||||||
distances = np.zeros((len(originaltracks), 2))
|
distances = np.zeros((len(originaltracks), 2))
|
||||||
|
orientations = np.zeros((len(originaltracks), 2))
|
||||||
|
|
||||||
for i, (idx, p) in enumerate(zip(indices, pp)):
|
for i, (idx, p) in enumerate(zip(indices, pp)):
|
||||||
if f < last_frame[0]:
|
if self.userlabeled[idx]:
|
||||||
self.tracks[idx] = 2
|
print("user")
|
||||||
last_frame[1] = f
|
processed += 1
|
||||||
last_pos[1] = p
|
last_pos[originaltracks[i]-1] = pp[i]
|
||||||
# last_angle[1] = self.orientations[idx]
|
last_frame[originaltracks[i]-1] = f
|
||||||
|
last_angle[originaltracks[i]-1] = self.orientations[idx]
|
||||||
continue
|
continue
|
||||||
if f < last_frame[1]:
|
dist_assignments[i], distances[i, :] = assign_by_distance(f, p)
|
||||||
last_frame[0] = f
|
angle_assignments[i], orientations[i,:] = assign_by_orientation(f, self.orientations[idx])
|
||||||
last_pos[0] = p
|
|
||||||
# last_angle[0] = self.orientations[idx]
|
|
||||||
self.tracks[idx] = 1
|
|
||||||
continue
|
|
||||||
# else, we have already seen track one and track two entries
|
|
||||||
if f - last_frame[0] == 0 or f - last_frame[1] == 0:
|
|
||||||
print(f"framecount is zero! current frame {f}, last frame {last_frame[0]} and {last_frame[1]}")
|
|
||||||
distance_to_trackone = np.linalg.norm(p - last_pos[0])/(f - last_frame[0])
|
|
||||||
distance_to_tracktwo = np.linalg.norm(p - last_pos[1])/(f - last_frame[1])
|
|
||||||
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
|
|
||||||
distances[i, 0] = distance_to_trackone
|
|
||||||
distances[i, 1] = distance_to_tracktwo
|
|
||||||
assignments[i] = most_likely_track
|
|
||||||
# check (re) assignment update and proceed
|
# check (re) assignment update and proceed
|
||||||
if len(assignments) > 1 and (np.all(assignments == 1) or np.all(assignments == 2)):
|
print("dist", distances)
|
||||||
logging.warning("frame %i: Issues assigning based on distances %s", f, str(distances))
|
print("angle", orientations)
|
||||||
|
if needs_checking(originaltracks, dist_assignments):
|
||||||
|
logging.info("frame %i: Issues assigning based on distances %s", f, str(distances))
|
||||||
assignment_error = True
|
assignment_error = True
|
||||||
errors += 1
|
errors += 1
|
||||||
if self._stoponerror:
|
if self._stoponerror:
|
||||||
|
embed()
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
processed += 1
|
processed += 1
|
||||||
@ -95,9 +149,10 @@ class ConsistencyWorker(QRunnable):
|
|||||||
if assignment_error:
|
if assignment_error:
|
||||||
self.tracks[idx] = -1
|
self.tracks[idx] = -1
|
||||||
else:
|
else:
|
||||||
self.tracks[idx] = assignments[i]
|
self.tracks[idx] = dist_assignments[i]
|
||||||
last_pos[assignments[i]-1] = pp[i]
|
last_pos[dist_assignments[i]-1] = pp[i]
|
||||||
last_frame[assignments[i]-1] = f
|
last_frame[dist_assignments[i]-1] = f
|
||||||
|
last_angle[dist_assignments[i]-1] = self.orientations[idx]
|
||||||
assignment_error = False
|
assignment_error = False
|
||||||
if steps > 0 and f % steps == 0:
|
if steps > 0 and f % steps == 0:
|
||||||
progress += 1
|
progress += 1
|
||||||
@ -305,9 +360,12 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._all_lengths = None
|
self._all_lengths = None
|
||||||
self._all_bendedness = None
|
self._all_bendedness = None
|
||||||
self._all_scores = None
|
self._all_scores = None
|
||||||
|
self._userlabeled = None
|
||||||
|
self._maxframes = 0
|
||||||
self._frames = None
|
self._frames = None
|
||||||
self._tracks = None
|
self._tracks = None
|
||||||
self._worker = None
|
self._worker = None
|
||||||
|
self._dataworker = None
|
||||||
self._processed_frames = 0
|
self._processed_frames = 0
|
||||||
|
|
||||||
self._errorlabel = QLabel()
|
self._errorlabel = QLabel()
|
||||||
@ -327,7 +385,7 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._proceedbtn = QPushButton("proceed")
|
self._proceedbtn = QPushButton("proceed")
|
||||||
self._proceedbtn.clicked.connect(self.proceed)
|
self._proceedbtn.clicked.connect(self.proceed)
|
||||||
self._proceedbtn.setEnabled(False)
|
self._proceedbtn.setEnabled(False)
|
||||||
|
|
||||||
self._refreshbtn = QPushButton("refresh")
|
self._refreshbtn = QPushButton("refresh")
|
||||||
self._refreshbtn.clicked.connect(self.refresh)
|
self._refreshbtn.clicked.connect(self.refresh)
|
||||||
self._refreshbtn.setEnabled(True)
|
self._refreshbtn.setEnabled(True)
|
||||||
@ -341,7 +399,7 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._progressbar.setMaximum(100)
|
self._progressbar.setMaximum(100)
|
||||||
|
|
||||||
self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
|
self._stoponerror = QCheckBox("Stop processing whenever an error is encountered")
|
||||||
self._stoponerror.setToolTip("Stop process whenever ")
|
self._stoponerror.setToolTip("Stop process upon errors")
|
||||||
self._stoponerror.setCheckable(True)
|
self._stoponerror.setCheckable(True)
|
||||||
self._stoponerror.setChecked(True)
|
self._stoponerror.setChecked(True)
|
||||||
self.threadpool = QThreadPool()
|
self.threadpool = QThreadPool()
|
||||||
@ -373,24 +431,37 @@ class ConsistencyClassifier(QWidget):
|
|||||||
data : Trackingdata
|
data : Trackingdata
|
||||||
The tracking data.
|
The tracking data.
|
||||||
"""
|
"""
|
||||||
|
self.setEnabled(False)
|
||||||
|
self._progressbar.setRange(0,0)
|
||||||
self._data = data
|
self._data = data
|
||||||
self._all_pos = data.centerOfGravity()
|
self._dataworker = ConsitencyDataLoader(self._data)
|
||||||
self._all_orientations = data.orientation()
|
self._dataworker.signals.stopped.connect(self.data_processed)
|
||||||
self._all_lengths = data.animalLength()
|
self.threadpool.start(self._dataworker)
|
||||||
self._all_bendedness = data.bendedness()
|
|
||||||
self._all_scores = data["confidence"] # ignore for now, let's see how far this carries.
|
@Slot()
|
||||||
self._frames = data["frame"]
|
def data_processed(self):
|
||||||
self._tracks = data["track"]
|
if self._dataworker is not None:
|
||||||
self._maxframes = np.max(self._frames)
|
self._progressbar.setRange(0,100)
|
||||||
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
|
self._progressbar.setValue(0)
|
||||||
self._maxframeslabel.setText(str(self._maxframes))
|
self._all_pos = self._dataworker.positions
|
||||||
self._startframe_spinner.setMinimum(min_frame)
|
self._all_orientations = self._dataworker.orientations
|
||||||
self._startframe_spinner.setMaximum(self._frames[-1])
|
self._all_lengths = self._dataworker.lengths
|
||||||
self._startframe_spinner.setValue(self._frames[0] + 1)
|
self._all_bendedness = self._dataworker.bendedness
|
||||||
self._startbtn.setEnabled(True)
|
self._userlabeled = self._dataworker.userlabeled
|
||||||
self._assignedlabel.setText("0")
|
self._all_scores = self._dataworker.scores
|
||||||
self._errorlabel.setText("0")
|
self._frames = self._dataworker.frames
|
||||||
self._worker = None
|
self._tracks = self._dataworker.tracks
|
||||||
|
self._maxframes = np.max(self._frames)
|
||||||
|
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
|
||||||
|
self._maxframeslabel.setText(str(self._maxframes))
|
||||||
|
self._startframe_spinner.setMinimum(min_frame)
|
||||||
|
self._startframe_spinner.setMaximum(self._frames[-1])
|
||||||
|
self._startframe_spinner.setValue(self._frames[0] + 1)
|
||||||
|
self._startbtn.setEnabled(True)
|
||||||
|
self._assignedlabel.setText("0")
|
||||||
|
self._errorlabel.setText("0")
|
||||||
|
self._dataworker = None
|
||||||
|
self.setEnabled(True)
|
||||||
|
|
||||||
@Slot(float)
|
@Slot(float)
|
||||||
def on_progress(self, value):
|
def on_progress(self, value):
|
||||||
@ -410,7 +481,7 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._refreshbtn.setEnabled(False)
|
self._refreshbtn.setEnabled(False)
|
||||||
self._stopbtn.setEnabled(True)
|
self._stopbtn.setEnabled(True)
|
||||||
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
|
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
|
||||||
self._all_bendedness, self._frames, self._tracks,
|
self._all_bendedness, self._frames, self._tracks, self._userlabeled,
|
||||||
self._startframe_spinner.value(), self._stoponerror.isChecked())
|
self._startframe_spinner.value(), self._stoponerror.isChecked())
|
||||||
self._worker.signals.stopped.connect(self.worker_stopped)
|
self._worker.signals.stopped.connect(self.worker_stopped)
|
||||||
self._worker.signals.progress.connect(self.worker_progress)
|
self._worker.signals.progress.connect(self.worker_progress)
|
||||||
@ -470,6 +541,7 @@ class ClassifierWidget(QTabWidget):
|
|||||||
def consistency_tracker(self):
|
def consistency_tracker(self):
|
||||||
return self._consistency_tracker
|
return self._consistency_tracker
|
||||||
|
|
||||||
|
@Slot()
|
||||||
def update(self):
|
def update(self):
|
||||||
self.consistency_tracker.setData(self._data)
|
self.consistency_tracker.setData(self._data)
|
||||||
|
|
||||||
@ -485,10 +557,9 @@ def as_dict(df):
|
|||||||
def main():
|
def main():
|
||||||
test_size = False
|
test_size = False
|
||||||
import pickle
|
import pickle
|
||||||
from IPython import embed
|
|
||||||
from fixtracks.info import PACKAGE_ROOT
|
from fixtracks.info import PACKAGE_ROOT
|
||||||
|
|
||||||
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
datafile = PACKAGE_ROOT / "data/merged.pkl"
|
||||||
|
|
||||||
with open(datafile, "rb") as f:
|
with open(datafile, "rb") as f:
|
||||||
df = pickle.load(f)
|
df = pickle.load(f)
|
||||||
|
Loading…
Reference in New Issue
Block a user