Compare commits

...

21 Commits

Author SHA1 Message Date
00b6b54db9 [tracks] react on move request and move window 2025-03-01 17:11:37 +01:00
e3b26c3da4 [Timeline] send move request signal when user clicks onto the scene 2025-03-01 17:11:01 +01:00
30741be200 [tracks] changing window size immediately updates the view 2025-03-01 16:51:39 +01:00
50d982c93b [detectionview] fix bug, when no detection is in the current selection 2025-03-01 16:50:27 +01:00
6c54a86cde fix bug in tracks 2025-02-28 17:20:21 +01:00
03ebb6485a [some improvements] 2025-02-28 08:46:19 +01:00
116e0ce5de [wip] add score to items, and ignore them 2025-02-28 08:12:04 +01:00
d1b5776e69 [controls] remove size limit 2025-02-27 19:59:50 +01:00
4a76655766 [detections] add score to items 2025-02-27 19:59:46 +01:00
ae24463be2 [classifier] make sure, we always start with user labeled detections 2025-02-27 17:41:07 +01:00
15264dbe48 [tracks] allow jumping to a given frame 2025-02-27 16:14:48 +01:00
9d38421e02 [timeline] disable mouse dragging for now, trigger resize on setData 2025-02-26 11:23:20 +01:00
1c2f84b236 [enums] add userlabeled flag to detectionItems 2025-02-26 11:16:25 +01:00
ff3e0841a6 [classifier] better messaging 2025-02-26 11:16:04 +01:00
9e2c6f343a [trackingdata] fixes of selection handling, ...
something is still off with the deletion...
2025-02-26 11:15:49 +01:00
c0a7631acd [detectionview] simplify indexing 2025-02-26 11:15:15 +01:00
5758cf61c6 [selection] add shortcut, disable deletion for now 2025-02-26 11:14:49 +01:00
faf095a2a1 [tracks] add warnings around dangerzone actions 2025-02-26 11:14:18 +01:00
0c5e5629b7 [tracks] add way to flag many detections in one go 2025-02-26 09:24:20 +01:00
4ef6143d14 [tracks] layout tweaks 2025-02-26 08:32:02 +01:00
d6b91c25d2 [classifier] kind of handling mulitple detections in one frame 2025-02-26 08:19:59 +01:00
9 changed files with 329 additions and 144 deletions

View File

@@ -7,6 +7,8 @@ class DetectionData(Enum):
FRAME = 1 FRAME = 1
COORDINATES = 2 COORDINATES = 2
TRACK_ID = 3 TRACK_ID = 3
USERLABELED = 4
SCORE = 5
class Tracks(Enum): class Tracks(Enum):
TRACKONE = 1 TRACKONE = 1

View File

@@ -23,6 +23,7 @@ class DetectionSceneSignals(QObject):
class DetectionTimelineSignals(QObject): class DetectionTimelineSignals(QObject):
windowMoved = Signal() windowMoved = Signal()
manualMove = Signal() manualMove = Signal()
moveRequest = Signal(float)
class DetectionSignals(QObject): class DetectionSignals(QObject):
hover = Signal((int, QPointF)) hover = Signal((int, QPointF))

View File

@@ -14,8 +14,8 @@ class TrackingData(QObject):
if "userlabeled" not in self._data.keys(): if "userlabeled" not in self._data.keys():
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool) self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool)
self._columns = [k for k in self._data.keys()] self._columns = [k for k in self._data.keys()]
self._indices = self["index"] self._indices = np.arange(len(self["index"]), dtype=int)
self._selection = np.asarray([]) self._selection_indices = np.asarray([])
self._selected_ids = None self._selected_ids = None
@property @property
@@ -43,8 +43,8 @@ class TrackingData(QObject):
ids = np.sort(ids) ids = np.sort(ids)
indexes = np.ones_like(ids, dtype=int) * -1 indexes = np.ones_like(ids, dtype=int) * -1
j = 0 j = 0
for idx, i in enumerate(self._indices): for idx in self._indices:
if i == ids[j]: if self["index"][idx] == ids[j]:
indexes[j] = idx indexes[j] = idx
j += 1 j += 1
if j == len(indexes): if j == len(indexes):
@@ -54,19 +54,23 @@ class TrackingData(QObject):
@property @property
def selectionIndices(self): def selectionIndices(self):
return self._selection return self._selection_indices
@property
def selectionIDs(self):
return self._selected_ids
def setSelectionRange(self, col, start, stop): def setSelectionRange(self, col, start, stop):
logging.info("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop) logging.info("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop)
col_indices = np.where((self[col] >= start) & (self[col] < stop))[0] col_indices = np.where((self[col] >= start) & (self[col] < stop))[0]
self._selection = self._indices[col_indices] self._selection_indices = self._indices[col_indices]
if len(col_indices) < 1: if len(col_indices) < 1:
logging.warning("TrackingData: Selection range is empty!") logging.warning("TrackingData: Selection range is empty!")
def selectedData(self, col:str): def selectedData(self, col:str):
if col not in self.columns: if col not in self.columns:
logging.error("TrackingData:selectedData: Invalid column name! %s", col) logging.error("TrackingData:selectedData: Invalid column name! %s", col)
return self[col][self._selection] return self[col][self._selection_indices]
def setSelection(self, ids): def setSelection(self, ids):
""" """
@@ -78,8 +82,9 @@ class TrackingData(QObject):
An array-like object containing the IDs to be set as user selections. An array-like object containing the IDs to be set as user selections.
""" """
logging.debug("TrackingData.setSelection: %i number of ids", len(ids)) logging.debug("TrackingData.setSelection: %i number of ids", len(ids))
self._selection = self._find(ids) self._selection_indices = self._find(ids)
self._selected_ids = ids self._selected_ids = ids
# print(self._selected_ids, self._selection_indices)
def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None: def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
"""Assign a new track_id to the user-selected detections """Assign a new track_id to the user-selected detections
@@ -92,9 +97,12 @@ class TrackingData(QObject):
Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched. Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched.
""" """
logging.info("TrackingData: set track id %i for selection, set user-labeled status %s", track_id, str(setUserLabeled)) logging.info("TrackingData: set track id %i for selection, set user-labeled status %s", track_id, str(setUserLabeled))
self["track"][self._selection] = track_id # print(self._selected_ids, self._selection_indices)
# print("before: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
self["track"][self._selection_indices] = track_id
if setUserLabeled: if setUserLabeled:
self.setUserLabeledStatus(True, True) self.setUserLabeledStatus(True, True)
# print("after: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
def setUserLabeledStatus(self, new_status: bool, selection=True): def setUserLabeledStatus(self, new_status: bool, selection=True):
"""Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection. """Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection.
@@ -109,7 +117,7 @@ class TrackingData(QObject):
logging.debug("TrackingData: (Re-)setting assignment status of %s to %s", logging.debug("TrackingData: (Re-)setting assignment status of %s to %s",
"user selected data" if selection else " ALL", str(new_status)) "user selected data" if selection else " ALL", str(new_status))
if selection: if selection:
self["userlabeled"][self._selection] = new_status self["userlabeled"][self._selection_indices] = new_status
else: else:
self["userlabeled"][:] = new_status self["userlabeled"][:] = new_status
@@ -123,12 +131,14 @@ class TrackingData(QObject):
def deleteDetections(self, ids=None): def deleteDetections(self, ids=None):
if ids is not None: if ids is not None:
logging.debug("TrackingData.deleteDetections of %i detections", len(ids))
del_indices = self._find(ids) del_indices = self._find(ids)
else: else:
del_indices = self._indices logging.debug("TrackingData.deleteDetections of all selected detections (%i)", len(self._selected_ids))
del_indices = self._selected_ids
for c in self._columns: for c in self._columns:
self._data[c] = np.delete(self._data[c], del_indices, axis=0) self._data[c] = np.delete(self._data[c], del_indices, axis=0)
self._indices = self["index"] self._indices = self._indices[:-len(del_indices)]
self._selected_ids = np.setdiff1d(self._selected_ids, del_indices) self._selected_ids = np.setdiff1d(self._selected_ids, del_indices)
def assignTracks(self, tracks:np.ndarray): def assignTracks(self, tracks:np.ndarray):
@@ -171,10 +181,10 @@ class TrackingData(QObject):
and M is number of keypoints and M is number of keypoints
""" """
if selection: if selection:
if len(self._selection) < 1: if len(self._selection_indices) < 1:
logging.info("TrackingData.coordinates returns empty array, not detections in range!") logging.info("TrackingData.coordinates returns empty array, not detections in range!")
return np.ndarray([]) return np.ndarray([])
return np.stack(self["keypoints"][self._selection]).astype(np.float32) return np.stack(self["keypoints"][self._selection_indices]).astype(np.float32)
return np.stack(self["keypoints"]).astype(np.float32) return np.stack(self["keypoints"]).astype(np.float32)
def keypointScores(self, selection=False): def keypointScores(self, selection=False):
@@ -188,10 +198,10 @@ class TrackingData(QObject):
with N the number of detections and M the number of keypoints. with N the number of detections and M the number of keypoints.
""" """
if selection: if selection:
if len(self._selection) < 1: if len(self._selection_indices) < 1:
logging.info("TrackingData.scores returns empty array, not detections in range!") logging.info("TrackingData.scores returns empty array, not detections in range!")
return None return None
return np.stack(self["keypoint_score"][self._selection]).astype(np.float32) return np.stack(self["keypoint_score"][self._selection_indices]).astype(np.float32)
return np.stack(self["keypoint_score"]).astype(np.float32) return np.stack(self["keypoint_score"]).astype(np.float32)
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]): def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):

View File

@@ -1,8 +1,8 @@
import logging import logging
import numpy as np import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QDoubleSpinBox
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
@@ -13,24 +13,26 @@ from fixtracks.utils.trackingdata import TrackingData
from IPython import embed from IPython import embed
class Detection(): class Detection():
def __init__(self, id, frame, track, position, orientation, length, userlabeled): def __init__(self, id, frame, track, position, orientation, length, userlabeled, confidence):
self.id = id self.id = id
self.frame = frame self.frame = frame
self.track = track self.track = track
self.position = position self.position = position
self.score = 0.0 self.confidence = confidence
self.angle = orientation self.angle = orientation
self.length = length self.length = length
self.userlabeled = userlabeled self.userlabeled = userlabeled
class WorkerSignals(QObject): class WorkerSignals(QObject):
error = Signal(str) message = Signal(str)
running = Signal(bool) running = Signal(bool)
progress = Signal(int, int, int) progress = Signal(int, int, int)
currentframe = Signal(int) currentframe = Signal(int)
stopped = Signal(int) stopped = Signal(int)
class ConsitencyDataLoader(QRunnable):
class ConsistencyDataLoader(QRunnable):
def __init__(self, data): def __init__(self, data):
super().__init__() super().__init__()
self.signals = WorkerSignals() self.signals = WorkerSignals()
@@ -40,7 +42,7 @@ class ConsitencyDataLoader(QRunnable):
self.lengths = None self.lengths = None
self.orientations = None self.orientations = None
self.userlabeled = None self.userlabeled = None
self.scores = None self.confidence = None
self.frames = None self.frames = None
self.tracks = None self.tracks = None
@@ -52,17 +54,18 @@ class ConsitencyDataLoader(QRunnable):
self.positions = self.data.centerOfGravity() self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation() self.orientations = self.data.orientation()
self.lengths = self.data.animalLength() self.lengths = self.data.animalLength()
self.bendedness = self.data.bendedness() # self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"] self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries. self.confidence = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"] self.frames = self.data["frame"]
self.tracks = self.data["track"] self.tracks = self.data["track"]
self.signals.stopped.emit(0) self.signals.stopped.emit(0)
class ConsistencyWorker(QRunnable): class ConsistencyWorker(QRunnable):
def __init__(self, positions, orientations, lengths, bendedness, frames, tracks, def __init__(self, positions, orientations, lengths, bendedness, frames, tracks,
userlabeled, startframe=0, stoponerror=False) -> None: userlabeled, confidence, startframe=0, stoponerror=False, min_confidence=0.0) -> None:
super().__init__() super().__init__()
self.signals = WorkerSignals() self.signals = WorkerSignals()
self.positions = positions self.positions = positions
@@ -70,6 +73,8 @@ class ConsistencyWorker(QRunnable):
self.lengths = lengths self.lengths = lengths
self.bendedness = bendedness self.bendedness = bendedness
self.userlabeled = userlabeled self.userlabeled = userlabeled
self.confidence = confidence
self._min_confidence = min_confidence
self.frames = frames self.frames = frames
self.tracks = tracks self.tracks = tracks
self._startframe = startframe self._startframe = startframe
@@ -88,24 +93,15 @@ class ConsistencyWorker(QRunnable):
if np.any(self.positions[i] < 0.1): if np.any(self.positions[i] < 0.1):
logging.debug("Encountered probably invalid position %s", str(self.positions[i])) logging.debug("Encountered probably invalid position %s", str(self.positions[i]))
continue continue
if self._min_confidence > 0.0 and self.confidence[i] < self._min_confidence:
self.tracks[i] = -1
continue
d = Detection(i, frame, self.tracks[i], self.positions[i], d = Detection(i, frame, self.tracks[i], self.positions[i],
self.orientations[i], self.lengths[i], self.orientations[i], self.lengths[i],
self.userlabeled[i]) self.userlabeled[i], self.confidence[i])
detections.append(d) detections.append(d)
return detections return detections
def needs_checking(original, new):
res = False
for n, o in zip(new, original):
res = (o == 1 or o == 2) and n != o
if res:
print("inverted assignment, needs cross-checking?")
if not res:
res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
if res:
print("all detections would be assigned to one track!")
return res
def assign_by_distance(d): def assign_by_distance(d):
t1_step = d.frame - last_detections[1].frame t1_step = d.frame - last_detections[1].frame
t2_step = d.frame - last_detections[2].frame t2_step = d.frame - last_detections[2].frame
@@ -138,6 +134,31 @@ class ConsistencyWorker(QRunnable):
most_likely_track = np.argmin(length_differences) + 1 most_likely_track = np.argmin(length_differences) + 1
return most_likely_track, length_differences return most_likely_track, length_differences
def check_multiple_detections(detections):
if self._min_confidence > 0.0:
for i, d in enumerate(detections):
if d.confidence < self._min_confidence:
del detections[i]
distances = np.zeros((len(detections), len(detections)))
for i, d1 in enumerate(detections):
for j, d2 in enumerate(detections):
distances[i, j] = np.abs(np.linalg.norm(d2.position - d1.position))
lowest_dist = np.argmin(np.sum(distances, axis=1))
del detections[lowest_dist]
return detections
def find_last_userlabeled(startframe):
t1index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 1))[0][-1]
t2index = np.where((self.frames < startframe) & (self.userlabeled) & (self.tracks == 2))[0][-1]
d1 = Detection(t1index, self.frames[t1index], self.tracks[t1index], self.positions[t1index],
self.orientations[t1index], self.lengths[t1index], self.userlabeled[t1index],
self.confidence[t1index])
d2 = Detection(t1index, self.frames[t2index], self.tracks[t2index], self.positions[t2index],
self.orientations[t2index], self.lengths[t2index], self.userlabeled[t2index],
self.confidence[t1index])
last_detections[1] = d1
last_detections[2] = d2
unique_frames = np.unique(self.frames) unique_frames = np.unique(self.frames)
steps = int((len(unique_frames) - self._startframe) // 100) steps = int((len(unique_frames) - self._startframe) // 100)
errors = 0 errors = 0
@@ -145,24 +166,34 @@ class ConsistencyWorker(QRunnable):
progress = 0 progress = 0
self._stoprequest = False self._stoprequest = False
last_detections = {1: None, 2: None, -1: None} last_detections = {1: None, 2: None, -1: None}
find_last_userlabeled(self._startframe)
for f in unique_frames[unique_frames >= self._startframe]: for f in unique_frames[unique_frames >= self._startframe]:
if self._stoprequest: if self._stoprequest:
break break
error = False error = False
message = ""
self.signals.currentframe.emit(f) self.signals.currentframe.emit(f)
indices = np.where(self.frames == f)[0] indices = np.where(self.frames == f)[0]
detections = get_detections(f, indices) detections = get_detections(f, indices)
done = [False, False] done = [False, False]
if len(detections) == 0: if len(detections) == 0:
continue continue
if len(detections) > 2:
message = f"Frame {f}: More than 2 detections ({len(detections)}) in the same frame!"
logging.info("ConsistencyTracker: %s", message)
self.signals.message.emit(message)
while len(detections) > 2:
detections = check_multiple_detections(detections)
if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]): if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]):
# more than one detection # more than one detection
if detections[0].userlabeled and detections[1].userlabeled: if detections[0].userlabeled and detections[1].userlabeled:
if detections[0].track == detections[1].track: if detections[0].track == detections[1].track:
error = True error = True
logging.info("Classification error both detections in the same frame are assigned to the same track!") message = f"Frame {f}: Classification error both detections in the same frame are assigned to the same track!"
logging.info("ConsistencyTracker: %s", message)
self.signals.message.emit(message)
elif detections[0].userlabeled and not detections[1].userlabeled: elif detections[0].userlabeled and not detections[1].userlabeled:
detections[1].track = 1 if detections[0].track == 2 else 2 detections[1].track = 1 if detections[0].track == 2 else 2
elif not detections[0].userlabeled and detections[1].userlabeled: elif not detections[0].userlabeled and detections[1].userlabeled:
@@ -178,50 +209,54 @@ class ConsistencyWorker(QRunnable):
elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled
last_detections[detections[0].track] = detections[0] last_detections[detections[0].track] = detections[0]
done[0] = True done[0] = True
if np.sum(done) == len(detections): if np.sum(done) == len(detections):
continue continue
# if f == 2088:
# embed()
# return
if error and self._stoponerror: if error and self._stoponerror:
self.signals.error.emit("Classification error both detections in the same frame are assigned to the same track!") self.signals.message.emit(f"Tracking stopped at frame {f}.")
break break
elif error:
continue
dist_assignments = np.zeros(2, dtype=int) dist_assignments = np.zeros(2, dtype=int)
orientation_assignments = np.zeros_like(dist_assignments) orientation_assignments = np.zeros_like(dist_assignments)
length_assignments = np.zeros_like(dist_assignments) length_assignments = np.zeros_like(dist_assignments)
distances = np.zeros((2, 2)) distances = np.zeros((2, 2))
orientations = np.zeros_like(distances) orientations = np.zeros_like(distances)
lengths = np.zeros_like(distances) lengths = np.zeros_like(distances)
assignments = np.zeros((2, 2)) assignments = np.zeros(2)
for i, d in enumerate(detections): for i, d in enumerate(detections):
dist_assignments[i], distances[i, :] = assign_by_distance(d) dist_assignments[i], distances[i, :] = assign_by_distance(d)
orientation_assignments[i], orientations[i,:] = assign_by_orientation(d) orientation_assignments[i], orientations[i,:] = assign_by_orientation(d)
length_assignments[i], lengths[i, :] = assign_by_length(d) length_assignments[i], lengths[i, :] = assign_by_length(d)
assignments[i, :] = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3 assignments = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3
diffs = np.diff(assignments, axis=1)
error = False error = False
temp = {} temp = {}
message = "" message = ""
for i, d in enumerate(detections): if len(detections) > 1:
temp = {} if assignments[0] == assignments[1]:
if diffs[i] == 0: # both are equally likely
d.track = -1 d.track = -1
error = True error = True
message = "Classification error both detections in the same frame are assigned to the same track!" errors += 1
message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!"
break break
if diffs[i] < 0: elif assignments[0] != assignments[1]:
d.track = 1 detections[0].track = assignments[0]
detections[1].track = assignments[1]
temp[detections[0].track] = detections[0]
temp[detections[1].track] = detections[1]
self.tracks[detections[0].id] = detections[0].track
self.tracks[detections[1].id] = detections[1].track
else: else:
d.track = 2 if np.abs(np.diff(distances[0,:])) > 50: # maybe include the time difference into this?
self.tracks[d.id] = d.track detections[0].track = assignments[0]
if d.track not in temp: temp[detections[0].track] = detections[0]
temp[d.track] = d self.tracks[detections[0].id] = detections[0].track
else: else:
self.tracks[detections[0].id] = -1
message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned."
error = True error = True
message = "Double assignment to the same track!" errors += 1
break
if not error: if not error:
for k in temp: for k in temp:
@@ -231,8 +266,8 @@ class ConsistencyWorker(QRunnable):
for idx in indices: for idx in indices:
self.tracks[idx] = -1 self.tracks[idx] = -1
errors += 1 errors += 1
if self._stoponerror: if error and self._stoponerror:
self.signals.error.emit(message) self.signals.message.emit(message)
break break
processed += 1 processed += 1
@@ -240,6 +275,7 @@ class ConsistencyWorker(QRunnable):
progress += 1 progress += 1
self.signals.progress.emit(progress, processed, errors) self.signals.progress.emit(progress, processed, errors)
self.signals.message.emit(f"Tracking stopped at frame {f}.")
self.signals.stopped.emit(f) self.signals.stopped.emit(f)
@@ -318,6 +354,7 @@ class SizeClassifier(QWidget):
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2 tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks return tracks
class NeighborhoodValidator(QWidget): class NeighborhoodValidator(QWidget):
apply = Signal() apply = Signal()
name = "Neighborhood Validator" name = "Neighborhood Validator"
@@ -442,6 +479,7 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = None self._all_lengths = None
self._all_bendedness = None self._all_bendedness = None
self._all_scores = None self._all_scores = None
self._confidence = None
self._userlabeled = None self._userlabeled = None
self._maxframes = 0 self._maxframes = 0
self._frames = None self._frames = None
@@ -487,25 +525,38 @@ class ConsistencyClassifier(QWidget):
self._stoponerror.setChecked(True) self._stoponerror.setChecked(True)
self.threadpool = QThreadPool() self.threadpool = QThreadPool()
self._ignore_confidence = QCheckBox("Ignore detections widh confidence less than")
self._confidence_spinner = QDoubleSpinBox()
self._confidence_spinner.setRange(0.0, 1.0)
self._confidence_spinner.setSingleStep(0.01)
self._confidence_spinner.setDecimals(2)
self._confidence_spinner.setValue(0.5)
self._messagebox = QTextEdit()
self._messagebox.setFocusPolicy(Qt.NoFocus)
self._messagebox.setReadOnly(True)
lyt = QGridLayout() lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 ) lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2) lyt.addWidget(self._startframe_spinner, 0, 1, 1, 1)
lyt.addWidget(QLabel("of"), 1, 1, 1, 1) lyt.addWidget(QLabel("of"), 0, 2, 1, 1)
lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1) lyt.addWidget(self._maxframeslabel, 0, 3, 1, 1)
lyt.addWidget(self._stoponerror, 2, 0, 1, 3) lyt.addWidget(self._stoponerror, 1, 0, 1, 3)
lyt.addWidget(QLabel("Current frame"), 3,0) lyt.addWidget(self._ignore_confidence, 3, 0, 1, 3)
lyt.addWidget(self._framelabel, 3,1) lyt.addWidget(self._confidence_spinner, 3, 3, 1, 1)
lyt.addWidget(QLabel("assigned"), 4, 0) lyt.addWidget(QLabel("Current frame"), 4, 0)
lyt.addWidget(self._assignedlabel, 4, 1) lyt.addWidget(self._framelabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0) lyt.addWidget(QLabel("(Re-)Assigned"), 5, 0)
lyt.addWidget(self._errorlabel, 5, 1) lyt.addWidget(self._assignedlabel, 5, 1)
lyt.addWidget(QLabel("Errors/issues"), 5, 2)
lyt.addWidget(self._errorlabel, 5, 3, 1, 1)
lyt.addWidget(self._messagebox, 6, 0, 2, 4)
lyt.addWidget(self._startbtn, 6, 0) lyt.addWidget(self._startbtn, 8, 0, 1, 2)
lyt.addWidget(self._stopbtn, 6, 1) lyt.addWidget(self._stopbtn, 8, 2)
lyt.addWidget(self._proceedbtn, 6, 2) # lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._apply_btn, 7, 0, 1, 2) lyt.addWidget(self._refreshbtn, 8, 3, 1, 1)
lyt.addWidget(self._refreshbtn, 7, 2, 1, 1) lyt.addWidget(self._apply_btn, 9, 0, 1, 4)
lyt.addWidget(self._progressbar, 8, 0, 1, 3) lyt.addWidget(self._progressbar, 10, 0, 1, 4)
self.setLayout(lyt) self.setLayout(lyt)
def setData(self, data:TrackingData): def setData(self, data:TrackingData):
@@ -530,24 +581,34 @@ class ConsistencyClassifier(QWidget):
self._all_lengths = self._dataworker.lengths self._all_lengths = self._dataworker.lengths
self._all_bendedness = self._dataworker.bendedness self._all_bendedness = self._dataworker.bendedness
self._userlabeled = self._dataworker.userlabeled self._userlabeled = self._dataworker.userlabeled
self._all_scores = self._dataworker.scores self._confidence = self._dataworker.confidence
self._frames = self._dataworker.frames self._frames = self._dataworker.frames
self._tracks = self._dataworker.tracks self._tracks = self._dataworker.tracks
self._dataworker = None self._dataworker = None
if np.sum(self._userlabeled) < 1: if np.sum(self._userlabeled) < 1:
logging.error("ConsistencyTracker: I need at least 1 user-labeled frame to start with!") msg = "ConsistencyTracker: I need at least 1 user-labeled frame to start with!"
logging.error(msg)
self._messagebox.append(msg)
self.setEnabled(False) self.setEnabled(False)
else: else:
t1_userlabeled = self._frames[self._userlabeled & (self._tracks == 1)] t1_userlabeled = self._frames[self._userlabeled & (self._tracks == 1)]
t2_userlabeled = self._frames[self._userlabeled & (self._tracks == 2)] t2_userlabeled = self._frames[self._userlabeled & (self._tracks == 2)]
max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]]) if any([len(t1_userlabeled) == 0, len(t2_userlabeled)== 0]):
min_startframe = np.max([t1_userlabeled[0], t2_userlabeled[0]]) self._messagebox.append("Error preparing data! Make sure that the first user-labeled frames contain both tracks!")
self.setEnabled(False)
return
max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]]) -1
first_guess = np.max([t1_userlabeled[0], t2_userlabeled[0]])
while first_guess not in t1_userlabeled or first_guess not in t2_userlabeled:
first_guess += 1
min_startframe = first_guess + 1
self._maxframes = np.max(self._frames) self._maxframes = np.max(self._frames)
self._maxframeslabel.setText(str(self._maxframes)) self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_startframe) self._startframe_spinner.setMinimum(min_startframe)
self._startframe_spinner.setMaximum(max_startframe) self._startframe_spinner.setMaximum(max_startframe)
self._startframe_spinner.setValue(min_startframe) self._startframe_spinner.setValue(min_startframe)
self._startframe_spinner.setSingleStep(20) self._startframe_spinner.setSingleStep(20)
self._startframe_spinner.setToolTip(f"Maximum possible start frame: {max_startframe}")
self._startbtn.setEnabled(True) self._startbtn.setEnabled(True)
self._assignedlabel.setText("0") self._assignedlabel.setText("0")
self._errorlabel.setText("0") self._errorlabel.setText("0")
@@ -561,32 +622,39 @@ class ConsistencyClassifier(QWidget):
def stop(self): def stop(self):
if self._worker is not None: if self._worker is not None:
self._worker.stop() self._worker.stop()
self._startbtn.setEnabled(True) self._messagebox.append("Stopping tracking.")
self._proceedbtn.setEnabled(True)
self._stopbtn.setEnabled(False)
self._refreshbtn.setEnabled(True)
def start(self): def start(self):
confidence_level = self._confidence_spinner.value() if self._ignore_confidence.isChecked() else 0.0
self._startbtn.setEnabled(False) self._startbtn.setEnabled(False)
self._refreshbtn.setEnabled(False) self._refreshbtn.setEnabled(False)
self._stopbtn.setEnabled(True) self._stopbtn.setEnabled(True)
self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths, self._worker = ConsistencyWorker(self._all_pos, self._all_orientations, self._all_lengths,
self._all_bendedness, self._frames, self._tracks, self._userlabeled, self._all_bendedness, self._frames, self._tracks, self._userlabeled,
self._startframe_spinner.value(), self._stoponerror.isChecked()) self._confidence, self._startframe_spinner.value(), self._stoponerror.isChecked(),
min_confidence=confidence_level)
self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error)
self._worker.signals.currentframe.connect(self.worker_frame) self._worker.signals.currentframe.connect(self.worker_frame)
self._messagebox.append("Tracking in progress ...")
self.threadpool.start(self._worker) self.threadpool.start(self._worker)
def worker_frame(self, frame): def worker_frame(self, frame):
self._framelabel.setText(str(frame)) self._framelabel.setText(str(frame))
def worker_error(self, msg):
self._messagebox.append(msg)
def proceed(self): def proceed(self):
self.start() self.start()
def refresh(self): def refresh(self):
self._dataworker = ConsitencyDataLoader(self._data) self.setEnabled(False)
self._dataworker = ConsistencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed) self._dataworker.signals.stopped.connect(self.data_processed)
self._messagebox.clear()
self._messagebox.append("Refreshing...")
self.threadpool.start(self._dataworker) self.threadpool.start(self._dataworker)
def worker_progress(self, progress, processed, errors): def worker_progress(self, progress, processed, errors):
@@ -595,13 +663,15 @@ class ConsistencyClassifier(QWidget):
self._assignedlabel.setText(str(processed)) self._assignedlabel.setText(str(processed))
def worker_stopped(self, frame): def worker_stopped(self, frame):
self._apply_btn.setEnabled(True)
self._startbtn.setEnabled(True) self._startbtn.setEnabled(True)
self._proceedbtn.setEnabled(True)
self._stopbtn.setEnabled(False) self._stopbtn.setEnabled(False)
self._apply_btn.setEnabled(True)
self._refreshbtn.setEnabled(True)
self._startframe_spinner.setValue(frame-1) self._startframe_spinner.setValue(frame-1)
self._proceedbtn.setEnabled(bool(frame < self._maxframes-1)) self._proceedbtn.setEnabled(bool(frame < self._maxframes-1))
self._refreshbtn.setEnabled(True)
self._processed_frames = frame self._processed_frames = frame
self._messagebox.append("... done.")
def assignedTracks(self): def assignedTracks(self):
return self._tracks return self._tracks

View File

@@ -21,10 +21,10 @@ class Window(QGraphicsRectItem):
self.setBrush(brush) self.setBrush(brush)
self.setZValue(1.0) self.setZValue(1.0)
self.setAcceptHoverEvents(True) # Enable hover events if needed self.setAcceptHoverEvents(True) # Enable hover events if needed
self.setFlags( # self.setFlags(
QGraphicsItem.ItemIsMovable | # Enables item dragging # QGraphicsItem.ItemIsMovable | # Enables item dragging
QGraphicsItem.ItemIsSelectable # Enables item selection # QGraphicsItem.ItemIsSelectable # Enables item selection
) # )
self._y = y self._y = y
def setWindowX(self, newx): def setWindowX(self, newx):
@@ -65,6 +65,7 @@ class Window(QGraphicsRectItem):
def mousePressEvent(self, event): def mousePressEvent(self, event):
self.setCursor(Qt.ClosedHandCursor) self.setCursor(Qt.ClosedHandCursor)
# print(event.pos())
super().mousePressEvent(event) super().mousePressEvent(event)
def mouseReleaseEvent(self, event): def mouseReleaseEvent(self, event):
@@ -121,6 +122,7 @@ class DetectionTimeline(QWidget):
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 85.)) self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 85.))
self._scene.setBackgroundBrush(self._bg_brush) self._scene.setBackgroundBrush(self._bg_brush)
self._scene.addItem(self._window) self._scene.addItem(self._window)
self._scene.mousePressEvent = self.on_sceneMousePress
self._view = QGraphicsView() self._view = QGraphicsView()
# self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform) # self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform)
@@ -151,12 +153,22 @@ class DetectionTimeline(QWidget):
self._position_label.setFont(f) self._position_label.setFont(f)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.setSpacing(0)
layout.setContentsMargins(5, 2, 5, 2)
layout.addWidget(self._view) layout.addWidget(self._view)
layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight) layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout) self.setLayout(layout)
# self.setMaximumHeight(100) # self.setMaximumHeight(100)
# self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed) # self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
def on_sceneMousePress(self, event):
scene_pos = event.scenePos()
relpos = scene_pos.x() / self._total_width
relpos = 0 if relpos < 0.0 else relpos
relpos = 2000/self._total_width if scene_pos.x() > self._total_width else relpos
self.signals.moveRequest.emit(relpos)
logging.debug("Timeline: Scene clicked at position: %.2f, %.2f --> rel x-pos %.3f", scene_pos.x(), scene_pos.y(), relpos)
def clear(self): def clear(self):
for i in self._scene.items(): for i in self._scene.items():
if isinstance(i, (QGraphicsLineItem, QGraphicsEllipseItem)): if isinstance(i, (QGraphicsLineItem, QGraphicsEllipseItem)):
@@ -166,6 +178,7 @@ class DetectionTimeline(QWidget):
logging.debug("Timeline: setData!") logging.debug("Timeline: setData!")
self._data = data self._data = data
self.update() self.update()
self.resizeEvent(None)
def update(self): def update(self):
self.clear() self.clear()
@@ -309,8 +322,7 @@ def main():
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)
data = TrackingData() data = TrackingData(as_dict(df))
data.setData(as_dict(df))
data.setSelection(np.arange(0,100, 1)) data.setSelection(np.arange(0,100, 1))
data.setUserLabeledStatus(True) data.setUserLabeledStatus(True)
start_x = 0.1 start_x = 0.1
@@ -329,12 +341,14 @@ def main():
backBtn.clicked.connect(lambda: back(0.2)) backBtn.clicked.connect(lambda: back(0.2))
btnLyt = QHBoxLayout() btnLyt = QHBoxLayout()
btnLyt.setSpacing(1)
btnLyt.addWidget(backBtn) btnLyt.addWidget(backBtn)
btnLyt.addWidget(zeroBtn) btnLyt.addWidget(zeroBtn)
btnLyt.addWidget(fwdBtn) btnLyt.addWidget(fwdBtn)
view.setWindowPos(start_x) view.setWindowPos(start_x)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.setSpacing(1)
layout.addWidget(view) layout.addWidget(view)
layout.addLayout(btnLyt) layout.addLayout(btnLyt)
window.setLayout(layout) window.setLayout(layout)

View File

@@ -10,6 +10,7 @@ from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, Dete
from fixtracks.utils.enums import DetectionData, Tracks from fixtracks.utils.enums import DetectionData, Tracks
from fixtracks.utils.trackingdata import TrackingData from fixtracks.utils.trackingdata import TrackingData
class Detection(QGraphicsEllipseItem): class Detection(QGraphicsEllipseItem):
signals = DetectionSignals() signals = DetectionSignals()
@@ -128,21 +129,23 @@ class DetectionView(QWidget):
del it del it
def updateDetections(self, keypoint=-1): def updateDetections(self, keypoint=-1):
logging.info("DetectionView.updateDetections!")
self.clearDetections() self.clearDetections()
if self._data is None: if self._data is None:
return return
frames = self._data.selectedData("frame") frames = self._data.selectedData("frame")
tracks = self._data.selectedData("track") tracks = self._data.selectedData("track")
ids = self._data.selectedData("index")
coordinates = self._data.coordinates(selection=True) coordinates = self._data.coordinates(selection=True)
centercoordinates = self._data.centerOfGravity(selection=True) centercoordinates = self._data.centerOfGravity(selection=True)
userlabeled = self._data.selectedData("userlabeled") userlabeled = self._data.selectedData("userlabeled")
scores = self._data.selectedData("confidence")
indices = self._data.selectionIndices
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0) image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
num_detections = len(frames)
for i, idx in enumerate(indices): for i, (id, f, t, l, s) in enumerate(zip(ids, frames, tracks, userlabeled, scores)):
t = tracks[i]
c = Tracks.fromValue(t).toColor() c = Tracks.fromValue(t).toColor()
c.setAlpha(int(i * 255 / num_detections))
if keypoint >= 0: if keypoint >= 0:
x = coordinates[i, keypoint, 0] x = coordinates[i, keypoint, 0]
y = coordinates[i, keypoint, 1] y = coordinates[i, keypoint, 1]
@@ -151,10 +154,12 @@ class DetectionView(QWidget):
y = centercoordinates[i, 1] y = centercoordinates[i, 1]
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(c)) item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(c))
item.setData(DetectionData.TRACK_ID.value, tracks[i]) item.setData(DetectionData.TRACK_ID.value, t)
item.setData(DetectionData.ID.value, idx) item.setData(DetectionData.ID.value, id)
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :]) item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
item.setData(DetectionData.FRAME.value, frames[i]) item.setData(DetectionData.FRAME.value, f)
item.setData(DetectionData.USERLABELED.value, l)
item.setData(DetectionData.SCORE.value, s)
item = self._scene.addItem(item) item = self._scene.addItem(item)
def fit_image_to_view(self): def fit_image_to_view(self):

View File

@@ -4,7 +4,7 @@ import numpy as np
from PySide6.QtCore import Qt, Signal, QSize from PySide6.QtCore import Qt, Signal, QSize
from PySide6.QtGui import QFont from PySide6.QtGui import QFont
from PySide6.QtWidgets import QWidget, QLabel, QPushButton, QSizePolicy from PySide6.QtWidgets import QWidget, QLabel, QPushButton, QSizePolicy
from PySide6.QtWidgets import QGridLayout, QVBoxLayout from PySide6.QtWidgets import QGridLayout, QVBoxLayout, QApplication
from fixtracks.utils.styles import pushBtnStyle from fixtracks.utils.styles import pushBtnStyle
@@ -15,6 +15,7 @@ class SelectionControls(QWidget):
assignTwo = Signal() assignTwo = Signal()
assignOther = Signal() assignOther = Signal()
accept = Signal() accept = Signal()
accept_until = Signal()
unaccept = Signal() unaccept = Signal()
delete = Signal() delete = Signal()
revertall = Signal() revertall = Signal()
@@ -51,7 +52,6 @@ class SelectionControls(QWidget):
quarterstepBackBtn.setStyleSheet(pushBtnStyle("darkgray")) quarterstepBackBtn.setStyleSheet(pushBtnStyle("darkgray"))
quarterstepBackBtn.clicked.connect(lambda: self.on_Back(quarterstep)) quarterstepBackBtn.clicked.connect(lambda: self.on_Back(quarterstep))
fwdBtn = QPushButton(">>|") fwdBtn = QPushButton(">>|")
fwdBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) fwdBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
fwdBtn.setShortcut(Qt.Key.Key_Right) fwdBtn.setShortcut(Qt.Key.Key_Right)
@@ -102,7 +102,7 @@ class SelectionControls(QWidget):
acceptBtn.setFont(font) acceptBtn.setFont(font)
acceptBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) acceptBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
acceptBtn.setStyleSheet(pushBtnStyle("darkgray")) acceptBtn.setStyleSheet(pushBtnStyle("darkgray"))
acceptBtn.setToolTip(f"Accept assignments of current selection as TRUE") acceptBtn.setToolTip(f"Accept assignments of current selection as TRUE, Hold shift while clicking to accept all until here.")
acceptBtn.clicked.connect(self.on_Accept) acceptBtn.clicked.connect(self.on_Accept)
unacceptBtn = QPushButton("un-accept") unacceptBtn = QPushButton("un-accept")
@@ -117,8 +117,9 @@ class SelectionControls(QWidget):
deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
deleteBtn.setStyleSheet(pushBtnStyle("red")) deleteBtn.setStyleSheet(pushBtnStyle("red"))
deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!") deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!")
deleteBtn.setEnabled(False) deleteBtn.setShortcut("Ctrl+D")
deleteBtn.clicked.connect(self.on_Delete) deleteBtn.clicked.connect(self.on_Delete)
deleteBtn.setEnabled(False)
revertBtn = QPushButton("revert assignments") revertBtn = QPushButton("revert assignments")
revertBtn.setFont(font) revertBtn.setFont(font)
@@ -171,7 +172,7 @@ class SelectionControls(QWidget):
grid.setColumnStretch(0, 1) grid.setColumnStretch(0, 1)
grid.setColumnStretch(7, 1) grid.setColumnStretch(7, 1)
self.setLayout(grid) self.setLayout(grid)
self.setMaximumSize(QSize(500, 500)) # self.setMaximumSize(QSize(500, 500))
def setWindow(self, start:int=0, end:int=0): def setWindow(self, start:int=0, end:int=0):
self.startframe.setText(f"{start:.0f}") self.startframe.setText(f"{start:.0f}")
@@ -210,6 +211,11 @@ class SelectionControls(QWidget):
def on_Accept(self): def on_Accept(self):
logging.debug("SelectionControl: accept AssignmentBtn") logging.debug("SelectionControl: accept AssignmentBtn")
modifiers = QApplication.keyboardModifiers()
if modifiers == Qt.KeyboardModifier.ShiftModifier:
logging.debug("Shift key was pressed during accept")
self.accept_until.emit()
else:
self.accept.emit() self.accept.emit()
def on_Unaccept(self): def on_Unaccept(self):

View File

@@ -94,7 +94,8 @@ class SkeletonWidget(QWidget):
i = s.data(DetectionData.ID.value) i = s.data(DetectionData.ID.value)
t = s.data(DetectionData.TRACK_ID.value) t = s.data(DetectionData.TRACK_ID.value)
f = s.data(DetectionData.FRAME.value) f = s.data(DetectionData.FRAME.value)
self._info_label.setText(f"Id {i}, track {t} on frame {f}, length {l:.1f} px") sc = s.data(DetectionData.SCORE.value)
self._info_label.setText(f"Id {i}, track {t} on frame {f}, length {l:.1f} px, confidence {sc:.2f}")
else: else:
self._info_label.setText("") self._info_label.setText("")
@@ -129,7 +130,7 @@ class SkeletonWidget(QWidget):
self._scene.setSceneRect(self._minx, self._miny, self._maxx - self._minx, self._maxy - self._miny) self._scene.setSceneRect(self._minx, self._miny, self._maxx - self._minx, self._maxy - self._miny)
self._view.fitInView(self._scene.sceneRect(), Qt.AspectRatioMode.KeepAspectRatio) self._view.fitInView(self._scene.sceneRect(), Qt.AspectRatioMode.KeepAspectRatio)
def addSkeleton(self, coords, detection_id, frame, track, brush, update=True): def addSkeleton(self, coords, detection_id, frame, track, score, brush, update=True):
def check_extent(x, y, w, h): def check_extent(x, y, w, h):
if x == 0 and y == 0: if x == 0 and y == 0:
return return
@@ -157,12 +158,14 @@ class SkeletonWidget(QWidget):
item.setData(DetectionData.ID.value, detection_id) item.setData(DetectionData.ID.value, detection_id)
item.setData(DetectionData.TRACK_ID.value, track) item.setData(DetectionData.TRACK_ID.value, track)
item.setData(DetectionData.FRAME.value, frame) item.setData(DetectionData.FRAME.value, frame)
item.setData(DetectionData.SCORE.value, score)
self._skeletons.append(item) self._skeletons.append(item)
if update: if update:
self.update() self.update()
def addSkeletons(self, coordinates:np.ndarray, detection_ids:np.ndarray, def addSkeletons(self, coordinates:np.ndarray, detection_ids:np.ndarray,
frames:np.ndarray, tracks:np.ndarray, brush:QBrush): frames:np.ndarray, tracks:np.ndarray, scores:np.ndarray,
brush:QBrush):
num_detections = 0 if coordinates is None else coordinates.shape[0] num_detections = 0 if coordinates is None else coordinates.shape[0]
logging.debug("SkeletonWidget: add %i Skeletons", num_detections) logging.debug("SkeletonWidget: add %i Skeletons", num_detections)
if num_detections < 1: if num_detections < 1:
@@ -172,9 +175,10 @@ class SkeletonWidget(QWidget):
detection_ids = detection_ids[sorting] detection_ids = detection_ids[sorting]
frames = frames[sorting] frames = frames[sorting]
tracks = tracks[sorting] tracks = tracks[sorting]
scores = scores[sorting]
for i in range(num_detections): for i in range(num_detections):
self.addSkeleton(coordinates[i,:,:], detection_ids[i], frames[i], self.addSkeleton(coordinates[i,:,:], detection_ids[i], frames[i],
tracks[i], brush=brush, update=False) tracks[i], scores[i], brush=brush, update=False)
self.update() self.update()
# def addSkeleton(self, coords, detection_id, brush): # def addSkeleton(self, coords, detection_id, brush):

View File

@@ -1,11 +1,11 @@
import logging import logging
import numpy as np import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal from PySide6.QtCore import Qt, QThreadPool, Signal
from PySide6.QtGui import QImage, QBrush, QColor from PySide6.QtGui import QImage, QBrush, QColor
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QFileDialog, QMessageBox
from fixtracks.utils.reader import PickleLoader from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.writer import PickleWriter from fixtracks.utils.writer import PickleWriter
@@ -37,12 +37,6 @@ class FixTracks(QWidget):
self._detectionView = DetectionView() self._detectionView = DetectionView()
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected) self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
self._skeleton = SkeletonWidget() self._skeleton = SkeletonWidget()
# self._skeleton.setMaximumSize(QSize(400, 400))
top_splitter = QSplitter(Qt.Orientation.Horizontal)
top_splitter.addWidget(self._detectionView)
top_splitter.addWidget(self._skeleton)
top_splitter.setStretchFactor(0, 2)
top_splitter.setStretchFactor(1, 1)
self._progress_bar = QProgressBar(self) self._progress_bar = QProgressBar(self)
self._progress_bar.setMaximumHeight(20) self._progress_bar.setMaximumHeight(20)
@@ -51,6 +45,7 @@ class FixTracks(QWidget):
self._timeline = DetectionTimeline() self._timeline = DetectionTimeline()
self._timeline.signals.windowMoved.connect(self.on_windowChanged) self._timeline.signals.windowMoved.connect(self.on_windowChanged)
self._timeline.signals.moveRequest.connect(self.on_moveRequest)
self._windowspinner = QSpinBox() self._windowspinner = QSpinBox()
self._windowspinner.setRange(10, 10000) self._windowspinner.setRange(10, 10000)
@@ -61,15 +56,31 @@ class FixTracks(QWidget):
self._keypointcombo = QComboBox() self._keypointcombo = QComboBox()
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected) self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
combo_layout = QGridLayout() self._goto_spinner = QSpinBox()
combo_layout.addWidget(QLabel("Window:"), 0, 0) self._goto_spinner.setSingleStep(1)
combo_layout.addWidget(self._windowspinner, 0, 1)
combo_layout.addWidget(QLabel("Keypoint:"), 1, 0)
combo_layout.addWidget(self._keypointcombo, 1, 1)
timelinebox = QHBoxLayout() self._gotobtn = QPushButton("go!")
timelinebox.addWidget(self._timeline) self._gotobtn.setToolTip("Jump to a given frame")
self._gotobtn.clicked.connect(self.on_goto)
combo_layout = QHBoxLayout()
combo_layout.addWidget(QLabel("Window width:"))
combo_layout.addWidget(self._windowspinner)
combo_layout.addWidget(QLabel("frames"))
combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
combo_layout.addWidget(QLabel("Keypoint:"))
combo_layout.addWidget(self._keypointcombo)
combo_layout.addItem(QSpacerItem(10, 10, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
combo_layout.addWidget(QLabel("Jump to frame:"))
combo_layout.addWidget(self._goto_spinner)
combo_layout.addWidget(self._gotobtn)
combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
combo_layout.setSpacing(1)
timelinebox = QVBoxLayout()
timelinebox.setSpacing(2)
timelinebox.addLayout(combo_layout) timelinebox.addLayout(combo_layout)
timelinebox.addWidget(self._timeline)
self._controls_widget = SelectionControls() self._controls_widget = SelectionControls()
self._controls_widget.assignOne.connect(self.on_assignOne) self._controls_widget.assignOne.connect(self.on_assignOne)
@@ -78,6 +89,7 @@ class FixTracks(QWidget):
self._controls_widget.fwd.connect(self.on_forward) self._controls_widget.fwd.connect(self.on_forward)
self._controls_widget.back.connect(self.on_backward) self._controls_widget.back.connect(self.on_backward)
self._controls_widget.accept.connect(self.on_setUserFlag) self._controls_widget.accept.connect(self.on_setUserFlag)
self._controls_widget.accept_until.connect(self.on_setUserFlagsUntil)
self._controls_widget.unaccept.connect(self.on_unsetUserFlag) self._controls_widget.unaccept.connect(self.on_unsetUserFlag)
self._controls_widget.delete.connect(self.on_deleteDetection) self._controls_widget.delete.connect(self.on_deleteDetection)
self._controls_widget.revertall.connect(self.on_revertUserFlags) self._controls_widget.revertall.connect(self.on_revertUserFlags)
@@ -103,6 +115,7 @@ class FixTracks(QWidget):
data_selection_box.addWidget(QLabel("Select data file")) data_selection_box.addWidget(QLabel("Select data file"))
data_selection_box.addWidget(self._data_combo) data_selection_box.addWidget(self._data_combo)
data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
data_selection_box.setSpacing(0)
btnBox = QHBoxLayout() btnBox = QHBoxLayout()
btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft) btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
@@ -118,9 +131,14 @@ class FixTracks(QWidget):
cntrlBox = QHBoxLayout() cntrlBox = QHBoxLayout()
cntrlBox.addWidget(self._classifier) cntrlBox.addWidget(self._classifier)
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter) cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding)) cntrlBox.addWidget(self._skeleton)
cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
cntrlBox.setSpacing(0)
cntrlBox.setContentsMargins(0,0,0,0)
vbox = QVBoxLayout() vbox = QVBoxLayout()
vbox.setSpacing(0)
vbox.setContentsMargins(0,0,0,0)
vbox.addLayout(timelinebox) vbox.addLayout(timelinebox)
vbox.addLayout(cntrlBox) vbox.addLayout(cntrlBox)
vbox.addLayout(btnBox) vbox.addLayout(btnBox)
@@ -128,13 +146,16 @@ class FixTracks(QWidget):
container.setLayout(vbox) container.setLayout(vbox)
splitter = QSplitter(Qt.Orientation.Vertical) splitter = QSplitter(Qt.Orientation.Vertical)
splitter.addWidget(top_splitter) splitter.addWidget(self._detectionView)
splitter.addWidget(container) splitter.addWidget(container)
splitter.setStretchFactor(0, 3) splitter.setStretchFactor(0, 3)
splitter.setStretchFactor(1, 1) splitter.setStretchFactor(1, 1)
layout = QVBoxLayout() layout = QVBoxLayout()
layout.addLayout(data_selection_box) layout.addLayout(data_selection_box)
layout.addWidget(splitter) layout.addWidget(splitter)
layout.setSpacing(0)
layout.setContentsMargins(5,2,2,5)
self.setLayout(layout) self.setLayout(layout)
def on_autoClassify(self, tracks): def on_autoClassify(self, tracks):
@@ -160,15 +181,19 @@ class FixTracks(QWidget):
self._detectionView.setImage(img) self._detectionView.setImage(img)
def update(self): def update(self):
kp = self._keypointcombo.currentText().lower()
if len(kp) == 0:
return
kpi = -1 if "center" in kp else int(kp)
start_frame = self._currentWindowPos start_frame = self._currentWindowPos
stop_frame = start_frame + self._currentWindowWidth stop_frame = start_frame + self._currentWindowWidth
self._timeline.setWindow(start_frame / self._maxframes, self._timeline.setWindow(start_frame / self._maxframes,
self._currentWindowWidth/self._maxframes) self._currentWindowWidth/self._maxframes)
logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame) logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
self._data.setSelectionRange("frame", start_frame, stop_frame) self._data.setSelectionRange("frame", start_frame, stop_frame)
self._controls_widget.setWindow(start_frame, stop_frame) self._controls_widget.setWindow(start_frame, stop_frame)
kp = self._keypointcombo.currentText().lower()
kpi = -1 if "center" in kp else int(kp)
self._detectionView.updateDetections(kpi) self._detectionView.updateDetections(kpi)
@property @property
@@ -211,6 +236,7 @@ class FixTracks(QWidget):
self._currentWindowPos = 0 self._currentWindowPos = 0
self._currentWindowWidth = self._windowspinner.value() self._currentWindowWidth = self._windowspinner.value()
self._maxframes = np.max(self._data["frame"]) self._maxframes = np.max(self._data["frame"])
self._goto_spinner.setMaximum(self._maxframes)
self.populateKeypointCombo(self._data.numKeypoints()) self.populateKeypointCombo(self._data.numKeypoints())
self._timeline.setData(self._data) self._timeline.setData(self._data)
# self._timeline.setWindow(self._currentWindowPos / self._maxframes, # self._timeline.setWindow(self._currentWindowPos / self._maxframes,
@@ -270,6 +296,12 @@ class FixTracks(QWidget):
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_setUserFlagsUntil(self):
self._data.setSelectionRange("frame", 0, self._currentWindowPos + self._currentWindowWidth)
self._data.setUserLabeledStatus(True)
self._timeline.update()
self.update()
def on_unsetUserFlag(self): def on_unsetUserFlag(self):
logging.debug("Tracks:unsetUserFlag") logging.debug("Tracks:unsetUserFlag")
self._data.setUserLabeledStatus(False) self._data.setUserLabeledStatus(False)
@@ -278,14 +310,30 @@ class FixTracks(QWidget):
def on_revertUserFlags(self): def on_revertUserFlags(self):
logging.debug("Tracks:revert ALL UserFlags and track assignments") logging.debug("Tracks:revert ALL UserFlags and track assignments")
msg_box = QMessageBox()
msg_box.setIcon(QMessageBox.Icon.Warning)
msg_box.setText(f"Are you sure you want to revert ALL track assignments?")
msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
msg_box.setDefaultButton(QMessageBox.StandardButton.No)
ret = msg_box.exec()
if ret == QMessageBox.StandardButton.Yes:
self._data.revertUserLabeledStatus() self._data.revertUserLabeledStatus()
self._data.revertTrackAssignments() self._data.revertTrackAssignments()
self._timeline.update() self._timeline.update()
self.update() self.update()
def on_deleteDetection(self): def on_deleteDetection(self):
logging.warning("Tracks:delete detections is currently not supported!") logging.info("Tracks:deleting detections!")
# self._data.deleteDetections() msg_box = QMessageBox()
msg_box.setIcon(QMessageBox.Icon.Warning)
msg_box.setText(f"Are you sure you want to delete the selected ({len(self._data.selectionIndices)})detections?")
msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
msg_box.setDefaultButton(QMessageBox.StandardButton.No)
ret = msg_box.exec()
if ret == QMessageBox.StandardButton.Yes:
self._data.deleteDetections()
self._timeline.update() self._timeline.update()
self.update() self.update()
@@ -296,6 +344,11 @@ class FixTracks(QWidget):
self.update() self.update()
self._manualmove = False self._manualmove = False
def on_moveRequest(self, pos):
new_pos = int(np.round(pos * self._maxframes))
self._currentWindowPos = new_pos
self.update()
def on_windowSizeChanged(self, value): def on_windowSizeChanged(self, value):
"""Reacts on the user window-width selection. Selection is done in the unit of frames. """Reacts on the user window-width selection. Selection is done in the unit of frames.
@@ -306,14 +359,29 @@ class FixTracks(QWidget):
""" """
self._currentWindowWidth = value self._currentWindowWidth = value
logging.debug("Tracks:OnWindowSizeChanged %i franes", value) logging.debug("Tracks:OnWindowSizeChanged %i franes", value)
self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes) # if self._maxframes == 0:
self._controls_widget.setSelectedTracks(None) # self._timeline.setWindowWidth(self._currentWindowWidth / 2000)
# else:
# self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
# self._controls_widget.setSelectedTracks(None)
self.update()
def on_goto(self):
target = self._goto_spinner.value()
if target > self._maxframes - self._currentWindowWidth:
target = self._maxframes - self._currentWindowWidth
logging.info("Jump to frame %i", target)
self._currentWindowPos = target
self._timeline.setWindow(self._currentWindowPos / self._maxframes,
self._currentWindowWidth / self._maxframes)
self.update()
def on_detectionsSelected(self, detections): def on_detectionsSelected(self, detections):
logging.debug("Tracks: %i Detections selected", len(detections)) logging.debug("Tracks: %i Detections selected", len(detections))
tracks = np.zeros(len(detections), dtype=int) tracks = np.zeros(len(detections), dtype=int)
ids = np.zeros_like(tracks) ids = np.zeros_like(tracks)
frames = np.zeros_like(tracks) frames = np.zeros_like(tracks)
scores = np.zeros(tracks.shape, dtype=float)
coordinates = None coordinates = None
if len(detections) > 0: if len(detections) > 0:
c = detections[0].data(DetectionData.COORDINATES.value) c = detections[0].data(DetectionData.COORDINATES.value)
@@ -324,15 +392,20 @@ class FixTracks(QWidget):
ids[i] = d.data(DetectionData.ID.value) ids[i] = d.data(DetectionData.ID.value)
frames[i] = d.data(DetectionData.FRAME.value) frames[i] = d.data(DetectionData.FRAME.value)
coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value) coordinates[i, :, :] = d.data(DetectionData.COORDINATES.value)
scores[i] = d.data(DetectionData.SCORE.value)
self._data.setSelection(ids) self._data.setSelection(ids)
self._controls_widget.setSelectedTracks(tracks) self._controls_widget.setSelectedTracks(tracks)
self._skeleton.clear() self._skeleton.clear()
self._skeleton.addSkeletons(coordinates, ids, frames, tracks, QBrush(QColor(10, 255, 65, 255))) self._skeleton.addSkeletons(coordinates, ids, frames, tracks, scores, QBrush(QColor(10, 255, 65, 255)))
def moveWindow(self, stepsize): def moveWindow(self, stepsize):
logging.info("Tracks.moveWindow: move window with stepsize %.2f", stepsize) logging.info("Tracks.moveWindow: move window with stepsize %.2f", stepsize)
self._manualmove = True self._manualmove = True
new_start_frame = self._currentWindowPos + np.round(stepsize * self._currentWindowWidth) new_start_frame = self._currentWindowPos + np.round(stepsize * self._currentWindowWidth)
if new_start_frame < 0:
new_start_frame = 0
elif new_start_frame + self._currentWindowWidth > self._maxframes:
new_start_frame = self._maxframes - self._currentWindowWidth
self._currentWindowPos = new_start_frame self._currentWindowPos = new_start_frame
self._controls_widget.setSelectedTracks(None) self._controls_widget.setSelectedTracks(None)
self.update() self.update()