Compare commits

...

7 Commits

6 changed files with 114 additions and 55 deletions

View File

@@ -23,6 +23,7 @@ class DetectionSceneSignals(QObject):
class DetectionTimelineSignals(QObject):
windowMoved = Signal()
manualMove = Signal()
moveRequest = Signal(float)
class DetectionSignals(QObject):
hover = Signal((int, QPointF))

View File

@@ -84,7 +84,7 @@ class TrackingData(QObject):
logging.debug("TrackingData.setSelection: %i number of ids", len(ids))
self._selection_indices = self._find(ids)
self._selected_ids = ids
print(self._selected_ids, self._selection_indices)
# print(self._selected_ids, self._selection_indices)
def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
"""Assign a new track_id to the user-selected detections
@@ -97,12 +97,12 @@ class TrackingData(QObject):
Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched.
"""
logging.info("TrackingData: set track id %i for selection, set user-labeled status %s", track_id, str(setUserLabeled))
print(self._selected_ids, self._selection_indices)
print("before: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
# print(self._selected_ids, self._selection_indices)
# print("before: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
self["track"][self._selection_indices] = track_id
if setUserLabeled:
self.setUserLabeledStatus(True, True)
print("after: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
# print("after: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
def setUserLabeledStatus(self, new_status: bool, selection=True):
"""Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection.

View File

@@ -2,7 +2,7 @@ import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QDoubleSpinBox
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor
@@ -13,16 +13,17 @@ from fixtracks.utils.trackingdata import TrackingData
from IPython import embed
class Detection():
def __init__(self, id, frame, track, position, orientation, length, userlabeled):
def __init__(self, id, frame, track, position, orientation, length, userlabeled, confidence):
self.id = id
self.frame = frame
self.track = track
self.position = position
self.score = 0.0
self.confidence = confidence
self.angle = orientation
self.length = length
self.userlabeled = userlabeled
class WorkerSignals(QObject):
message = Signal(str)
running = Signal(bool)
@@ -30,7 +31,8 @@ class WorkerSignals(QObject):
currentframe = Signal(int)
stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
class ConsistencyDataLoader(QRunnable):
def __init__(self, data):
super().__init__()
self.signals = WorkerSignals()
@@ -40,7 +42,7 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = None
self.orientations = None
self.userlabeled = None
self.scores = None
self.confidence = None
self.frames = None
self.tracks = None
@@ -54,15 +56,16 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = self.data.animalLength()
# self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
self.confidence = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"]
self.tracks = self.data["track"]
self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable):
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
userlabeled, startframe=0, stoponerror=False) -> None:
userlabeled, confidence, startframe=0, stoponerror=False, min_confidence=0.0) -> None:
super().__init__()
self.signals = WorkerSignals()
self.positions = positions
@@ -70,6 +73,8 @@ class ConsistencyWorker(QRunnable):
self.lengths = lengths
self.bendedness = bendedness
self.userlabeled = userlabeled
self.confidence = confidence
self._min_confidence = min_confidence
self.frames = frames
self.tracks = tracks
self._startframe = startframe
@@ -88,9 +93,12 @@ class ConsistencyWorker(QRunnable):
if np.any(self.positions[i] < 0.1):
logging.debug("Encountered probably invalid position %s", str(self.positions[i]))
continue
if self._min_confidence > 0.0 and self.confidence[i] < self._min_confidence:
self.tracks[i] = -1
continue
d = Detection(i, frame, self.tracks[i], self.positions[i],
self.orientations[i], self.lengths[i],
self.userlabeled[i])
self.userlabeled[i], self.confidence[i])
detections.append(d)
return detections
@@ -127,6 +135,10 @@ class ConsistencyWorker(QRunnable):
return most_likely_track, length_differences
def check_multiple_detections(detections):
if self._min_confidence > 0.0:
for i, d in enumerate(detections):
if d.confidence < self._min_confidence:
del detections[i]
distances = np.zeros((len(detections), len(detections)))
for i, d1 in enumerate(detections):
for j, d2 in enumerate(detections):
@@ -139,9 +151,11 @@ class ConsistencyWorker(QRunnable):
t1index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 1))[0][-1]
t2index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 2))[0][-1]
d1 = Detection(t1index, self.frames[t1index], self.tracks[t1index], self.positions[t1index],
self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index])
self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index],
self.confidence[t1index])
d2 = Detection(t1index, self.frames[t2index], self.tracks[t2index], self.positions[t2index],
self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index])
self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index],
self.confidence[t1index])
last_detections[1] = d1
last_detections[2] = d2
@@ -223,6 +237,7 @@ class ConsistencyWorker(QRunnable):
if assignments[0] == assignments[1]:
d.track = -1
error = True
errors += 1
message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!"
break
elif assignments[0] != assignments[1]:
@@ -241,6 +256,7 @@ class ConsistencyWorker(QRunnable):
self.tracks[detections[0].id] = -1
message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned."
error = True
errors += 1
if not error:
for k in temp:
@@ -250,7 +266,7 @@ class ConsistencyWorker(QRunnable):
for idx in indices:
self.tracks[idx] = -1
errors += 1
if self._stoponerror:
if error and self._stoponerror:
self.signals.message.emit(message)
break
processed += 1
@@ -258,6 +274,7 @@ class ConsistencyWorker(QRunnable):
if steps > 0 and f % steps == 0:
progress += 1
self.signals.progress.emit(progress, processed, errors)
self.signals.message.emit(f"Tracking stopped at frame {f}.")
self.signals.stopped.emit(f)
@@ -337,6 +354,7 @@ class SizeClassifier(QWidget):
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks
class NeighborhoodValidator(QWidget):
apply = Signal()
name = "Neighborhood Validator"
@@ -461,6 +479,7 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = None
self._all_bendedness = None
self._all_scores = None
self._confidence = None
self._userlabeled = None
self._maxframes = 0
self._frames = None
@@ -506,30 +525,38 @@ class ConsistencyClassifier(QWidget):
self._stoponerror.setChecked(True)
self.threadpool = QThreadPool()
self._ignore_confidence = QCheckBox("Ignore detections widh confidence less than")
self._confidence_spinner = QDoubleSpinBox()
self._confidence_spinner.setRange(0.0, 1.0)
self._confidence_spinner.setSingleStep(0.01)
self._confidence_spinner.setDecimals(2)
self._confidence_spinner.setValue(0.5)
self._messagebox = QTextEdit()
self._messagebox.setFocusPolicy(Qt.NoFocus)
self._messagebox.setReadOnly(True)
lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
lyt.addWidget(QLabel("of"), 1, 1, 1, 1)
lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1)
lyt.addWidget(self._stoponerror, 2, 0, 1, 3)
lyt.addWidget(QLabel("Current frame"), 3,0)
lyt.addWidget(self._framelabel, 3,1)
lyt.addWidget(QLabel("assigned"), 4, 0)
lyt.addWidget(self._assignedlabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0)
lyt.addWidget(self._errorlabel, 5, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 3)
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 1)
lyt.addWidget(QLabel("of"), 0, 2, 1, 1)
lyt.addWidget(self._maxframeslabel, 0, 3, 1, 1)
lyt.addWidget(self._stoponerror, 1, 0, 1, 3)
lyt.addWidget(self._ignore_confidence, 3, 0, 1, 3)
lyt.addWidget(self._confidence_spinner, 3, 3, 1, 1)
lyt.addWidget(QLabel("Current frame"), 4, 0)
lyt.addWidget(self._framelabel, 4, 1)
lyt.addWidget(QLabel("(Re-)Assigned"), 5, 0)
lyt.addWidget(self._assignedlabel, 5, 1)
lyt.addWidget(QLabel("Errors/issues"), 5, 2)
lyt.addWidget(self._errorlabel, 5, 3, 1, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 4)
lyt.addWidget(self._startbtn, 8, 0)
lyt.addWidget(self._stopbtn, 8, 1)
lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._apply_btn, 9, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 9, 2, 1, 1)
lyt.addWidget(self._progressbar, 10, 0, 1, 3)
lyt.addWidget(self._startbtn, 8, 0, 1, 2)
lyt.addWidget(self._stopbtn, 8, 2)
# lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._refreshbtn, 8, 3, 1, 1)
lyt.addWidget(self._apply_btn, 9, 0, 1, 4)
lyt.addWidget(self._progressbar, 10, 0, 1, 4)
self.setLayout(lyt)
def setData(self, data:TrackingData):
@@ -554,7 +581,7 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = self._dataworker.lengths
self._all_bendedness = self._dataworker.bendedness
self._userlabeled = self._dataworker.userlabeled
self._all_scores = self._dataworker.scores
self._confidence = self._dataworker.confidence
self._frames = self._dataworker.frames
self._tracks = self._dataworker.tracks
self._dataworker = None
@@ -581,6 +608,7 @@ class ConsistencyClassifier(QWidget):
self._startframe_spinner.setMaximum(max_startframe)
self._startframe_spinner.setValue(min_startframe)
self._startframe_spinner.setSingleStep(20)
self._startframe_spinner.setToolTip(f"Maximum possible start frame: {max_startframe}")
self._startbtn.setEnabled(True)
self._assignedlabel.setText("0")
self._errorlabel.setText("0")
@@ -597,12 +625,14 @@ class ConsistencyClassifier(QWidget):
self._messagebox.append("Stopping tracking.")
def start(self):
confidence_level = self._confidence_spinner.value() if self._ignore_confidence.isChecked() else 0.0
self._startbtn.setEnabled(False)
self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._userlabeled,
self._startframe_spinner.value(), self._stoponerror.isChecked())
self._confidence, self._startframe_spinner.value(), self._stoponerror.isChecked(),
min_confidence=confidence_level)
self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error)
@@ -621,7 +651,7 @@ class ConsistencyClassifier(QWidget):
def refresh(self):
self.setEnabled(False)
self._dataworker = ConsitencyDataLoader(self._data)
self._dataworker = ConsistencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed)
self._messagebox.clear()
self._messagebox.append("Refreshing...")

View File

@@ -65,6 +65,7 @@ class Window(QGraphicsRectItem):
def mousePressEvent(self, event):
self.setCursor(Qt.ClosedHandCursor)
# print(event.pos())
super().mousePressEvent(event)
def mouseReleaseEvent(self, event):
@@ -121,6 +122,7 @@ class DetectionTimeline(QWidget):
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 85.))
self._scene.setBackgroundBrush(self._bg_brush)
self._scene.addItem(self._window)
self._scene.mousePressEvent = self.on_sceneMousePress
self._view = QGraphicsView()
# self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform)
@@ -151,12 +153,22 @@ class DetectionTimeline(QWidget):
self._position_label.setFont(f)
layout = QVBoxLayout()
layout.setSpacing(0)
layout.setContentsMargins(5, 2, 5, 2)
layout.addWidget(self._view)
layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout)
# self.setMaximumHeight(100)
# self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
def on_sceneMousePress(self, event):
scene_pos = event.scenePos()
relpos = scene_pos.x() / self._total_width
relpos = 0 if relpos < 0.0 else relpos
relpos = 2000/self._total_width if scene_pos.x() > self._total_width else relpos
self.signals.moveRequest.emit(relpos)
logging.debug("Timeline: Scene clicked at position: %.2f, %.2f --> rel x-pos %.3f", scene_pos.x(), scene_pos.y(), relpos)
def clear(self):
for i in self._scene.items():
if isinstance(i, (QGraphicsLineItem, QGraphicsEllipseItem)):
@@ -310,8 +322,7 @@ def main():
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)
data = TrackingData()
data.setData(as_dict(df))
data = TrackingData(as_dict(df))
data.setSelection(np.arange(0,100, 1))
data.setUserLabeledStatus(True)
start_x = 0.1
@@ -330,12 +341,14 @@ def main():
backBtn.clicked.connect(lambda: back(0.2))
btnLyt = QHBoxLayout()
btnLyt.setSpacing(1)
btnLyt.addWidget(backBtn)
btnLyt.addWidget(zeroBtn)
btnLyt.addWidget(fwdBtn)
view.setWindowPos(start_x)
layout = QVBoxLayout()
layout.setSpacing(1)
layout.addWidget(view)
layout.addLayout(btnLyt)
window.setLayout(layout)

View File

@@ -142,9 +142,10 @@ class DetectionView(QWidget):
scores = self._data.selectedData("confidence")
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
num_detections = len(frames)
for i, (id, f, t, l, s) in enumerate(zip(ids, frames, tracks, userlabeled, scores)):
c = Tracks.fromValue(t).toColor()
c.setAlpha(int(i * 255 / num_detections))
if keypoint >= 0:
x = coordinates[i, keypoint, 0]
y = coordinates[i, keypoint, 1]
@@ -157,10 +158,8 @@ class DetectionView(QWidget):
item.setData(DetectionData.ID.value, id)
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
item.setData(DetectionData.FRAME.value, f)
# item.setData(DetectionData.USERLABELED.value, l)
item.setData(DetectionData.USERLABELED.value, l)
item.setData(DetectionData.SCORE.value, s)
print(s)
print(item.data(DetectionData.SCORE.value))
item = self._scene.addItem(item)
def fit_image_to_view(self):

View File

@@ -45,6 +45,7 @@ class FixTracks(QWidget):
self._timeline = DetectionTimeline()
self._timeline.signals.windowMoved.connect(self.on_windowChanged)
self._timeline.signals.moveRequest.connect(self.on_moveRequest)
self._windowspinner = QSpinBox()
self._windowspinner.setRange(10, 10000)
@@ -55,8 +56,9 @@ class FixTracks(QWidget):
self._keypointcombo = QComboBox()
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
self._gotoframe = QSpinBox()
self._gotoframe.setSingleStep(1)
self._goto_spinner = QSpinBox()
self._goto_spinner.setSingleStep(1)
self._gotobtn = QPushButton("go!")
self._gotobtn.setToolTip("Jump to a given frame")
self._gotobtn.clicked.connect(self.on_goto)
@@ -65,14 +67,15 @@ class FixTracks(QWidget):
combo_layout.addWidget(QLabel("Window width:"))
combo_layout.addWidget(self._windowspinner)
combo_layout.addWidget(QLabel("frames"))
combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
combo_layout.addWidget(QLabel("Keypoint:"))
combo_layout.addWidget(self._keypointcombo)
combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
combo_layout.addWidget(QLabel("Jump to frame:"))
combo_layout.addWidget(self._gotoframe)
combo_layout.addWidget(self._goto_spinner)
combo_layout.addWidget(self._gotobtn)
combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
combo_layout.setSpacing(1)
timelinebox = QVBoxLayout()
timelinebox.setSpacing(2)
@@ -112,6 +115,7 @@ class FixTracks(QWidget):
data_selection_box.addWidget(QLabel("Select data file"))
data_selection_box.addWidget(self._data_combo)
data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
data_selection_box.setSpacing(0)
btnBox = QHBoxLayout()
btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
@@ -129,8 +133,12 @@ class FixTracks(QWidget):
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addWidget(self._skeleton)
cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
cntrlBox.setSpacing(0)
cntrlBox.setContentsMargins(0,0,0,0)
vbox = QVBoxLayout()
vbox.setSpacing(0)
vbox.setContentsMargins(0,0,0,0)
vbox.addLayout(timelinebox)
vbox.addLayout(cntrlBox)
vbox.addLayout(btnBox)
@@ -142,9 +150,12 @@ class FixTracks(QWidget):
splitter.addWidget(container)
splitter.setStretchFactor(0, 3)
splitter.setStretchFactor(1, 1)
layout = QVBoxLayout()
layout.addLayout(data_selection_box)
layout.addWidget(splitter)
layout.setSpacing(0)
layout.setContentsMargins(5,2,2,5)
self.setLayout(layout)
def on_autoClassify(self, tracks):
@@ -225,7 +236,7 @@ class FixTracks(QWidget):
self._currentWindowPos = 0
self._currentWindowWidth = self._windowspinner.value()
self._maxframes = np.max(self._data["frame"])
self._gotoframe.setMaximum(self._maxframes)
self._goto_spinner.setMaximum(self._maxframes)
self.populateKeypointCombo(self._data.numKeypoints())
self._timeline.setData(self._data)
# self._timeline.setWindow(self._currentWindowPos / self._maxframes,
@@ -333,6 +344,11 @@ class FixTracks(QWidget):
self.update()
self._manualmove = False
def on_moveRequest(self, pos):
new_pos = int(np.round(pos * self._maxframes))
self._currentWindowPos = new_pos
self.update()
def on_windowSizeChanged(self, value):
"""Reacts on the user window-width selection. Selection is done in the unit of frames.
@@ -343,14 +359,15 @@ class FixTracks(QWidget):
"""
self._currentWindowWidth = value
logging.debug("Tracks:OnWindowSizeChanged %i franes", value)
if self._maxframes == 0:
self._timeline.setWindowWidth(self._currentWindowWidth / 2000)
else:
self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
self._controls_widget.setSelectedTracks(None)
# if self._maxframes == 0:
# self._timeline.setWindowWidth(self._currentWindowWidth / 2000)
# else:
# self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
# self._controls_widget.setSelectedTracks(None)
self.update()
def on_goto(self):
target = self._gotoframe.value()
target = self._goto_spinner.value()
if target > self._maxframes - self._currentWindowWidth:
target = self._maxframes - self._currentWindowWidth
logging.info("Jump to frame %i", target)
@@ -364,7 +381,7 @@ class FixTracks(QWidget):
tracks = np.zeros(len(detections), dtype=int)
ids = np.zeros_like(tracks)
frames = np.zeros_like(tracks)
scores = np.zeros_like(tracks)
scores = np.zeros(tracks.shape, dtype=float)
coordinates = None
if len(detections) > 0:
c = detections[0].data(DetectionData.COORDINATES.value)
@@ -376,7 +393,6 @@ class FixTracks(QWidget):
frames[i] = d.data(DetectionData.FRAME.value)
coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
scores[i] = d.data(DetectionData.SCORE.value)
print(scores[i])
self._data.setSelection(ids)
self._controls_widget.setSelectedTracks(tracks)
self._skeleton.clear()