From 64e75ba4b01f107aaecb0855b36fe987569005e4 Mon Sep 17 00:00:00 2001
From: Jan Grewe <jan.grewe@g-node.org>
Date: Fri, 21 Feb 2025 16:24:18 +0100
Subject: [PATCH] [tracks] delegate functionality to widgets, cleanup

---
 fixtracks/widgets/tracks.py | 68 ++++++++++++-------------------------
 1 file changed, 22 insertions(+), 46 deletions(-)

diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py
index 02ca094..4c4a524 100644
--- a/fixtracks/widgets/tracks.py
+++ b/fixtracks/widgets/tracks.py
@@ -2,8 +2,8 @@ import logging
 import numpy as np
 import pandas as pd
 
-from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject
-from PySide6.QtGui import QImage, QBrush, QColor, QFont
+from PySide6.QtCore import Qt, QThreadPool, Signal
+from PySide6.QtGui import QImage, QBrush, QColor
 from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
 from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
 
@@ -28,15 +28,11 @@ class FixTracks(QWidget):
         self._threadpool = QThreadPool()
         self._reader = None
         self._image = None
-        self._clear_detections = True
         self._currentWindowPos = 0  # in frames
         self._currentWindowWidth = 0  # in frames
         self._maxframes = 0
         self._data = TrackingData()
-        self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
-                         "assigned_right": QBrush(QColor.fromString("green")),
-                         "unassigned": QBrush(QColor.fromString("red"))
-        }
+
         self._detectionView = DetectionView()
         self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
         self._skeleton = SkeletonWidget()
@@ -60,7 +56,7 @@ class FixTracks(QWidget):
         self._windowspinner.setSingleStep(50)
         self._windowspinner.setValue(500)
         self._windowspinner.valueChanged.connect(self.on_windowSizeChanged)
-        # self._timeline.setWindowWidth(0.01)
+
         self._keypointcombo = QComboBox()
         self._keypointcombo.currentIndexChanged.connect(self.on_keypointSelected)
 
@@ -143,6 +139,7 @@ class FixTracks(QWidget):
     def on_autoClassify(self, tracks):
         self._data.setSelectionRange("index", 0, self._data.numDetections)
         self._data.assignTracks(tracks)
+        self._timeline.update()
         self.update()
 
     def on_dataSelection(self):
@@ -162,41 +159,15 @@ class FixTracks(QWidget):
         self._detectionView.setImage(img)
 
     def update(self):
-        def update_detectionView(df, name):
-            if len(df) == 0:
-                return
-            keypoint = self._keypointcombo.currentIndex()
-            coords = np.stack(df["keypoints"].values).astype(np.float32)[:, :,:]
-            tracks = df["track"].values.astype(int)
-            ids = df.index.values.astype(int)
-            frames = df["frame"].values.astype(int)
-            self._detectionView.addDetections(coords, tracks, ids, frames, keypoint, self._brushes[name])
-        
-        self._timeline.setData(self._data)
         start_frame = self._currentWindowPos
         stop_frame = start_frame + self._currentWindowWidth
-        self._controls_widget.setWindow(start_frame, stop_frame)
         logging.debug("Tracks:update: Updating View for detection range %i, %i frames", start_frame, stop_frame)
         self._data.setSelectionRange("frame", start_frame, stop_frame)
-        frames = self._data.selectedData("frame")
-        tracks = self._data.selectedData("track")
-        keypoints = self._data.selectedData("keypoints")
-        index = self._data.selectedData("index")
-
-        df = pd.DataFrame({"frame": frames,
-                           "track": tracks,
-                           "keypoints": keypoints},
-                            index=index)
-        assigned_left = df[(df.track == self.trackone_id)]
-        assigned_right = df[(df.track == self.tracktwo_id)]
-        unassigned = df[(df.track != self.trackone_id) & (df.track != self.tracktwo_id)]
-
-        if self._clear_detections:
-            self._detectionView.clearDetections()
-        update_detectionView(unassigned, "unassigned")
-        update_detectionView(assigned_left, "assigned_left")
-        update_detectionView(assigned_right, "assigned_right")
-        self._classifier.setData(self._data)
+
+        self._controls_widget.setWindow(start_frame, stop_frame)
+        kp = self._keypointcombo.currentText().lower()
+        kpi = -1 if "center" in kp else int(kp)
+        self._detectionView.updateDetections(kpi)
 
     @property
     def fileList(self):
@@ -223,6 +194,7 @@ class FixTracks(QWidget):
 
     def populateKeypointCombo(self, num_keypoints):
         self._keypointcombo.clear()
+        self._keypointcombo.addItem("Center")
         for i in range(num_keypoints):
             self._keypointcombo.addItem(str(i))
         self._keypointcombo.setCurrentIndex(0)
@@ -241,12 +213,10 @@ class FixTracks(QWidget):
             self._timeline.setData(self._data)
             self._timeline.setWindow(self._currentWindowPos / self._maxframes,
                                      self._currentWindowWidth / self._maxframes)
-            coordinates = self._data.coordinates()
-            positions = self._data.centerOfGravity()
-            self._classifier.size_classifier.setCoordinates(coordinates)
-            self._classifier.consistency_tracker.setData(self._data)
+            self._detectionView.setData(self._data)
+            self._classifier.setData(self._data)
             self.update()
-            logging.info("Finished loading data: %i frames, %i detections", self._maxframes, len(positions))
+            logging.info("Finished loading data: %i frames", self._maxframes)
 
     def on_keypointSelected(self):
         self.update()
@@ -278,36 +248,43 @@ class FixTracks(QWidget):
     def on_assignOne(self):
         logging.debug("Assigning user selection to track One")
         self._data.assignUserSelection(self.trackone_id)
+        self._timeline.update()
         self.update()
 
     def on_assignTwo(self):
         logging.debug("Assigning user selection to track Two")
         self._data.assignUserSelection(self.tracktwo_id)
+        self._timeline.update()
         self.update()
 
     def on_assignOther(self):
         logging.debug("Assigning user selection to track Other")
         self._data.assignUserSelection(self.trackother_id, False)
+        self._timeline.update()
         self.update()
 
     def on_setUserFlag(self):
         self._data.setAssignmentStatus(True)
+        self._timeline.update()
         self.update()
 
     def on_unsetUserFlag(self):
         logging.debug("Tracks:unsetUserFlag")
         self._data.setAssignmentStatus(False)
+        self._timeline.update()
         self.update()
 
     def on_revertUserFlags(self):
         logging.debug("Tracks:revert ALL UserFlags and track assignments")
         self._data.revertAssignmentStatus()
         self._data.revertTrackAssignments()
+        self._timeline.update()
         self.update()
 
     def on_deleteDetection(self):
-        logging.debug("Tracks:delete detections")
+        logging.warning("Tracks:delete detections is currently not supported!")
         # self._data.deleteDetections()
+        self._timeline.update()
         self.update()
 
     def on_windowChanged(self):
@@ -350,7 +327,6 @@ class FixTracks(QWidget):
         self.update()
 
     def moveWindow(self, stepsize):
-        self._clear_detections = True
         step = np.round(stepsize * (self._currentWindowWidth))
         new_start_frame = self._currentWindowPos + step
         self._timeline.setWindowPos(new_start_frame / self._maxframes)