[wip] add score to items, and ignore them

This commit is contained in:
Jan Grewe 2025-02-28 08:12:04 +01:00
parent d1b5776e69
commit 116e0ce5de
4 changed files with 71 additions and 38 deletions

View File

@ -2,7 +2,7 @@ import logging
import numpy as np import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QDoubleSpinBox
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
@ -13,16 +13,17 @@ from fixtracks.utils.trackingdata import TrackingData
from IPython import embed from IPython import embed
class Detection(): class Detection():
def __init__(self, id, frame, track, position, orientation, length, userlabeled): def __init__(self, id, frame, track, position, orientation, length, userlabeled, confidence):
self.id = id self.id = id
self.frame = frame self.frame = frame
self.track = track self.track = track
self.position = position self.position = position
self.score = 0.0 self.confidence = confidence
self.angle = orientation self.angle = orientation
self.length = length self.length = length
self.userlabeled = userlabeled self.userlabeled = userlabeled
class WorkerSignals(QObject): class WorkerSignals(QObject):
message = Signal(str) message = Signal(str)
running = Signal(bool) running = Signal(bool)
@ -30,7 +31,8 @@ class WorkerSignals(QObject):
currentframe = Signal(int) currentframe = Signal(int)
stopped = Signal(int) stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
class ConsistencyDataLoader(QRunnable):
def __init__(self, data): def __init__(self, data):
super().__init__() super().__init__()
self.signals = WorkerSignals() self.signals = WorkerSignals()
@ -40,7 +42,7 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = None self.lengths = None
self.orientations = None self.orientations = None
self.userlabeled = None self.userlabeled = None
self.scores = None self.confidence = None
self.frames = None self.frames = None
self.tracks = None self.tracks = None
@ -54,15 +56,16 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = self.data.animalLength() self.lengths = self.data.animalLength()
# self.bendedness = self.data.bendedness() # self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"] self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries. self.confidence = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"] self.frames = self.data["frame"]
self.tracks = self.data["track"] self.tracks = self.data["track"]
self.signals.stopped.emit(0) self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable): class ConsistencyWorker(QRunnable):
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
userlabeled, startframe=0, stoponerror=False) -> None: userlabeled, confidence, startframe=0, stoponerror=False, min_confidence=0.0) -> None:
super().__init__() super().__init__()
self.signals = WorkerSignals() self.signals = WorkerSignals()
self.positions = positions self.positions = positions
@ -70,6 +73,8 @@ class ConsistencyWorker(QRunnable):
self.lengths = lengths self.lengths = lengths
self.bendedness = bendedness self.bendedness = bendedness
self.userlabeled = userlabeled self.userlabeled = userlabeled
self.confidence = confidence
self._min_confidence = min_confidence
self.frames = frames self.frames = frames
self.tracks = tracks self.tracks = tracks
self._startframe = startframe self._startframe = startframe
@ -88,9 +93,11 @@ class ConsistencyWorker(QRunnable):
if np.any(self.positions[i] < 0.1): if np.any(self.positions[i] < 0.1):
logging.debug("Encountered probably invalid position %s", str(self.positions[i])) logging.debug("Encountered probably invalid position %s", str(self.positions[i]))
continue continue
if self._min_confidence > 0.0 and self.confidence[i] < self._min_confidence:
continue
d = Detection(i, frame, self.tracks[i], self.positions[i], d = Detection(i, frame, self.tracks[i], self.positions[i],
self.orientations[i], self.lengths[i], self.orientations[i], self.lengths[i],
self.userlabeled[i]) self.userlabeled[i], self.confidence[i])
detections.append(d) detections.append(d)
return detections return detections
@ -127,6 +134,10 @@ class ConsistencyWorker(QRunnable):
return most_likely_track, length_differences return most_likely_track, length_differences
def check_multiple_detections(detections): def check_multiple_detections(detections):
if self._min_confidence > 0.0:
for i, d in detections:
if d.confidence < self._min_confidence:
del detections[i]
distances = np.zeros((len(detections), len(detections))) distances = np.zeros((len(detections), len(detections)))
for i, d1 in enumerate(detections): for i, d1 in enumerate(detections):
for j, d2 in enumerate(detections): for j, d2 in enumerate(detections):
@ -139,9 +150,11 @@ class ConsistencyWorker(QRunnable):
t1index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 1))[0][-1] t1index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 1))[0][-1]
t2index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 2))[0][-1] t2index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 2))[0][-1]
d1 = Detection(t1index, self.frames[t1index], self.tracks[t1index], self.positions[t1index], d1 = Detection(t1index, self.frames[t1index], self.tracks[t1index], self.positions[t1index],
self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index]) self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index],
self.confidence[t1index])
d2 = Detection(t1index, self.frames[t2index], self.tracks[t2index], self.positions[t2index], d2 = Detection(t1index, self.frames[t2index], self.tracks[t2index], self.positions[t2index],
self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index]) self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index],
self.confidence[t1index])
last_detections[1] = d1 last_detections[1] = d1
last_detections[2] = d2 last_detections[2] = d2
@ -337,6 +350,7 @@ class SizeClassifier(QWidget):
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2 tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks return tracks
class NeighborhoodValidator(QWidget): class NeighborhoodValidator(QWidget):
apply = Signal() apply = Signal()
name = "Neighborhood Validator" name = "Neighborhood Validator"
@ -506,30 +520,38 @@ class ConsistencyClassifier(QWidget):
self._stoponerror.setChecked(True) self._stoponerror.setChecked(True)
self.threadpool = QThreadPool() self.threadpool = QThreadPool()
self._ignore_confidence = QCheckBox("Ignore detections widh confidence less than")
self._confidence_spinner = QDoubleSpinBox()
self._confidence_spinner.setRange(0.0, 1.0)
self._confidence_spinner.setSingleStep(0.01)
self._confidence_spinner.setDecimals(2)
self._confidence_spinner.setValue(0.5)
self._messagebox = QTextEdit() self._messagebox = QTextEdit()
self._messagebox.setFocusPolicy(Qt.NoFocus) self._messagebox.setFocusPolicy(Qt.NoFocus)
self._messagebox.setReadOnly(True) self._messagebox.setReadOnly(True)
lyt = QGridLayout() lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 ) lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2) lyt.addWidget(self._startframe_spinner, 0, 1, 1, 1)
lyt.addWidget(QLabel("of"), 1, 1, 1, 1) lyt.addWidget(QLabel("of"), 0, 2, 1, 1)
lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1) lyt.addWidget(self._maxframeslabel, 0, 3, 1, 1)
lyt.addWidget(self._stoponerror, 2, 0, 1, 3) lyt.addWidget(self._stoponerror, 1, 0, 1, 3)
lyt.addWidget(QLabel("Current frame"), 3,0) lyt.addWidget(self._ignore_confidence, 3, 0, 1, 3)
lyt.addWidget(self._framelabel, 3,1) lyt.addWidget(self._confidence_spinner, 3, 3, 1, 1)
lyt.addWidget(QLabel("assigned"), 4, 0) lyt.addWidget(QLabel("Current frame"), 4, 0)
lyt.addWidget(self._assignedlabel, 4, 1) lyt.addWidget(self._framelabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0) lyt.addWidget(QLabel("(Re-)Assigned"), 5, 0)
lyt.addWidget(self._errorlabel, 5, 1) lyt.addWidget(self._assignedlabel, 5, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 3) lyt.addWidget(QLabel("Errors/issues"), 5, 2)
lyt.addWidget(self._errorlabel, 5, 3, 1, 1)
lyt.addWidget(self._startbtn, 8, 0) lyt.addWidget(self._messagebox, 6, 0, 2, 4)
lyt.addWidget(self._stopbtn, 8, 1)
lyt.addWidget(self._proceedbtn, 8, 2) lyt.addWidget(self._startbtn, 8, 0, 1, 2)
lyt.addWidget(self._apply_btn, 9, 0, 1, 2) lyt.addWidget(self._stopbtn, 8, 2)
lyt.addWidget(self._refreshbtn, 9, 2, 1, 1) # lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._progressbar, 10, 0, 1, 3) lyt.addWidget(self._refreshbtn, 8, 3, 1, 1)
lyt.addWidget(self._apply_btn, 9, 0, 1, 4)
lyt.addWidget(self._progressbar, 10, 0, 1, 4)
self.setLayout(lyt) self.setLayout(lyt)
def setData(self, data:TrackingData): def setData(self, data:TrackingData):
@ -554,7 +576,7 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = self._dataworker.lengths self._all_lengths = self._dataworker.lengths
self._all_bendedness = self._dataworker.bendedness self._all_bendedness = self._dataworker.bendedness
self._userlabeled = self._dataworker.userlabeled self._userlabeled = self._dataworker.userlabeled
self._all_scores = self._dataworker.scores self._confidence = self._dataworker.confidence
self._frames = self._dataworker.frames self._frames = self._dataworker.frames
self._tracks = self._dataworker.tracks self._tracks = self._dataworker.tracks
self._dataworker = None self._dataworker = None
@ -597,12 +619,14 @@ class ConsistencyClassifier(QWidget):
self._messagebox.append("Stopping tracking.") self._messagebox.append("Stopping tracking.")
def start(self): def start(self):
confidence_level = self._confidence_spinner.value() if self._ignore_confidence.isChecked() else 0.0
self._startbtn.setEnabled(False) self._startbtn.setEnabled(False)
self._refreshbtn.setEnabled(False) self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True) self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths, self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._userlabeled, self._all_bendedness, self._frames, self._tracks, self._userlabeled,
self._startframe_spinner.value(), self._stoponerror.isChecked()) self._confidence, self._startframe_spinner.value(), self._stoponerror.isChecked(),
min_confidence=confidence_level)
self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error) self._worker.signals.message.connect(self.worker_error)
@ -621,7 +645,7 @@ class ConsistencyClassifier(QWidget):
def refresh(self): def refresh(self):
self.setEnabled(False) self.setEnabled(False)
self._dataworker = ConsitencyDataLoader(self._data) self._dataworker = ConsistencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed) self._dataworker.signals.stopped.connect(self.data_processed)
self._messagebox.clear() self._messagebox.clear()
self._messagebox.append("Refreshing...") self._messagebox.append("Refreshing...")

View File

@ -151,6 +151,8 @@ class DetectionTimeline(QWidget):
self._position_label.setFont(f) self._position_label.setFont(f)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.setSpacing(0)
layout.setContentsMargins(5, 2, 5, 2)
layout.addWidget(self._view) layout.addWidget(self._view)
layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight) layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout) self.setLayout(layout)
@ -310,8 +312,7 @@ def main():
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)
data = TrackingData() data = TrackingData(as_dict(df))
data.setData(as_dict(df))
data.setSelection(np.arange(0,100, 1)) data.setSelection(np.arange(0,100, 1))
data.setUserLabeledStatus(True) data.setUserLabeledStatus(True)
start_x = 0.1 start_x = 0.1
@ -330,12 +331,14 @@ def main():
backBtn.clicked.connect(lambda: back(0.2)) backBtn.clicked.connect(lambda: back(0.2))
btnLyt = QHBoxLayout() btnLyt = QHBoxLayout()
btnLyt.setSpacing(1)
btnLyt.addWidget(backBtn) btnLyt.addWidget(backBtn)
btnLyt.addWidget(zeroBtn) btnLyt.addWidget(zeroBtn)
btnLyt.addWidget(fwdBtn) btnLyt.addWidget(fwdBtn)
view.setWindowPos(start_x) view.setWindowPos(start_x)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.setSpacing(1)
layout.addWidget(view) layout.addWidget(view)
layout.addLayout(btnLyt) layout.addLayout(btnLyt)
window.setLayout(layout) window.setLayout(layout)

View File

@ -157,10 +157,8 @@ class DetectionView(QWidget):
item.setData(DetectionData.ID.value, id) item.setData(DetectionData.ID.value, id)
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :]) item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
item.setData(DetectionData.FRAME.value, f) item.setData(DetectionData.FRAME.value, f)
# item.setData(DetectionData.USERLABELED.value, l) item.setData(DetectionData.USERLABELED.value, l)
item.setData(DetectionData.SCORE.value, s) item.setData(DetectionData.SCORE.value, s)
print(s)
print(item.data(DetectionData.SCORE.value))
item = self._scene.addItem(item) item = self._scene.addItem(item)
def fit_image_to_view(self): def fit_image_to_view(self):

View File

@ -73,6 +73,7 @@ class FixTracks(QWidget):
combo_layout.addWidget(self._gotoframe) combo_layout.addWidget(self._gotoframe)
combo_layout.addWidget(self._gotobtn) combo_layout.addWidget(self._gotobtn)
combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
combo_layout.setSpacing(1)
timelinebox = QVBoxLayout() timelinebox = QVBoxLayout()
timelinebox.setSpacing(2) timelinebox.setSpacing(2)
@ -112,6 +113,7 @@ class FixTracks(QWidget):
data_selection_box.addWidget(QLabel("Select data file")) data_selection_box.addWidget(QLabel("Select data file"))
data_selection_box.addWidget(self._data_combo) data_selection_box.addWidget(self._data_combo)
data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
data_selection_box.setSpacing(0)
btnBox = QHBoxLayout() btnBox = QHBoxLayout()
btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft) btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
@ -129,8 +131,12 @@ class FixTracks(QWidget):
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter) cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addWidget(self._skeleton) cntrlBox.addWidget(self._skeleton)
cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding)) cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
cntrlBox.setSpacing(0)
cntrlBox.setContentsMargins(0,0,0,0)
vbox = QVBoxLayout() vbox = QVBoxLayout()
vbox.setSpacing(0)
vbox.setContentsMargins(0,0,0,0)
vbox.addLayout(timelinebox) vbox.addLayout(timelinebox)
vbox.addLayout(cntrlBox) vbox.addLayout(cntrlBox)
vbox.addLayout(btnBox) vbox.addLayout(btnBox)
@ -142,9 +148,12 @@ class FixTracks(QWidget):
splitter.addWidget(container) splitter.addWidget(container)
splitter.setStretchFactor(0, 3) splitter.setStretchFactor(0, 3)
splitter.setStretchFactor(1, 1) splitter.setStretchFactor(1, 1)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.addLayout(data_selection_box) layout.addLayout(data_selection_box)
layout.addWidget(splitter) layout.addWidget(splitter)
layout.setSpacing(0)
layout.setContentsMargins(5,2,2,5)
self.setLayout(layout) self.setLayout(layout)
def on_autoClassify(self, tracks): def on_autoClassify(self, tracks):
@ -364,7 +373,7 @@ class FixTracks(QWidget):
tracks = np.zeros(len(detections), dtype=int) tracks = np.zeros(len(detections), dtype=int)
ids = np.zeros_like(tracks) ids = np.zeros_like(tracks)
frames = np.zeros_like(tracks) frames = np.zeros_like(tracks)
scores = np.zeros_like(tracks) scores = np.zeros(tracks.shape, dtype=float)
coordinates = None coordinates = None
if len(detections) > 0: if len(detections) > 0:
c = detections[0].data(DetectionData.COORDINATES.value) c = detections[0].data(DetectionData.COORDINATES.value)
@ -376,7 +385,6 @@ class FixTracks(QWidget):
frames[i] = d.data(DetectionData.FRAME.value) frames[i] = d.data(DetectionData.FRAME.value)
coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value) coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
scores[i] = d.data(DetectionData.SCORE.value) scores[i] = d.data(DetectionData.SCORE.value)
print(scores[i])
self._data.setSelection(ids) self._data.setSelection(ids)
self._controls_widget.setSelectedTracks(tracks) self._controls_widget.setSelectedTracks(tracks)
self._skeleton.clear() self._skeleton.clear()