Compare commits

...

14 Commits

11 changed files with 302 additions and 230 deletions

80
FixTracks.py Normal file
View File

@ -0,0 +1,80 @@
"""
pyside6-rcc resources.qrc -o resources.py
"""
import sys
import logging
import argparse
import platform
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import QSettings
from PySide6.QtGui import QIcon, QPalette
from fixtracks import info, mainwindow
logging.basicConfig(level=logging.INFO, force=True)
def is_dark_mode(app: QApplication) -> bool:
palette = app.palette()
# Check the brightness of the window text and base colors
text_color = palette.color(QPalette.ColorRole.WindowText)
base_color = palette.color(QPalette.ColorRole.Base)
# Calculate brightness (0 for dark, 255 for bright)
def brightness(color):
return (color.red() * 299 + color.green() * 587 + color.blue() * 114) // 1000
return brightness(base_color) < brightness(text_color)
def set_logging(loglevel):
logging.basicConfig(level=loglevel, force=True)
def main(args):
set_logging(logging.DEBUG)
if platform.system() == "Windows":
# from PySide6.QtWinExtras import QtWin
myappid = f"{info.organization_name}.{info.application_version}"
# QtWin.setCurrentProcessExplicitAppUserModelID(myappid)
settings = QSettings()
width = int(settings.value("app/width", 1024))
height = int(settings.value("app/height", 768))
x = int(settings.value("app/pos_x", 100))
y = int(settings.value("app/pos_y", 100))
app = QApplication(sys.argv)
app.setApplicationName(info.application_name)
app.setApplicationVersion(str(info.application_version))
app.setOrganizationDomain(info.organization_name)
# if platform.system() == 'Linux':
# icn = QIcon(":/icons/app_icon")
# app.setWindowIcon(icn)
# Create a Qt widget, which will be our window.
window = mainwindow.MainWindow(is_dark_mode(app))
window.setGeometry(100, 100, 1024, 768)
window.setWindowTitle("FixTracks")
window.setMinimumWidth(1024)
window.setMinimumHeight(768)
window.resize(width, height)
window.move(x, y)
window.show()
# Start the event loop.
app.exec()
pos = window.pos()
settings.setValue("app/width", window.width())
settings.setValue("app/height", window.height())
settings.setValue("app/pos_x", pos.x())
settings.setValue("app/pos_y", pos.y())
if __name__ == "__main__":
levels = {"critical": logging.CRITICAL, "error": logging.ERROR, "warning":logging.WARNING, "info":logging.INFO, "debug":logging.DEBUG}
parser = argparse.ArgumentParser(description="FixTracks. Tools for fixing animal tracking")
parser.add_argument("-ll", "--loglevel", type=str, default="INFO", help=f"The log level that should be used. Valid levels are {[str(k) for k in levels.keys()]}")
args = parser.parse_args()
args.loglevel = levels[args.loglevel if args.loglevel.lower() in levels else "info"]
main(args)

View File

@ -144,8 +144,8 @@ class MainWindow(QMainWindow):
about.show()
def on_help(self, s):
help = HelpDialog(self)
help.show()
help_dlg = HelpDialog(self)
help_dlg.show()
# @Slot(None)
def exit_request(self):

View File

@ -1,7 +1,29 @@
from enum import Enum
from PySide6.QtGui import QColor
class DetectionData(Enum):
ID = 0
FRAME = 1
COORDINATES = 2
TRACK_ID = 3
TRACK_ID = 3
class Tracks(Enum):
TRACKONE = 1
TRACKTWO = 2
UNASSIGNED = -1
def toColor(self):
track_colors = {
Tracks.TRACKONE: QColor.fromString("orange"),
Tracks.TRACKTWO: QColor.fromString("green"),
Tracks.UNASSIGNED: QColor.fromString("red")
}
return track_colors.get(self, QColor(128, 128, 128)) # Default to black if not found
@staticmethod
def fromValue(value):
for track in Tracks:
if track.value == value:
return track
return Tracks.UNASSIGNED

View File

@ -5,6 +5,7 @@ import pandas as pd
from PySide6.QtCore import QObject
class TrackingData(QObject):
def __init__(self, parent=None):
super().__init__(parent)
@ -58,10 +59,15 @@ class TrackingData(QObject):
self._start = start
self._stop = stop
self._selection_column = col
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
col_indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
self._indices = self["index"][col_indices]
if len(col_indices) < 1:
logging.warning("TrackingData: Selection range is empty!")
def selectedData(self, col):
return self._data[col][self._indices]
def selectedData(self, col:str):
if col not in self.columns:
logging.error("TrackingData:selectedData: Invalid column name! %s", col)
return self[col][self._indices]
def setUserSelection(self, ids):
"""
@ -84,14 +90,20 @@ class TrackingData(QObject):
userFlag : bool
Should the "userlabeled" state of the detections be set to True or False?
"""
self._data["track"][self._user_selections] = track_id
self["track"][self._user_selections] = track_id
self.setAssignmentStatus(userFlag)
def setAssignmentStatus(self, isTrue: bool):
self._data["userlabeled"][self._user_selections] = isTrue
logging.debug("TrackingData:Re-setting assignment status of user selected data to %s", str(isTrue))
self["userlabeled"][self._user_selections] = isTrue
def revertAssignmentStatus(self):
self._data["userlabeled"][:] = False
logging.debug("TrackingData:Un-setting assignment status of all data!")
self["userlabeled"][:] = False
def revertTrackAssignments(self):
logging.debug("TrackingData: Reverting all track assignments!")
self["track"][:] = -1
def deleteDetections(self):
# from IPython import embed
@ -102,21 +114,21 @@ class TrackingData(QObject):
# embed()
pass
def assignTracks(self, tracks):
"""assignTracks _summary_
def assignTracks(self, tracks:np.ndarray):
"""assigns the given tracks to the user-selected detections. If the sizes of
provided tracks and the user selection do not match and error is logged and the tracks are not set.
Parameters
----------
tracks : _type_
_description_
tracks : np.ndarray
The track information.
Returns
-------
_type_
_description_
None
"""
if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!")
logging.error("Trackingdata: Size of passed tracks does not match data!")
return
self._data["track"] = tracks
@ -142,11 +154,13 @@ class TrackingData(QObject):
and M is number of keypoints
"""
if selection:
return np.stack(self._data["keypoints"][self._start:self._stop, :, :]).astype(np.float32)
else:
return np.stack(self._data["keypoints"]).astype(np.float32)
if len(self._indices) < 1:
logging.info("TrackingData.coordinates returns empty array, not detections in range!")
return np.ndarray([])
return np.stack(self._data["keypoints"][self._indices]).astype(np.float32)
return np.stack(self._data["keypoints"]).astype(np.float32)
def keypointScores(self):
def keypointScores(self, selection=False):
"""
Returns the keypoint scores as a NumPy array of type float32.
@ -155,10 +169,15 @@ class TrackingData(QObject):
numpy.ndarray
A NumPy array of type float32 containing the keypoint scores of the shape (N, M)
with N the number of detections and M the number of keypoints.
"""
"""
if selection:
if len(self._indices) < 1:
logging.info("TrackingData.scores returns empty array, not detections in range!")
return np.ndarray([])
return np.stack(self._data["keypoint_score"][self._indices]).astype(np.float32)
return np.stack(self._data["keypoint_score"]).astype(np.float32)
def centerOfGravity(self, threshold=0.8):
def centerOfGravity(self, selection=False, threshold=0.8, nodes=[0,1,2]):
"""
Calculate the center of gravity of keypoints weighted by their scores. Ignores keypoints that have a score
less than threshold.
@ -166,16 +185,19 @@ class TrackingData(QObject):
Parameters:
-----------
threshold: float
keypoints with a score less than threshold are ignored
nodes with a score less than threshold are ignored
nodes: list
nodes/keypoints to consider for estimation. Defaults to [0,1,2]
Returns:
--------
np.ndarray:
A NumPy array of shape (N, 2) containing the center of gravity for each detection.
"""
scores = self.keypointScores()
scores = self.keypointScores(selection)
scores[scores < threshold] = 0.0
weighted_coords = self.coordinates() * scores[:, :, np.newaxis]
scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0
weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis]
sum_scores = np.sum(scores, axis=1, keepdims=True)
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
return center_of_gravity
@ -279,19 +301,19 @@ def main():
frames = data["frame"]
tracks = data["track"]
bendedness = data.bendedness()
positions = data.coordinates()[[160388, 160389]]
# positions = data.coordinates()[[160388, 160389]]
embed()
tracks = data["track"]
cogs = all_cogs[tracks==1]
all_dists = neighborDistances(cogs, 2, False)
plt.hist(all_dists[1:, 0], bins=1000)
print(np.percentile(all_dists[1:, 0], 99))
print(np.percentile(all_dists[1:, 0], 1))
plt.gca().set_xscale("log")
plt.gca().set_yscale("log")
# plt.hist(all_dists[1:, 0], bins=1000)
# print(np.percentile(all_dists[1:, 0], 99))
# print(np.percentile(all_dists[1:, 0], 1))
# plt.gca().set_xscale("log")
# plt.gca().set_yscale("log")
# plt.hist(all_dists[1:, 1], bins=100)
plt.show()
# plt.show()
# def compute_neighbor_distances(cogs, window=10):
# distances = []
# for i in range(len(cogs)):
@ -303,9 +325,6 @@ def main():
# return distances
# print("estimating neighorhood distances")
# neighbor_distances = compute_neighbor_distances(cogs)
embed()
if __name__ == "__main__":
main()

View File

@ -34,6 +34,9 @@ class ConsitencyDataLoader(QRunnable):
@Slot()
def run(self):
if self.data is None:
logging.error("ConsistencyTracker.DataLoader failed. No Data!")
return
self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation()
self.lengths = self.data.animalLength()
@ -464,9 +467,6 @@ class ConsistencyClassifier(QWidget):
self.setEnabled(False)
self._progressbar.setRange(0,0)
self._data = data
self._dataworker = ConsitencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed)
self.threadpool.start(self._dataworker)
@Slot()
def data_processed(self):
@ -482,6 +482,7 @@ class ConsistencyClassifier(QWidget):
self._frames = self._dataworker.frames
self._tracks = self._dataworker.tracks
self._maxframes = np.max(self._frames)
# FIXME the following line causes an error when there are no detections in the range
min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
self._maxframeslabel.setText(str(self._maxframes))
self._startframe_spinner.setMinimum(min_frame)
@ -525,7 +526,9 @@ class ConsistencyClassifier(QWidget):
self.start()
def refresh(self):
self.setData(self._data)
self._dataworker = ConsitencyDataLoader(self._data)
self._dataworker.signals.stopped.connect(self.data_processed)
self.threadpool.start(self._dataworker)
def worker_progress(self, progress, processed, errors):
self._progressbar.setValue(progress)
@ -556,7 +559,8 @@ class ClassifierWidget(QTabWidget):
self._consistency_tracker = ConsistencyClassifier()
self.addTab(self._size_classifier, SizeClassifier.name)
self.addTab(self._consistency_tracker, ConsistencyClassifier.name)
self.tabBarClicked.connect(self.update)
# self.tabBarClicked.connect(self.update)
self.currentChanged.connect(self.tabChanged)
self._size_classifier.apply.connect(self._on_applySizeClassifier)
self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker)
@ -576,12 +580,21 @@ class ClassifierWidget(QTabWidget):
def consistency_tracker(self):
return self._consistency_tracker
@Slot()
def tabChanged(self):
if isinstance(self.currentWidget(), ConsistencyClassifier):
self.consistency_tracker.refresh()
@Slot()
def update(self):
self.consistency_tracker.setData(self._data)
if isinstance(self.currentWidget(), ConsistencyClassifier):
self.consistency_tracker.refresh()
def setData(self, data:TrackingData):
self._data = data
self.consistency_tracker.setData(data)
coordinates = self._data.coordinates()
self._size_classifier.setCoordinates(coordinates)
def as_dict(df):
d = {c: df[c].values for c in df.columns}

View File

@ -1,13 +1,14 @@
import logging
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QLabel
from PySide6.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem, QGraphicsRectItem, QGraphicsLineItem
from PySide6.QtWidgets import QGraphicsView, QGraphicsScene, QGraphicsItem, QGraphicsRectItem, QGraphicsLineItem, QGraphicsEllipseItem
from PySide6.QtCore import Qt, QRectF, QRectF
from PySide6.QtGui import QBrush, QColor, QPen, QFont
from fixtracks.utils.signals import DetectionTimelineSignals
from fixtracks.utils.trackingdata import TrackingData
class Window(QGraphicsRectItem):
@ -40,24 +41,23 @@ class Window(QGraphicsRectItem):
self.signals.windowMoved.emit()
def setWindow(self, newx:float, newwidth:float):
def setWindow(self, newx: float, newwidth: float):
"""
Update the window to the specified range.
Parameters
----------
newx : float
The new x-coordinate of the window.
newwidth : float
The new width of the window.
Returns
-------
None
"""
"""
Update the window to the specified range.
Parameters
----------
newx : float
The new x-coordinate of the window.
newwidth : float
The new width of the window.
Returns
-------
None
"""
logging.debug("timeline.window: update window to range %.5f to %.5f", newx, newwidth)
self._width = newwidth
r = self.rect()
self.setRect(newx, r.y(), self._width, r.height())
self.update()
self.signals.windowMoved.emit()
def mouseMoveEvent(self, event):
@ -86,40 +86,44 @@ class Window(QGraphicsRectItem):
class DetectionTimeline(QWidget):
signals = DetectionTimelineSignals()
def __init__(self, detectiondata=None, trackone_id=1, tracktwo_id=2, parent=None):
def __init__(self, trackone_id=1, tracktwo_id=2, parent=None):
super().__init__(parent)
self._trackone = trackone_id
self._tracktwo = tracktwo_id
self._data = detectiondata
self._data = None
self._rangeStart = 0.0
self._rangeStop = 0.005
self._total_width = 2000
self._stepCount = 300
self._stepCount = 1000
self._bg_brush = QBrush(QColor(20, 20, 20, 255))
transparent_brush = QBrush(QColor(200, 200, 200, 64))
self._white_pen = QPen(QColor.fromString("white"))
self._white_pen.setWidth(0.1)
self._t1_pen = QPen(QColor.fromString("orange"))
self._t1_pen.setWidth(2)
self._t1_pen.setWidth(1)
self._t2_pen = QPen(QColor(0, 255, 0, 255))
self._t2_pen.setWidth(2)
self._t2_pen.setWidth(1)
self._other_pen = QPen(QColor.fromString("red"))
self._other_pen.setWidth(2)
axis_pen = QPen(QColor.fromString("white"))
axis_pen.setWidth(2)
self._other_pen.setWidth(1)
window_pen = QPen(QColor.fromString("white"))
window_pen.setWidth(2)
self._user_brush = QBrush(QColor.fromString("white"))
user_pen = QPen(QColor.fromString("white"))
user_pen.setWidth(2)
font = QFont()
font.setPointSize(15)
font.setBold(False)
font.setBold(True)
self._window = Window(0, 0, 100, 60, axis_pen, transparent_brush)
self._window = Window(0, 0, 100, 60, window_pen, transparent_brush)
self._window.signals.windowMoved.connect(self.on_windowMoved)
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 65.))
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 85.))
self._scene.setBackgroundBrush(self._bg_brush)
self._scene.addItem(self._window)
self._view = QGraphicsView()
# self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform);
# self._view.setRenderHints(QPainter.RenderHint.Antialiasing | QPainter.RenderHint.SmoothPixmapTransform)
self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
self._view.setScene(self._scene)
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
@ -136,6 +140,10 @@ class DetectionTimeline(QWidget):
other_label.setFont(font)
other_label.setDefaultTextColor(self._other_pen.color())
other_label.setPos(200, 50)
user_label = self._scene.addText("user-labeled", font)
user_label.setFont(font)
user_label.setDefaultTextColor(user_pen.color())
user_label.setPos(350, 50)
self._position_label = QLabel("")
f = self._position_label.font()
@ -146,51 +154,48 @@ class DetectionTimeline(QWidget):
layout.addWidget(self._view)
layout.addWidget(self._position_label, Qt.AlignmentFlag.AlignRight)
self.setLayout(layout)
if self._data is not None:
self.draw_coverage()
# self.setMaximumHeight(100)
# self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)
def setDetectionData(self, data):
self._data = data
def clear(self):
for i in self._scene.items():
if isinstance(i, QGraphicsLineItem):
if isinstance(i, (QGraphicsLineItem, QGraphicsEllipseItem)):
self._scene.removeItem(i)
def setData(self, data:TrackingData):
logging.debug("Timeline: setData!")
self._data = data
self.update()
def update(self):
self.clear()
self.draw_coverage()
def draw_coverage(self):
# FIXME this must be disentangled. timeline should not have to deal with two different ways of data storage
if isinstance(self._data, pd.DataFrame):
maxframe = np.max(self._data.frame.values)
bins = np.linspace(0, maxframe, self._stepCount)
pos = np.linspace(0, self._scene.width(), self._stepCount)
track1_frames = self._data.frame.values[self._data.track == self._trackone]
track2_frames = self._data.frame.values[self._data.track == self._tracktwo]
other_frames = self._data.frame.values[(self._data.track != self._trackone) &
(self._data.track != self._tracktwo)]
elif isinstance(self._data, dict):
logging.debug("Timeline: drawCoverage!")
if isinstance(self._data, TrackingData):
maxframe = np.max(self._data["frame"])
bins = np.linspace(0, maxframe, self._stepCount)
pos = np.linspace(0, self._scene.width(), self._stepCount)
pos = np.linspace(0, self._scene.width(), self._stepCount) # of the vertical dashes is this correct?
track1_frames = self._data["frame"][self._data["track"] == self._trackone]
track2_frames = self._data["frame"][self._data["track"] == self._tracktwo]
other_frames = self._data["frame"][(self._data["track"] != self._trackone) &
(self._data["track"] != self._tracktwo)]
userlabeled = self._data["frame"][self._data["userlabeled"]]
else:
print("Data is not trackingdata")
return
t1_coverage, _ = np.histogram(track1_frames, bins=bins)
t1_coverage = t1_coverage > 0
t2_coverage, _ = np.histogram(track2_frames, bins=bins)
t2_coverage = t2_coverage > 0
other_coverage, _ = np.histogram(other_frames, bins=bins)
other_coverage = other_coverage > 0
labeled_coverage, _ = np.histogram(userlabeled, bins=bins)
for i in range(len(t1_coverage)-1):
for i in range(len(bins)-1):
if t1_coverage[i]: self._scene.addLine(pos[i], 0, pos[i], 15., pen=self._t1_pen)
if t2_coverage[i]: self._scene.addLine(pos[i], 17, pos[i], 32., pen=self._t2_pen)
if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen)
if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen)
if labeled_coverage[i]: self._scene.addEllipse(pos[i]-2, 52, 4, 4, brush=self._user_brush)
def updatePositionLabel(self):
start = np.round(self._rangeStart * 100, 4)
@ -207,6 +212,7 @@ class DetectionTimeline(QWidget):
def fit_scene_to_view(self):
"""Scale the image to fit the QGraphicsView."""
logging.debug("Timeline: fit scene to view")
self._view.fitInView(self._scene.sceneRect(), Qt.KeepAspectRatio)
def resizeEvent(self, event):
@ -289,6 +295,11 @@ def main():
view.setWindowPos(0.0)
print(view.windowBounds())
def as_dict(df):
d = {c: df[c].values for c in df.columns}
d["index"] = df.index.values
return d
import pickle
import numpy as np
from PySide6.QtWidgets import QApplication, QPushButton, QHBoxLayout
@ -298,13 +309,18 @@ def main():
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)
data = TrackingData()
data.setData(as_dict(df))
data.setUserSelection(np.arange(0,100, 1))
data.setAssignmentStatus(True)
start_x = 0.1
app = QApplication([])
window = QWidget()
window.setMinimumSize(200, 75)
view = DetectionTimeline(df)
view = DetectionTimeline()
view.setData(data)
fwdBtn = QPushButton(">>")
fwdBtn.clicked.connect(lambda: fwd(0.5))
zeroBtn = QPushButton("0->|")

View File

@ -3,12 +3,12 @@ import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem
from PySide6.QtCore import Qt, QPointF, QRectF, QPointF
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage, QPen
from fixtracks.info import PACKAGE_ROOT
from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals
from ..utils.enums import DetectionData
from fixtracks.utils.enums import DetectionData, Tracks
from fixtracks.utils.trackingdata import TrackingData
class Detection(QGraphicsEllipseItem):
signals = DetectionSignals()
@ -79,6 +79,7 @@ class DetectionView(QWidget):
def __init__(self, parent=None):
super().__init__(parent)
self._img = None
self._data = None
self._pixmapitem = None
self._scene = DetectionScene()
# self.setRenderHint(QGraphicsView.RenderFlag.Ren Antialiasing)
@ -90,7 +91,6 @@ class DetectionView(QWidget):
self._minZoom = 0.1
self._maxZoom = 10
self._currentZoom = 1.0
lyt = QVBoxLayout()
lyt.addWidget(self._view)
self.setLayout(lyt)
@ -116,6 +116,9 @@ class DetectionView(QWidget):
self._view.setScene(self._scene)
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
def setData(self, data:TrackingData):
self._data = data
def clearDetections(self):
items = self._scene.items()
if items is not None:
@ -124,23 +127,34 @@ class DetectionView(QWidget):
self._scene.removeItem(it)
del it
def addDetections(self, coordinates:np.array, track_ids:np.array, detection_ids:np.array, frames: np.array,
keypoint:int, brush:QBrush):
def updateDetections(self, keypoint=-1):
self.clearDetections()
if self._data is None:
return
frames = self._data.selectedData("frame")
tracks = self._data.selectedData("track")
coordinates = self._data.coordinates(selection=True)
centercoordinates = self._data.centerOfGravity(selection=True)
userlabeled = self._data.selectedData("userlabeled")
indices = self._data.selectedData("index")
image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0,0,0,0)
num_detections = coordinates.shape[0]
for i in range(num_detections):
x = coordinates[i, keypoint, 0]
y = coordinates[i, keypoint, 1]
c = brush.color()
c.setAlpha(int(i * 255 / num_detections))
brush.setColor(c)
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=brush)
item.setData(DetectionData.TRACK_ID.value, track_ids[i])
item.setData(DetectionData.ID.value, detection_ids[i])
for i, idx in enumerate(indices):
t = tracks[i]
c = Tracks.fromValue(t).toColor()
if keypoint >= 0:
x = coordinates[i, keypoint, 0]
y = coordinates[i, keypoint, 1]
else:
x = centercoordinates[i, 0]
y = centercoordinates[i, 1]
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(c))
item.setData(DetectionData.TRACK_ID.value, tracks[i])
item.setData(DetectionData.ID.value, idx)
item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
item.setData(DetectionData.FRAME.value, frames[i])
item = self._scene.addItem(item)
# logging.debug("DetectionView: Number of items in scene: %i", len(self._scene.items()))
def fit_image_to_view(self):
"""Scale the image to fit the QGraphicsView."""
@ -159,7 +173,7 @@ class DetectionView(QWidget):
def main():
def items_selected(items):
print("items selected")
# FIXME The following code will no longer work...
import pickle
import numpy as np
from IPython import embed

View File

@ -117,6 +117,7 @@ class SelectionControls(QWidget):
deleteBtn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
deleteBtn.setStyleSheet(pushBtnStyle("red"))
deleteBtn.setToolTip(f"DANGERZONE! Delete current selection of detections!")
deleteBtn.setEnabled(False)
deleteBtn.clicked.connect(self.on_Delete)
revertBtn = QPushButton("revert assignments")

View File

@ -2,8 +2,8 @@ import logging
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject
from PySide6.QtGui import QImage, QBrush, QColor, QFont
from PySide6.QtCore import Qt, QThreadPool, Signal
from PySide6.QtGui import QImage, QBrush, QColor
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
@ -28,15 +28,11 @@ class FixTracks(QWidget):
self._threadpool = QThreadPool()
self._reader = None
self._image = None
self._clear_detections = True
self._currentWindowPos = 0 # in frames
self._currentWindowWidth = 0 # in frames
self._maxframes = 0
self._data = TrackingData()
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
"assigned_right": QBrush(QColor.fromString("green")),
"unassigned": QBrush(QColor.fromString("red"))
}
self._detectionView = DetectionView()
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
self._skeleton = SkeletonWidget()
@ -60,7 +56,7 @@ class FixTracks(QWidget):
self._windowspinner.setSingleStep(50)
self._windowspinner.setValue(500)
self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)
# self._timeline.setWindowWidth(0.01)
self._keypointcombo = QComboBox()
self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
@ -143,7 +139,7 @@ class FixTracks(QWidget):
def on_autoClassify(self, tracks):
self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data)
self._timeline.update()
self.update()
def on_dataSelection(self):
@ -163,40 +159,15 @@ class FixTracks(QWidget):
self._detectionView.setImage(img)
def update(self):
def update_detectionView(df, name):
if len(df) == 0:
return
keypoint = self._keypointcombo.currentIndex()
coords = np.stack(df["keypoints"].values).astype(np.float32)[:, :,:]
tracks = df["track"].values.astype(int)
ids = df.index.values.astype(int)
frames = df["frame"].values.astype(int)
self._detectionView.addDetections(coords, tracks, ids, frames, keypoint, self._brushes[name])
start_frame = self._currentWindowPos
stop_frame = start_frame + self._currentWindowWidth
self._controls_widget.setWindow(start_frame, stop_frame)
logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
self._data.setSelectionRange("frame", start_frame, stop_frame)
frames = self._data.selectedData("frame")
tracks = self._data.selectedData("track")
keypoints = self._data.selectedData("keypoints")
index = self._data.selectedData("index")
df = pd.DataFrame({"frame": frames,
"track": tracks,
"keypoints": keypoints},
index=index)
assigned_left = df[(df.track == self.trackone_id)]
assigned_right = df[(df.track == self.tracktwo_id)]
unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)]
if self._clear_detections:
self._detectionView.clearDetections()
update_detectionView(unassigned, "unassigned")
update_detectionView(assigned_left, "assigned_left")
update_detectionView(assigned_right, "assigned_right")
self._classifier.setData(self._data)
self._controls_widget.setWindow(start_frame, stop_frame)
kp = self._keypointcombo.currentText().lower()
kpi = -1 if "center" in kp else int(kp)
self._detectionView.updateDetections(kpi)
@property
def fileList(self):
@ -223,6 +194,7 @@ class FixTracks(QWidget):
def populateKeypointCombo(self, num_keypoints):
self._keypointcombo.clear()
self._keypointcombo.addItem("Center")
for i in range(num_keypoints):
self._keypointcombo.addItem(str(i))
self._keypointcombo.setCurrentIndex(0)
@ -238,17 +210,13 @@ class FixTracks(QWidget):
self._currentWindowWidth = self._windowspinner.value()
self._maxframes = self._data.max("frame")
self.populateKeypointCombo(self._data.numKeypoints())
self._timeline.setDetectionData(self._data.data)
self._timeline.setData(self._data)
self._timeline.setWindow(self._currentWindowPos / self._maxframes,
self._currentWindowWidth / self._maxframes)
coordinates = self._data.coordinates()
positions = self._data.centerOfGravity()
tracks = self._data["track"]
frames = self._data["frame"]
self._classifier.size_classifier.setCoordinates(coordinates)
self._classifier.consistency_tracker.setData(self._data)
self._detectionView.setData(self._data)
self._classifier.setData(self._data)
self.update()
logging.info("Finished loading data: %i frames, %i detections", self._maxframes, len(positions))
logging.info("Finished loading data: %i frames", self._maxframes)
def on_keypointSelected(self):
self.update()
@ -280,38 +248,43 @@ class FixTracks(QWidget):
def on_assignOne(self):
logging.debug("Assigning user selection to track One")
self._data.assignUserSelection(self.trackone_id)
self._timeline.setDetectionData(self._data.data)
self._timeline.update()
self.update()
def on_assignTwo(self):
logging.debug("Assigning user selection to track Two")
self._data.assignUserSelection(self.tracktwo_id)
self._timeline.setDetectionData(self._data.data)
self._timeline.update()
self.update()
def on_assignOther(self):
logging.debug("Assigning user selection to track Other")
self._data.assignUserSelection(self.trackother_id, False)
self._timeline.setDetectionData(self._data.data)
self._timeline.update()
self.update()
def on_setUserFlag(self):
self._data.setAssignmentStatus(True)
self._timeline.update()
self.update()
def on_unsetUserFlag(self):
logging.debug("Tracks:unsetUserFlag")
self._data.setAssignmentStatus(False)
self._timeline.update()
self.update()
def on_revertUserFlags(self):
logging.debug("Tracks:revert ALL UserFlags")
logging.debug("Tracks:revert ALL UserFlags and track assignments")
self._data.revertAssignmentStatus()
self._data.revertTrackAssignments()
self._timeline.update()
self.update()
def on_deleteDetection(self):
logging.debug("Tracks:delete detections")
logging.warning("Tracks:delete detections is currently not supported!")
# self._data.deleteDetections()
self._timeline.update()
self.update()
def on_windowChanged(self):
@ -354,7 +327,6 @@ class FixTracks(QWidget):
self.update()
def moveWindow(self, stepsize):
self._clear_detections = True
step = np.round(stepsize * (self._currentWindowWidth))
new_start_frame = self._currentWindowPos + step
self._timeline.setWindowPos(new_start_frame / self._maxframes)

65
main.py
View File

@ -1,65 +0,0 @@
"""
pyside6-rcc resources.qrc -o resources.py
"""
import sys
import platform
import logging
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import QSettings
from PySide6.QtGui import QIcon, QPalette
from fixtracks import fixtracks, info
logging.basicConfig(level=logging.INFO, force=True)
def is_dark_mode(app: QApplication) -> bool:
palette = app.palette()
# Check the brightness of the window text and base colors
text_color = palette.color(QPalette.ColorRole.WindowText)
base_color = palette.color(QPalette.ColorRole.Base)
# Calculate brightness (0 for dark, 255 for bright)
def brightness(color):
return (color.red() * 299 + color.green() * 587 + color.blue() * 114) // 1000
return brightness(base_color) < brightness(text_color)
# import resources # needs to be imported somewhere in the project to be picked up by qt
if platform.system() == "Windows":
# from PySide6.QtWinExtras import QtWin
myappid = f"{info.organization_name}.{info.application_version}"
# QtWin.setCurrentProcessExplicitAppUserModelID(myappid)
settings = QSettings()
width = int(settings.value("app/width", 1024))
height = int(settings.value("app/height", 768))
x = int(settings.value("app/pos_x", 100))
y = int(settings.value("app/pos_y", 100))
app = QApplication(sys.argv)
app.setApplicationName(info.application_name)
app.setApplicationVersion(str(info.application_version))
app.setOrganizationDomain(info.organization_name)
# if platform.system() == 'Linux':
# icn = QIcon(":/icons/app_icon")
# app.setWindowIcon(icn)
# Create a Qt widget, which will be our window.
window = fixtracks.MainWindow(is_dark_mode(app))
window.setGeometry(100, 100, 1024, 768)
window.setWindowTitle("FixTracks")
window.setMinimumWidth(1024)
window.setMinimumHeight(768)
window.resize(width, height)
window.move(x, y)
window.show()
# Start the event loop.
app.exec()
pos = window.pos()
settings.setValue("app/width", window.width())
settings.setValue("app/height", window.height())
settings.setValue("app/pos_x", pos.x())
settings.setValue("app/pos_y", pos.y())

View File

@ -2,7 +2,7 @@
name = "fixtracks"
version = "0.1.0"
description = "A project to fix track metadata"
authors = ["Your Name <your.email@example.com>"]
authors = ["Your Name <jan.grewe@uni-tuebingen.de>"]
license = "MIT"
[tool.poetry.dependencies]