Compare commits

...

12 Commits

9 changed files with 175 additions and 59 deletions

View File

@@ -8,6 +8,7 @@ class DetectionData(Enum):
COORDINATES = 2 COORDINATES = 2
TRACK_ID = 3 TRACK_ID = 3
USERLABELED = 4 USERLABELED = 4
SCORE = 5
class Tracks(Enum): class Tracks(Enum):
TRACKONE = 1 TRACKONE = 1

View File

@@ -23,6 +23,7 @@ class DetectionSceneSignals(QObject):
class DetectionTimelineSignals(QObject): class DetectionTimelineSignals(QObject):
windowMoved = Signal() windowMoved = Signal()
manualMove = Signal() manualMove = Signal()
moveRequest = Signal(float)
class DetectionSignals(QObject): class DetectionSignals(QObject):
hover = Signal((int, QPointF)) hover = Signal((int, QPointF))

View File

@@ -84,7 +84,7 @@ class TrackingData(QObject):
logging.debug("TrackingData.setSelection: %i number of ids", len(ids)) logging.debug("TrackingData.setSelection: %i number of ids", len(ids))
self._selection_indices = self._find(ids) self._selection_indices = self._find(ids)
self._selected_ids = ids self._selected_ids = ids
print(self._selected_ids, self._selection_indices) # print(self._selected_ids, self._selection_indices)
def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None: def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
"""Assign a new track_id to the user-selected detections """Assign a new track_id to the user-selected detections
@@ -97,12 +97,12 @@ class TrackingData(QObject):
Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched. Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched.
""" """
logging.info("TrackingData: set track id %i for selection, set user-labeled status %s", track_id, str(setUserLabeled)) logging.info("TrackingData: set track id %i for selection, set user-labeled status %s", track_id, str(setUserLabeled))
print(self._selected_ids, self._selection_indices) # print(self._selected_ids, self._selection_indices)
print("before: ", self["track"][self._selection_indices], self["frame"][self._selection_indices]) # print("before: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
self["track"][self._selection_indices] = track_id self["track"][self._selection_indices] = track_id
if setUserLabeled: if setUserLabeled:
self.setUserLabeledStatus(True, True) self.setUserLabeledStatus(True, True)
print("after: ", self["track"][self._selection_indices], self["frame"][self._selection_indices]) # print("after: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
def setUserLabeledStatus(self, new_status: bool, selection=True): def setUserLabeledStatus(self, new_status: bool, selection=True):
"""Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection. """Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection.

View File

@@ -2,7 +2,7 @@ import logging
import numpy as np import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QDoubleSpinBox
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
@@ -13,16 +13,17 @@ from fixtracks.utils.trackingdata import TrackingData
from IPython import embed from IPython import embed
class Detection(): class Detection():
def __init__(self, id, frame, track, position, orientation, length, userlabeled): def __init__(self, id, frame, track, position, orientation, length, userlabeled, confidence):
self.id = id self.id = id
self.frame = frame self.frame = frame
self.track = track self.track = track
self.position = position self.position = position
self.score = 0.0 self.confidence = confidence
self.angle = orientation self.angle = orientation
self.length = length self.length = length
self.userlabeled = userlabeled self.userlabeled = userlabeled
class WorkerSignals(QObject): class WorkerSignals(QObject):
message = Signal(str) message = Signal(str)
running = Signal(bool) running = Signal(bool)
@@ -30,7 +31,8 @@ class WorkerSignals(QObject):
currentframe = Signal(int) currentframe = Signal(int)
stopped = Signal(int) stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
class ConsistencyDataLoader(QRunnable):
def __init__(self, data): def __init__(self, data):
super().__init__() super().__init__()
self.signals = WorkerSignals() self.signals = WorkerSignals()
@@ -40,7 +42,7 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = None self.lengths = None
self.orientations = None self.orientations = None
self.userlabeled = None self.userlabeled = None
self.scores = None self.confidence = None
self.frames = None self.frames = None
self.tracks = None self.tracks = None
@@ -54,15 +56,16 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = self.data.animalLength() self.lengths = self.data.animalLength()
# self.bendedness = self.data.bendedness() # self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"] self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries. self.confidence = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"] self.frames = self.data["frame"]
self.tracks = self.data["track"] self.tracks = self.data["track"]
self.signals.stopped.emit(0) self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable): class ConsistencyWorker(QRunnable):
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
userlabeled, startframe=0, stoponerror=False) -> None: userlabeled, confidence, startframe=0, stoponerror=False, min_confidence=0.0) -> None:
super().__init__() super().__init__()
self.signals = WorkerSignals() self.signals = WorkerSignals()
self.positions = positions self.positions = positions
@@ -70,6 +73,8 @@ class ConsistencyWorker(QRunnable):
self.lengths = lengths self.lengths = lengths
self.bendedness = bendedness self.bendedness = bendedness
self.userlabeled = userlabeled self.userlabeled = userlabeled
self.confidence = confidence
self._min_confidence = min_confidence
self.frames = frames self.frames = frames
self.tracks = tracks self.tracks = tracks
self._startframe = startframe self._startframe = startframe
@@ -88,9 +93,12 @@ class ConsistencyWorker(QRunnable):
if np.any(self.positions[i] < 0.1): if np.any(self.positions[i] < 0.1):
logging.debug("Encountered probably invalid position %s", str(self.positions[i])) logging.debug("Encountered probably invalid position %s", str(self.positions[i]))
continue continue
if self._min_confidence > 0.0 and self.confidence[i] < self._min_confidence:
self.tracks[i] = -1
continue
d = Detection(i, frame, self.tracks[i], self.positions[i], d = Detection(i, frame, self.tracks[i], self.positions[i],
self.orientations[i], self.lengths[i], self.orientations[i], self.lengths[i],
self.userlabeled[i]) self.userlabeled[i], self.confidence[i])
detections.append(d) detections.append(d)
return detections return detections
@@ -127,6 +135,10 @@ class ConsistencyWorker(QRunnable):
return most_likely_track, length_differences return most_likely_track, length_differences
def check_multiple_detections(detections): def check_multiple_detections(detections):
if self._min_confidence > 0.0:
for i, d in enumerate(detections):
if d.confidence < self._min_confidence:
del detections[i]
distances = np.zeros((len(detections), len(detections))) distances = np.zeros((len(detections), len(detections)))
for i, d1 in enumerate(detections): for i, d1 in enumerate(detections):
for j, d2 in enumerate(detections): for j, d2 in enumerate(detections):
@@ -135,6 +147,18 @@ class ConsistencyWorker(QRunnable):
del detections[lowest_dist] del detections[lowest_dist]
return detections return detections
def find_last_userlabeled(startframe):
t1index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 1))[0][-1]
t2index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 2))[0][-1]
d1 = Detection(t1index, self.frames[t1index], self.tracks[t1index], self.positions[t1index],
self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index],
self.confidence[t1index])
d2 = Detection(t1index, self.frames[t2index], self.tracks[t2index], self.positions[t2index],
self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index],
self.confidence[t1index])
last_detections[1] = d1
last_detections[2] = d2
unique_frames = np.unique(self.frames) unique_frames = np.unique(self.frames)
steps = int((len(unique_frames) - self._startframe) // 100) steps = int((len(unique_frames) - self._startframe) // 100)
errors = 0 errors = 0
@@ -142,6 +166,7 @@ class ConsistencyWorker(QRunnable):
progress = 0 progress = 0
self._stoprequest = False self._stoprequest = False
last_detections = {1: None, 2: None, -1: None} last_detections = {1: None, 2: None, -1: None}
find_last_userlabeled(self._startframe)
for f in unique_frames[unique_frames >= self._startframe]: for f in unique_frames[unique_frames >= self._startframe]:
if self._stoprequest: if self._stoprequest:
@@ -188,7 +213,7 @@ class ConsistencyWorker(QRunnable):
continue continue
if error and self._stoponerror: if error and self._stoponerror:
self.signals.message.emit("Tracking stopped at frame %i.", f) self.signals.message.emit(f"Tracking stopped at frame {f}.")
break break
elif error: elif error:
continue continue
@@ -212,6 +237,7 @@ class ConsistencyWorker(QRunnable):
if assignments[0] == assignments[1]: if assignments[0] == assignments[1]:
d.track = -1 d.track = -1
error = True error = True
errors += 1
message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!" message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!"
break break
elif assignments[0] != assignments[1]: elif assignments[0] != assignments[1]:
@@ -230,6 +256,7 @@ class ConsistencyWorker(QRunnable):
self.tracks[detections[0].id] = -1 self.tracks[detections[0].id] = -1
message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned." message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned."
error = True error = True
errors += 1
if not error: if not error:
for k in temp: for k in temp:
@@ -239,7 +266,7 @@ class ConsistencyWorker(QRunnable):
for idx in indices: for idx in indices:
self.tracks[idx] = -1 self.tracks[idx] = -1
errors += 1 errors += 1
if self._stoponerror: if error and self._stoponerror:
self.signals.message.emit(message) self.signals.message.emit(message)
break break
processed += 1 processed += 1
@@ -247,6 +274,7 @@ class ConsistencyWorker(QRunnable):
if steps > 0 and f % steps == 0: if steps > 0 and f % steps == 0:
progress += 1 progress += 1
self.signals.progress.emit(progress, processed, errors) self.signals.progress.emit(progress, processed, errors)
self.signals.message.emit(f"Tracking stopped at frame {f}.") self.signals.message.emit(f"Tracking stopped at frame {f}.")
self.signals.stopped.emit(f) self.signals.stopped.emit(f)
@@ -326,6 +354,7 @@ class SizeClassifier(QWidget):
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2 tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks return tracks
class NeighborhoodValidator(QWidget): class NeighborhoodValidator(QWidget):
apply = Signal() apply = Signal()
name = "Neighborhood Validator" name = "Neighborhood Validator"
@@ -450,6 +479,7 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = None self._all_lengths = None
self._all_bendedness = None self._all_bendedness = None
self._all_scores = None self._all_scores = None
self._confidence = None
self._userlabeled = None self._userlabeled = None
self._maxframes = 0 self._maxframes = 0
self._frames = None self._frames = None
@@ -495,30 +525,38 @@ class ConsistencyClassifier(QWidget):
self._stoponerror.setChecked(True) self._stoponerror.setChecked(True)
self.threadpool = QThreadPool() self.threadpool = QThreadPool()
self._ignore_confidence = QCheckBox("Ignore detections widh confidence less than")
self._confidence_spinner = QDoubleSpinBox()
self._confidence_spinner.setRange(0.0, 1.0)
self._confidence_spinner.setSingleStep(0.01)
self._confidence_spinner.setDecimals(2)
self._confidence_spinner.setValue(0.5)
self._messagebox = QTextEdit() self._messagebox = QTextEdit()
self._messagebox.setFocusPolicy(Qt.NoFocus) self._messagebox.setFocusPolicy(Qt.NoFocus)
self._messagebox.setReadOnly(True) self._messagebox.setReadOnly(True)
lyt = QGridLayout() lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 ) lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2) lyt.addWidget(self._startframe_spinner, 0, 1, 1, 1)
lyt.addWidget(QLabel("of"), 1, 1, 1, 1) lyt.addWidget(QLabel("of"), 0, 2, 1, 1)
lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1) lyt.addWidget(self._maxframeslabel, 0, 3, 1, 1)
lyt.addWidget(self._stoponerror, 2, 0, 1, 3) lyt.addWidget(self._stoponerror, 1, 0, 1, 3)
lyt.addWidget(QLabel("Current frame"), 3,0) lyt.addWidget(self._ignore_confidence, 3, 0, 1, 3)
lyt.addWidget(self._framelabel, 3,1) lyt.addWidget(self._confidence_spinner, 3, 3, 1, 1)
lyt.addWidget(QLabel("assigned"), 4, 0) lyt.addWidget(QLabel("Current frame"), 4, 0)
lyt.addWidget(self._assignedlabel, 4, 1) lyt.addWidget(self._framelabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0) lyt.addWidget(QLabel("(Re-)Assigned"), 5, 0)
lyt.addWidget(self._errorlabel, 5, 1) lyt.addWidget(self._assignedlabel, 5, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 3) lyt.addWidget(QLabel("Errors/issues"), 5, 2)
lyt.addWidget(self._errorlabel, 5, 3, 1, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 4)
lyt.addWidget(self._startbtn, 8, 0) lyt.addWidget(self._startbtn, 8, 0, 1, 2)
lyt.addWidget(self._stopbtn, 8, 1) lyt.addWidget(self._stopbtn, 8, 2)
lyt.addWidget(self._proceedbtn, 8, 2) # lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._apply_btn, 9, 0, 1, 2) lyt.addWidget(self._refreshbtn, 8, 3, 1, 1)
lyt.addWidget(self._refreshbtn, 9, 2, 1, 1) lyt.addWidget(self._apply_btn, 9, 0, 1, 4)
lyt.addWidget(self._progressbar, 10, 0, 1, 3) lyt.addWidget(self._progressbar, 10, 0, 1, 4)
self.setLayout(lyt) self.setLayout(lyt)
def setData(self, data:TrackingData): def setData(self, data:TrackingData):
@@ -543,7 +581,7 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = self._dataworker.lengths self._all_lengths = self._dataworker.lengths
self._all_bendedness = self._dataworker.bendedness self._all_bendedness = self._dataworker.bendedness
self._userlabeled = self._dataworker.userlabeled self._userlabeled = self._dataworker.userlabeled
self._all_scores = self._dataworker.scores self._confidence = self._dataworker.confidence
self._frames = self._dataworker.frames self._frames = self._dataworker.frames
self._tracks = self._dataworker.tracks self._tracks = self._dataworker.tracks
self._dataworker = None self._dataworker = None
@@ -559,14 +597,18 @@ class ConsistencyClassifier(QWidget):
self._messagebox.append("Error preparing data! Make sure that the first user-labeled frames contain both tracks!") self._messagebox.append("Error preparing data! Make sure that the first user-labeled frames contain both tracks!")
self.setEnabled(False) self.setEnabled(False)
return return
max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]]) max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]]) -1
min_startframe = np.max([t1_userlabeled[0], t2_userlabeled[0]]) first_guess = np.max([t1_userlabeled[0], t2_userlabeled[0]])
while first_guess not in t1_userlabeled or first_guess not in t2_userlabeled:
first_guess += 1
min_startframe = first_guess + 1
self._maxframes = np.max(self._frames) self._maxframes = np.max(self._frames)
self._maxframeslabel.setText(str(self._maxframes)) self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_startframe) self._startframe_spinner.setMinimum(min_startframe)
self._startframe_spinner.setMaximum(max_startframe) self._startframe_spinner.setMaximum(max_startframe)
self._startframe_spinner.setValue(min_startframe) self._startframe_spinner.setValue(min_startframe)
self._startframe_spinner.setSingleStep(20) self._startframe_spinner.setSingleStep(20)
self._startframe_spinner.setToolTip(f"Maximum possible start frame: {max_startframe}")
self._startbtn.setEnabled(True) self._startbtn.setEnabled(True)
self._assignedlabel.setText("0") self._assignedlabel.setText("0")
self._errorlabel.setText("0") self._errorlabel.setText("0")
@@ -583,12 +625,14 @@ class ConsistencyClassifier(QWidget):
self._messagebox.append("Stopping tracking.") self._messagebox.append("Stopping tracking.")
def start(self): def start(self):
confidence_level = self._confidence_spinner.value() if self._ignore_confidence.isChecked() else 0.0
self._startbtn.setEnabled(False) self._startbtn.setEnabled(False)
self._refreshbtn.setEnabled(False) self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True) self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths, self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._userlabeled, self._all_bendedness, self._frames, self._tracks, self._userlabeled,
self._startframe_spinner.value(), self._stoponerror.isChecked()) self._confidence, self._startframe_spinner.value(), self._stoponerror.isChecked(),
min_confidence=confidence_level)
self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error) self._worker.signals.message.connect(self.worker_error)
@@ -607,7 +651,7 @@ class ConsistencyClassifier(QWidget):
def refresh(self): def refresh(self):
self.setEnabled(False) self.setEnabled(False)
self._dataworker = ConsitencyDataLoader(self._data) self._dataworker = ConsistencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed) self._dataworker.signals.stopped.connect(self.data_processed)
self._messagebox.clear() self._messagebox.clear()
self._messagebox.append("Refreshing...") self._messagebox.append("Refreshing...")

View File

@@ -21,10 +21,10 @@ class Window(QGraphicsRectItem):
self.setBrush(brush) self.setBrush(brush)
self.setZValue(1.0) self.setZValue(1.0)
self.setAcceptHoverEvents(True) # Enable hover events if needed self.setAcceptHoverEvents(True) # Enable hover events if needed
self.setFlags( # self.setFlags(
QGraphicsItem.ItemIsMovable | # Enables item dragging # QGraphicsItem.ItemIsMovable | # Enables item dragging
QGraphicsItem.ItemIsSelectable # Enables item selection # QGraphicsItem.ItemIsSelectable # Enables item selection
) # )
self._y = y self._y = y
def setWindowX(self, newx): def setWindowX(self, newx):
@@ -65,6 +65,7 @@ class Window(QGraphicsRectItem):
def mousePressEvent(self, event): def mousePressEvent(self, event):
self.setCursor(Qt.ClosedHandCursor) self.setCursor(Qt.ClosedHandCursor)
# print(event.pos())
super().mousePressEvent(event) super().mousePressEvent(event)
def mouseReleaseEvent(self, event): def mouseReleaseEvent(self, event):
@@ -121,6 +122,7 @@ class DetectionTimeline(QWidget):
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 85.)) self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 85.))
self._scene.setBackgroundBrush(self._bg_brush) self._scene.setBackgroundBrush(self._bg_brush)
self._scene.addItem(self._window) self._scene.addItem(self._window)
self._scene.mousePressEvent = self.on_sceneMousePress
self._view = QGraphicsView() self._view = QGraphicsView()
# self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform) # self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform)
@@ -151,12 +153,22 @@ class DetectionTimeline(QWidget):
self._position_label.setFont(f) self._position_label.setFont(f)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.setSpacing(0)
layout.setContentsMargins(5, 2, 5, 2)
layout.addWidget(self._view) layout.addWidget(self._view)
layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight) layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout) self.setLayout(layout)
# self.setMaximumHeight(100) # self.setMaximumHeight(100)
# self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed) # self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
def on_sceneMousePress(self, event):
scene_pos = event.scenePos()
relpos = scene_pos.x() / self._total_width
relpos = 0 if relpos < 0.0 else relpos
relpos = 2000/self._total_width if scene_pos.x() > self._total_width else relpos
self.signals.moveRequest.emit(relpos)
logging.debug("Timeline: Scene clicked at position: %.2f, %.2f --> rel x-pos %.3f", scene_pos.x(), scene_pos.y(), relpos)
def clear(self): def clear(self):
for i in self._scene.items(): for i in self._scene.items():
if isinstance(i, (QGraphicsLineItem, QGraphicsEllipseItem)): if isinstance(i, (QGraphicsLineItem, QGraphicsEllipseItem)):
@@ -166,6 +178,7 @@ class DetectionTimeline(QWidget):
logging.debug("Timeline: setData!") logging.debug("Timeline: setData!")
self._data = data self._data = data
self.update() self.update()
self.resizeEvent(None)
def update(self): def update(self):
self.clear() self.clear()
@@ -309,8 +322,7 @@ def main():
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)
data = TrackingData() data = TrackingData(as_dict(df))
data.setData(as_dict(df))
data.setSelection(np.arange(0,100, 1)) data.setSelection(np.arange(0,100, 1))
data.setUserLabeledStatus(True) data.setUserLabeledStatus(True)
start_x = 0.1 start_x = 0.1
@@ -329,12 +341,14 @@ def main():
backBtn.clicked.connect(lambda: back(0.2)) backBtn.clicked.connect(lambda: back(0.2))
btnLyt = QHBoxLayout() btnLyt = QHBoxLayout()
btnLyt.setSpacing(1)
btnLyt.addWidget(backBtn) btnLyt.addWidget(backBtn)
btnLyt.addWidget(zeroBtn) btnLyt.addWidget(zeroBtn)
btnLyt.addWidget(fwdBtn) btnLyt.addWidget(fwdBtn)
view.setWindowPos(start_x) view.setWindowPos(start_x)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.setSpacing(1)
layout.addWidget(view) layout.addWidget(view)
layout.addLayout(btnLyt) layout.addLayout(btnLyt)
window.setLayout(layout) window.setLayout(layout)

View File

@@ -10,6 +10,7 @@ from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, Dete
from fixtracks.utils.enums import DetectionData, Tracks from fixtracks.utils.enums import DetectionData, Tracks
from fixtracks.utils.trackingdata import TrackingData from fixtracks.utils.trackingdata import TrackingData
class Detection(QGraphicsEllipseItem): class Detection(QGraphicsEllipseItem):
signals = DetectionSignals() signals = DetectionSignals()
@@ -138,11 +139,13 @@ class DetectionView(QWidget):
coordinates = self._data.coordinates(selection=True) coordinates = self._data.coordinates(selection=True)
centercoordinates = self._data.centerOfGravity(selection=True) centercoordinates = self._data.centerOfGravity(selection=True)
userlabeled = self._data.selectedData("userlabeled") userlabeled = self._data.selectedData("userlabeled")
scores = self._data.selectedData("confidence")
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0) image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
num_detections = len(frames)
for i, (id, f, t, l) in enumerate(zip(ids, frames, tracks, userlabeled)): for i, (id, f, t, l, s) in enumerate(zip(ids, frames, tracks, userlabeled, scores)):
c = Tracks.fromValue(t).toColor() c = Tracks.fromValue(t).toColor()
c.setAlpha(int(i * 255 / num_detections))
if keypoint >= 0: if keypoint >= 0:
x = coordinates[i, keypoint, 0] x = coordinates[i, keypoint, 0]
y = coordinates[i, keypoint, 1] y = coordinates[i, keypoint, 1]
@@ -156,6 +159,7 @@ class DetectionView(QWidget):
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :]) item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
item.setData(DetectionData.FRAME.value, f) item.setData(DetectionData.FRAME.value, f)
item.setData(DetectionData.USERLABELED.value, l) item.setData(DetectionData.USERLABELED.value, l)
item.setData(DetectionData.SCORE.value, s)
item = self._scene.addItem(item) item = self._scene.addItem(item)
def fit_image_to_view(self): def fit_image_to_view(self):
@@ -214,7 +218,7 @@ def main():
view.setImage(img) view.setImage(img)
view.addDetections(bg_coords, bg_tracks, bg_ids, background_brush) view.addDetections(bg_coords, bg_tracks, bg_ids, background_brush)
view.addDetections(focus_coords, focus_tracks, focus_ids, focus_brush) view.addDetections(focus_coords, focus_tracks, focus_ids, focus_brush)
view.addDetections(scnd_coords, scnd_tracks, scnd_ids, second_brush) view.addDetections(scnd_coords, scnd_tracks, scnd_ids, second_brush)
window.setLayout(layout) window.setLayout(layout)
window.show() window.show()
app.exec() app.exec()

View File

@@ -52,7 +52,6 @@ class SelectionControls(QWidget):
quarterstepBackBtn.setStyleSheet(pushBtnStyle("darkgray")) quarterstepBackBtn.setStyleSheet(pushBtnStyle("darkgray"))
quarterstepBackBtn.clicked.connect(lambda: self.on_Back(quarterstep)) quarterstepBackBtn.clicked.connect(lambda: self.on_Back(quarterstep))
fwdBtn = QPushButton(">>|") fwdBtn = QPushButton(">>|")
fwdBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) fwdBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
fwdBtn.setShortcut(Qt.Key.Key_Right) fwdBtn.setShortcut(Qt.Key.Key_Right)
@@ -173,7 +172,7 @@ class SelectionControls(QWidget):
grid.setColumnStretch(0, 1) grid.setColumnStretch(0, 1)
grid.setColumnStretch(7, 1) grid.setColumnStretch(7, 1)
self.setLayout(grid) self.setLayout(grid)
self.setMaximumSize(QSize(500, 500)) # self.setMaximumSize(QSize(500, 500))
def setWindow(self, start:int=0, end:int=0): def setWindow(self, start:int=0, end:int=0):
self.startframe.setText(f"{start:.0f}") self.startframe.setText(f"{start:.0f}")

View File

@@ -94,7 +94,8 @@ class SkeletonWidget(QWidget):
i = s.data(DetectionData.ID.value) i = s.data(DetectionData.ID.value)
t = s.data(DetectionData.TRACK_ID.value) t = s.data(DetectionData.TRACK_ID.value)
f = s.data(DetectionData.FRAME.value) f = s.data(DetectionData.FRAME.value)
self._info_label.setText(f"Id {i}, track {t} on frame {f}, length {l:.1f} px") sc = s.data(DetectionData.SCORE.value)
self._info_label.setText(f"Id {i}, track {t} on frame {f}, length {l:.1f} px, confidence {sc:.2f}")
else: else:
self._info_label.setText("") self._info_label.setText("")
@@ -129,7 +130,7 @@ class SkeletonWidget(QWidget):
self._scene.setSceneRect(self._minx, self._miny, self._maxx - self._minx, self._maxy - self._miny) self._scene.setSceneRect(self._minx, self._miny, self._maxx - self._minx, self._maxy - self._miny)
self._view.fitInView(self._scene.sceneRect(), Qt.AspectRatioMode.KeepAspectRatio) self._view.fitInView(self._scene.sceneRect(), Qt.AspectRatioMode.KeepAspectRatio)
def addSkeleton(self, coords, detection_id, frame, track, brush, update=True): def addSkeleton(self, coords, detection_id, frame, track, score, brush, update=True):
def check_extent(x, y, w, h): def check_extent(x, y, w, h):
if x == 0 and y == 0: if x == 0 and y == 0:
return return
@@ -157,12 +158,14 @@ class SkeletonWidget(QWidget):
item.setData(DetectionData.ID.value, detection_id) item.setData(DetectionData.ID.value, detection_id)
item.setData(DetectionData.TRACK_ID.value, track) item.setData(DetectionData.TRACK_ID.value, track)
item.setData(DetectionData.FRAME.value, frame) item.setData(DetectionData.FRAME.value, frame)
item.setData(DetectionData.SCORE.value, score)
self._skeletons.append(item) self._skeletons.append(item)
if update: if update:
self.update() self.update()
def addSkeletons(self, coordinates:np.ndarray, detection_ids:np.ndarray, def addSkeletons(self, coordinates:np.ndarray, detection_ids:np.ndarray,
frames:np.ndarray, tracks:np.ndarray, brush:QBrush): frames:np.ndarray, tracks:np.ndarray, scores:np.ndarray,
brush:QBrush):
num_detections = 0 if coordinates is None else coordinates.shape[0] num_detections = 0 if coordinates is None else coordinates.shape[0]
logging.debug("SkeletonWidget: add %i Skeletons", num_detections) logging.debug("SkeletonWidget: add %i Skeletons", num_detections)
if num_detections < 1: if num_detections < 1:
@@ -172,9 +175,10 @@ class SkeletonWidget(QWidget):
detection_ids = detection_ids[sorting] detection_ids = detection_ids[sorting]
frames = frames[sorting] frames = frames[sorting]
tracks = tracks[sorting] tracks = tracks[sorting]
scores = scores[sorting]
for i in range(num_detections): for i in range(num_detections):
self.addSkeleton(coordinates[i,:,:], detection_ids[i], frames[i], self.addSkeleton(coordinates[i,:,:], detection_ids[i], frames[i],
tracks[i], brush=brush, update=False) tracks[i], scores[i], brush=brush, update=False)
self.update() self.update()
# def addSkeleton(self, coords, detection_id, brush): # def addSkeleton(self, coords, detection_id, brush):

View File

@@ -45,6 +45,7 @@ class FixTracks(QWidget):
self._timeline = DetectionTimeline() self._timeline = DetectionTimeline()
self._timeline.signals.windowMoved.connect(self.on_windowChanged) self._timeline.signals.windowMoved.connect(self.on_windowChanged)
self._timeline.signals.moveRequest.connect(self.on_moveRequest)
self._windowspinner = QSpinBox() self._windowspinner = QSpinBox()
self._windowspinner.setRange(10, 10000) self._windowspinner.setRange(10, 10000)
@@ -55,13 +56,26 @@ class FixTracks(QWidget):
self._keypointcombo = QComboBox() self._keypointcombo = QComboBox()
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected) self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
self._goto_spinner = QSpinBox()
self._goto_spinner.setSingleStep(1)
self._gotobtn = QPushButton("go!")
self._gotobtn.setToolTip("Jump to a given frame")
self._gotobtn.clicked.connect(self.on_goto)
combo_layout = QHBoxLayout() combo_layout = QHBoxLayout()
combo_layout.addWidget(QLabel("Window width:")) combo_layout.addWidget(QLabel("Window width:"))
combo_layout.addWidget(self._windowspinner) combo_layout.addWidget(self._windowspinner)
combo_layout.addWidget(QLabel("frames")) combo_layout.addWidget(QLabel("frames"))
combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
combo_layout.addWidget(QLabel("Keypoint:")) combo_layout.addWidget(QLabel("Keypoint:"))
combo_layout.addWidget(self._keypointcombo) combo_layout.addWidget(self._keypointcombo)
combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
combo_layout.addWidget(QLabel("Jump to frame:"))
combo_layout.addWidget(self._goto_spinner)
combo_layout.addWidget(self._gotobtn)
combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
combo_layout.setSpacing(1)
timelinebox = QVBoxLayout() timelinebox = QVBoxLayout()
timelinebox.setSpacing(2) timelinebox.setSpacing(2)
@@ -101,6 +115,7 @@ class FixTracks(QWidget):
data_selection_box.addWidget(QLabel("Select data file")) data_selection_box.addWidget(QLabel("Select data file"))
data_selection_box.addWidget(self._data_combo) data_selection_box.addWidget(self._data_combo)
data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
data_selection_box.setSpacing(0)
btnBox = QHBoxLayout() btnBox = QHBoxLayout()
btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft) btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
@@ -118,8 +133,12 @@ class FixTracks(QWidget):
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter) cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addWidget(self._skeleton) cntrlBox.addWidget(self._skeleton)
cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding)) cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
cntrlBox.setSpacing(0)
cntrlBox.setContentsMargins(0,0,0,0)
vbox = QVBoxLayout() vbox = QVBoxLayout()
vbox.setSpacing(0)
vbox.setContentsMargins(0,0,0,0)
vbox.addLayout(timelinebox) vbox.addLayout(timelinebox)
vbox.addLayout(cntrlBox) vbox.addLayout(cntrlBox)
vbox.addLayout(btnBox) vbox.addLayout(btnBox)
@@ -131,9 +150,12 @@ class FixTracks(QWidget):
splitter.addWidget(container) splitter.addWidget(container)
splitter.setStretchFactor(0, 3) splitter.setStretchFactor(0, 3)
splitter.setStretchFactor(1, 1) splitter.setStretchFactor(1, 1)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.addLayout(data_selection_box) layout.addLayout(data_selection_box)
layout.addWidget(splitter) layout.addWidget(splitter)
layout.setSpacing(0)
layout.setContentsMargins(5,2,2,5)
self.setLayout(layout) self.setLayout(layout)
def on_autoClassify(self, tracks): def on_autoClassify(self, tracks):
@@ -159,15 +181,19 @@ class FixTracks(QWidget):
self._detectionView.setImage(img) self._detectionView.setImage(img)
def update(self): def update(self):
kp = self._keypointcombo.currentText().lower()
if len(kp) == 0:
return
kpi = -1 if "center" in kp else int(kp)
start_frame = self._currentWindowPos start_frame = self._currentWindowPos
stop_frame = start_frame + self._currentWindowWidth stop_frame = start_frame + self._currentWindowWidth
self._timeline.setWindow(start_frame / self._maxframes, self._timeline.setWindow(start_frame / self._maxframes,
self._currentWindowWidth/self._maxframes) self._currentWindowWidth/self._maxframes)
logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame) logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
self._data.setSelectionRange("frame", start_frame, stop_frame) self._data.setSelectionRange("frame", start_frame, stop_frame)
self._controls_widget.setWindow(start_frame, stop_frame) self._controls_widget.setWindow(start_frame, stop_frame)
kp = self._keypointcombo.currentText().lower()
kpi = -1 if "center" in kp else int(kp)
self._detectionView.updateDetections(kpi) self._detectionView.updateDetections(kpi)
@property @property
@@ -210,6 +236,7 @@ class FixTracks(QWidget):
self._currentWindowPos = 0 self._currentWindowPos = 0
self._currentWindowWidth = self._windowspinner.value() self._currentWindowWidth = self._windowspinner.value()
self._maxframes = np.max(self._data["frame"]) self._maxframes = np.max(self._data["frame"])
self._goto_spinner.setMaximum(self._maxframes)
self.populateKeypointCombo(self._data.numKeypoints()) self.populateKeypointCombo(self._data.numKeypoints())
self._timeline.setData(self._data) self._timeline.setData(self._data)
# self._timeline.setWindow(self._currentWindowPos / self._maxframes, # self._timeline.setWindow(self._currentWindowPos / self._maxframes,
@@ -317,6 +344,11 @@ class FixTracks(QWidget):
self.update() self.update()
self._manualmove = False self._manualmove = False
def on_moveRequest(self, pos):
new_pos = int(np.round(pos * self._maxframes))
self._currentWindowPos = new_pos
self.update()
def on_windowSizeChanged(self, value): def on_windowSizeChanged(self, value):
"""Reacts on the user window-width selection. Selection is done in the unit of frames. """Reacts on the user window-width selection. Selection is done in the unit of frames.
@@ -327,17 +359,29 @@ class FixTracks(QWidget):
""" """
self._currentWindowWidth = value self._currentWindowWidth = value
logging.debug("Tracks:OnWindowSizeChanged %i franes", value) logging.debug("Tracks:OnWindowSizeChanged %i franes", value)
if self._maxframes == 0: # if self._maxframes == 0:
self._timeline.setWindowWidth(self._currentWindowWidth / 2000) # self._timeline.setWindowWidth(self._currentWindowWidth / 2000)
else: # else:
self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes) # self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
self._controls_widget.setSelectedTracks(None) # self._controls_widget.setSelectedTracks(None)
self.update()
def on_goto(self):
target = self._goto_spinner.value()
if target > self._maxframes - self._currentWindowWidth:
target = self._maxframes - self._currentWindowWidth
logging.info("Jump to frame %i", target)
self._currentWindowPos = target
self._timeline.setWindow(self._currentWindowPos / self._maxframes,
self._currentWindowWidth / self._maxframes)
self.update()
def on_detectionsSelected(self, detections): def on_detectionsSelected(self, detections):
logging.debug("Tracks: %i Detections selected", len(detections)) logging.debug("Tracks: %i Detections selected", len(detections))
tracks = np.zeros(len(detections), dtype=int) tracks = np.zeros(len(detections), dtype=int)
ids = np.zeros_like(tracks) ids = np.zeros_like(tracks)
frames = np.zeros_like(tracks) frames = np.zeros_like(tracks)
scores = np.zeros(tracks.shape, dtype=float)
coordinates = None coordinates = None
if len(detections) > 0: if len(detections) > 0:
c = detections[0].data(DetectionData.COORDINATES.value) c = detections[0].data(DetectionData.COORDINATES.value)
@@ -348,15 +392,20 @@ class FixTracks(QWidget):
ids[i] = d.data(DetectionData.ID.value) ids[i] = d.data(DetectionData.ID.value)
frames[i] = d.data(DetectionData.FRAME.value) frames[i] = d.data(DetectionData.FRAME.value)
coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value) coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
scores[i] = d.data(DetectionData.SCORE.value)
self._data.setSelection(ids) self._data.setSelection(ids)
self._controls_widget.setSelectedTracks(tracks) self._controls_widget.setSelectedTracks(tracks)
self._skeleton.clear() self._skeleton.clear()
self._skeleton.addSkeletons(coordinates, ids, frames, tracks, QBrush(QColor(10, 255, 65, 255))) self._skeleton.addSkeletons(coordinates, ids, frames, tracks, scores, QBrush(QColor(10, 255, 65, 255)))
def moveWindow(self, stepsize): def moveWindow(self, stepsize):
logging.info("Tracks.moveWindow: move window with stepsize %.2f", stepsize) logging.info("Tracks.moveWindow: move window with stepsize %.2f", stepsize)
self._manualmove = True self._manualmove = True
new_start_frame = self._currentWindowPos + np.round(stepsize * self._currentWindowWidth) new_start_frame = self._currentWindowPos + np.round(stepsize * self._currentWindowWidth)
if new_start_frame < 0:
new_start_frame = 0
elif new_start_frame + self._currentWindowWidth > self._maxframes:
new_start_frame = self._maxframes - self._currentWindowWidth
self._currentWindowPos = new_start_frame self._currentWindowPos = new_start_frame
self._controls_widget.setSelectedTracks(None) self._controls_widget.setSelectedTracks(None)
self.update() self.update()