[wip] add score to items, and ignore them

This commit is contained in:
Jan Grewe 2025-02-28 08:12:04 +01:00
parent d1b5776e69
commit 116e0ce5de
4 changed files with 71 additions and 38 deletions

View File

@ -2,7 +2,7 @@ import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QDoubleSpinBox
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor
@ -13,16 +13,17 @@ from fixtracks.utils.trackingdata import TrackingData
from IPython import embed
class Detection():
def __init__(self, id, frame, track, position, orientation, length, userlabeled):
def __init__(self, id, frame, track, position, orientation, length, userlabeled, confidence):
self.id = id
self.frame = frame
self.track = track
self.position = position
self.score = 0.0
self.confidence = confidence
self.angle = orientation
self.length = length
self.userlabeled = userlabeled
class WorkerSignals(QObject):
message = Signal(str)
running = Signal(bool)
@ -30,7 +31,8 @@ class WorkerSignals(QObject):
currentframe = Signal(int)
stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
class ConsistencyDataLoader(QRunnable):
def __init__(self, data):
super().__init__()
self.signals = WorkerSignals()
@ -40,7 +42,7 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = None
self.orientations = None
self.userlabeled = None
self.scores = None
self.confidence = None
self.frames = None
self.tracks = None
@ -54,15 +56,16 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = self.data.animalLength()
# self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
self.confidence = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"]
self.tracks = self.data["track"]
self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable):
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
userlabeled, startframe=0, stoponerror=False) -> None:
userlabeled, confidence, startframe=0, stoponerror=False, min_confidence=0.0) -> None:
super().__init__()
self.signals = WorkerSignals()
self.positions = positions
@ -70,6 +73,8 @@ class ConsistencyWorker(QRunnable):
self.lengths = lengths
self.bendedness = bendedness
self.userlabeled = userlabeled
self.confidence = confidence
self._min_confidence = min_confidence
self.frames = frames
self.tracks = tracks
self._startframe = startframe
@ -88,9 +93,11 @@ class ConsistencyWorker(QRunnable):
if np.any(self.positions[i] < 0.1):
logging.debug("Encountered probably invalid position %s", str(self.positions[i]))
continue
if self._min_confidence > 0.0 and self.confidence[i] < self._min_confidence:
continue
d = Detection(i, frame, self.tracks[i], self.positions[i],
self.orientations[i], self.lengths[i],
self.userlabeled[i])
self.userlabeled[i], self.confidence[i])
detections.append(d)
return detections
@ -127,6 +134,10 @@ class ConsistencyWorker(QRunnable):
return most_likely_track, length_differences
def check_multiple_detections(detections):
if self._min_confidence > 0.0:
for i, d in detections:
if d.confidence < self._min_confidence:
del detections[i]
distances = np.zeros((len(detections), len(detections)))
for i, d1 in enumerate(detections):
for j, d2 in enumerate(detections):
@ -139,9 +150,11 @@ class ConsistencyWorker(QRunnable):
t1index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 1))[0][-1]
t2index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 2))[0][-1]
d1 = Detection(t1index, self.frames[t1index], self.tracks[t1index], self.positions[t1index],
self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index])
self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index],
self.confidence[t1index])
d2 = Detection(t1index, self.frames[t2index], self.tracks[t2index], self.positions[t2index],
self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index])
self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index],
self.confidence[t1index])
last_detections[1] = d1
last_detections[2] = d2
@ -337,6 +350,7 @@ class SizeClassifier(QWidget):
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks
class NeighborhoodValidator(QWidget):
apply = Signal()
name = "Neighborhood Validator"
@ -506,30 +520,38 @@ class ConsistencyClassifier(QWidget):
self._stoponerror.setChecked(True)
self.threadpool = QThreadPool()
self._ignore_confidence = QCheckBox("Ignore detections widh confidence less than")
self._confidence_spinner = QDoubleSpinBox()
self._confidence_spinner.setRange(0.0, 1.0)
self._confidence_spinner.setSingleStep(0.01)
self._confidence_spinner.setDecimals(2)
self._confidence_spinner.setValue(0.5)
self._messagebox = QTextEdit()
self._messagebox.setFocusPolicy(Qt.NoFocus)
self._messagebox.setReadOnly(True)
lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
lyt.addWidget(QLabel("of"), 1, 1, 1, 1)
lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1)
lyt.addWidget(self._stoponerror, 2, 0, 1, 3)
lyt.addWidget(QLabel("Current frame"), 3,0)
lyt.addWidget(self._framelabel, 3,1)
lyt.addWidget(QLabel("assigned"), 4, 0)
lyt.addWidget(self._assignedlabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0)
lyt.addWidget(self._errorlabel, 5, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 3)
lyt.addWidget(self._startbtn, 8, 0)
lyt.addWidget(self._stopbtn, 8, 1)
lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._apply_btn, 9, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 9, 2, 1, 1)
lyt.addWidget(self._progressbar, 10, 0, 1, 3)
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 1)
lyt.addWidget(QLabel("of"), 0, 2, 1, 1)
lyt.addWidget(self._maxframeslabel, 0, 3, 1, 1)
lyt.addWidget(self._stoponerror, 1, 0, 1, 3)
lyt.addWidget(self._ignore_confidence, 3, 0, 1, 3)
lyt.addWidget(self._confidence_spinner, 3, 3, 1, 1)
lyt.addWidget(QLabel("Current frame"), 4, 0)
lyt.addWidget(self._framelabel, 4, 1)
lyt.addWidget(QLabel("(Re-)Assigned"), 5, 0)
lyt.addWidget(self._assignedlabel, 5, 1)
lyt.addWidget(QLabel("Errors/issues"), 5, 2)
lyt.addWidget(self._errorlabel, 5, 3, 1, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 4)
lyt.addWidget(self._startbtn, 8, 0, 1, 2)
lyt.addWidget(self._stopbtn, 8, 2)
# lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._refreshbtn, 8, 3, 1, 1)
lyt.addWidget(self._apply_btn, 9, 0, 1, 4)
lyt.addWidget(self._progressbar, 10, 0, 1, 4)
self.setLayout(lyt)
def setData(self, data:TrackingData):
@ -554,7 +576,7 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = self._dataworker.lengths
self._all_bendedness = self._dataworker.bendedness
self._userlabeled = self._dataworker.userlabeled
self._all_scores = self._dataworker.scores
self._confidence = self._dataworker.confidence
self._frames = self._dataworker.frames
self._tracks = self._dataworker.tracks
self._dataworker = None
@ -597,12 +619,14 @@ class ConsistencyClassifier(QWidget):
self._messagebox.append("Stopping tracking.")
def start(self):
confidence_level = self._confidence_spinner.value() if self._ignore_confidence.isChecked() else 0.0
self._startbtn.setEnabled(False)
self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._userlabeled,
self._startframe_spinner.value(), self._stoponerror.isChecked())
self._confidence, self._startframe_spinner.value(), self._stoponerror.isChecked(),
min_confidence=confidence_level)
self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error)
@ -621,7 +645,7 @@ class ConsistencyClassifier(QWidget):
def refresh(self):
self.setEnabled(False)
self._dataworker = ConsitencyDataLoader(self._data)
self._dataworker = ConsistencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed)
self._messagebox.clear()
self._messagebox.append("Refreshing...")

View File

@ -151,6 +151,8 @@ class DetectionTimeline(QWidget):
self._position_label.setFont(f)
layout = QVBoxLayout()
layout.setSpacing(0)
layout.setContentsMargins(5, 2, 5, 2)
layout.addWidget(self._view)
layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout)
@ -310,8 +312,7 @@ def main():
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)
data = TrackingData()
data.setData(as_dict(df))
data = TrackingData(as_dict(df))
data.setSelection(np.arange(0,100, 1))
data.setUserLabeledStatus(True)
start_x = 0.1
@ -330,12 +331,14 @@ def main():
backBtn.clicked.connect(lambda: back(0.2))
btnLyt = QHBoxLayout()
btnLyt.setSpacing(1)
btnLyt.addWidget(backBtn)
btnLyt.addWidget(zeroBtn)
btnLyt.addWidget(fwdBtn)
view.setWindowPos(start_x)
layout = QVBoxLayout()
layout.setSpacing(1)
layout.addWidget(view)
layout.addLayout(btnLyt)
window.setLayout(layout)

View File

@ -157,10 +157,8 @@ class DetectionView(QWidget):
item.setData(DetectionData.ID.value, id)
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
item.setData(DetectionData.FRAME.value, f)
# item.setData(DetectionData.USERLABELED.value, l)
item.setData(DetectionData.USERLABELED.value, l)
item.setData(DetectionData.SCORE.value, s)
print(s)
print(item.data(DetectionData.SCORE.value))
item = self._scene.addItem(item)
def fit_image_to_view(self):

View File

@ -73,6 +73,7 @@ class FixTracks(QWidget):
combo_layout.addWidget(self._gotoframe)
combo_layout.addWidget(self._gotobtn)
combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
combo_layout.setSpacing(1)
timelinebox = QVBoxLayout()
timelinebox.setSpacing(2)
@ -112,6 +113,7 @@ class FixTracks(QWidget):
data_selection_box.addWidget(QLabel("Select data file"))
data_selection_box.addWidget(self._data_combo)
data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
data_selection_box.setSpacing(0)
btnBox = QHBoxLayout()
btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
@ -129,8 +131,12 @@ class FixTracks(QWidget):
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addWidget(self._skeleton)
cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
cntrlBox.setSpacing(0)
cntrlBox.setContentsMargins(0,0,0,0)
vbox = QVBoxLayout()
vbox.setSpacing(0)
vbox.setContentsMargins(0,0,0,0)
vbox.addLayout(timelinebox)
vbox.addLayout(cntrlBox)
vbox.addLayout(btnBox)
@ -142,9 +148,12 @@ class FixTracks(QWidget):
splitter.addWidget(container)
splitter.setStretchFactor(0, 3)
splitter.setStretchFactor(1, 1)
layout = QVBoxLayout()
layout.addLayout(data_selection_box)
layout.addWidget(splitter)
layout.setSpacing(0)
layout.setContentsMargins(5,2,2,5)
self.setLayout(layout)
def on_autoClassify(self, tracks):
@ -364,7 +373,7 @@ class FixTracks(QWidget):
tracks = np.zeros(len(detections), dtype=int)
ids = np.zeros_like(tracks)
frames = np.zeros_like(tracks)
scores = np.zeros_like(tracks)
scores = np.zeros(tracks.shape, dtype=float)
coordinates = None
if len(detections) > 0:
c = detections[0].data(DetectionData.COORDINATES.value)
@ -376,7 +385,6 @@ class FixTracks(QWidget):
frames[i] = d.data(DetectionData.FRAME.value)
coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
scores[i] = d.data(DetectionData.SCORE.value)
print(scores[i])
self._data.setSelection(ids)
self._controls_widget.setSelectedTracks(tracks)
self._skeleton.clear()