diff --git a/etrack/arena.py b/etrack/arena.py new file mode 100644 index 0000000..ef1ae25 --- /dev/null +++ b/etrack/arena.py @@ -0,0 +1,217 @@ +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as patches + +from skimage.draw import disk + +from IPython import embed + +from util import RegionShape, AnalysisType, Illumination + + +class Region(object): + + def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None: + assert len(origin) == 2 + self._origin = origin + self._extent = extent + self._inverted_y = inverted_y + self._name = name + self._shape_type = region_shape + self._check_extent(extent) + self._parent = parent + + @staticmethod + def circular_mask(width, height, center, radius): + assert center[1] + radius < width and center[1] - radius > 0 + assert center[0] + radius < height and center[0] - radius > 0 + + mask = np.zeros((height, width), dtype=np.uint8) + rr, cc = disk(reversed(center), radius) + mask[rr, cc] = 1 + + return mask + + @property + def _max_extent(self): + if self._shape_type == RegionShape.Rectangular: + max_extent = (self._origin[0] + self._extent[0], self._origin[1] + self._extent[1]) + else: + max_extent = (self._origin[0] + self._extent, self._origin[1] + self._extent) + return max_extent + + @property + def _min_extent(self): + if self._shape_type == RegionShape.Rectangular: + min_extent = self._origin + else: + min_extent = (self._origin[0] - self._extent, self._origin[1] - self._extent) + return min_extent + + @property + def position(self): + """Returns the position and extent of the region as 4-tuple, (x, y, width, height) + """ + x = self._min_extent[0] + y = self._min_extent[1] + width = self._max_extent[0] - self._min_extent[0] + height = self._max_extent[1] - self._min_extent[1] + return x, y, width, height + + def _check_extent(self, ext): + """Checks whether the extent matches the shape. i.e. if the shape is Rectangular, extent must be a length 2 list, tuple, otherwise, if the region is circular, extent must be a single numerical value. + + Parameters + ---------- + ext : tuple, or numeric scalar + """ + if self._shape_type == RegionShape.Rectangular: + if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2: + raise ValueError("Extent must be a length 2 list or tuple for rectangular regions!") + elif self._shape_type == RegionShape.Circular: + if not isinstance(ext, (int, float)): + raise ValueError("Extent must be a numerical scalar for circular regions!") + else: + raise ValueError(f"Invalid ShapeType, {self._shape_type}!") + + def fits(self, other) -> bool: + """ + Returns true if the other region fits inside this region! + """ + assert isinstance(other, Region) + does_fit = all((other._min_extent[0] >= self._min_extent[0], other._min_extent[1] >= self._min_extent[1], + other._max_extent[0] <= self._max_extent[0], other._max_extent[1] <= self._max_extent[1])) + return does_fit + + @property + def is_child(self): + return self._parent is not None + + def points_in_region(self, x, y, analysis_type=AnalysisType.Full): + """returns the indices of the points specified by 'x' and 'y' that fall into this region. + + Parameters + ---------- + x : np.ndarray + the x positions + y : np.ndarray + the y positions + analysis_type : AnalysisType, optional + defines how the positions are evaluated, by default AnalysisType.Full + FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points? + """ + if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full): + x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0] + y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0] + if analysis_type == AnalysisType.Full: + indices = np.array(list(set(x_indices).intersection(set(y_indices))), dtype=int) + elif analysis_type == AnalysisType.CollapseX: + indices = np.asarray(x_indices, dtype=int) + else: + indices = np.asarray(y_indices, dtype=int) + else: + if self.is_child: + mask = self.circular_mask(self._parent.position[2], self._parent.position[3], self._origin, self._extent) + else: + mask = self.circular_mask(self.position[2], self.position[3], self._origin, self._extent) + img = np.zeros_like(mask) + img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1 + temp = np.where(img & mask) + indices = [] + for i, j in zip(list(temp[1]), list(temp[0])): + matches = np.where((x == i) & (y == j)) + if len(matches[0]) == 0: + continue + indices.append(matches[0][0]) + indices = np.array(indices) + return indices + + + def patch(self, **kwargs): + if "fc" not in kwargs: + kwargs["fc"] = None + kwargs["fill"] = False + if self._shape_type == RegionShape.Rectangular: + w = self.position[2] + h = self.position[3] + return patches.Rectangle(self._origin, w, h, **kwargs) + else: + return patches.Circle(self._origin, self._extent, **kwargs) + + def __repr__(self): + return f"Region: '{self._name}' of {self._shape_type} shape." + + +class Arena(Region): + + def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular, + illumination=Illumination.Backlight) -> None: + super().__init__(origin, extent, inverted_y, name, arena_shape) + self._illumination = illumination + self.regions = [] + + def add_region(self, origin, extent, name="", shape_type=RegionShape.Rectangular, region=None): + if region is None: + region = Region(origin, extent, name=name, region_shape=shape_type, parent=self) + else: + region._parent = self + if self.fits(region): + self.regions.append(region) + else: + raise Warning(f"Region {region} fits not! Not added to the list of regions!") + + def __repr__(self): + return f"Arena: '{self._name}' of {self._shape_type} shape." + + def plot(self, axis=None): + if axis is None: + fig = plt.figure() + axis = fig.add_subplot(111) + axis.add_patch(self.patch()) + axis.set_xlim([self._origin[0], self._max_extent[0]]) + axis.set_ylim([self._origin[1], self._max_extent[1]]) + for r in self.regions: + axis.add_patch(r.patch()) + return axis + +if __name__ == "__main__": + a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular) + a.add_region((0, 0), (100, 300), name="small rect1") + a.add_region((150, 0), (100, 300), name="small rect2") + a.add_region((300, 0), (100, 300), name="small rect3") + a.add_region((600, 400), 150, name="circ", shape_type=RegionShape.Circular) + axis = a.plot() + x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int) + y = np.asarray((np.sin(x*0.01) + 1) * a.position[3] / 2 + a.position[1] -1, dtype=int) + #y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int) + axis.scatter(x, y, c="k", s=2) + + ind = a.regions[3].points_in_region(x, y) + if len(ind) > 0: + axis.scatter(x[ind], y[ind], label="circ full") + + ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseX) + if len(ind) > 0: + axis.scatter(x[ind], y[ind] - 10, label="circ collapseX") + + ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY) + if len(ind) > 0: + axis.scatter(x[ind], y[ind] + 10, label="circ collapseY") + + + ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX) + if len(ind) > 0: + axis.scatter(x[ind], y[ind]-10, label="rect collapseX") + + ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY) + if len(ind) > 0: + axis.scatter(x[ind], y[ind] + 10, label="rect collapseY") + + ind = a.regions[2].points_in_region(x, y, AnalysisType.Full) + if len(ind) > 0: + axis.scatter(x[ind], y[ind]+20, label="rect full") + axis.legend() + plt.show() + + a.plot() + plt.show() \ No newline at end of file diff --git a/etrack/tracking_result.py b/etrack/tracking_result.py index aab0128..30b4b2b 100644 --- a/etrack/tracking_result.py +++ b/etrack/tracking_result.py @@ -185,13 +185,13 @@ if __name__ == '__main__': x3 = np.interp(time, time2, x2) y3 = np.interp(time, time2, y2) - - fig, axes = plt.subplots(3,1, sharex=True) - axes[0].plot(time, x) - axes[0].plot(time, x3) + fig, axes = plt.subplots(3,1, sharex=True) + axes[0].plot(time, x) + axes[0].plot(time, x3) axes[1].plot(time, y) - axes[1].plot(time, y3) - axes[2].plot(time, l) + axes[1].plot(time, y3) + + axes[2].plot(time, l) plt.show() embed() \ No newline at end of file diff --git a/etrack/util.py b/etrack/util.py new file mode 100644 index 0000000..d4a29ba --- /dev/null +++ b/etrack/util.py @@ -0,0 +1,22 @@ +from enum import Enum + +class Illumination(Enum): + Backlight = 0 + Incident = 1 + + +class RegionShape(Enum): + Circular = 0 + Rectangular = 1 + + def __str__(self) -> str: + return self.name + + +class AnalysisType(Enum): + Full = 0 + CollapseX = 1 + CollapseY = 2 + + def __str__(self) -> str: + return self.name