Compare commits
8 Commits
d6b91c25d2
...
1c2f84b236
Author | SHA1 | Date | |
---|---|---|---|
1c2f84b236 | |||
ff3e0841a6 | |||
9e2c6f343a | |||
c0a7631acd | |||
5758cf61c6 | |||
faf095a2a1 | |||
0c5e5629b7 | |||
4ef6143d14 |
@ -7,6 +7,7 @@ class DetectionData(Enum):
|
|||||||
FRAME = 1
|
FRAME = 1
|
||||||
COORDINATES = 2
|
COORDINATES = 2
|
||||||
TRACK_ID = 3
|
TRACK_ID = 3
|
||||||
|
USERLABELED = 4
|
||||||
|
|
||||||
class Tracks(Enum):
|
class Tracks(Enum):
|
||||||
TRACKONE = 1
|
TRACKONE = 1
|
||||||
|
@ -14,8 +14,8 @@ class TrackingData(QObject):
|
|||||||
if "userlabeled" not in self._data.keys():
|
if "userlabeled" not in self._data.keys():
|
||||||
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool)
|
self._data["userlabeled"] = np.zeros_like(self["frame"], dtype=bool)
|
||||||
self._columns = [k for k in self._data.keys()]
|
self._columns = [k for k in self._data.keys()]
|
||||||
self._indices = self["index"]
|
self._indices = np.arange(len(self["index"]), dtype=int)
|
||||||
self._selection = np.asarray([])
|
self._selection_indices = np.asarray([])
|
||||||
self._selected_ids = None
|
self._selected_ids = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -43,8 +43,8 @@ class TrackingData(QObject):
|
|||||||
ids = np.sort(ids)
|
ids = np.sort(ids)
|
||||||
indexes = np.ones_like(ids, dtype=int) * -1
|
indexes = np.ones_like(ids, dtype=int) * -1
|
||||||
j = 0
|
j = 0
|
||||||
for idx, i in enumerate(self._indices):
|
for idx in self._indices:
|
||||||
if i == ids[j]:
|
if self["index"][idx] == ids[j]:
|
||||||
indexes[j] = idx
|
indexes[j] = idx
|
||||||
j += 1
|
j += 1
|
||||||
if j == len(indexes):
|
if j == len(indexes):
|
||||||
@ -54,19 +54,23 @@ class TrackingData(QObject):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def selectionIndices(self):
|
def selectionIndices(self):
|
||||||
return self._selection
|
return self._selection_indices
|
||||||
|
|
||||||
|
@property
|
||||||
|
def selectionIDs(self):
|
||||||
|
return self._selected_ids
|
||||||
|
|
||||||
def setSelectionRange(self, col, start, stop):
|
def setSelectionRange(self, col, start, stop):
|
||||||
logging.info("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop)
|
logging.info("Trackingdata: set selection range based on column %s to %.2f - %.2f", col, start, stop)
|
||||||
col_indices = np.where((self[col] >= start) & (self[col] < stop))[0]
|
col_indices = np.where((self[col] >= start) & (self[col] < stop))[0]
|
||||||
self._selection = self._indices[col_indices]
|
self._selection_indices = self._indices[col_indices]
|
||||||
if len(col_indices) < 1:
|
if len(col_indices) < 1:
|
||||||
logging.warning("TrackingData: Selection range is empty!")
|
logging.warning("TrackingData: Selection range is empty!")
|
||||||
|
|
||||||
def selectedData(self, col:str):
|
def selectedData(self, col:str):
|
||||||
if col not in self.columns:
|
if col not in self.columns:
|
||||||
logging.error("TrackingData:selectedData: Invalid column name! %s", col)
|
logging.error("TrackingData:selectedData: Invalid column name! %s", col)
|
||||||
return self[col][self._selection]
|
return self[col][self._selection_indices]
|
||||||
|
|
||||||
def setSelection(self, ids):
|
def setSelection(self, ids):
|
||||||
"""
|
"""
|
||||||
@ -78,8 +82,9 @@ class TrackingData(QObject):
|
|||||||
An array-like object containing the IDs to be set as user selections.
|
An array-like object containing the IDs to be set as user selections.
|
||||||
"""
|
"""
|
||||||
logging.debug("TrackingData.setSelection: %i number of ids", len(ids))
|
logging.debug("TrackingData.setSelection: %i number of ids", len(ids))
|
||||||
self._selection = self._find(ids)
|
self._selection_indices = self._find(ids)
|
||||||
self._selected_ids = ids
|
self._selected_ids = ids
|
||||||
|
print(self._selected_ids, self._selection_indices)
|
||||||
|
|
||||||
def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
|
def setTrack(self, track_id:int, setUserLabeled:bool=True)-> None:
|
||||||
"""Assign a new track_id to the user-selected detections
|
"""Assign a new track_id to the user-selected detections
|
||||||
@ -92,9 +97,12 @@ class TrackingData(QObject):
|
|||||||
Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched.
|
Should the "userlabeled" state of the detections be set to True? Otherwise they will be left untouched.
|
||||||
"""
|
"""
|
||||||
logging.info("TrackingData: set track id %i for selection, set user-labeled status %s", track_id, str(setUserLabeled))
|
logging.info("TrackingData: set track id %i for selection, set user-labeled status %s", track_id, str(setUserLabeled))
|
||||||
self["track"][self._selection] = track_id
|
print(self._selected_ids, self._selection_indices)
|
||||||
|
print("before: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
|
||||||
|
self["track"][self._selection_indices] = track_id
|
||||||
if setUserLabeled:
|
if setUserLabeled:
|
||||||
self.setUserLabeledStatus(True, True)
|
self.setUserLabeledStatus(True, True)
|
||||||
|
print("after: ", self["track"][self._selection_indices], self["frame"][self._selection_indices])
|
||||||
|
|
||||||
def setUserLabeledStatus(self, new_status: bool, selection=True):
|
def setUserLabeledStatus(self, new_status: bool, selection=True):
|
||||||
"""Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection.
|
"""Sets the status of the "userlabeled" column to a given value (True|False). This can done for ALL data in one go, or only for the UserSelection.
|
||||||
@ -109,7 +117,7 @@ class TrackingData(QObject):
|
|||||||
logging.debug("TrackingData: (Re-)setting assignment status of %s to %s",
|
logging.debug("TrackingData: (Re-)setting assignment status of %s to %s",
|
||||||
"user selected data" if selection else " ALL", str(new_status))
|
"user selected data" if selection else " ALL", str(new_status))
|
||||||
if selection:
|
if selection:
|
||||||
self["userlabeled"][self._selection] = new_status
|
self["userlabeled"][self._selection_indices] = new_status
|
||||||
else:
|
else:
|
||||||
self["userlabeled"][:] = new_status
|
self["userlabeled"][:] = new_status
|
||||||
|
|
||||||
@ -123,12 +131,14 @@ class TrackingData(QObject):
|
|||||||
|
|
||||||
def deleteDetections(self, ids=None):
|
def deleteDetections(self, ids=None):
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
|
logging.debug("TrackingData.deleteDetections of %i detections", len(ids))
|
||||||
del_indices = self._find(ids)
|
del_indices = self._find(ids)
|
||||||
else:
|
else:
|
||||||
del_indices = self._indices
|
logging.debug("TrackingData.deleteDetections of all selected detections (%i)", len(self._selected_ids))
|
||||||
|
del_indices = self._selected_ids
|
||||||
for c in self._columns:
|
for c in self._columns:
|
||||||
self._data[c] = np.delete(self._data[c], del_indices, axis=0)
|
self._data[c] = np.delete(self._data[c], del_indices, axis=0)
|
||||||
self._indices = self["index"]
|
self._indices = self._indices[:-len(del_indices)]
|
||||||
self._selected_ids = np.setdiff1d(self._selected_ids, del_indices)
|
self._selected_ids = np.setdiff1d(self._selected_ids, del_indices)
|
||||||
|
|
||||||
def assignTracks(self, tracks:np.ndarray):
|
def assignTracks(self, tracks:np.ndarray):
|
||||||
@ -171,10 +181,10 @@ class TrackingData(QObject):
|
|||||||
and M is number of keypoints
|
and M is number of keypoints
|
||||||
"""
|
"""
|
||||||
if selection:
|
if selection:
|
||||||
if len(self._selection) < 1:
|
if len(self._selection_indices) < 1:
|
||||||
logging.info("TrackingData.coordinates returns empty array, not detections in range!")
|
logging.info("TrackingData.coordinates returns empty array, not detections in range!")
|
||||||
return np.ndarray([])
|
return np.ndarray([])
|
||||||
return np.stack(self["keypoints"][self._selection]).astype(np.float32)
|
return np.stack(self["keypoints"][self._selection_indices]).astype(np.float32)
|
||||||
return np.stack(self["keypoints"]).astype(np.float32)
|
return np.stack(self["keypoints"]).astype(np.float32)
|
||||||
|
|
||||||
def keypointScores(self, selection=False):
|
def keypointScores(self, selection=False):
|
||||||
@ -188,10 +198,10 @@ class TrackingData(QObject):
|
|||||||
with N the number of detections and M the number of keypoints.
|
with N the number of detections and M the number of keypoints.
|
||||||
"""
|
"""
|
||||||
if selection:
|
if selection:
|
||||||
if len(self._selection) < 1:
|
if len(self._selection_indices) < 1:
|
||||||
logging.info("TrackingData.scores returns empty array, not detections in range!")
|
logging.info("TrackingData.scores returns empty array, not detections in range!")
|
||||||
return None
|
return None
|
||||||
return np.stack(self["keypoint_score"][self._selection]).astype(np.float32)
|
return np.stack(self["keypoint_score"][self._selection_indices]).astype(np.float32)
|
||||||
return np.stack(self["keypoint_score"]).astype(np.float32)
|
return np.stack(self["keypoint_score"]).astype(np.float32)
|
||||||
|
|
||||||
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):
|
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):
|
||||||
|
@ -247,7 +247,7 @@ class ConsistencyWorker(QRunnable):
|
|||||||
if steps > 0 and f % steps == 0:
|
if steps > 0 and f % steps == 0:
|
||||||
progress += 1
|
progress += 1
|
||||||
self.signals.progress.emit(progress, processed, errors)
|
self.signals.progress.emit(progress, processed, errors)
|
||||||
self.signals.message.emit("Tracking stopped at frame %i.", f)
|
self.signals.message.emit(f"Tracking stopped at frame {f}.")
|
||||||
self.signals.stopped.emit(f)
|
self.signals.stopped.emit(f)
|
||||||
|
|
||||||
|
|
||||||
@ -548,11 +548,17 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._tracks = self._dataworker.tracks
|
self._tracks = self._dataworker.tracks
|
||||||
self._dataworker = None
|
self._dataworker = None
|
||||||
if np.sum(self._userlabeled) < 1:
|
if np.sum(self._userlabeled) < 1:
|
||||||
logging.error("ConsistencyTracker: I need at least 1 user-labeled frame to start with!")
|
msg = "ConsistencyTracker: I need at least 1 user-labeled frame to start with!"
|
||||||
|
logging.error(msg)
|
||||||
|
self._messagebox.append(msg)
|
||||||
self.setEnabled(False)
|
self.setEnabled(False)
|
||||||
else:
|
else:
|
||||||
t1_userlabeled = self._frames[self._userlabeled & (self._tracks == 1)]
|
t1_userlabeled = self._frames[self._userlabeled & (self._tracks == 1)]
|
||||||
t2_userlabeled = self._frames[self._userlabeled & (self._tracks == 2)]
|
t2_userlabeled = self._frames[self._userlabeled & (self._tracks == 2)]
|
||||||
|
if any([len(t1_userlabeled) == 0, len(t2_userlabeled)== 0]):
|
||||||
|
self._messagebox.append("Error preparing data! Make sure that the first user-labeled frames contain both tracks!")
|
||||||
|
self.setEnabled(False)
|
||||||
|
return
|
||||||
max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]])
|
max_startframe = np.min([t1_userlabeled[-1], t2_userlabeled[-1]])
|
||||||
min_startframe = np.max([t1_userlabeled[0], t2_userlabeled[0]])
|
min_startframe = np.max([t1_userlabeled[0], t2_userlabeled[0]])
|
||||||
self._maxframes = np.max(self._frames)
|
self._maxframes = np.max(self._frames)
|
||||||
@ -574,10 +580,7 @@ class ConsistencyClassifier(QWidget):
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
if self._worker is not None:
|
if self._worker is not None:
|
||||||
self._worker.stop()
|
self._worker.stop()
|
||||||
self._startbtn.setEnabled(True)
|
self._messagebox.append("Stopping tracking.")
|
||||||
self._proceedbtn.setEnabled(True)
|
|
||||||
self._stopbtn.setEnabled(False)
|
|
||||||
self._refreshbtn.setEnabled(True)
|
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
self._startbtn.setEnabled(False)
|
self._startbtn.setEnabled(False)
|
||||||
@ -590,6 +593,7 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._worker.signals.progress.connect(self.worker_progress)
|
self._worker.signals.progress.connect(self.worker_progress)
|
||||||
self._worker.signals.message.connect(self.worker_error)
|
self._worker.signals.message.connect(self.worker_error)
|
||||||
self._worker.signals.currentframe.connect(self.worker_frame)
|
self._worker.signals.currentframe.connect(self.worker_frame)
|
||||||
|
self._messagebox.append("Tracking in progress ...")
|
||||||
self.threadpool.start(self._worker)
|
self.threadpool.start(self._worker)
|
||||||
|
|
||||||
def worker_frame(self, frame):
|
def worker_frame(self, frame):
|
||||||
@ -602,8 +606,11 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self.start()
|
self.start()
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
|
self.setEnabled(False)
|
||||||
self._dataworker = ConsitencyDataLoader(self._data)
|
self._dataworker = ConsitencyDataLoader(self._data)
|
||||||
self._dataworker.signals.stopped.connect(self.data_processed)
|
self._dataworker.signals.stopped.connect(self.data_processed)
|
||||||
|
self._messagebox.clear()
|
||||||
|
self._messagebox.append("Refreshing...")
|
||||||
self.threadpool.start(self._dataworker)
|
self.threadpool.start(self._dataworker)
|
||||||
|
|
||||||
def worker_progress(self, progress, processed, errors):
|
def worker_progress(self, progress, processed, errors):
|
||||||
@ -612,13 +619,15 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._assignedlabel.setText(str(processed))
|
self._assignedlabel.setText(str(processed))
|
||||||
|
|
||||||
def worker_stopped(self, frame):
|
def worker_stopped(self, frame):
|
||||||
self._apply_btn.setEnabled(True)
|
|
||||||
self._startbtn.setEnabled(True)
|
self._startbtn.setEnabled(True)
|
||||||
|
self._proceedbtn.setEnabled(True)
|
||||||
self._stopbtn.setEnabled(False)
|
self._stopbtn.setEnabled(False)
|
||||||
|
self._apply_btn.setEnabled(True)
|
||||||
|
self._refreshbtn.setEnabled(True)
|
||||||
self._startframe_spinner.setValue(frame-1)
|
self._startframe_spinner.setValue(frame-1)
|
||||||
self._proceedbtn.setEnabled(bool(frame < self._maxframes-1))
|
self._proceedbtn.setEnabled(bool(frame < self._maxframes-1))
|
||||||
self._refreshbtn.setEnabled(True)
|
|
||||||
self._processed_frames = frame
|
self._processed_frames = frame
|
||||||
|
self._messagebox.append("... done.")
|
||||||
|
|
||||||
def assignedTracks(self):
|
def assignedTracks(self):
|
||||||
return self._tracks
|
return self._tracks
|
||||||
@ -683,7 +692,7 @@ def main():
|
|||||||
import pickle
|
import pickle
|
||||||
from fixtracks.info import PACKAGE_ROOT
|
from fixtracks.info import PACKAGE_ROOT
|
||||||
|
|
||||||
datafile = PACKAGE_ROOT / "data/merged_small_starter.pkl"
|
datafile = PACKAGE_ROOT / "data/merged_small_beginning.pkl"
|
||||||
|
|
||||||
with open(datafile, "rb") as f:
|
with open(datafile, "rb") as f:
|
||||||
df = pickle.load(f)
|
df = pickle.load(f)
|
||||||
|
@ -128,20 +128,20 @@ class DetectionView(QWidget):
|
|||||||
del it
|
del it
|
||||||
|
|
||||||
def updateDetections(self, keypoint=-1):
|
def updateDetections(self, keypoint=-1):
|
||||||
|
logging.info("DetectionView.updateDetections!")
|
||||||
self.clearDetections()
|
self.clearDetections()
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
return
|
return
|
||||||
frames = self._data.selectedData("frame")
|
frames = self._data.selectedData("frame")
|
||||||
tracks = self._data.selectedData("track")
|
tracks = self._data.selectedData("track")
|
||||||
|
ids = self._data.selectedData("index")
|
||||||
coordinates = self._data.coordinates(selection=True)
|
coordinates = self._data.coordinates(selection=True)
|
||||||
centercoordinates = self._data.centerOfGravity(selection=True)
|
centercoordinates = self._data.centerOfGravity(selection=True)
|
||||||
userlabeled = self._data.selectedData("userlabeled")
|
userlabeled = self._data.selectedData("userlabeled")
|
||||||
|
|
||||||
indices = self._data.selectionIndices
|
|
||||||
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
|
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
|
||||||
|
|
||||||
for i, idx in enumerate(indices):
|
for i, (id, f, t, l) in enumerate(zip(ids, frames, tracks, userlabeled)):
|
||||||
t = tracks[i]
|
|
||||||
c = Tracks.fromValue(t).toColor()
|
c = Tracks.fromValue(t).toColor()
|
||||||
if keypoint >= 0:
|
if keypoint >= 0:
|
||||||
x = coordinates[i, keypoint, 0]
|
x = coordinates[i, keypoint, 0]
|
||||||
@ -151,10 +151,11 @@ class DetectionView(QWidget):
|
|||||||
y = centercoordinates[i, 1]
|
y = centercoordinates[i, 1]
|
||||||
|
|
||||||
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(c))
|
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(c))
|
||||||
item.setData(DetectionData.TRACK_ID.value, tracks[i])
|
item.setData(DetectionData.TRACK_ID.value, t)
|
||||||
item.setData(DetectionData.ID.value, idx)
|
item.setData(DetectionData.ID.value, id)
|
||||||
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
|
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
|
||||||
item.setData(DetectionData.FRAME.value, frames[i])
|
item.setData(DetectionData.FRAME.value, f)
|
||||||
|
item.setData(DetectionData.USERLABELED.value, l)
|
||||||
item = self._scene.addItem(item)
|
item = self._scene.addItem(item)
|
||||||
|
|
||||||
def fit_image_to_view(self):
|
def fit_image_to_view(self):
|
||||||
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from PySide6.QtCore import Qt, Signal, QSize
|
from PySide6.QtCore import Qt, Signal, QSize
|
||||||
from PySide6.QtGui import QFont
|
from PySide6.QtGui import QFont
|
||||||
from PySide6.QtWidgets import QWidget, QLabel, QPushButton, QSizePolicy
|
from PySide6.QtWidgets import QWidget, QLabel, QPushButton, QSizePolicy
|
||||||
from PySide6.QtWidgets import QGridLayout, QVBoxLayout
|
from PySide6.QtWidgets import QGridLayout, QVBoxLayout, QApplication
|
||||||
|
|
||||||
from fixtracks.utils.styles import pushBtnStyle
|
from fixtracks.utils.styles import pushBtnStyle
|
||||||
|
|
||||||
@ -15,6 +15,7 @@ class SelectionControls(QWidget):
|
|||||||
assignTwo = Signal()
|
assignTwo = Signal()
|
||||||
assignOther = Signal()
|
assignOther = Signal()
|
||||||
accept = Signal()
|
accept = Signal()
|
||||||
|
accept_until = Signal()
|
||||||
unaccept = Signal()
|
unaccept = Signal()
|
||||||
delete = Signal()
|
delete = Signal()
|
||||||
revertall = Signal()
|
revertall = Signal()
|
||||||
@ -102,7 +103,7 @@ class SelectionControls(QWidget):
|
|||||||
acceptBtn.setFont(font)
|
acceptBtn.setFont(font)
|
||||||
acceptBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
acceptBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
||||||
acceptBtn.setStyleSheet(pushBtnStyle("darkgray"))
|
acceptBtn.setStyleSheet(pushBtnStyle("darkgray"))
|
||||||
acceptBtn.setToolTip(f"Accept assignments of current selection as TRUE")
|
acceptBtn.setToolTip(f"Accept assignments of current selection as TRUE, Hold shift while clicking to accept all until here.")
|
||||||
acceptBtn.clicked.connect(self.on_Accept)
|
acceptBtn.clicked.connect(self.on_Accept)
|
||||||
|
|
||||||
unacceptBtn = QPushButton("un-accept")
|
unacceptBtn = QPushButton("un-accept")
|
||||||
@ -117,8 +118,9 @@ class SelectionControls(QWidget):
|
|||||||
deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
||||||
deleteBtn.setStyleSheet(pushBtnStyle("red"))
|
deleteBtn.setStyleSheet(pushBtnStyle("red"))
|
||||||
deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!")
|
deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!")
|
||||||
deleteBtn.setEnabled(False)
|
deleteBtn.setShortcut("Ctrl+D")
|
||||||
deleteBtn.clicked.connect(self.on_Delete)
|
deleteBtn.clicked.connect(self.on_Delete)
|
||||||
|
deleteBtn.setEnabled(False)
|
||||||
|
|
||||||
revertBtn = QPushButton("revert assignments")
|
revertBtn = QPushButton("revert assignments")
|
||||||
revertBtn.setFont(font)
|
revertBtn.setFont(font)
|
||||||
@ -210,7 +212,12 @@ class SelectionControls(QWidget):
|
|||||||
|
|
||||||
def on_Accept(self):
|
def on_Accept(self):
|
||||||
logging.debug("SelectionControl: accept AssignmentBtn")
|
logging.debug("SelectionControl: accept AssignmentBtn")
|
||||||
self.accept.emit()
|
modifiers = QApplication.keyboardModifiers()
|
||||||
|
if modifiers == Qt.KeyboardModifier.ShiftModifier:
|
||||||
|
logging.debug("Shift key was pressed during accept")
|
||||||
|
self.accept_until.emit()
|
||||||
|
else:
|
||||||
|
self.accept.emit()
|
||||||
|
|
||||||
def on_Unaccept(self):
|
def on_Unaccept(self):
|
||||||
logging.debug("SelectionControl: revoke user assignmentBtn")
|
logging.debug("SelectionControl: revoke user assignmentBtn")
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
from PySide6.QtCore import Qt, QThreadPool, Signal
|
from PySide6.QtCore import Qt, QThreadPool, Signal
|
||||||
from PySide6.QtGui import QImage, QBrush, QColor
|
from PySide6.QtGui import QImage, QBrush, QColor
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
||||||
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
|
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QFileDialog, QMessageBox
|
||||||
|
|
||||||
|
|
||||||
from fixtracks.utils.reader import PickleLoader
|
from fixtracks.utils.reader import PickleLoader
|
||||||
from fixtracks.utils.writer import PickleWriter
|
from fixtracks.utils.writer import PickleWriter
|
||||||
@ -37,12 +37,6 @@ class FixTracks(QWidget):
|
|||||||
self._detectionView = DetectionView()
|
self._detectionView = DetectionView()
|
||||||
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
|
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
|
||||||
self._skeleton = SkeletonWidget()
|
self._skeleton = SkeletonWidget()
|
||||||
# self._skeleton.setMaximumSize(QSize(400, 400))
|
|
||||||
top_splitter = QSplitter(Qt.Orientation.Horizontal)
|
|
||||||
top_splitter.addWidget(self._detectionView)
|
|
||||||
top_splitter.addWidget(self._skeleton)
|
|
||||||
top_splitter.setStretchFactor(0, 2)
|
|
||||||
top_splitter.setStretchFactor(1, 1)
|
|
||||||
|
|
||||||
self._progress_bar = QProgressBar(self)
|
self._progress_bar = QProgressBar(self)
|
||||||
self._progress_bar.setMaximumHeight(20)
|
self._progress_bar.setMaximumHeight(20)
|
||||||
@ -61,15 +55,18 @@ class FixTracks(QWidget):
|
|||||||
self._keypointcombo = QComboBox()
|
self._keypointcombo = QComboBox()
|
||||||
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
|
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
|
||||||
|
|
||||||
combo_layout = QGridLayout()
|
combo_layout = QHBoxLayout()
|
||||||
combo_layout.addWidget(QLabel("Window:"), 0, 0)
|
combo_layout.addWidget(QLabel("Window width:"))
|
||||||
combo_layout.addWidget(self._windowspinner, 0, 1)
|
combo_layout.addWidget(self._windowspinner)
|
||||||
combo_layout.addWidget(QLabel("Keypoint:"), 1, 0)
|
combo_layout.addWidget(QLabel("frames"))
|
||||||
combo_layout.addWidget(self._keypointcombo, 1, 1)
|
combo_layout.addWidget(QLabel("Keypoint:"))
|
||||||
|
combo_layout.addWidget(self._keypointcombo)
|
||||||
|
combo_layout.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
|
||||||
|
|
||||||
timelinebox = QHBoxLayout()
|
timelinebox = QVBoxLayout()
|
||||||
timelinebox.addWidget(self._timeline)
|
timelinebox.setSpacing(2)
|
||||||
timelinebox.addLayout(combo_layout)
|
timelinebox.addLayout(combo_layout)
|
||||||
|
timelinebox.addWidget(self._timeline)
|
||||||
|
|
||||||
self._controls_widget = SelectionControls()
|
self._controls_widget = SelectionControls()
|
||||||
self._controls_widget.assignOne.connect(self.on_assignOne)
|
self._controls_widget.assignOne.connect(self.on_assignOne)
|
||||||
@ -78,6 +75,7 @@ class FixTracks(QWidget):
|
|||||||
self._controls_widget.fwd.connect(self.on_forward)
|
self._controls_widget.fwd.connect(self.on_forward)
|
||||||
self._controls_widget.back.connect(self.on_backward)
|
self._controls_widget.back.connect(self.on_backward)
|
||||||
self._controls_widget.accept.connect(self.on_setUserFlag)
|
self._controls_widget.accept.connect(self.on_setUserFlag)
|
||||||
|
self._controls_widget.accept_until.connect(self.on_setUserFlagsUntil)
|
||||||
self._controls_widget.unaccept.connect(self.on_unsetUserFlag)
|
self._controls_widget.unaccept.connect(self.on_unsetUserFlag)
|
||||||
self._controls_widget.delete.connect(self.on_deleteDetection)
|
self._controls_widget.delete.connect(self.on_deleteDetection)
|
||||||
self._controls_widget.revertall.connect(self.on_revertUserFlags)
|
self._controls_widget.revertall.connect(self.on_revertUserFlags)
|
||||||
@ -118,7 +116,8 @@ class FixTracks(QWidget):
|
|||||||
cntrlBox = QHBoxLayout()
|
cntrlBox = QHBoxLayout()
|
||||||
cntrlBox.addWidget(self._classifier)
|
cntrlBox.addWidget(self._classifier)
|
||||||
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
|
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
|
||||||
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
|
cntrlBox.addWidget(self._skeleton)
|
||||||
|
cntrlBox.addItem(QSpacerItem(50, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
|
||||||
|
|
||||||
vbox = QVBoxLayout()
|
vbox = QVBoxLayout()
|
||||||
vbox.addLayout(timelinebox)
|
vbox.addLayout(timelinebox)
|
||||||
@ -128,7 +127,7 @@ class FixTracks(QWidget):
|
|||||||
container.setLayout(vbox)
|
container.setLayout(vbox)
|
||||||
|
|
||||||
splitter = QSplitter(Qt.Orientation.Vertical)
|
splitter = QSplitter(Qt.Orientation.Vertical)
|
||||||
splitter.addWidget(top_splitter)
|
splitter.addWidget(self._detectionView)
|
||||||
splitter.addWidget(container)
|
splitter.addWidget(container)
|
||||||
splitter.setStretchFactor(0, 3)
|
splitter.setStretchFactor(0, 3)
|
||||||
splitter.setStretchFactor(1, 1)
|
splitter.setStretchFactor(1, 1)
|
||||||
@ -270,6 +269,12 @@ class FixTracks(QWidget):
|
|||||||
self._timeline.update()
|
self._timeline.update()
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
|
def on_setUserFlagsUntil(self):
|
||||||
|
self._data.setSelectionRange("frame", 0, self._currentWindowPos + self._currentWindowWidth)
|
||||||
|
self._data.setUserLabeledStatus(True)
|
||||||
|
self._timeline.update()
|
||||||
|
self.update()
|
||||||
|
|
||||||
def on_unsetUserFlag(self):
|
def on_unsetUserFlag(self):
|
||||||
logging.debug("Tracks:unsetUserFlag")
|
logging.debug("Tracks:unsetUserFlag")
|
||||||
self._data.setUserLabeledStatus(False)
|
self._data.setUserLabeledStatus(False)
|
||||||
@ -278,14 +283,30 @@ class FixTracks(QWidget):
|
|||||||
|
|
||||||
def on_revertUserFlags(self):
|
def on_revertUserFlags(self):
|
||||||
logging.debug("Tracks:revert ALL UserFlags and track assignments")
|
logging.debug("Tracks:revert ALL UserFlags and track assignments")
|
||||||
self._data.revertUserLabeledStatus()
|
msg_box = QMessageBox()
|
||||||
self._data.revertTrackAssignments()
|
msg_box.setIcon(QMessageBox.Icon.Warning)
|
||||||
|
msg_box.setText(f"Are you sure you want to revert ALL track assignments?")
|
||||||
|
msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
|
||||||
|
msg_box.setDefaultButton(QMessageBox.StandardButton.No)
|
||||||
|
ret = msg_box.exec()
|
||||||
|
|
||||||
|
if ret == QMessageBox.StandardButton.Yes:
|
||||||
|
self._data.revertUserLabeledStatus()
|
||||||
|
self._data.revertTrackAssignments()
|
||||||
self._timeline.update()
|
self._timeline.update()
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
def on_deleteDetection(self):
|
def on_deleteDetection(self):
|
||||||
logging.warning("Tracks:delete detections is currently not supported!")
|
logging.info("Tracks:deleting detections!")
|
||||||
# self._data.deleteDetections()
|
msg_box = QMessageBox()
|
||||||
|
msg_box.setIcon(QMessageBox.Icon.Warning)
|
||||||
|
msg_box.setText(f"Are you sure you want to delete the selected ({len(self._data.selectionIndices)})detections?")
|
||||||
|
msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
|
||||||
|
msg_box.setDefaultButton(QMessageBox.StandardButton.No)
|
||||||
|
ret = msg_box.exec()
|
||||||
|
|
||||||
|
if ret == QMessageBox.StandardButton.Yes:
|
||||||
|
self._data.deleteDetections()
|
||||||
self._timeline.update()
|
self._timeline.update()
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
@ -306,7 +327,10 @@ class FixTracks(QWidget):
|
|||||||
"""
|
"""
|
||||||
self._currentWindowWidth = value
|
self._currentWindowWidth = value
|
||||||
logging.debug("Tracks:OnWindowSizeChanged %i franes", value)
|
logging.debug("Tracks:OnWindowSizeChanged %i franes", value)
|
||||||
self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
|
if self._maxframes == 0:
|
||||||
|
self._timeline.setWindowWidth(self._currentWindowWidth / 2000)
|
||||||
|
else:
|
||||||
|
self._timeline.setWindowWidth(self._currentWindowWidth / self._maxframes)
|
||||||
self._controls_widget.setSelectedTracks(None)
|
self._controls_widget.setSelectedTracks(None)
|
||||||
|
|
||||||
def on_detectionsSelected(self, detections):
|
def on_detectionsSelected(self, detections):
|
||||||
|
Loading…
Reference in New Issue
Block a user