Compare commits

...

14 Commits

11 changed files with 302 additions and 230 deletions

80
FixTracks.py Normal file
View File

@ -0,0 +1,80 @@
"""
pyside6-rcc resources.qrc -o resources.py
"""
import sys
import logging
import argparse
import platform
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import QSettings
from PySide6.QtGui import QIcon, QPalette
from fixtracks import info, mainwindow
logging.basicConfig(level=logging.INFO, force=True)
def is_dark_mode(app: QApplication) -> bool:
palette = app.palette()
# Check the brightness of the window text and base colors
text_color = palette.color(QPalette.ColorRole.WindowText)
base_color = palette.color(QPalette.ColorRole.Base)
# Calculate brightness (0 for dark, 255 for bright)
def brightness(color):
return (color.red() * 299 + color.green() * 587 + color.blue() * 114) // 1000
return brightness(base_color) < brightness(text_color)
def set_logging(loglevel):
logging.basicConfig(level=loglevel, force=True)
def main(args):
set_logging(logging.DEBUG)
if platform.system() == "Windows":
# from PySide6.QtWinExtras import QtWin
myappid = f"{info.organization_name}.{info.application_version}"
# QtWin.setCurrentProcessExplicitAppUserModelID(myappid)
settings = QSettings()
width = int(settings.value("app/width", 1024))
height = int(settings.value("app/height", 768))
x = int(settings.value("app/pos_x", 100))
y = int(settings.value("app/pos_y", 100))
app = QApplication(sys.argv)
app.setApplicationName(info.application_name)
app.setApplicationVersion(str(info.application_version))
app.setOrganizationDomain(info.organization_name)
# if platform.system() == 'Linux':
# icn = QIcon(":/icons/app_icon")
# app.setWindowIcon(icn)
# Create a Qt widget, which will be our window.
window = mainwindow.MainWindow(is_dark_mode(app))
window.setGeometry(100, 100, 1024, 768)
window.setWindowTitle("FixTracks")
window.setMinimumWidth(1024)
window.setMinimumHeight(768)
window.resize(width, height)
window.move(x, y)
window.show()
# Start the event loop.
app.exec()
pos = window.pos()
settings.setValue("app/width", window.width())
settings.setValue("app/height", window.height())
settings.setValue("app/pos_x", pos.x())
settings.setValue("app/pos_y", pos.y())
if __name__ == "__main__":
levels = {"critical": logging.CRITICAL, "error": logging.ERROR, "warning":logging.WARNING, "info":logging.INFO, "debug":logging.DEBUG}
parser = argparse.ArgumentParser(description="FixTracks. Tools for fixing animal tracking")
parser.add_argument("-ll", "--loglevel", type=str, default="INFO", help=f"The log level that should be used. Valid levels are {[str(k) for k in levels.keys()]}")
args = parser.parse_args()
args.loglevel = levels[args.loglevel if args.loglevel.lower() in levels else "info"]
main(args)

View File

@ -144,8 +144,8 @@ class MainWindow(QMainWindow):
about.show() about.show()
def on_help(self, s): def on_help(self, s):
help = HelpDialog(self) help_dlg = HelpDialog(self)
help.show() help_dlg.show()
# @Slot(None) # @Slot(None)
def exit_request(self): def exit_request(self):

View File

@ -1,7 +1,29 @@
from enum import Enum from enum import Enum
from PySide6.QtGui import QColor
class DetectionData(Enum): class DetectionData(Enum):
ID = 0 ID = 0
FRAME = 1 FRAME = 1
COORDINATES = 2 COORDINATES = 2
TRACK_ID = 3 TRACK_ID = 3
class Tracks(Enum):
TRACKONE = 1
TRACKTWO = 2
UNASSIGNED = -1
def toColor(self):
track_colors = {
Tracks.TRACKONE: QColor.fromString("orange"),
Tracks.TRACKTWO: QColor.fromString("green"),
Tracks.UNASSIGNED: QColor.fromString("red")
}
return track_colors.get(self, QColor(128, 128, 128)) # Default to black if not found
@staticmethod
def fromValue(value):
for track in Tracks:
if track.value == value:
return track
return Tracks.UNASSIGNED

View File

@ -5,6 +5,7 @@ import pandas as pd
from PySide6.QtCore import QObject from PySide6.QtCore import QObject
class TrackingData(QObject): class TrackingData(QObject):
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
@ -58,10 +59,15 @@ class TrackingData(QObject):
self._start = start self._start = start
self._stop = stop self._stop = stop
self._selection_column = col self._selection_column = col
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0] col_indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
self._indices = self["index"][col_indices]
if len(col_indices) < 1:
logging.warning("TrackingData: Selection range is empty!")
def selectedData(self, col): def selectedData(self, col:str):
return self._data[col][self._indices] if col not in self.columns:
logging.error("TrackingData:selectedData: Invalid column name! %s", col)
return self[col][self._indices]
def setUserSelection(self, ids): def setUserSelection(self, ids):
""" """
@ -84,14 +90,20 @@ class TrackingData(QObject):
userFlag : bool userFlag : bool
Should the "userlabeled" state of the detections be set to True or False? Should the "userlabeled" state of the detections be set to True or False?
""" """
self._data["track"][self._user_selections] = track_id self["track"][self._user_selections] = track_id
self.setAssignmentStatus(userFlag) self.setAssignmentStatus(userFlag)
def setAssignmentStatus(self, isTrue: bool): def setAssignmentStatus(self, isTrue: bool):
self._data["userlabeled"][self._user_selections] = isTrue logging.debug("TrackingData:Re-setting assignment status of user selected data to %s", str(isTrue))
self["userlabeled"][self._user_selections] = isTrue
def revertAssignmentStatus(self): def revertAssignmentStatus(self):
self._data["userlabeled"][:] = False logging.debug("TrackingData:Un-setting assignment status of all data!")
self["userlabeled"][:] = False
def revertTrackAssignments(self):
logging.debug("TrackingData: Reverting all track assignments!")
self["track"][:] = -1
def deleteDetections(self): def deleteDetections(self):
# from IPython import embed # from IPython import embed
@ -102,21 +114,21 @@ class TrackingData(QObject):
# embed() # embed()
pass pass
def assignTracks(self, tracks): def assignTracks(self, tracks:np.ndarray):
"""assignTracks _summary_ """assigns the given tracks to the user-selected detections. If the sizes of
provided tracks and the user selection do not match and error is logged and the tracks are not set.
Parameters Parameters
---------- ----------
tracks : _type_ tracks : np.ndarray
_description_ The track information.
Returns Returns
------- -------
_type_ None
_description_
""" """
if len(tracks) != self.numDetections: if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!") logging.error("Trackingdata: Size of passed tracks does not match data!")
return return
self._data["track"] = tracks self._data["track"] = tracks
@ -142,11 +154,13 @@ class TrackingData(QObject):
and M is number of keypoints and M is number of keypoints
""" """
if selection: if selection:
return np.stack(self._data["keypoints"][self._start:self._stop, :, :]).astype(np.float32) if len(self._indices) < 1:
else: logging.info("TrackingData.coordinates returns empty array, not detections in range!")
return np.ndarray([])
return np.stack(self._data["keypoints"][self._indices]).astype(np.float32)
return np.stack(self._data["keypoints"]).astype(np.float32) return np.stack(self._data["keypoints"]).astype(np.float32)
def keypointScores(self): def keypointScores(self, selection=False):
""" """
Returns the keypoint scores as a NumPy array of type float32. Returns the keypoint scores as a NumPy array of type float32.
@ -156,9 +170,14 @@ class TrackingData(QObject):
A NumPy array of type float32 containing the keypoint scores of the shape (N, M) A NumPy array of type float32 containing the keypoint scores of the shape (N, M)
with N the number of detections and M the number of keypoints. with N the number of detections and M the number of keypoints.
""" """
if selection:
if len(self._indices) < 1:
logging.info("TrackingData.scores returns empty array, not detections in range!")
return np.ndarray([])
return np.stack(self._data["keypoint_score"][self._indices]).astype(np.float32)
return np.stack(self._data["keypoint_score"]).astype(np.float32) return np.stack(self._data["keypoint_score"]).astype(np.float32)
def centerOfGravity(self, threshold=0.8): def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):
""" """
Calculate the center of gravity of keypoints weighted by their scores. Ignores keypoints that have a score Calculate the center of gravity of keypoints weighted by their scores. Ignores keypoints that have a score
less than threshold. less than threshold.
@ -166,16 +185,19 @@ class TrackingData(QObject):
Parameters: Parameters:
----------- -----------
threshold: float threshold: float
keypoints with a score less than threshold are ignored nodes with a score less than threshold are ignored
nodes: list
nodes/keypoints to consider for estimation. Defaults to [0,1,2]
Returns: Returns:
-------- --------
np.ndarray: np.ndarray:
A NumPy array of shape (N, 2) containing the center of gravity for each detection. A NumPy array of shape (N, 2) containing the center of gravity for each detection.
""" """
scores = self.keypointScores() scores = self.keypointScores(selection)
scores[scores < threshold] = 0.0 scores[scores < threshold] = 0.0
weighted_coords = self.coordinates() * scores[:, :, np.newaxis] scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0
weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis]
sum_scores = np.sum(scores, axis=1, keepdims=True) sum_scores = np.sum(scores, axis=1, keepdims=True)
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
return center_of_gravity return center_of_gravity
@ -279,19 +301,19 @@ def main():
frames = data["frame"] frames = data["frame"]
tracks = data["track"] tracks = data["track"]
bendedness = data.bendedness() bendedness = data.bendedness()
positions = data.coordinates()[[160388, 160389]] # positions = data.coordinates()[[160388, 160389]]
embed() embed()
tracks = data["track"] tracks = data["track"]
cogs = all_cogs[tracks==1] cogs = all_cogs[tracks==1]
all_dists = neighborDistances(cogs, 2, False) all_dists = neighborDistances(cogs, 2, False)
plt.hist(all_dists[1:, 0], bins=1000) # plt.hist(all_dists[1:, 0], bins=1000)
print(np.percentile(all_dists[1:, 0], 99)) # print(np.percentile(all_dists[1:, 0], 99))
print(np.percentile(all_dists[1:, 0], 1)) # print(np.percentile(all_dists[1:, 0], 1))
plt.gca().set_xscale("log") # plt.gca().set_xscale("log")
plt.gca().set_yscale("log") # plt.gca().set_yscale("log")
# plt.hist(all_dists[1:, 1], bins=100) # plt.hist(all_dists[1:, 1], bins=100)
plt.show() # plt.show()
# def compute_neighbor_distances(cogs, window=10): # def compute_neighbor_distances(cogs, window=10):
# distances = [] # distances = []
# for i in range(len(cogs)): # for i in range(len(cogs)):
@ -303,9 +325,6 @@ def main():
# return distances # return distances
# print("estimating neighorhood distances") # print("estimating neighorhood distances")
# neighbor_distances = compute_neighbor_distances(cogs) # neighbor_distances = compute_neighbor_distances(cogs)
embed()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -34,6 +34,9 @@ class ConsitencyDataLoader(QRunnable):
@Slot() @Slot()
def run(self): def run(self):
if self.data is None:
logging.error("ConsistencyTracker.DataLoader failed. No Data!")
return
self.positions = self.data.centerOfGravity() self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation() self.orientations = self.data.orientation()
self.lengths = self.data.animalLength() self.lengths = self.data.animalLength()
@ -464,9 +467,6 @@ class ConsistencyClassifier(QWidget):
self.setEnabled(False) self.setEnabled(False)
self._progressbar.setRange(0,0) self._progressbar.setRange(0,0)
self._data = data self._data = data
self._dataworker = ConsitencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed)
self.threadpool.start(self._dataworker)
@Slot() @Slot()
def data_processed(self): def data_processed(self):
@ -482,6 +482,7 @@ class ConsistencyClassifier(QWidget):
self._frames = self._dataworker.frames self._frames = self._dataworker.frames
self._tracks = self._dataworker.tracks self._tracks = self._dataworker.tracks
self._maxframes = np.max(self._frames) self._maxframes = np.max(self._frames)
# FIXME the following line causes an error when there are no detections in the range
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1 min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._maxframeslabel.setText(str(self._maxframes)) self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_frame) self._startframe_spinner.setMinimum(min_frame)
@ -525,7 +526,9 @@ class ConsistencyClassifier(QWidget):
self.start() self.start()
def refresh(self): def refresh(self):
self.setData(self._data) self._dataworker = ConsitencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed)
self.threadpool.start(self._dataworker)
def worker_progress(self, progress, processed, errors): def worker_progress(self, progress, processed, errors):
self._progressbar.setValue(progress) self._progressbar.setValue(progress)
@ -556,7 +559,8 @@ class ClassifierWidget(QTabWidget):
self._consistency_tracker = ConsistencyClassifier() self._consistency_tracker = ConsistencyClassifier()
self.addTab(self._size_classifier, SizeClassifier.name) self.addTab(self._size_classifier, SizeClassifier.name)
self.addTab(self._consistency_tracker, ConsistencyClassifier.name) self.addTab(self._consistency_tracker, ConsistencyClassifier.name)
self.tabBarClicked.connect(self.update) # self.tabBarClicked.connect(self.update)
self.currentChanged.connect(self.tabChanged)
self._size_classifier.apply.connect(self._on_applySizeClassifier) self._size_classifier.apply.connect(self._on_applySizeClassifier)
self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker) self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker)
@ -576,12 +580,21 @@ class ClassifierWidget(QTabWidget):
def consistency_tracker(self): def consistency_tracker(self):
return self._consistency_tracker return self._consistency_tracker
@Slot()
def tabChanged(self):
if isinstance(self.currentWidget(), ConsistencyClassifier):
self.consistency_tracker.refresh()
@Slot() @Slot()
def update(self): def update(self):
self.consistency_tracker.setData(self._data) if isinstance(self.currentWidget(), ConsistencyClassifier):
self.consistency_tracker.refresh()
def setData(self, data:TrackingData): def setData(self, data:TrackingData):
self._data = data self._data = data
self.consistency_tracker.setData(data)
coordinates = self._data.coordinates()
self._size_classifier.setCoordinates(coordinates)
def as_dict(df): def as_dict(df):
d = {c: df[c].values for c in df.columns} d = {c: df[c].values for c in df.columns}

View File

@ -1,13 +1,14 @@
import logging import logging
import numpy as np import numpy as np
import pandas as pd
from PySide6.QtCore import Qt from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QLabel from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QLabel
from PySide6.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem, QGraphicsRectItem, QGraphicsLineItem from PySide6.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem, QGraphicsRectItem, QGraphicsLineItem, QGraphicsEllipseItem
from PySide6.QtCore import Qt, QRectF, QRectF from PySide6.QtCore import Qt, QRectF, QRectF
from PySide6.QtGui import QBrush, QColor, QPen, QFont from PySide6.QtGui import QBrush, QColor, QPen, QFont
from fixtracks.utils.signals import DetectionTimelineSignals from fixtracks.utils.signals import DetectionTimelineSignals
from fixtracks.utils.trackingdata import TrackingData
class Window(QGraphicsRectItem): class Window(QGraphicsRectItem):
@ -39,7 +40,6 @@ class Window(QGraphicsRectItem):
self.setRect(r) self.setRect(r)
self.signals.windowMoved.emit() self.signals.windowMoved.emit()
def setWindow(self, newx:float, newwidth:float):
def setWindow(self, newx:float, newwidth:float): def setWindow(self, newx:float, newwidth:float):
""" """
Update the window to the specified range. Update the window to the specified range.
@ -53,11 +53,11 @@ class Window(QGraphicsRectItem):
------- -------
None None
""" """
logging.debug("timeline.window: update window to range %.5f to %.5f", newx, newwidth) logging.debug("timeline.window: update window to range %.5f to %.5f", newx, newwidth)
self._width = newwidth self._width = newwidth
r = self.rect() r = self.rect()
self.setRect(newx, r.y(), self._width, r.height()) self.setRect(newx, r.y(), self._width, r.height())
self.update()
self.signals.windowMoved.emit() self.signals.windowMoved.emit()
def mouseMoveEvent(self, event): def mouseMoveEvent(self, event):
@ -86,40 +86,44 @@ class Window(QGraphicsRectItem):
class DetectionTimeline(QWidget): class DetectionTimeline(QWidget):
signals = DetectionTimelineSignals() signals = DetectionTimelineSignals()
def __init__(self, detectiondata=None, trackone_id=1, tracktwo_id=2, parent=None): def __init__(self, trackone_id=1, tracktwo_id=2, parent=None):
super().__init__(parent) super().__init__(parent)
self._trackone = trackone_id self._trackone = trackone_id
self._tracktwo = tracktwo_id self._tracktwo = tracktwo_id
self._data = detectiondata self._data = None
self._rangeStart = 0.0 self._rangeStart = 0.0
self._rangeStop = 0.005 self._rangeStop = 0.005
self._total_width = 2000 self._total_width = 2000
self._stepCount = 300 self._stepCount = 1000
self._bg_brush = QBrush(QColor(20, 20, 20, 255)) self._bg_brush = QBrush(QColor(20, 20, 20, 255))
transparent_brush = QBrush(QColor(200, 200, 200, 64)) transparent_brush = QBrush(QColor(200, 200, 200, 64))
self._white_pen = QPen(QColor.fromString("white")) self._white_pen = QPen(QColor.fromString("white"))
self._white_pen.setWidth(0.1) self._white_pen.setWidth(0.1)
self._t1_pen = QPen(QColor.fromString("orange")) self._t1_pen = QPen(QColor.fromString("orange"))
self._t1_pen.setWidth(2) self._t1_pen.setWidth(1)
self._t2_pen = QPen(QColor(0, 255, 0, 255)) self._t2_pen = QPen(QColor(0, 255, 0, 255))
self._t2_pen.setWidth(2) self._t2_pen.setWidth(1)
self._other_pen = QPen(QColor.fromString("red")) self._other_pen = QPen(QColor.fromString("red"))
self._other_pen.setWidth(2) self._other_pen.setWidth(1)
axis_pen = QPen(QColor.fromString("white")) window_pen = QPen(QColor.fromString("white"))
axis_pen.setWidth(2) window_pen.setWidth(2)
self._user_brush = QBrush(QColor.fromString("white"))
user_pen = QPen(QColor.fromString("white"))
user_pen.setWidth(2)
font = QFont() font = QFont()
font.setPointSize(15) font.setPointSize(15)
font.setBold(False) font.setBold(True)
self._window = Window(0, 0, 100, 60, axis_pen, transparent_brush) self._window = Window(0, 0, 100, 60, window_pen, transparent_brush)
self._window.signals.windowMoved.connect(self.on_windowMoved) self._window.signals.windowMoved.connect(self.on_windowMoved)
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 65.)) self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 85.))
self._scene.setBackgroundBrush(self._bg_brush) self._scene.setBackgroundBrush(self._bg_brush)
self._scene.addItem(self._window) self._scene.addItem(self._window)
self._view = QGraphicsView() self._view = QGraphicsView()
# self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform); # self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform)
self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
self._view.setScene(self._scene) self._view.setScene(self._scene)
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio) self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
@ -136,6 +140,10 @@ class DetectionTimeline(QWidget):
other_label.setFont(font) other_label.setFont(font)
other_label.setDefaultTextColor(self._other_pen.color()) other_label.setDefaultTextColor(self._other_pen.color())
other_label.setPos(200, 50) other_label.setPos(200, 50)
user_label = self._scene.addText("user-labeled", font)
user_label.setFont(font)
user_label.setDefaultTextColor(user_pen.color())
user_label.setPos(350, 50)
self._position_label = QLabel("") self._position_label = QLabel("")
f = self._position_label.font() f = self._position_label.font()
@ -146,51 +154,48 @@ class DetectionTimeline(QWidget):
layout.addWidget(self._view) layout.addWidget(self._view)
layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight) layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout) self.setLayout(layout)
if self._data is not None:
self.draw_coverage()
# self.setMaximumHeight(100) # self.setMaximumHeight(100)
# self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed) # self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
def setDetectionData(self, data): def clear(self):
self._data = data
for i in self._scene.items(): for i in self._scene.items():
if isinstance(i, QGraphicsLineItem): if isinstance(i, (QGraphicsLineItem, QGraphicsEllipseItem)):
self._scene.removeItem(i) self._scene.removeItem(i)
def setData(self, data:TrackingData):
logging.debug("Timeline: setData!")
self._data = data
self.update()
def update(self):
self.clear()
self.draw_coverage() self.draw_coverage()
def draw_coverage(self): def draw_coverage(self):
# FIXME this must be disentangled. timeline should not have to deal with two different ways of data storage logging.debug("Timeline: drawCoverage!")
if isinstance(self._data, pd.DataFrame): if isinstance(self._data, TrackingData):
maxframe = np.max(self._data.frame.values)
bins = np.linspace(0, maxframe, self._stepCount)
pos = np.linspace(0, self._scene.width(), self._stepCount)
track1_frames = self._data.frame.values[self._data.track == self._trackone]
track2_frames = self._data.frame.values[self._data.track == self._tracktwo]
other_frames = self._data.frame.values[(self._data.track != self._trackone) &
(self._data.track != self._tracktwo)]
elif isinstance(self._data, dict):
maxframe = np.max(self._data["frame"]) maxframe = np.max(self._data["frame"])
bins = np.linspace(0, maxframe, self._stepCount) bins = np.linspace(0, maxframe, self._stepCount)
pos = np.linspace(0, self._scene.width(), self._stepCount) pos = np.linspace(0, self._scene.width(), self._stepCount) # of the vertical dashes is this correct?
track1_frames = self._data["frame"][self._data["track"] == self._trackone] track1_frames = self._data["frame"][self._data["track"] == self._trackone]
track2_frames = self._data["frame"][self._data["track"] == self._tracktwo] track2_frames = self._data["frame"][self._data["track"] == self._tracktwo]
other_frames = self._data["frame"][(self._data["track"] != self._trackone) & other_frames = self._data["frame"][(self._data["track"] != self._trackone) &
(self._data["track"] != self._tracktwo)] (self._data["track"] != self._tracktwo)]
userlabeled = self._data["frame"][self._data["userlabeled"]]
else: else:
print("Data is not trackingdata")
return return
t1_coverage, _ = np.histogram(track1_frames, bins=bins) t1_coverage, _ = np.histogram(track1_frames, bins=bins)
t1_coverage = t1_coverage > 0
t2_coverage, _ = np.histogram(track2_frames, bins=bins) t2_coverage, _ = np.histogram(track2_frames, bins=bins)
t2_coverage = t2_coverage > 0
other_coverage, _ = np.histogram(other_frames, bins=bins) other_coverage, _ = np.histogram(other_frames, bins=bins)
other_coverage = other_coverage > 0 labeled_coverage, _ = np.histogram(userlabeled, bins=bins)
for i in range(len(t1_coverage)-1): for i in range(len(bins)-1):
if t1_coverage[i]: self._scene.addLine(pos[i], 0, pos[i], 15., pen=self._t1_pen) if t1_coverage[i]: self._scene.addLine(pos[i], 0, pos[i], 15., pen=self._t1_pen)
if t2_coverage[i]: self._scene.addLine(pos[i], 17, pos[i], 32., pen=self._t2_pen) if t2_coverage[i]: self._scene.addLine(pos[i], 17, pos[i], 32., pen=self._t2_pen)
if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen) if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen)
if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen)
if labeled_coverage[i]: self._scene.addEllipse(pos[i]-2, 52, 4, 4, brush=self._user_brush)
def updatePositionLabel(self): def updatePositionLabel(self):
start = np.round(self._rangeStart * 100, 4) start = np.round(self._rangeStart * 100, 4)
@ -207,6 +212,7 @@ class DetectionTimeline(QWidget):
def fit_scene_to_view(self): def fit_scene_to_view(self):
"""Scale the image to fit the QGraphicsView.""" """Scale the image to fit the QGraphicsView."""
logging.debug("Timeline: fit scene to view")
self._view.fitInView(self._scene.sceneRect(), Qt.KeepAspectRatio) self._view.fitInView(self._scene.sceneRect(), Qt.KeepAspectRatio)
def resizeEvent(self, event): def resizeEvent(self, event):
@ -289,6 +295,11 @@ def main():
view.setWindowPos(0.0) view.setWindowPos(0.0)
print(view.windowBounds()) print(view.windowBounds())
def as_dict(df):
d = {c: df[c].values for c in df.columns}
d["index"] = df.index.values
return d
import pickle import pickle
import numpy as np import numpy as np
from PySide6.QtWidgets import QApplication, QPushButton, QHBoxLayout from PySide6.QtWidgets import QApplication, QPushButton, QHBoxLayout
@ -298,13 +309,18 @@ def main():
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)
data = TrackingData()
data.setData(as_dict(df))
data.setUserSelection(np.arange(0,100, 1))
data.setAssignmentStatus(True)
start_x = 0.1 start_x = 0.1
app = QApplication([]) app = QApplication([])
window = QWidget() window = QWidget()
window.setMinimumSize(200, 75) window.setMinimumSize(200, 75)
view = DetectionTimeline(df) view = DetectionTimeline()
view.setData(data)
fwdBtn = QPushButton(">>") fwdBtn = QPushButton(">>")
fwdBtn.clicked.connect(lambda: fwd(0.5)) fwdBtn.clicked.connect(lambda: fwd(0.5))
zeroBtn = QPushButton("0->|") zeroBtn = QPushButton("0->|")

View File

@ -3,12 +3,12 @@ import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem
from PySide6.QtCore import Qt, QPointF, QRectF, QPointF from PySide6.QtCore import Qt, QPointF, QRectF, QPointF
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage from PySide6.QtGui import QPixmap, QBrush, QColor, QImage, QPen
from fixtracks.info import PACKAGE_ROOT from fixtracks.info import PACKAGE_ROOT
from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals
from ..utils.enums import DetectionData from fixtracks.utils.enums import DetectionData, Tracks
from fixtracks.utils.trackingdata import TrackingData
class Detection(QGraphicsEllipseItem): class Detection(QGraphicsEllipseItem):
signals = DetectionSignals() signals = DetectionSignals()
@ -79,6 +79,7 @@ class DetectionView(QWidget):
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self._img = None self._img = None
self._data = None
self._pixmapitem = None self._pixmapitem = None
self._scene = DetectionScene() self._scene = DetectionScene()
# self.setRenderHint(QGraphicsView.RenderFlag.Ren Antialiasing) # self.setRenderHint(QGraphicsView.RenderFlag.Ren Antialiasing)
@ -90,7 +91,6 @@ class DetectionView(QWidget):
self._minZoom = 0.1 self._minZoom = 0.1
self._maxZoom = 10 self._maxZoom = 10
self._currentZoom = 1.0 self._currentZoom = 1.0
lyt = QVBoxLayout() lyt = QVBoxLayout()
lyt.addWidget(self._view) lyt.addWidget(self._view)
self.setLayout(lyt) self.setLayout(lyt)
@ -116,6 +116,9 @@ class DetectionView(QWidget):
self._view.setScene(self._scene) self._view.setScene(self._scene)
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio) self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
def setData(self, data:TrackingData):
self._data = data
def clearDetections(self): def clearDetections(self):
items = self._scene.items() items = self._scene.items()
if items is not None: if items is not None:
@ -124,23 +127,34 @@ class DetectionView(QWidget):
self._scene.removeItem(it) self._scene.removeItem(it)
del it del it
def addDetections(self, coordinates:np.array, track_ids:np.array, detection_ids:np.array, frames: np.array, def updateDetections(self, keypoint=-1):
keypoint:int, brush:QBrush): self.clearDetections()
if self._data is None:
return
frames = self._data.selectedData("frame")
tracks = self._data.selectedData("track")
coordinates = self._data.coordinates(selection=True)
centercoordinates = self._data.centerOfGravity(selection=True)
userlabeled = self._data.selectedData("userlabeled")
indices = self._data.selectedData("index")
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0) image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
num_detections = coordinates.shape[0]
for i in range(num_detections): for i, idx in enumerate(indices):
t = tracks[i]
c = Tracks.fromValue(t).toColor()
if keypoint >= 0:
x = coordinates[i, keypoint, 0] x = coordinates[i, keypoint, 0]
y = coordinates[i, keypoint, 1] y = coordinates[i, keypoint, 1]
c = brush.color() else:
c.setAlpha(int(i * 255 / num_detections)) x = centercoordinates[i, 0]
brush.setColor(c) y = centercoordinates[i, 1]
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=brush)
item.setData(DetectionData.TRACK_ID.value, track_ids[i]) item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(c))
item.setData(DetectionData.ID.value, detection_ids[i]) item.setData(DetectionData.TRACK_ID.value, tracks[i])
item.setData(DetectionData.ID.value, idx)
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :]) item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
item.setData(DetectionData.FRAME.value, frames[i]) item.setData(DetectionData.FRAME.value, frames[i])
item = self._scene.addItem(item) item = self._scene.addItem(item)
# logging.debug("DetectionView: Number of items in scene: %i", len(self._scene.items()))
def fit_image_to_view(self): def fit_image_to_view(self):
"""Scale the image to fit the QGraphicsView.""" """Scale the image to fit the QGraphicsView."""
@ -159,7 +173,7 @@ class DetectionView(QWidget):
def main(): def main():
def items_selected(items): def items_selected(items):
print("items selected") print("items selected")
# FIXME The following code will no longer work...
import pickle import pickle
import numpy as np import numpy as np
from IPython import embed from IPython import embed

View File

@ -117,6 +117,7 @@ class SelectionControls(QWidget):
deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
deleteBtn.setStyleSheet(pushBtnStyle("red")) deleteBtn.setStyleSheet(pushBtnStyle("red"))
deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!") deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!")
deleteBtn.setEnabled(False)
deleteBtn.clicked.connect(self.on_Delete) deleteBtn.clicked.connect(self.on_Delete)
revertBtn = QPushButton("revert assignments") revertBtn = QPushButton("revert assignments")

View File

@ -2,8 +2,8 @@ import logging
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject from PySide6.QtCore import Qt, QThreadPool, Signal
from PySide6.QtGui import QImage, QBrush, QColor, QFont from PySide6.QtGui import QImage, QBrush, QColor
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
@ -28,15 +28,11 @@ class FixTracks(QWidget):
self._threadpool = QThreadPool() self._threadpool = QThreadPool()
self._reader = None self._reader = None
self._image = None self._image = None
self._clear_detections = True
self._currentWindowPos = 0 # in frames self._currentWindowPos = 0 # in frames
self._currentWindowWidth = 0 # in frames self._currentWindowWidth = 0 # in frames
self._maxframes = 0 self._maxframes = 0
self._data = TrackingData() self._data = TrackingData()
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
"assigned_right": QBrush(QColor.fromString("green")),
"unassigned": QBrush(QColor.fromString("red"))
}
self._detectionView = DetectionView() self._detectionView = DetectionView()
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected) self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
self._skeleton = SkeletonWidget() self._skeleton = SkeletonWidget()
@ -60,7 +56,7 @@ class FixTracks(QWidget):
self._windowspinner.setSingleStep(50) self._windowspinner.setSingleStep(50)
self._windowspinner.setValue(500) self._windowspinner.setValue(500)
self._windowspinner.valueChanged.connect(self.on_windowSizeChanged) self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)
# self._timeline.setWindowWidth(0.01)
self._keypointcombo = QComboBox() self._keypointcombo = QComboBox()
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected) self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
@ -143,7 +139,7 @@ class FixTracks(QWidget):
def on_autoClassify(self, tracks): def on_autoClassify(self, tracks):
self._data.setSelectionRange("index", 0, self._data.numDetections) self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks) self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data) self._timeline.update()
self.update() self.update()
def on_dataSelection(self): def on_dataSelection(self):
@ -163,40 +159,15 @@ class FixTracks(QWidget):
self._detectionView.setImage(img) self._detectionView.setImage(img)
def update(self): def update(self):
def update_detectionView(df, name):
if len(df) == 0:
return
keypoint = self._keypointcombo.currentIndex()
coords = np.stack(df["keypoints"].values).astype(np.float32)[:, :,:]
tracks = df["track"].values.astype(int)
ids = df.index.values.astype(int)
frames = df["frame"].values.astype(int)
self._detectionView.addDetections(coords, tracks, ids, frames, keypoint, self._brushes[name])
start_frame = self._currentWindowPos start_frame = self._currentWindowPos
stop_frame = start_frame + self._currentWindowWidth stop_frame = start_frame + self._currentWindowWidth
self._controls_widget.setWindow(start_frame, stop_frame)
logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame) logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
self._data.setSelectionRange("frame", start_frame, stop_frame) self._data.setSelectionRange("frame", start_frame, stop_frame)
frames = self._data.selectedData("frame")
tracks = self._data.selectedData("track") self._controls_widget.setWindow(start_frame, stop_frame)
keypoints = self._data.selectedData("keypoints") kp = self._keypointcombo.currentText().lower()
index = self._data.selectedData("index") kpi = -1 if "center" in kp else int(kp)
self._detectionView.updateDetections(kpi)
df = pd.DataFrame({"frame": frames,
"track": tracks,
"keypoints": keypoints},
index=index)
assigned_left = df[(df.track == self.trackone_id)]
assigned_right = df[(df.track == self.tracktwo_id)]
unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)]
if self._clear_detections:
self._detectionView.clearDetections()
update_detectionView(unassigned, "unassigned")
update_detectionView(assigned_left, "assigned_left")
update_detectionView(assigned_right, "assigned_right")
self._classifier.setData(self._data)
@property @property
def fileList(self): def fileList(self):
@ -223,6 +194,7 @@ class FixTracks(QWidget):
def populateKeypointCombo(self, num_keypoints): def populateKeypointCombo(self, num_keypoints):
self._keypointcombo.clear() self._keypointcombo.clear()
self._keypointcombo.addItem("Center")
for i in range(num_keypoints): for i in range(num_keypoints):
self._keypointcombo.addItem(str(i)) self._keypointcombo.addItem(str(i))
self._keypointcombo.setCurrentIndex(0) self._keypointcombo.setCurrentIndex(0)
@ -238,17 +210,13 @@ class FixTracks(QWidget):
self._currentWindowWidth = self._windowspinner.value() self._currentWindowWidth = self._windowspinner.value()
self._maxframes = self._data.max("frame") self._maxframes = self._data.max("frame")
self.populateKeypointCombo(self._data.numKeypoints()) self.populateKeypointCombo(self._data.numKeypoints())
self._timeline.setDetectionData(self._data.data) self._timeline.setData(self._data)
self._timeline.setWindow(self._currentWindowPos / self._maxframes, self._timeline.setWindow(self._currentWindowPos / self._maxframes,
self._currentWindowWidth / self._maxframes) self._currentWindowWidth / self._maxframes)
coordinates = self._data.coordinates() self._detectionView.setData(self._data)
positions = self._data.centerOfGravity() self._classifier.setData(self._data)
tracks = self._data["track"]
frames = self._data["frame"]
self._classifier.size_classifier.setCoordinates(coordinates)
self._classifier.consistency_tracker.setData(self._data)
self.update() self.update()
logging.info("Finished loading data: %i frames, %i detections", self._maxframes, len(positions)) logging.info("Finished loading data: %i frames", self._maxframes)
def on_keypointSelected(self): def on_keypointSelected(self):
self.update() self.update()
@ -280,38 +248,43 @@ class FixTracks(QWidget):
def on_assignOne(self): def on_assignOne(self):
logging.debug("Assigning user selection to track One") logging.debug("Assigning user selection to track One")
self._data.assignUserSelection(self.trackone_id) self._data.assignUserSelection(self.trackone_id)
self._timeline.setDetectionData(self._data.data) self._timeline.update()
self.update() self.update()
def on_assignTwo(self): def on_assignTwo(self):
logging.debug("Assigning user selection to track Two") logging.debug("Assigning user selection to track Two")
self._data.assignUserSelection(self.tracktwo_id) self._data.assignUserSelection(self.tracktwo_id)
self._timeline.setDetectionData(self._data.data) self._timeline.update()
self.update() self.update()
def on_assignOther(self): def on_assignOther(self):
logging.debug("Assigning user selection to track Other") logging.debug("Assigning user selection to track Other")
self._data.assignUserSelection(self.trackother_id, False) self._data.assignUserSelection(self.trackother_id, False)
self._timeline.setDetectionData(self._data.data) self._timeline.update()
self.update() self.update()
def on_setUserFlag(self): def on_setUserFlag(self):
self._data.setAssignmentStatus(True) self._data.setAssignmentStatus(True)
self._timeline.update()
self.update() self.update()
def on_unsetUserFlag(self): def on_unsetUserFlag(self):
logging.debug("Tracks:unsetUserFlag") logging.debug("Tracks:unsetUserFlag")
self._data.setAssignmentStatus(False) self._data.setAssignmentStatus(False)
self._timeline.update()
self.update() self.update()
def on_revertUserFlags(self): def on_revertUserFlags(self):
logging.debug("Tracks:revert ALL UserFlags") logging.debug("Tracks:revert ALL UserFlags and track assignments")
self._data.revertAssignmentStatus() self._data.revertAssignmentStatus()
self._data.revertTrackAssignments()
self._timeline.update()
self.update() self.update()
def on_deleteDetection(self): def on_deleteDetection(self):
logging.debug("Tracks:delete detections") logging.warning("Tracks:delete detections is currently not supported!")
# self._data.deleteDetections() # self._data.deleteDetections()
self._timeline.update()
self.update() self.update()
def on_windowChanged(self): def on_windowChanged(self):
@ -354,7 +327,6 @@ class FixTracks(QWidget):
self.update() self.update()
def moveWindow(self, stepsize): def moveWindow(self, stepsize):
self._clear_detections = True
step = np.round(stepsize * (self._currentWindowWidth)) step = np.round(stepsize * (self._currentWindowWidth))
new_start_frame = self._currentWindowPos + step new_start_frame = self._currentWindowPos + step
self._timeline.setWindowPos(new_start_frame / self._maxframes) self._timeline.setWindowPos(new_start_frame / self._maxframes)

65
main.py
View File

@ -1,65 +0,0 @@
"""
pyside6-rcc resources.qrc -o resources.py
"""
import sys
import platform
import logging
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import QSettings
from PySide6.QtGui import QIcon, QPalette
from fixtracks import fixtracks, info
logging.basicConfig(level=logging.INFO, force=True)
def is_dark_mode(app: QApplication) -> bool:
palette = app.palette()
# Check the brightness of the window text and base colors
text_color = palette.color(QPalette.ColorRole.WindowText)
base_color = palette.color(QPalette.ColorRole.Base)
# Calculate brightness (0 for dark, 255 for bright)
def brightness(color):
return (color.red() * 299 + color.green() * 587 + color.blue() * 114) // 1000
return brightness(base_color) < brightness(text_color)
# import resources # needs to be imported somewhere in the project to be picked up by qt
if platform.system() == "Windows":
# from PySide6.QtWinExtras import QtWin
myappid = f"{info.organization_name}.{info.application_version}"
# QtWin.setCurrentProcessExplicitAppUserModelID(myappid)
settings = QSettings()
width = int(settings.value("app/width", 1024))
height = int(settings.value("app/height", 768))
x = int(settings.value("app/pos_x", 100))
y = int(settings.value("app/pos_y", 100))
app = QApplication(sys.argv)
app.setApplicationName(info.application_name)
app.setApplicationVersion(str(info.application_version))
app.setOrganizationDomain(info.organization_name)
# if platform.system() == 'Linux':
# icn = QIcon(":/icons/app_icon")
# app.setWindowIcon(icn)
# Create a Qt widget, which will be our window.
window = fixtracks.MainWindow(is_dark_mode(app))
window.setGeometry(100, 100, 1024, 768)
window.setWindowTitle("FixTracks")
window.setMinimumWidth(1024)
window.setMinimumHeight(768)
window.resize(width, height)
window.move(x, y)
window.show()
# Start the event loop.
app.exec()
pos = window.pos()
settings.setValue("app/width", window.width())
settings.setValue("app/height", window.height())
settings.setValue("app/pos_x", pos.x())
settings.setValue("app/pos_y", pos.y())

View File

@ -2,7 +2,7 @@
name = "fixtracks" name = "fixtracks"
version = "0.1.0" version = "0.1.0"
description = "A project to fix track metadata" description = "A project to fix track metadata"
authors = ["Your Name <your.email@example.com>"] authors = ["Your Name <jan.grewe@uni-tuebingen.de>"]
license = "MIT" license = "MIT"
[tool.poetry.dependencies] [tool.poetry.dependencies]