From 3e1cbe4b9bf2bdf785b9062f86d16d483f1c38da Mon Sep 17 00:00:00 2001
From: Jan Grewe <jan.grewe@g-node.org>
Date: Mon, 12 Sep 2022 16:56:43 +0200
Subject: [PATCH] [trackingdata] add fps, some docs, change interpolate

---
 etrack/tracking_data.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/etrack/tracking_data.py b/etrack/tracking_data.py
index c231f1a..abcd783 100644
--- a/etrack/tracking_data.py
+++ b/etrack/tracking_data.py
@@ -23,6 +23,7 @@ class TrackingData(object):
         self._threshold = quality_threshold
         self._position_limits = position_limits
         self._time_limits = temporal_limits
+        self._fps = fps
 
     @property
     def original_positions(self):
@@ -32,16 +33,17 @@ class TrackingData(object):
     def original_quality(self):
         return self._orgquality
 
-    def interpolate(self, store=True, min_count=10):
+    def interpolate(self, start_time=None, end_time=None, min_count=5):
         if len(self._x) < min_count:
-            print(f"{self._node} data has less than {min_count} data points with sufficient quality!")
-            return None
-        x = np.interp(self._orgtime, self._time, self._x)
-        y = np.interp(self._orgtime, self._time, self._y)
-        if store:
-            self._x = x
-            self._y = y
-            self._time = self._orgtime.copy()
+            print(f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!")
+            return None, None, None
+        start = self._time[0] if start_time is None else start_time
+        end = self._time[-1] if end_time is None else end_time
+        time = np.arange(start, end, 1./self._fps)
+        x = np.interp(time, self._time, self._x)
+        y = np.interp(time, self._time, self._y)
+
+        return x, y, time
 
     @property
     def quality_threshold(self):
@@ -96,9 +98,14 @@ class TrackingData(object):
             raise ValueError(f"The new_limits vector must be a 2-tuple of the form (start, end). ")
         self._time_limits = new_limits
 
-    def filter_tracks(self):
-        """Applies the filters to the tracking data. All filters will be applied squentially, i.e. an AND connection.
+    def filter_tracks(self, align_time=True):
+        """Applies the filters to the tracking data. All filters will be applied sequentially, i.e. an AND connection.
         To change the filter settings use the setters for 'quality_threshold', 'temporal_limits', 'position_limits'. Setting them to None disables the respective filter discarding the setting.
+
+        Parameters
+        ----------
+        align_time: bool
+            Controls whether the time vector is aligned to the first time point at which the agent is within the positional_limits. Default = True
         """
         self._x = self._orgx.copy()
         self._y = self._orgy.copy()
@@ -112,9 +119,9 @@ class TrackingData(object):
                                (self._y >= self.position_limits[1]) & (self._y < y_max))
             self._x = self._x[indices]
             self._y = self._y[indices]
-            self._time = self._time[indices]
+            self._time = self._time[indices] - self._time[0] if align_time else 0.0
             self._quality = self._quality[indices]
-        
+
         if self.temporal_limits is not None:
             indices = np.where((self._time >= self.temporal_limits[0]) &
                                (self._time < self.temporal_limits[1]))