From 469a35724d4c89280ee5d2ee9b7d6cd781a1a1ba Mon Sep 17 00:00:00 2001
From: Jan Grewe <jan.grewe@g-node.org>
Date: Fri, 10 Feb 2023 18:45:57 +0100
Subject: [PATCH] logging

---
 etrack/arena.py         | 159 +++++++++++++++++++++++++++++++---------
 etrack/tracking_data.py |  64 +++++++++++-----
 2 files changed, 170 insertions(+), 53 deletions(-)

diff --git a/etrack/arena.py b/etrack/arena.py
index 52061af..86b7233 100644
--- a/etrack/arena.py
+++ b/etrack/arena.py
@@ -1,3 +1,4 @@
+import logging
 import numpy as np
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches
@@ -7,9 +8,20 @@ from skimage.draw import disk
 from .util import RegionShape, AnalysisType, Illumination
 from IPython import embed
 
-class Region(object):
 
-    def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
+class Region(object):
+    def __init__(
+        self,
+        origin,
+        extent,
+        inverted_y=True,
+        name="",
+        region_shape=RegionShape.Rectangular,
+        parent=None,
+    ) -> None:
+        logging.debug(
+            f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
+        )
         assert len(origin) == 2
         self._origin = origin
         self._extent = extent
@@ -41,9 +53,15 @@ class Region(object):
     @property
     def _max_extent(self):
         if self._shape_type == RegionShape.Rectangular:
-            max_extent = (self._origin[0] + self._extent[0], self._origin[1] + self._extent[1])
+            max_extent = (
+                self._origin[0] + self._extent[0],
+                self._origin[1] + self._extent[1],
+            )
         else:
-            max_extent = (self._origin[0] + self._extent, self._origin[1] + self._extent)
+            max_extent = (
+                self._origin[0] + self._extent,
+                self._origin[1] + self._extent,
+            )
         return max_extent
 
     @property
@@ -51,13 +69,31 @@ class Region(object):
         if self._shape_type == RegionShape.Rectangular:
             min_extent = self._origin
         else:
-            min_extent = (self._origin[0] - self._extent, self._origin[1] - self._extent)
+            min_extent = (
+                self._origin[0] - self._extent,
+                self._origin[1] - self._extent,
+            )
         return min_extent
 
+    @property
+    def xmax(self):
+        return self._max_extent[0]
+
+    @property
+    def xmin(self):
+        return self._min_extent[0]
+
+    @property
+    def ymin(self):
+        return self._min_extent[1]
+
+    @property
+    def ymax(self):
+        return self._max_extent[1]
+
     @property
     def position(self):
-        """Returns the position and extent of the region as 4-tuple, (x, y, width, height)
-        """
+        """Returns the position and extent of the region as 4-tuple, (x, y, width, height)"""
         x = self._min_extent[0]
         y = self._min_extent[1]
         width = self._max_extent[0] - self._min_extent[0]
@@ -73,20 +109,39 @@ class Region(object):
         """
         if self._shape_type == RegionShape.Rectangular:
             if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2:
-                raise ValueError("Extent must be a length 2 list or tuple for rectangular regions!")
+                raise ValueError(
+                    "Extent must be a length 2 list or tuple for rectangular regions!"
+                )
         elif self._shape_type == RegionShape.Circular:
             if not isinstance(ext, (int, float)):
-                raise ValueError("Extent must be a numerical scalar for circular regions!")
+                raise ValueError(
+                    "Extent must be a numerical scalar for circular regions!"
+                )
         else:
             raise ValueError(f"Invalid ShapeType, {self._shape_type}!")
 
     def fits(self, other) -> bool:
         """
-            Returns true if the other region fits inside this region!
+        Returns true if the other region fits inside this region!
         """
         assert isinstance(other, Region)
-        does_fit = all((other._min_extent[0] >= self._min_extent[0], other._min_extent[1] >= self._min_extent[1], 
-                        other._max_extent[0] <= self._max_extent[0], other._max_extent[1] <= self._max_extent[1]))
+        does_fit = all(
+            (
+                other._min_extent[0] >= self._min_extent[0],
+                other._min_extent[1] >= self._min_extent[1],
+                other._max_extent[0] <= self._max_extent[0],
+                other._max_extent[1] <= self._max_extent[1],
+            )
+        )
+        if not does_fit:
+            m = (
+                f"Region {other.name} does not fit into {self.name}. "
+                f"min x: {other._min_extent[0] >= self._min_extent[0]},",
+                f"min y: {other._min_extent[1] >= self._min_extent[1]},",
+                f"max x: {other._max_extent[0] <= self._max_extent[0]},",
+                f"max y: {other._max_extent[1] <= self._max_extent[1]}",
+            )
+            logging.debug(m)
         return does_fit
 
     @property
@@ -106,22 +161,38 @@ class Region(object):
             defines how the positions are evaluated, by default AnalysisType.Full
             FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
         """
-        if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full):
+        if self._shape_type == RegionShape.Rectangular or (
+            self._shape_type == RegionShape.Circular
+            and analysis_type != AnalysisType.Full
+        ):
             if analysis_type == AnalysisType.Full:
-                indices = np.where(((y >= self._min_extent[1]) & (y <= self._max_extent[1])) & 
-                                   ((x >= self._min_extent[0]) & (x <= self._max_extent[0])))[0]
+                indices = np.where(
+                    ((y >= self._min_extent[1]) & (y <= self._max_extent[1]))
+                    & ((x >= self._min_extent[0]) & (x <= self._max_extent[0]))
+                )[0]
                 indices = np.array(indices, dtype=int)
             elif analysis_type == AnalysisType.CollapseX:
-                x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0]
+                x_indices = np.where(
+                    (x >= self._min_extent[0]) & (x <= self._max_extent[0])
+                )[0]
                 indices = np.asarray(x_indices, dtype=int)
             else:
-                y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0]
+                y_indices = np.where(
+                    (y >= self._min_extent[1]) & (y <= self._max_extent[1])
+                )[0]
                 indices = np.asarray(y_indices, dtype=int)
         else:
             if self.is_child:
-                mask = self.circular_mask(self._parent.position[2], self._parent.position[3], self._origin, self._extent)
+                mask = self.circular_mask(
+                    self._parent.position[2],
+                    self._parent.position[3],
+                    self._origin,
+                    self._extent,
+                )
             else:
-                mask = self.circular_mask(self.position[2], self.position[3], self._origin, self._extent)
+                mask = self.circular_mask(
+                    self.position[2], self.position[3], self._origin, self._extent
+                )
             img = np.zeros_like(mask)
             img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1
             temp = np.where(img & mask)
@@ -156,7 +227,7 @@ class Region(object):
         indices = self.points_in_region(x, y, analysis_type)
         if len(indices) == 0:
             return np.array([]), np.array([])
-  
+
         diffs = np.diff(indices)
         if len(diffs) == sum(diffs):
             entering = [time[indices[0]]]
@@ -164,7 +235,7 @@ class Region(object):
         else:
             entering = []
             leaving = []
-            jumps  = np.where(diffs > 1)[0]
+            jumps = np.where(diffs > 1)[0]
             start = time[indices[0]]
             for i in range(len(jumps)):
                 end = time[indices[jumps[i]]]
@@ -193,22 +264,37 @@ class Region(object):
 
 
 class Arena(Region):
-
-    def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
-                 illumination=Illumination.Backlight) -> None:
+    def __init__(
+        self,
+        origin,
+        extent,
+        inverted_y=True,
+        name="",
+        arena_shape=RegionShape.Rectangular,
+        illumination=Illumination.Backlight,
+    ) -> None:
         super().__init__(origin, extent, inverted_y, name, arena_shape)
         self._illumination = illumination
         self.regions = {}
 
-    def add_region(self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None):
+    def add_region(
+        self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None
+    ):
         if name is None or name in self.regions.keys():
-            raise ValueError("Region name '{name}' is invalid. The name must not be None and must be unique among the regions.") 
+            raise ValueError(
+                "Region name '{name}' is invalid. The name must not be None and must be unique among the regions."
+            )
         if region is None:
-            region = Region(origin, extent, name=name, region_shape=shape_type, parent=self)
+            region = Region(
+                origin, extent, name=name, region_shape=shape_type, parent=self
+            )
         else:
             region._parent = self
-        if ~self.fits(region):
-            print(f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!")
+        doesfit = self.fits(region)
+        if not doesfit:
+            logging.warn(
+                f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!"
+            )
         self.regions[name] = region
 
     def remove_region(self, name):
@@ -276,8 +362,10 @@ if __name__ == "__main__":
     a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
     axis = a.plot()
     x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
-    y = np.asarray((np.sin(x*0.01) + 1) * a.position[3] / 2 + a.position[1] -1, dtype=int)
-    #y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
+    y = np.asarray(
+        (np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int
+    )
+    # y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
     axis.scatter(x, y, c="k", s=2)
 
     ind = a.regions[3].points_in_region(x, y)
@@ -288,14 +376,13 @@ if __name__ == "__main__":
     if len(ind) > 0:
         axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
 
-    ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)   
+    ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)
     if len(ind) > 0:
         axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")
 
-
     ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX)
     if len(ind) > 0:
-        axis.scatter(x[ind], y[ind]-10, label="rect collapseX")
+        axis.scatter(x[ind], y[ind] - 10, label="rect collapseX")
 
     ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY)
     if len(ind) > 0:
@@ -303,9 +390,9 @@ if __name__ == "__main__":
 
     ind = a.regions[2].points_in_region(x, y, AnalysisType.Full)
     if len(ind) > 0:
-        axis.scatter(x[ind], y[ind]+20, label="rect full")
+        axis.scatter(x[ind], y[ind] + 20, label="rect full")
     axis.legend()
     plt.show()
 
     a.plot()
-    plt.show()
\ No newline at end of file
+    plt.show()
diff --git a/etrack/tracking_data.py b/etrack/tracking_data.py
index abcd783..a945975 100644
--- a/etrack/tracking_data.py
+++ b/etrack/tracking_data.py
@@ -8,9 +8,21 @@ class TrackingData(object):
     Using the 'quality_threshold', 'temporal_limits', or the 'position_limits' data can be filtered (see filter_tracks function).
     The 'interpolate' function allows to fill up the gaps that be result from filtering with linearly interpolated data points.
 
-    More may follow... 
+    More may follow...
     """
-    def __init__(self, x, y, time, quality, node="", fps=None, quality_threshold=None, temporal_limits=None, position_limits=None) -> None:
+
+    def __init__(
+        self,
+        x,
+        y,
+        time,
+        quality,
+        node="",
+        fps=None,
+        quality_threshold=None,
+        temporal_limits=None,
+        position_limits=None,
+    ) -> None:
         self._orgx = x
         self._orgy = y
         self._orgtime = time
@@ -35,11 +47,13 @@ class TrackingData(object):
 
     def interpolate(self, start_time=None, end_time=None, min_count=5):
         if len(self._x) < min_count:
-            print(f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!")
+            print(
+                f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!"
+            )
             return None, None, None
         start = self._time[0] if start_time is None else start_time
         end = self._time[-1] if end_time is None else end_time
-        time = np.arange(start, end, 1./self._fps)
+        time = np.arange(start, end, 1.0 / self._fps)
         x = np.interp(time, self._time, self._x)
         y = np.interp(time, self._time, self._y)
 
@@ -56,7 +70,7 @@ class TrackingData(object):
         Parameters
         ----------
         new_threshold : float
-            
+
         """
         self._threshold = new_threshold
 
@@ -77,8 +91,12 @@ class TrackingData(object):
         ------
         ValueError, if new_value is not a 4-tuple
         """
-        if new_limits is not None and not (isinstance(new_limits, (tuple, list)) and len(new_limits) == 4):
-            raise ValueError(f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)")
+        if new_limits is not None and not (
+            isinstance(new_limits, (tuple, list)) and len(new_limits) == 4
+        ):
+            raise ValueError(
+                f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)"
+            )
         self._position_limits = new_limits
 
     @property
@@ -94,8 +112,12 @@ class TrackingData(object):
         new_limits : 2-tuple
             The new limits in the form (start, end) given in seconds.
         """
-        if new_limits is not None and not (isinstance(new_limits, (tuple, list)) and len(new_limits) == 2):
-            raise ValueError(f"The new_limits vector must be a 2-tuple of the form (start, end). ")
+        if new_limits is not None and not (
+            isinstance(new_limits, (tuple, list)) and len(new_limits) == 2
+        ):
+            raise ValueError(
+                f"The new_limits vector must be a 2-tuple of the form (start, end). "
+            )
         self._time_limits = new_limits
 
     def filter_tracks(self, align_time=True):
@@ -115,16 +137,22 @@ class TrackingData(object):
         if self.position_limits is not None:
             x_max = self.position_limits[0] + self.position_limits[2]
             y_max = self.position_limits[1] + self.position_limits[3]
-            indices = np.where((self._x >= self.position_limits[0]) & (self._x < x_max) &
-                               (self._y >= self.position_limits[1]) & (self._y < y_max))
+            indices = np.where(
+                (self._x >= self.position_limits[0])
+                & (self._x < x_max)
+                & (self._y >= self.position_limits[1])
+                & (self._y < y_max)
+            )
             self._x = self._x[indices]
             self._y = self._y[indices]
             self._time = self._time[indices] - self._time[0] if align_time else 0.0
             self._quality = self._quality[indices]
 
         if self.temporal_limits is not None:
-            indices = np.where((self._time >= self.temporal_limits[0]) &
-                               (self._time < self.temporal_limits[1]))
+            indices = np.where(
+                (self._time >= self.temporal_limits[0])
+                & (self._time < self.temporal_limits[1])
+            )
             self._x = self._x[indices]
             self._y = self._y[indices]
             self._time = self._time[indices]
@@ -138,7 +166,7 @@ class TrackingData(object):
             self._quality = self._quality[indices]
 
     def positions(self):
-        """Returns the filtered data (if filters have been applied). 
+        """Returns the filtered data (if filters have been applied).
 
         Returns
         -------
@@ -154,7 +182,7 @@ class TrackingData(object):
         return self._x, self._y, self._time, self._quality
 
     def speed(self):
-        """ Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points.
+        """Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points.
 
         Returns
         -------
@@ -165,7 +193,9 @@ class TrackingData(object):
         tuple of np.ndarray
             The position
         """
-        speed = np.sqrt(np.diff(self._x)**2 + np.diff(self._y)**2) / np.diff(self._time)
+        speed = np.sqrt(np.diff(self._x) ** 2 + np.diff(self._y) ** 2) / np.diff(
+            self._time
+        )
         t = self._time[:-1] + np.diff(self._time) / 2
         x = self._x[:-1] + np.diff(self._x) / 2
         y = self._y[:-1] + np.diff(self._y) / 2
@@ -174,4 +204,4 @@ class TrackingData(object):
 
     def __repr__(self) -> str:
         s = f"Tracking data of node '{self._node}'!"
-        return s
\ No newline at end of file
+        return s