Compare commits
14 Commits
2c62ee28a9
...
64e75ba4b0
Author | SHA1 | Date | |
---|---|---|---|
64e75ba4b0 | |||
0f1b1d6252 | |||
2ff1af7c36 | |||
e33528392c | |||
af5dbc7dfc | |||
f09c78adb5 | |||
2e918866e1 | |||
367cbb021f | |||
dc4833e825 | |||
c231b52876 | |||
4762921ccd | |||
6f4ac1136b | |||
98900ff480 | |||
3206950f5e |
80
FixTracks.py
Normal file
80
FixTracks.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
pyside6-rcc resources.qrc -o resources.py
|
||||
|
||||
"""
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import platform
|
||||
|
||||
from PySide6.QtWidgets import QApplication
|
||||
from PySide6.QtCore import QSettings
|
||||
from PySide6.QtGui import QIcon, QPalette
|
||||
|
||||
from fixtracks import info, mainwindow
|
||||
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
|
||||
|
||||
def is_dark_mode(app: QApplication) -> bool:
|
||||
palette = app.palette()
|
||||
# Check the brightness of the window text and base colors
|
||||
text_color = palette.color(QPalette.ColorRole.WindowText)
|
||||
base_color = palette.color(QPalette.ColorRole.Base)
|
||||
|
||||
# Calculate brightness (0 for dark, 255 for bright)
|
||||
def brightness(color):
|
||||
return (color.red() * 299 + color.green() * 587 + color.blue() * 114) // 1000
|
||||
|
||||
return brightness(base_color) < brightness(text_color)
|
||||
|
||||
def set_logging(loglevel):
|
||||
logging.basicConfig(level=loglevel, force=True)
|
||||
|
||||
def main(args):
|
||||
set_logging(logging.DEBUG)
|
||||
if platform.system() == "Windows":
|
||||
# from PySide6.QtWinExtras import QtWin
|
||||
myappid = f"{info.organization_name}.{info.application_version}"
|
||||
# QtWin.setCurrentProcessExplicitAppUserModelID(myappid)
|
||||
settings = QSettings()
|
||||
width = int(settings.value("app/width", 1024))
|
||||
height = int(settings.value("app/height", 768))
|
||||
x = int(settings.value("app/pos_x", 100))
|
||||
y = int(settings.value("app/pos_y", 100))
|
||||
|
||||
app = QApplication(sys.argv)
|
||||
app.setApplicationName(info.application_name)
|
||||
app.setApplicationVersion(str(info.application_version))
|
||||
app.setOrganizationDomain(info.organization_name)
|
||||
|
||||
# if platform.system() == 'Linux':
|
||||
# icn = QIcon(":/icons/app_icon")
|
||||
# app.setWindowIcon(icn)
|
||||
# Create a Qt widget, which will be our window.
|
||||
window = mainwindow.MainWindow(is_dark_mode(app))
|
||||
window.setGeometry(100, 100, 1024, 768)
|
||||
window.setWindowTitle("FixTracks")
|
||||
window.setMinimumWidth(1024)
|
||||
window.setMinimumHeight(768)
|
||||
window.resize(width, height)
|
||||
window.move(x, y)
|
||||
window.show()
|
||||
|
||||
# Start the event loop.
|
||||
app.exec()
|
||||
pos = window.pos()
|
||||
settings.setValue("app/width", window.width())
|
||||
settings.setValue("app/height", window.height())
|
||||
settings.setValue("app/pos_x", pos.x())
|
||||
settings.setValue("app/pos_y", pos.y())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
levels = {"critical": logging.CRITICAL, "error": logging.ERROR, "warning":logging.WARNING, "info":logging.INFO, "debug":logging.DEBUG}
|
||||
parser = argparse.ArgumentParser(description="FixTracks. Tools for fixing animal tracking")
|
||||
parser.add_argument("-ll", "--loglevel", type=str, default="INFO", help=f"The log level that should be used. Valid levels are {[str(k) for k in levels.keys()]}")
|
||||
args = parser.parse_args()
|
||||
args.loglevel = levels[args.loglevel if args.loglevel.lower() in levels else "info"]
|
||||
|
||||
main(args)
|
@ -144,8 +144,8 @@ class MainWindow(QMainWindow):
|
||||
about.show()
|
||||
|
||||
def on_help(self, s):
|
||||
help = HelpDialog(self)
|
||||
help.show()
|
||||
help_dlg = HelpDialog(self)
|
||||
help_dlg.show()
|
||||
|
||||
# @Slot(None)
|
||||
def exit_request(self):
|
@ -1,7 +1,29 @@
|
||||
from enum import Enum
|
||||
|
||||
from PySide6.QtGui import QColor
|
||||
|
||||
class DetectionData(Enum):
|
||||
ID = 0
|
||||
FRAME = 1
|
||||
COORDINATES = 2
|
||||
TRACK_ID = 3
|
||||
TRACK_ID = 3
|
||||
|
||||
class Tracks(Enum):
|
||||
TRACKONE = 1
|
||||
TRACKTWO = 2
|
||||
UNASSIGNED = -1
|
||||
|
||||
def toColor(self):
|
||||
track_colors = {
|
||||
Tracks.TRACKONE: QColor.fromString("orange"),
|
||||
Tracks.TRACKTWO: QColor.fromString("green"),
|
||||
Tracks.UNASSIGNED: QColor.fromString("red")
|
||||
}
|
||||
return track_colors.get(self, QColor(128, 128, 128)) # Default to black if not found
|
||||
|
||||
@staticmethod
|
||||
def fromValue(value):
|
||||
for track in Tracks:
|
||||
if track.value == value:
|
||||
return track
|
||||
return Tracks.UNASSIGNED
|
||||
|
@ -5,6 +5,7 @@ import pandas as pd
|
||||
|
||||
from PySide6.QtCore import QObject
|
||||
|
||||
|
||||
class TrackingData(QObject):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
@ -58,10 +59,15 @@ class TrackingData(QObject):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._selection_column = col
|
||||
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
col_indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
self._indices = self["index"][col_indices]
|
||||
if len(col_indices) < 1:
|
||||
logging.warning("TrackingData: Selection range is empty!")
|
||||
|
||||
def selectedData(self, col):
|
||||
return self._data[col][self._indices]
|
||||
def selectedData(self, col:str):
|
||||
if col not in self.columns:
|
||||
logging.error("TrackingData:selectedData: Invalid column name! %s", col)
|
||||
return self[col][self._indices]
|
||||
|
||||
def setUserSelection(self, ids):
|
||||
"""
|
||||
@ -84,14 +90,20 @@ class TrackingData(QObject):
|
||||
userFlag : bool
|
||||
Should the "userlabeled" state of the detections be set to True or False?
|
||||
"""
|
||||
self._data["track"][self._user_selections] = track_id
|
||||
self["track"][self._user_selections] = track_id
|
||||
self.setAssignmentStatus(userFlag)
|
||||
|
||||
def setAssignmentStatus(self, isTrue: bool):
|
||||
self._data["userlabeled"][self._user_selections] = isTrue
|
||||
logging.debug("TrackingData:Re-setting assignment status of user selected data to %s", str(isTrue))
|
||||
self["userlabeled"][self._user_selections] = isTrue
|
||||
|
||||
def revertAssignmentStatus(self):
|
||||
self._data["userlabeled"][:] = False
|
||||
logging.debug("TrackingData:Un-setting assignment status of all data!")
|
||||
self["userlabeled"][:] = False
|
||||
|
||||
def revertTrackAssignments(self):
|
||||
logging.debug("TrackingData: Reverting all track assignments!")
|
||||
self["track"][:] = -1
|
||||
|
||||
def deleteDetections(self):
|
||||
# from IPython import embed
|
||||
@ -102,21 +114,21 @@ class TrackingData(QObject):
|
||||
# embed()
|
||||
pass
|
||||
|
||||
def assignTracks(self, tracks):
|
||||
"""assignTracks _summary_
|
||||
def assignTracks(self, tracks:np.ndarray):
|
||||
"""assigns the given tracks to the user-selected detections. If the sizes of
|
||||
provided tracks and the user selection do not match and error is logged and the tracks are not set.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tracks : _type_
|
||||
_description_
|
||||
tracks : np.ndarray
|
||||
The track information.
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
None
|
||||
"""
|
||||
if len(tracks) != self.numDetections:
|
||||
logging.error("DataController: Size of passed tracks does not match data!")
|
||||
logging.error("Trackingdata: Size of passed tracks does not match data!")
|
||||
return
|
||||
self._data["track"] = tracks
|
||||
|
||||
@ -142,11 +154,13 @@ class TrackingData(QObject):
|
||||
and M is number of keypoints
|
||||
"""
|
||||
if selection:
|
||||
return np.stack(self._data["keypoints"][self._start:self._stop, :, :]).astype(np.float32)
|
||||
else:
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
if len(self._indices) < 1:
|
||||
logging.info("TrackingData.coordinates returns empty array, not detections in range!")
|
||||
return np.ndarray([])
|
||||
return np.stack(self._data["keypoints"][self._indices]).astype(np.float32)
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
|
||||
def keypointScores(self):
|
||||
def keypointScores(self, selection=False):
|
||||
"""
|
||||
Returns the keypoint scores as a NumPy array of type float32.
|
||||
|
||||
@ -155,10 +169,15 @@ class TrackingData(QObject):
|
||||
numpy.ndarray
|
||||
A NumPy array of type float32 containing the keypoint scores of the shape (N, M)
|
||||
with N the number of detections and M the number of keypoints.
|
||||
"""
|
||||
"""
|
||||
if selection:
|
||||
if len(self._indices) < 1:
|
||||
logging.info("TrackingData.scores returns empty array, not detections in range!")
|
||||
return np.ndarray([])
|
||||
return np.stack(self._data["keypoint_score"][self._indices]).astype(np.float32)
|
||||
return np.stack(self._data["keypoint_score"]).astype(np.float32)
|
||||
|
||||
def centerOfGravity(self, threshold=0.8):
|
||||
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):
|
||||
"""
|
||||
Calculate the center of gravity of keypoints weighted by their scores. Ignores keypoints that have a score
|
||||
less than threshold.
|
||||
@ -166,16 +185,19 @@ class TrackingData(QObject):
|
||||
Parameters:
|
||||
-----------
|
||||
threshold: float
|
||||
keypoints with a score less than threshold are ignored
|
||||
nodes with a score less than threshold are ignored
|
||||
nodes: list
|
||||
nodes/keypoints to consider for estimation. Defaults to [0,1,2]
|
||||
|
||||
Returns:
|
||||
--------
|
||||
np.ndarray:
|
||||
A NumPy array of shape (N, 2) containing the center of gravity for each detection.
|
||||
"""
|
||||
scores = self.keypointScores()
|
||||
scores = self.keypointScores(selection)
|
||||
scores[scores < threshold] = 0.0
|
||||
weighted_coords = self.coordinates() * scores[:, :, np.newaxis]
|
||||
scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0
|
||||
weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis]
|
||||
sum_scores = np.sum(scores, axis=1, keepdims=True)
|
||||
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
|
||||
return center_of_gravity
|
||||
@ -279,19 +301,19 @@ def main():
|
||||
frames = data["frame"]
|
||||
tracks = data["track"]
|
||||
bendedness = data.bendedness()
|
||||
positions = data.coordinates()[[160388, 160389]]
|
||||
# positions = data.coordinates()[[160388, 160389]]
|
||||
|
||||
embed()
|
||||
tracks = data["track"]
|
||||
cogs = all_cogs[tracks==1]
|
||||
all_dists = neighborDistances(cogs, 2, False)
|
||||
plt.hist(all_dists[1:, 0], bins=1000)
|
||||
print(np.percentile(all_dists[1:, 0], 99))
|
||||
print(np.percentile(all_dists[1:, 0], 1))
|
||||
plt.gca().set_xscale("log")
|
||||
plt.gca().set_yscale("log")
|
||||
# plt.hist(all_dists[1:, 0], bins=1000)
|
||||
# print(np.percentile(all_dists[1:, 0], 99))
|
||||
# print(np.percentile(all_dists[1:, 0], 1))
|
||||
# plt.gca().set_xscale("log")
|
||||
# plt.gca().set_yscale("log")
|
||||
# plt.hist(all_dists[1:, 1], bins=100)
|
||||
plt.show()
|
||||
# plt.show()
|
||||
# def compute_neighbor_distances(cogs, window=10):
|
||||
# distances = []
|
||||
# for i in range(len(cogs)):
|
||||
@ -303,9 +325,6 @@ def main():
|
||||
# return distances
|
||||
# print("estimating neighorhood distances")
|
||||
# neighbor_distances = compute_neighbor_distances(cogs)
|
||||
embed()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -34,6 +34,9 @@ class ConsitencyDataLoader(QRunnable):
|
||||
|
||||
@Slot()
|
||||
def run(self):
|
||||
if self.data is None:
|
||||
logging.error("ConsistencyTracker.DataLoader failed. No Data!")
|
||||
return
|
||||
self.positions = self.data.centerOfGravity()
|
||||
self.orientations = self.data.orientation()
|
||||
self.lengths = self.data.animalLength()
|
||||
@ -464,9 +467,6 @@ class ConsistencyClassifier(QWidget):
|
||||
self.setEnabled(False)
|
||||
self._progressbar.setRange(0,0)
|
||||
self._data = data
|
||||
self._dataworker = ConsitencyDataLoader(self._data)
|
||||
self._dataworker.signals.stopped.connect(self.data_processed)
|
||||
self.threadpool.start(self._dataworker)
|
||||
|
||||
@Slot()
|
||||
def data_processed(self):
|
||||
@ -482,6 +482,7 @@ class ConsistencyClassifier(QWidget):
|
||||
self._frames = self._dataworker.frames
|
||||
self._tracks = self._dataworker.tracks
|
||||
self._maxframes = np.max(self._frames)
|
||||
# FIXME the following line causes an error when there are no detections in the range
|
||||
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
|
||||
self._maxframeslabel.setText(str(self._maxframes))
|
||||
self._startframe_spinner.setMinimum(min_frame)
|
||||
@ -525,7 +526,9 @@ class ConsistencyClassifier(QWidget):
|
||||
self.start()
|
||||
|
||||
def refresh(self):
|
||||
self.setData(self._data)
|
||||
self._dataworker = ConsitencyDataLoader(self._data)
|
||||
self._dataworker.signals.stopped.connect(self.data_processed)
|
||||
self.threadpool.start(self._dataworker)
|
||||
|
||||
def worker_progress(self, progress, processed, errors):
|
||||
self._progressbar.setValue(progress)
|
||||
@ -556,7 +559,8 @@ class ClassifierWidget(QTabWidget):
|
||||
self._consistency_tracker = ConsistencyClassifier()
|
||||
self.addTab(self._size_classifier, SizeClassifier.name)
|
||||
self.addTab(self._consistency_tracker, ConsistencyClassifier.name)
|
||||
self.tabBarClicked.connect(self.update)
|
||||
# self.tabBarClicked.connect(self.update)
|
||||
self.currentChanged.connect(self.tabChanged)
|
||||
self._size_classifier.apply.connect(self._on_applySizeClassifier)
|
||||
self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker)
|
||||
|
||||
@ -576,12 +580,21 @@ class ClassifierWidget(QTabWidget):
|
||||
def consistency_tracker(self):
|
||||
return self._consistency_tracker
|
||||
|
||||
@Slot()
|
||||
def tabChanged(self):
|
||||
if isinstance(self.currentWidget(), ConsistencyClassifier):
|
||||
self.consistency_tracker.refresh()
|
||||
|
||||
@Slot()
|
||||
def update(self):
|
||||
self.consistency_tracker.setData(self._data)
|
||||
if isinstance(self.currentWidget(), ConsistencyClassifier):
|
||||
self.consistency_tracker.refresh()
|
||||
|
||||
def setData(self, data:TrackingData):
|
||||
self._data = data
|
||||
self.consistency_tracker.setData(data)
|
||||
coordinates = self._data.coordinates()
|
||||
self._size_classifier.setCoordinates(coordinates)
|
||||
|
||||
def as_dict(df):
|
||||
d = {c: df[c].values for c in df.columns}
|
||||
|
@ -1,13 +1,14 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from PySide6.QtCore import Qt
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QLabel
|
||||
from PySide6.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem, QGraphicsRectItem, QGraphicsLineItem
|
||||
from PySide6.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem, QGraphicsRectItem, QGraphicsLineItem, QGraphicsEllipseItem
|
||||
from PySide6.QtCore import Qt, QRectF, QRectF
|
||||
from PySide6.QtGui import QBrush, QColor, QPen, QFont
|
||||
|
||||
from fixtracks.utils.signals import DetectionTimelineSignals
|
||||
from fixtracks.utils.trackingdata import TrackingData
|
||||
|
||||
|
||||
class Window(QGraphicsRectItem):
|
||||
@ -40,24 +41,23 @@ class Window(QGraphicsRectItem):
|
||||
self.signals.windowMoved.emit()
|
||||
|
||||
def setWindow(self, newx:float, newwidth:float):
|
||||
def setWindow(self, newx: float, newwidth: float):
|
||||
"""
|
||||
Update the window to the specified range.
|
||||
Parameters
|
||||
----------
|
||||
newx : float
|
||||
The new x-coordinate of the window.
|
||||
newwidth : float
|
||||
The new width of the window.
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
|
||||
"""
|
||||
Update the window to the specified range.
|
||||
Parameters
|
||||
----------
|
||||
newx : float
|
||||
The new x-coordinate of the window.
|
||||
newwidth : float
|
||||
The new width of the window.
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
logging.debug("timeline.window: update window to range %.5f to %.5f", newx, newwidth)
|
||||
self._width = newwidth
|
||||
r = self.rect()
|
||||
self.setRect(newx, r.y(), self._width, r.height())
|
||||
self.update()
|
||||
self.signals.windowMoved.emit()
|
||||
|
||||
def mouseMoveEvent(self, event):
|
||||
@ -86,40 +86,44 @@ class Window(QGraphicsRectItem):
|
||||
class DetectionTimeline(QWidget):
|
||||
signals = DetectionTimelineSignals()
|
||||
|
||||
def __init__(self, detectiondata=None, trackone_id=1, tracktwo_id=2, parent=None):
|
||||
def __init__(self, trackone_id=1, tracktwo_id=2, parent=None):
|
||||
super().__init__(parent)
|
||||
self._trackone = trackone_id
|
||||
self._tracktwo = tracktwo_id
|
||||
self._data = detectiondata
|
||||
self._data = None
|
||||
self._rangeStart = 0.0
|
||||
self._rangeStop = 0.005
|
||||
self._total_width = 2000
|
||||
self._stepCount = 300
|
||||
self._stepCount = 1000
|
||||
self._bg_brush = QBrush(QColor(20, 20, 20, 255))
|
||||
transparent_brush = QBrush(QColor(200, 200, 200, 64))
|
||||
self._white_pen = QPen(QColor.fromString("white"))
|
||||
self._white_pen.setWidth(0.1)
|
||||
self._t1_pen = QPen(QColor.fromString("orange"))
|
||||
self._t1_pen.setWidth(2)
|
||||
self._t1_pen.setWidth(1)
|
||||
self._t2_pen = QPen(QColor(0, 255, 0, 255))
|
||||
self._t2_pen.setWidth(2)
|
||||
self._t2_pen.setWidth(1)
|
||||
self._other_pen = QPen(QColor.fromString("red"))
|
||||
self._other_pen.setWidth(2)
|
||||
axis_pen = QPen(QColor.fromString("white"))
|
||||
axis_pen.setWidth(2)
|
||||
self._other_pen.setWidth(1)
|
||||
window_pen = QPen(QColor.fromString("white"))
|
||||
window_pen.setWidth(2)
|
||||
self._user_brush = QBrush(QColor.fromString("white"))
|
||||
user_pen = QPen(QColor.fromString("white"))
|
||||
user_pen.setWidth(2)
|
||||
|
||||
font = QFont()
|
||||
font.setPointSize(15)
|
||||
font.setBold(False)
|
||||
font.setBold(True)
|
||||
|
||||
self._window = Window(0, 0, 100, 60, axis_pen, transparent_brush)
|
||||
self._window = Window(0, 0, 100, 60, window_pen, transparent_brush)
|
||||
self._window.signals.windowMoved.connect(self.on_windowMoved)
|
||||
|
||||
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 65.))
|
||||
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 85.))
|
||||
self._scene.setBackgroundBrush(self._bg_brush)
|
||||
self._scene.addItem(self._window)
|
||||
|
||||
self._view = QGraphicsView()
|
||||
# self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform);
|
||||
# self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform)
|
||||
self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
||||
self._view.setScene(self._scene)
|
||||
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
|
||||
@ -136,6 +140,10 @@ class DetectionTimeline(QWidget):
|
||||
other_label.setFont(font)
|
||||
other_label.setDefaultTextColor(self._other_pen.color())
|
||||
other_label.setPos(200, 50)
|
||||
user_label = self._scene.addText("user-labeled", font)
|
||||
user_label.setFont(font)
|
||||
user_label.setDefaultTextColor(user_pen.color())
|
||||
user_label.setPos(350, 50)
|
||||
|
||||
self._position_label = QLabel("")
|
||||
f = self._position_label.font()
|
||||
@ -146,51 +154,48 @@ class DetectionTimeline(QWidget):
|
||||
layout.addWidget(self._view)
|
||||
layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight)
|
||||
self.setLayout(layout)
|
||||
|
||||
if self._data is not None:
|
||||
self.draw_coverage()
|
||||
# self.setMaximumHeight(100)
|
||||
# self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
|
||||
|
||||
def setDetectionData(self, data):
|
||||
self._data = data
|
||||
def clear(self):
|
||||
for i in self._scene.items():
|
||||
if isinstance(i, QGraphicsLineItem):
|
||||
if isinstance(i, (QGraphicsLineItem, QGraphicsEllipseItem)):
|
||||
self._scene.removeItem(i)
|
||||
|
||||
def setData(self, data:TrackingData):
|
||||
logging.debug("Timeline: setData!")
|
||||
self._data = data
|
||||
self.update()
|
||||
|
||||
def update(self):
|
||||
self.clear()
|
||||
self.draw_coverage()
|
||||
|
||||
def draw_coverage(self):
|
||||
# FIXME this must be disentangled. timeline should not have to deal with two different ways of data storage
|
||||
if isinstance(self._data, pd.DataFrame):
|
||||
maxframe = np.max(self._data.frame.values)
|
||||
|
||||
bins = np.linspace(0, maxframe, self._stepCount)
|
||||
pos = np.linspace(0, self._scene.width(), self._stepCount)
|
||||
track1_frames = self._data.frame.values[self._data.track == self._trackone]
|
||||
track2_frames = self._data.frame.values[self._data.track == self._tracktwo]
|
||||
other_frames = self._data.frame.values[(self._data.track != self._trackone) &
|
||||
(self._data.track != self._tracktwo)]
|
||||
elif isinstance(self._data, dict):
|
||||
logging.debug("Timeline: drawCoverage!")
|
||||
if isinstance(self._data, TrackingData):
|
||||
maxframe = np.max(self._data["frame"])
|
||||
bins = np.linspace(0, maxframe, self._stepCount)
|
||||
pos = np.linspace(0, self._scene.width(), self._stepCount)
|
||||
pos = np.linspace(0, self._scene.width(), self._stepCount) # of the vertical dashes is this correct?
|
||||
track1_frames = self._data["frame"][self._data["track"] == self._trackone]
|
||||
track2_frames = self._data["frame"][self._data["track"] == self._tracktwo]
|
||||
other_frames = self._data["frame"][(self._data["track"] != self._trackone) &
|
||||
(self._data["track"] != self._tracktwo)]
|
||||
userlabeled = self._data["frame"][self._data["userlabeled"]]
|
||||
else:
|
||||
print("Data is not trackingdata")
|
||||
return
|
||||
t1_coverage, _ = np.histogram(track1_frames, bins=bins)
|
||||
t1_coverage = t1_coverage > 0
|
||||
t2_coverage, _ = np.histogram(track2_frames, bins=bins)
|
||||
t2_coverage = t2_coverage > 0
|
||||
other_coverage, _ = np.histogram(other_frames, bins=bins)
|
||||
other_coverage = other_coverage > 0
|
||||
labeled_coverage, _ = np.histogram(userlabeled, bins=bins)
|
||||
|
||||
for i in range(len(t1_coverage)-1):
|
||||
for i in range(len(bins)-1):
|
||||
if t1_coverage[i]: self._scene.addLine(pos[i], 0, pos[i], 15., pen=self._t1_pen)
|
||||
if t2_coverage[i]: self._scene.addLine(pos[i], 17, pos[i], 32., pen=self._t2_pen)
|
||||
if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen)
|
||||
if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen)
|
||||
if labeled_coverage[i]: self._scene.addEllipse(pos[i]-2, 52, 4, 4, brush=self._user_brush)
|
||||
|
||||
def updatePositionLabel(self):
|
||||
start = np.round(self._rangeStart * 100, 4)
|
||||
@ -207,6 +212,7 @@ class DetectionTimeline(QWidget):
|
||||
|
||||
def fit_scene_to_view(self):
|
||||
"""Scale the image to fit the QGraphicsView."""
|
||||
logging.debug("Timeline: fit scene to view")
|
||||
self._view.fitInView(self._scene.sceneRect(), Qt.KeepAspectRatio)
|
||||
|
||||
def resizeEvent(self, event):
|
||||
@ -289,6 +295,11 @@ def main():
|
||||
view.setWindowPos(0.0)
|
||||
print(view.windowBounds())
|
||||
|
||||
def as_dict(df):
|
||||
d = {c: df[c].values for c in df.columns}
|
||||
d["index"] = df.index.values
|
||||
return d
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
from PySide6.QtWidgets import QApplication, QPushButton, QHBoxLayout
|
||||
@ -298,13 +309,18 @@ def main():
|
||||
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||
with open(datafile, "rb") as f:
|
||||
df = pickle.load(f)
|
||||
|
||||
data = TrackingData()
|
||||
data.setData(as_dict(df))
|
||||
data.setUserSelection(np.arange(0,100, 1))
|
||||
data.setAssignmentStatus(True)
|
||||
start_x = 0.1
|
||||
app = QApplication([])
|
||||
window = QWidget()
|
||||
window.setMinimumSize(200, 75)
|
||||
|
||||
view = DetectionTimeline(df)
|
||||
view = DetectionTimeline()
|
||||
view.setData(data)
|
||||
|
||||
fwdBtn = QPushButton(">>")
|
||||
fwdBtn.clicked.connect(lambda: fwd(0.5))
|
||||
zeroBtn = QPushButton("0->|")
|
||||
|
@ -3,12 +3,12 @@ import numpy as np
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem
|
||||
from PySide6.QtCore import Qt, QPointF, QRectF, QPointF
|
||||
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage
|
||||
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage, QPen
|
||||
|
||||
from fixtracks.info import PACKAGE_ROOT
|
||||
from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals
|
||||
from ..utils.enums import DetectionData
|
||||
|
||||
from fixtracks.utils.enums import DetectionData, Tracks
|
||||
from fixtracks.utils.trackingdata import TrackingData
|
||||
|
||||
class Detection(QGraphicsEllipseItem):
|
||||
signals = DetectionSignals()
|
||||
@ -79,6 +79,7 @@ class DetectionView(QWidget):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._img = None
|
||||
self._data = None
|
||||
self._pixmapitem = None
|
||||
self._scene = DetectionScene()
|
||||
# self.setRenderHint(QGraphicsView.RenderFlag.Ren Antialiasing)
|
||||
@ -90,7 +91,6 @@ class DetectionView(QWidget):
|
||||
self._minZoom = 0.1
|
||||
self._maxZoom = 10
|
||||
self._currentZoom = 1.0
|
||||
|
||||
lyt = QVBoxLayout()
|
||||
lyt.addWidget(self._view)
|
||||
self.setLayout(lyt)
|
||||
@ -116,6 +116,9 @@ class DetectionView(QWidget):
|
||||
self._view.setScene(self._scene)
|
||||
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
|
||||
|
||||
def setData(self, data:TrackingData):
|
||||
self._data = data
|
||||
|
||||
def clearDetections(self):
|
||||
items = self._scene.items()
|
||||
if items is not None:
|
||||
@ -124,23 +127,34 @@ class DetectionView(QWidget):
|
||||
self._scene.removeItem(it)
|
||||
del it
|
||||
|
||||
def addDetections(self, coordinates:np.array, track_ids:np.array, detection_ids:np.array, frames: np.array,
|
||||
keypoint:int, brush:QBrush):
|
||||
def updateDetections(self, keypoint=-1):
|
||||
self.clearDetections()
|
||||
if self._data is None:
|
||||
return
|
||||
frames = self._data.selectedData("frame")
|
||||
tracks = self._data.selectedData("track")
|
||||
coordinates = self._data.coordinates(selection=True)
|
||||
centercoordinates = self._data.centerOfGravity(selection=True)
|
||||
userlabeled = self._data.selectedData("userlabeled")
|
||||
indices = self._data.selectedData("index")
|
||||
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
|
||||
num_detections = coordinates.shape[0]
|
||||
for i in range(num_detections):
|
||||
x = coordinates[i, keypoint, 0]
|
||||
y = coordinates[i, keypoint, 1]
|
||||
c = brush.color()
|
||||
c.setAlpha(int(i * 255 / num_detections))
|
||||
brush.setColor(c)
|
||||
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=brush)
|
||||
item.setData(DetectionData.TRACK_ID.value, track_ids[i])
|
||||
item.setData(DetectionData.ID.value, detection_ids[i])
|
||||
|
||||
for i, idx in enumerate(indices):
|
||||
t = tracks[i]
|
||||
c = Tracks.fromValue(t).toColor()
|
||||
if keypoint >= 0:
|
||||
x = coordinates[i, keypoint, 0]
|
||||
y = coordinates[i, keypoint, 1]
|
||||
else:
|
||||
x = centercoordinates[i, 0]
|
||||
y = centercoordinates[i, 1]
|
||||
|
||||
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(c))
|
||||
item.setData(DetectionData.TRACK_ID.value, tracks[i])
|
||||
item.setData(DetectionData.ID.value, idx)
|
||||
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
|
||||
item.setData(DetectionData.FRAME.value, frames[i])
|
||||
item = self._scene.addItem(item)
|
||||
# logging.debug("DetectionView: Number of items in scene: %i", len(self._scene.items()))
|
||||
|
||||
def fit_image_to_view(self):
|
||||
"""Scale the image to fit the QGraphicsView."""
|
||||
@ -159,7 +173,7 @@ class DetectionView(QWidget):
|
||||
def main():
|
||||
def items_selected(items):
|
||||
print("items selected")
|
||||
|
||||
# FIXME The following code will no longer work...
|
||||
import pickle
|
||||
import numpy as np
|
||||
from IPython import embed
|
||||
|
@ -117,6 +117,7 @@ class SelectionControls(QWidget):
|
||||
deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
||||
deleteBtn.setStyleSheet(pushBtnStyle("red"))
|
||||
deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!")
|
||||
deleteBtn.setEnabled(False)
|
||||
deleteBtn.clicked.connect(self.on_Delete)
|
||||
|
||||
revertBtn = QPushButton("revert assignments")
|
||||
|
@ -2,8 +2,8 @@ import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject
|
||||
from PySide6.QtGui import QImage, QBrush, QColor, QFont
|
||||
from PySide6.QtCore import Qt, QThreadPool, Signal
|
||||
from PySide6.QtGui import QImage, QBrush, QColor
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
||||
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
|
||||
|
||||
@ -28,15 +28,11 @@ class FixTracks(QWidget):
|
||||
self._threadpool = QThreadPool()
|
||||
self._reader = None
|
||||
self._image = None
|
||||
self._clear_detections = True
|
||||
self._currentWindowPos = 0 # in frames
|
||||
self._currentWindowWidth = 0 # in frames
|
||||
self._maxframes = 0
|
||||
self._data = TrackingData()
|
||||
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
|
||||
"assigned_right": QBrush(QColor.fromString("green")),
|
||||
"unassigned": QBrush(QColor.fromString("red"))
|
||||
}
|
||||
|
||||
self._detectionView = DetectionView()
|
||||
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
|
||||
self._skeleton = SkeletonWidget()
|
||||
@ -60,7 +56,7 @@ class FixTracks(QWidget):
|
||||
self._windowspinner.setSingleStep(50)
|
||||
self._windowspinner.setValue(500)
|
||||
self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)
|
||||
# self._timeline.setWindowWidth(0.01)
|
||||
|
||||
self._keypointcombo = QComboBox()
|
||||
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
|
||||
|
||||
@ -143,7 +139,7 @@ class FixTracks(QWidget):
|
||||
def on_autoClassify(self, tracks):
|
||||
self._data.setSelectionRange("index", 0, self._data.numDetections)
|
||||
self._data.assignTracks(tracks)
|
||||
self._timeline.setDetectionData(self._data.data)
|
||||
self._timeline.update()
|
||||
self.update()
|
||||
|
||||
def on_dataSelection(self):
|
||||
@ -163,40 +159,15 @@ class FixTracks(QWidget):
|
||||
self._detectionView.setImage(img)
|
||||
|
||||
def update(self):
|
||||
def update_detectionView(df, name):
|
||||
if len(df) == 0:
|
||||
return
|
||||
keypoint = self._keypointcombo.currentIndex()
|
||||
coords = np.stack(df["keypoints"].values).astype(np.float32)[:, :,:]
|
||||
tracks = df["track"].values.astype(int)
|
||||
ids = df.index.values.astype(int)
|
||||
frames = df["frame"].values.astype(int)
|
||||
self._detectionView.addDetections(coords, tracks, ids, frames, keypoint, self._brushes[name])
|
||||
|
||||
start_frame = self._currentWindowPos
|
||||
stop_frame = start_frame + self._currentWindowWidth
|
||||
self._controls_widget.setWindow(start_frame, stop_frame)
|
||||
logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
|
||||
self._data.setSelectionRange("frame", start_frame, stop_frame)
|
||||
frames = self._data.selectedData("frame")
|
||||
tracks = self._data.selectedData("track")
|
||||
keypoints = self._data.selectedData("keypoints")
|
||||
index = self._data.selectedData("index")
|
||||
|
||||
df = pd.DataFrame({"frame": frames,
|
||||
"track": tracks,
|
||||
"keypoints": keypoints},
|
||||
index=index)
|
||||
assigned_left = df[(df.track == self.trackone_id)]
|
||||
assigned_right = df[(df.track == self.tracktwo_id)]
|
||||
unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)]
|
||||
|
||||
if self._clear_detections:
|
||||
self._detectionView.clearDetections()
|
||||
update_detectionView(unassigned, "unassigned")
|
||||
update_detectionView(assigned_left, "assigned_left")
|
||||
update_detectionView(assigned_right, "assigned_right")
|
||||
self._classifier.setData(self._data)
|
||||
|
||||
self._controls_widget.setWindow(start_frame, stop_frame)
|
||||
kp = self._keypointcombo.currentText().lower()
|
||||
kpi = -1 if "center" in kp else int(kp)
|
||||
self._detectionView.updateDetections(kpi)
|
||||
|
||||
@property
|
||||
def fileList(self):
|
||||
@ -223,6 +194,7 @@ class FixTracks(QWidget):
|
||||
|
||||
def populateKeypointCombo(self, num_keypoints):
|
||||
self._keypointcombo.clear()
|
||||
self._keypointcombo.addItem("Center")
|
||||
for i in range(num_keypoints):
|
||||
self._keypointcombo.addItem(str(i))
|
||||
self._keypointcombo.setCurrentIndex(0)
|
||||
@ -238,17 +210,13 @@ class FixTracks(QWidget):
|
||||
self._currentWindowWidth = self._windowspinner.value()
|
||||
self._maxframes = self._data.max("frame")
|
||||
self.populateKeypointCombo(self._data.numKeypoints())
|
||||
self._timeline.setDetectionData(self._data.data)
|
||||
self._timeline.setData(self._data)
|
||||
self._timeline.setWindow(self._currentWindowPos / self._maxframes,
|
||||
self._currentWindowWidth / self._maxframes)
|
||||
coordinates = self._data.coordinates()
|
||||
positions = self._data.centerOfGravity()
|
||||
tracks = self._data["track"]
|
||||
frames = self._data["frame"]
|
||||
self._classifier.size_classifier.setCoordinates(coordinates)
|
||||
self._classifier.consistency_tracker.setData(self._data)
|
||||
self._detectionView.setData(self._data)
|
||||
self._classifier.setData(self._data)
|
||||
self.update()
|
||||
logging.info("Finished loading data: %i frames, %i detections", self._maxframes, len(positions))
|
||||
logging.info("Finished loading data: %i frames", self._maxframes)
|
||||
|
||||
def on_keypointSelected(self):
|
||||
self.update()
|
||||
@ -280,38 +248,43 @@ class FixTracks(QWidget):
|
||||
def on_assignOne(self):
|
||||
logging.debug("Assigning user selection to track One")
|
||||
self._data.assignUserSelection(self.trackone_id)
|
||||
self._timeline.setDetectionData(self._data.data)
|
||||
self._timeline.update()
|
||||
self.update()
|
||||
|
||||
def on_assignTwo(self):
|
||||
logging.debug("Assigning user selection to track Two")
|
||||
self._data.assignUserSelection(self.tracktwo_id)
|
||||
self._timeline.setDetectionData(self._data.data)
|
||||
self._timeline.update()
|
||||
self.update()
|
||||
|
||||
def on_assignOther(self):
|
||||
logging.debug("Assigning user selection to track Other")
|
||||
self._data.assignUserSelection(self.trackother_id, False)
|
||||
self._timeline.setDetectionData(self._data.data)
|
||||
self._timeline.update()
|
||||
self.update()
|
||||
|
||||
def on_setUserFlag(self):
|
||||
self._data.setAssignmentStatus(True)
|
||||
self._timeline.update()
|
||||
self.update()
|
||||
|
||||
def on_unsetUserFlag(self):
|
||||
logging.debug("Tracks:unsetUserFlag")
|
||||
self._data.setAssignmentStatus(False)
|
||||
self._timeline.update()
|
||||
self.update()
|
||||
|
||||
def on_revertUserFlags(self):
|
||||
logging.debug("Tracks:revert ALL UserFlags")
|
||||
logging.debug("Tracks:revert ALL UserFlags and track assignments")
|
||||
self._data.revertAssignmentStatus()
|
||||
self._data.revertTrackAssignments()
|
||||
self._timeline.update()
|
||||
self.update()
|
||||
|
||||
def on_deleteDetection(self):
|
||||
logging.debug("Tracks:delete detections")
|
||||
logging.warning("Tracks:delete detections is currently not supported!")
|
||||
# self._data.deleteDetections()
|
||||
self._timeline.update()
|
||||
self.update()
|
||||
|
||||
def on_windowChanged(self):
|
||||
@ -354,7 +327,6 @@ class FixTracks(QWidget):
|
||||
self.update()
|
||||
|
||||
def moveWindow(self, stepsize):
|
||||
self._clear_detections = True
|
||||
step = np.round(stepsize * (self._currentWindowWidth))
|
||||
new_start_frame = self._currentWindowPos + step
|
||||
self._timeline.setWindowPos(new_start_frame / self._maxframes)
|
||||
|
65
main.py
65
main.py
@ -1,65 +0,0 @@
|
||||
"""
|
||||
pyside6-rcc resources.qrc -o resources.py
|
||||
|
||||
"""
|
||||
import sys
|
||||
import platform
|
||||
import logging
|
||||
from PySide6.QtWidgets import QApplication
|
||||
from PySide6.QtCore import QSettings
|
||||
from PySide6.QtGui import QIcon, QPalette
|
||||
|
||||
from fixtracks import fixtracks, info
|
||||
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
|
||||
|
||||
def is_dark_mode(app: QApplication) -> bool:
|
||||
palette = app.palette()
|
||||
# Check the brightness of the window text and base colors
|
||||
text_color = palette.color(QPalette.ColorRole.WindowText)
|
||||
base_color = palette.color(QPalette.ColorRole.Base)
|
||||
|
||||
# Calculate brightness (0 for dark, 255 for bright)
|
||||
def brightness(color):
|
||||
return (color.red() * 299 + color.green() * 587 + color.blue() * 114) // 1000
|
||||
|
||||
return brightness(base_color) < brightness(text_color)
|
||||
|
||||
|
||||
# import resources # needs to be imported somewhere in the project to be picked up by qt
|
||||
|
||||
if platform.system() == "Windows":
|
||||
# from PySide6.QtWinExtras import QtWin
|
||||
myappid = f"{info.organization_name}.{info.application_version}"
|
||||
# QtWin.setCurrentProcessExplicitAppUserModelID(myappid)
|
||||
settings = QSettings()
|
||||
width = int(settings.value("app/width", 1024))
|
||||
height = int(settings.value("app/height", 768))
|
||||
x = int(settings.value("app/pos_x", 100))
|
||||
y = int(settings.value("app/pos_y", 100))
|
||||
app = QApplication(sys.argv)
|
||||
app.setApplicationName(info.application_name)
|
||||
app.setApplicationVersion(str(info.application_version))
|
||||
app.setOrganizationDomain(info.organization_name)
|
||||
|
||||
# if platform.system() == 'Linux':
|
||||
# icn = QIcon(":/icons/app_icon")
|
||||
# app.setWindowIcon(icn)
|
||||
# Create a Qt widget, which will be our window.
|
||||
window = fixtracks.MainWindow(is_dark_mode(app))
|
||||
window.setGeometry(100, 100, 1024, 768)
|
||||
window.setWindowTitle("FixTracks")
|
||||
window.setMinimumWidth(1024)
|
||||
window.setMinimumHeight(768)
|
||||
window.resize(width, height)
|
||||
window.move(x, y)
|
||||
window.show()
|
||||
|
||||
# Start the event loop.
|
||||
app.exec()
|
||||
pos = window.pos()
|
||||
settings.setValue("app/width", window.width())
|
||||
settings.setValue("app/height", window.height())
|
||||
settings.setValue("app/pos_x", pos.x())
|
||||
settings.setValue("app/pos_y", pos.y())
|
@ -2,7 +2,7 @@
|
||||
name = "fixtracks"
|
||||
version = "0.1.0"
|
||||
description = "A project to fix track metadata"
|
||||
authors = ["Your Name <your.email@example.com>"]
|
||||
authors = ["Your Name <jan.grewe@uni-tuebingen.de>"]
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
|
Loading…
Reference in New Issue
Block a user