From e33528392c5ba183614a003cd41e43aeda347763 Mon Sep 17 00:00:00 2001
From: Jan Grewe <jan.grewe@g-node.org>
Date: Fri, 21 Feb 2025 16:20:47 +0100
Subject: [PATCH] [classifier] separate setting of data and refresh

---
 fixtracks/widgets/classifier.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py
index bb4e7f1..edee355 100644
--- a/fixtracks/widgets/classifier.py
+++ b/fixtracks/widgets/classifier.py
@@ -34,6 +34,9 @@ class ConsitencyDataLoader(QRunnable):
 
     @Slot()
     def run(self):
+        if self.data is None:
+            logging.error("ConsistencyTracker.DataLoader failed. No Data!")
+            return
         self.positions = self.data.centerOfGravity()
         self.orientations = self.data.orientation()
         self.lengths = self.data.animalLength()
@@ -464,9 +467,6 @@ class ConsistencyClassifier(QWidget):
         self.setEnabled(False)
         self._progressbar.setRange(0,0)
         self._data = data
-        self._dataworker = ConsitencyDataLoader(self._data)
-        self._dataworker.signals.stopped.connect(self.data_processed)
-        self.threadpool.start(self._dataworker)
 
     @Slot()
     def data_processed(self):
@@ -482,6 +482,7 @@ class ConsistencyClassifier(QWidget):
             self._frames = self._dataworker.frames
             self._tracks = self._dataworker.tracks
             self._maxframes = np.max(self._frames)
+            # FIXME the following line causes an error when there are no detections in the range
             min_frame = max([self._frames[self._tracks == 1][0], self._frames[self._tracks == 2][0]]) + 1
             self._maxframeslabel.setText(str(self._maxframes))
             self._startframe_spinner.setMinimum(min_frame)
@@ -525,7 +526,9 @@ class ConsistencyClassifier(QWidget):
         self.start()
 
     def refresh(self):
-        self.setData(self._data)
+        self._dataworker = ConsitencyDataLoader(self._data)
+        self._dataworker.signals.stopped.connect(self.data_processed)
+        self.threadpool.start(self._dataworker)
 
     def worker_progress(self, progress, processed, errors):
         self._progressbar.setValue(progress)
@@ -556,7 +559,8 @@ class ClassifierWidget(QTabWidget):
         self._consistency_tracker = ConsistencyClassifier()
         self.addTab(self._size_classifier, SizeClassifier.name)
         self.addTab(self._consistency_tracker, ConsistencyClassifier.name)
-        self.tabBarClicked.connect(self.update)
+        # self.tabBarClicked.connect(self.update)
+        self.currentChanged.connect(self.tabChanged)
         self._size_classifier.apply.connect(self._on_applySizeClassifier)
         self._consistency_tracker.apply.connect(self._on_applyConsistencyTracker)
 
@@ -576,12 +580,21 @@ class ClassifierWidget(QTabWidget):
     def consistency_tracker(self):
         return self._consistency_tracker
 
+    @Slot()
+    def tabChanged(self):
+        if isinstance(self.currentWidget(), ConsistencyClassifier):
+            self.consistency_tracker.refresh()
+
     @Slot()
     def update(self):
-        self.consistency_tracker.setData(self._data)
+        if isinstance(self.currentWidget(), ConsistencyClassifier):
+            self.consistency_tracker.refresh()
 
     def setData(self, data:TrackingData):
         self._data = data
+        self.consistency_tracker.setData(data)
+        coordinates = self._data.coordinates()
+        self._size_classifier.setCoordinates(coordinates)
 
 def as_dict(df):
     d = {c: df[c].values for c in df.columns}