From ae277ce8fbb9f083399611439b1b89fb42f66359 Mon Sep 17 00:00:00 2001
From: Jan Grewe <jan.grewe@g-node.org>
Date: Wed, 31 Aug 2022 16:56:26 +0200
Subject: [PATCH] arena abstraction

---
 etrack/arena.py           | 217 ++++++++++++++++++++++++++++++++++++++
 etrack/tracking_result.py |  12 +--
 etrack/util.py            |  22 ++++
 3 files changed, 245 insertions(+), 6 deletions(-)
 create mode 100644 etrack/arena.py
 create mode 100644 etrack/util.py

diff --git a/etrack/arena.py b/etrack/arena.py
new file mode 100644
index 0000000..ef1ae25
--- /dev/null
+++ b/etrack/arena.py
@@ -0,0 +1,217 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+
+from skimage.draw import disk
+
+from IPython import embed
+
+from util import RegionShape, AnalysisType, Illumination
+
+
+class Region(object):
+
+    def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
+        assert len(origin) == 2
+        self._origin = origin
+        self._extent = extent
+        self._inverted_y = inverted_y
+        self._name = name
+        self._shape_type = region_shape
+        self._check_extent(extent)
+        self._parent = parent
+
+    @staticmethod
+    def circular_mask(width, height, center, radius):
+        assert center[1] + radius < width and center[1] - radius > 0
+        assert center[0] + radius < height and center[0] - radius > 0
+
+        mask = np.zeros((height, width), dtype=np.uint8)
+        rr, cc = disk(reversed(center), radius)
+        mask[rr, cc] = 1
+
+        return mask
+
+    @property
+    def _max_extent(self):
+        if self._shape_type == RegionShape.Rectangular:
+            max_extent = (self._origin[0] + self._extent[0], self._origin[1] + self._extent[1])
+        else:
+            max_extent = (self._origin[0] + self._extent, self._origin[1] + self._extent)
+        return max_extent
+
+    @property
+    def _min_extent(self):
+        if self._shape_type == RegionShape.Rectangular:
+            min_extent = self._origin
+        else:
+            min_extent = (self._origin[0] - self._extent, self._origin[1] - self._extent)
+        return min_extent
+
+    @property
+    def position(self):
+        """Returns the position and extent of the region as 4-tuple, (x, y, width, height)
+        """
+        x = self._min_extent[0]
+        y = self._min_extent[1]
+        width = self._max_extent[0] - self._min_extent[0]
+        height = self._max_extent[1] - self._min_extent[1]
+        return x, y, width, height
+
+    def _check_extent(self, ext):
+        """Checks whether the extent matches the shape. i.e. if the shape is Rectangular, extent must be a length 2 list, tuple, otherwise, if the region is circular, extent must be a single numerical value.
+
+        Parameters
+        ----------
+        ext : tuple, or numeric scalar
+        """
+        if self._shape_type == RegionShape.Rectangular:
+            if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2:
+                raise ValueError("Extent must be a length 2 list or tuple for rectangular regions!")
+        elif self._shape_type == RegionShape.Circular:
+            if not isinstance(ext, (int, float)):
+                raise ValueError("Extent must be a numerical scalar for circular regions!")
+        else:
+            raise ValueError(f"Invalid ShapeType, {self._shape_type}!")
+
+    def fits(self, other) -> bool:
+        """
+            Returns true if the other region fits inside this region!
+        """
+        assert isinstance(other, Region)
+        does_fit = all((other._min_extent[0] >= self._min_extent[0], other._min_extent[1] >= self._min_extent[1], 
+                        other._max_extent[0] <= self._max_extent[0], other._max_extent[1] <= self._max_extent[1]))
+        return does_fit
+
+    @property
+    def is_child(self):
+        return self._parent is not None
+
+    def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
+        """returns the indices of the points specified by 'x' and 'y' that fall into this region.
+
+        Parameters
+        ----------
+        x : np.ndarray
+            the x positions
+        y : np.ndarray
+            the y positions
+        analysis_type : AnalysisType, optional
+            defines how the positions are evaluated, by default AnalysisType.Full
+            FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
+        """
+        if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full):
+            x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0]
+            y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0]
+            if analysis_type == AnalysisType.Full:
+                indices = np.array(list(set(x_indices).intersection(set(y_indices))), dtype=int)
+            elif analysis_type == AnalysisType.CollapseX:
+                indices = np.asarray(x_indices, dtype=int)
+            else:
+                indices = np.asarray(y_indices, dtype=int)
+        else:
+            if self.is_child:
+                mask = self.circular_mask(self._parent.position[2], self._parent.position[3], self._origin, self._extent)
+            else:
+                mask = self.circular_mask(self.position[2], self.position[3], self._origin, self._extent)
+            img = np.zeros_like(mask)
+            img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1
+            temp = np.where(img & mask)
+            indices = []
+            for i, j in zip(list(temp[1]), list(temp[0])):
+                matches = np.where((x == i) & (y == j))
+                if len(matches[0]) == 0:
+                    continue
+                indices.append(matches[0][0])
+            indices = np.array(indices)
+        return indices
+
+
+    def patch(self, **kwargs):
+        if "fc" not in kwargs:
+            kwargs["fc"] = None
+            kwargs["fill"] = False
+        if self._shape_type == RegionShape.Rectangular:
+            w = self.position[2]
+            h = self.position[3]
+            return patches.Rectangle(self._origin, w, h, **kwargs)
+        else:
+            return patches.Circle(self._origin, self._extent, **kwargs)
+
+    def __repr__(self):
+        return f"Region: '{self._name}' of {self._shape_type} shape."
+
+
+class Arena(Region):
+
+    def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
+                 illumination=Illumination.Backlight) -> None:
+        super().__init__(origin, extent, inverted_y, name, arena_shape)
+        self._illumination = illumination
+        self.regions = []
+
+    def add_region(self, origin, extent, name="", shape_type=RegionShape.Rectangular, region=None):
+        if region is None:
+            region = Region(origin, extent, name=name, region_shape=shape_type, parent=self)
+        else:
+            region._parent = self
+        if self.fits(region):
+            self.regions.append(region)
+        else:
+            raise Warning(f"Region {region} fits not! Not added to the list of regions!")
+
+    def __repr__(self):
+        return f"Arena: '{self._name}' of {self._shape_type} shape."
+
+    def plot(self, axis=None):
+        if axis is None:
+            fig = plt.figure()
+            axis = fig.add_subplot(111)
+        axis.add_patch(self.patch())
+        axis.set_xlim([self._origin[0], self._max_extent[0]])
+        axis.set_ylim([self._origin[1], self._max_extent[1]])
+        for r in self.regions:
+            axis.add_patch(r.patch())
+        return axis
+
+if __name__ == "__main__":
+    a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
+    a.add_region((0, 0), (100, 300), name="small rect1")
+    a.add_region((150, 0), (100, 300), name="small rect2")
+    a.add_region((300, 0), (100, 300), name="small rect3")
+    a.add_region((600, 400), 150, name="circ", shape_type=RegionShape.Circular)
+    axis = a.plot()
+    x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
+    y = np.asarray((np.sin(x*0.01) + 1) * a.position[3] / 2 + a.position[1] -1, dtype=int)
+    #y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
+    axis.scatter(x, y, c="k", s=2)
+
+    ind = a.regions[3].points_in_region(x, y)
+    if len(ind) > 0:
+        axis.scatter(x[ind], y[ind], label="circ full")
+
+    ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseX)
+    if len(ind) > 0:
+        axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
+
+    ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)   
+    if len(ind) > 0:
+        axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")
+
+
+    ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX)
+    if len(ind) > 0:
+        axis.scatter(x[ind], y[ind]-10, label="rect collapseX")
+
+    ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY)
+    if len(ind) > 0:
+        axis.scatter(x[ind], y[ind] + 10, label="rect collapseY")
+
+    ind = a.regions[2].points_in_region(x, y, AnalysisType.Full)
+    if len(ind) > 0:
+        axis.scatter(x[ind], y[ind]+20, label="rect full")
+    axis.legend()
+    plt.show()
+
+    a.plot()
+    plt.show()
\ No newline at end of file
diff --git a/etrack/tracking_result.py b/etrack/tracking_result.py
index aab0128..30b4b2b 100644
--- a/etrack/tracking_result.py
+++ b/etrack/tracking_result.py
@@ -185,13 +185,13 @@ if __name__ == '__main__':
     x3 = np.interp(time, time2, x2)
     y3 = np.interp(time, time2, y2)
 
-
-    fig, axes = plt.subplots(3,1, sharex=True)                                                                                                                                                  
-    axes[0].plot(time, x)    
-    axes[0].plot(time, x3)                                                                                                                                                                   
+    fig, axes = plt.subplots(3,1, sharex=True)
+    axes[0].plot(time, x)
+    axes[0].plot(time, x3)
     axes[1].plot(time, y)
-    axes[1].plot(time, y3)                                                                                                                                                                       
-    axes[2].plot(time, l)                            
+    axes[1].plot(time, y3)
+
+    axes[2].plot(time, l)
     plt.show()
 
     embed()
\ No newline at end of file
diff --git a/etrack/util.py b/etrack/util.py
new file mode 100644
index 0000000..d4a29ba
--- /dev/null
+++ b/etrack/util.py
@@ -0,0 +1,22 @@
+from enum import Enum
+
+class Illumination(Enum):
+    Backlight = 0
+    Incident = 1
+
+
+class RegionShape(Enum):
+    Circular = 0
+    Rectangular = 1
+
+    def __str__(self) -> str:
+        return self.name
+
+
+class AnalysisType(Enum):
+    Full = 0
+    CollapseX = 1
+    CollapseY = 2
+
+    def __str__(self) -> str:
+        return self.name