From 9046e70592e84ccedda9f5b9125763a00f749124 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Tue, 6 Sep 2022 15:23:55 +0200 Subject: [PATCH] [arena] cleanup, change regions to dict, and ... add function to assign regions to positions --- etrack/arena.py | 49 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/etrack/arena.py b/etrack/arena.py index ef1ae25..6f270db 100644 --- a/etrack/arena.py +++ b/etrack/arena.py @@ -4,9 +4,7 @@ import matplotlib.patches as patches from skimage.draw import disk -from IPython import embed - -from util import RegionShape, AnalysisType, Illumination +from .util import RegionShape, AnalysisType, Illumination class Region(object): @@ -126,7 +124,6 @@ class Region(object): indices = np.array(indices) return indices - def patch(self, **kwargs): if "fc" not in kwargs: kwargs["fc"] = None @@ -148,17 +145,23 @@ class Arena(Region): illumination=Illumination.Backlight) -> None: super().__init__(origin, extent, inverted_y, name, arena_shape) self._illumination = illumination - self.regions = [] + self.regions = {} - def add_region(self, origin, extent, name="", shape_type=RegionShape.Rectangular, region=None): + def add_region(self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None): + if name is None or name in self.regions.keys(): + raise ValueError("Region name '{name}' is invalid. The name must not be None and must be unique among the regions.") if region is None: region = Region(origin, extent, name=name, region_shape=shape_type, parent=self) else: region._parent = self if self.fits(region): - self.regions.append(region) + self.regions[name] = region else: - raise Warning(f"Region {region} fits not! Not added to the list of regions!") + Warning(f"Region {region} fits not! Not added to the list of regions!") + + def remove_region(self, name): + if name in self.regions: + self.regions.pop(name) def __repr__(self): return f"Arena: '{self._name}' of {self._shape_type} shape." @@ -174,12 +177,34 @@ class Arena(Region): axis.add_patch(r.patch()) return axis + def region_vector(self, x, y): + """Returns a vector that contains the region names within which the agent was found. + + Parameters + ---------- + x : np.array + the x-positions + y : np.ndarray + the y-positions + + Returns + ------- + np.array + vector of the same size as x and y. Each entry is the region to which the position is assinged to. If the point is not assigned to a region, the entry will be empty. + """ + rv = np.empty(x.shape, dtype=str) + for r in self.regions: + indices = self.regions[r].points_in_region(x, y) + rv[indices] = r + return rv + + if __name__ == "__main__": a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular) - a.add_region((0, 0), (100, 300), name="small rect1") - a.add_region((150, 0), (100, 300), name="small rect2") - a.add_region((300, 0), (100, 300), name="small rect3") - a.add_region((600, 400), 150, name="circ", shape_type=RegionShape.Circular) + a.add_region("small rect1", (0, 0), (100, 300)) + a.add_region("small rect2", (150, 0), (100, 300)) + a.add_region("small rect3", (300, 0), (100, 300)) + a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular) axis = a.plot() x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int) y = np.asarray((np.sin(x*0.01) + 1) * a.position[3] / 2 + a.position[1] -1, dtype=int)