[arena] cleanup, change regions to dict, and ...

add function to assign regions to positions
This commit is contained in:
Jan Grewe 2022-09-06 15:23:55 +02:00
parent 6487cb07ff
commit 9046e70592

View File

@ -4,9 +4,7 @@ import matplotlib.patches as patches
from skimage.draw import disk
from IPython import embed
from util import RegionShape, AnalysisType, Illumination
from .util import RegionShape, AnalysisType, Illumination
class Region(object):
@ -126,7 +124,6 @@ class Region(object):
indices = np.array(indices)
return indices
def patch(self, **kwargs):
if "fc" not in kwargs:
kwargs["fc"] = None
@ -148,17 +145,23 @@ class Arena(Region):
illumination=Illumination.Backlight) -> None:
super().__init__(origin, extent, inverted_y, name, arena_shape)
self._illumination = illumination
self.regions = []
self.regions = {}
def add_region(self, origin, extent, name="", shape_type=RegionShape.Rectangular, region=None):
def add_region(self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None):
if name is None or name in self.regions.keys():
raise ValueError("Region name '{name}' is invalid. The name must not be None and must be unique among the regions.")
if region is None:
region = Region(origin, extent, name=name, region_shape=shape_type, parent=self)
else:
region._parent = self
if self.fits(region):
self.regions.append(region)
self.regions[name] = region
else:
raise Warning(f"Region {region} fits not! Not added to the list of regions!")
Warning(f"Region {region} fits not! Not added to the list of regions!")
def remove_region(self, name):
if name in self.regions:
self.regions.pop(name)
def __repr__(self):
return f"Arena: '{self._name}' of {self._shape_type} shape."
@ -174,12 +177,34 @@ class Arena(Region):
axis.add_patch(r.patch())
return axis
def region_vector(self, x, y):
"""Returns a vector that contains the region names within which the agent was found.
Parameters
----------
x : np.array
the x-positions
y : np.ndarray
the y-positions
Returns
-------
np.array
vector of the same size as x and y. Each entry is the region to which the position is assinged to. If the point is not assigned to a region, the entry will be empty.
"""
rv = np.empty(x.shape, dtype=str)
for r in self.regions:
indices = self.regions[r].points_in_region(x, y)
rv[indices] = r
return rv
if __name__ == "__main__":
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
a.add_region((0, 0), (100, 300), name="small rect1")
a.add_region((150, 0), (100, 300), name="small rect2")
a.add_region((300, 0), (100, 300), name="small rect3")
a.add_region((600, 400), 150, name="circ", shape_type=RegionShape.Circular)
a.add_region("small rect1", (0, 0), (100, 300))
a.add_region("small rect2", (150, 0), (100, 300))
a.add_region("small rect3", (300, 0), (100, 300))
a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
axis = a.plot()
x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
y = np.asarray((np.sin(x*0.01) + 1) * a.position[3] / 2 + a.position[1] -1, dtype=int)