arena abstraction
This commit is contained in:
parent
e854ab591f
commit
ae277ce8fb
217
etrack/arena.py
Normal file
217
etrack/arena.py
Normal file
@ -0,0 +1,217 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
|
||||
from skimage.draw import disk
|
||||
|
||||
from IPython import embed
|
||||
|
||||
from util import RegionShape, AnalysisType, Illumination
|
||||
|
||||
|
||||
class Region(object):
|
||||
|
||||
def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
|
||||
assert len(origin) == 2
|
||||
self._origin = origin
|
||||
self._extent = extent
|
||||
self._inverted_y = inverted_y
|
||||
self._name = name
|
||||
self._shape_type = region_shape
|
||||
self._check_extent(extent)
|
||||
self._parent = parent
|
||||
|
||||
@staticmethod
|
||||
def circular_mask(width, height, center, radius):
|
||||
assert center[1] + radius < width and center[1] - radius > 0
|
||||
assert center[0] + radius < height and center[0] - radius > 0
|
||||
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
rr, cc = disk(reversed(center), radius)
|
||||
mask[rr, cc] = 1
|
||||
|
||||
return mask
|
||||
|
||||
@property
|
||||
def _max_extent(self):
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
max_extent = (self._origin[0] + self._extent[0], self._origin[1] + self._extent[1])
|
||||
else:
|
||||
max_extent = (self._origin[0] + self._extent, self._origin[1] + self._extent)
|
||||
return max_extent
|
||||
|
||||
@property
|
||||
def _min_extent(self):
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
min_extent = self._origin
|
||||
else:
|
||||
min_extent = (self._origin[0] - self._extent, self._origin[1] - self._extent)
|
||||
return min_extent
|
||||
|
||||
@property
|
||||
def position(self):
|
||||
"""Returns the position and extent of the region as 4-tuple, (x, y, width, height)
|
||||
"""
|
||||
x = self._min_extent[0]
|
||||
y = self._min_extent[1]
|
||||
width = self._max_extent[0] - self._min_extent[0]
|
||||
height = self._max_extent[1] - self._min_extent[1]
|
||||
return x, y, width, height
|
||||
|
||||
def _check_extent(self, ext):
|
||||
"""Checks whether the extent matches the shape. i.e. if the shape is Rectangular, extent must be a length 2 list, tuple, otherwise, if the region is circular, extent must be a single numerical value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ext : tuple, or numeric scalar
|
||||
"""
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2:
|
||||
raise ValueError("Extent must be a length 2 list or tuple for rectangular regions!")
|
||||
elif self._shape_type == RegionShape.Circular:
|
||||
if not isinstance(ext, (int, float)):
|
||||
raise ValueError("Extent must be a numerical scalar for circular regions!")
|
||||
else:
|
||||
raise ValueError(f"Invalid ShapeType, {self._shape_type}!")
|
||||
|
||||
def fits(self, other) -> bool:
|
||||
"""
|
||||
Returns true if the other region fits inside this region!
|
||||
"""
|
||||
assert isinstance(other, Region)
|
||||
does_fit = all((other._min_extent[0] >= self._min_extent[0], other._min_extent[1] >= self._min_extent[1],
|
||||
other._max_extent[0] <= self._max_extent[0], other._max_extent[1] <= self._max_extent[1]))
|
||||
return does_fit
|
||||
|
||||
@property
|
||||
def is_child(self):
|
||||
return self._parent is not None
|
||||
|
||||
def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
|
||||
"""returns the indices of the points specified by 'x' and 'y' that fall into this region.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : np.ndarray
|
||||
the x positions
|
||||
y : np.ndarray
|
||||
the y positions
|
||||
analysis_type : AnalysisType, optional
|
||||
defines how the positions are evaluated, by default AnalysisType.Full
|
||||
FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
|
||||
"""
|
||||
if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full):
|
||||
x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0]
|
||||
y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0]
|
||||
if analysis_type == AnalysisType.Full:
|
||||
indices = np.array(list(set(x_indices).intersection(set(y_indices))), dtype=int)
|
||||
elif analysis_type == AnalysisType.CollapseX:
|
||||
indices = np.asarray(x_indices, dtype=int)
|
||||
else:
|
||||
indices = np.asarray(y_indices, dtype=int)
|
||||
else:
|
||||
if self.is_child:
|
||||
mask = self.circular_mask(self._parent.position[2], self._parent.position[3], self._origin, self._extent)
|
||||
else:
|
||||
mask = self.circular_mask(self.position[2], self.position[3], self._origin, self._extent)
|
||||
img = np.zeros_like(mask)
|
||||
img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1
|
||||
temp = np.where(img & mask)
|
||||
indices = []
|
||||
for i, j in zip(list(temp[1]), list(temp[0])):
|
||||
matches = np.where((x == i) & (y == j))
|
||||
if len(matches[0]) == 0:
|
||||
continue
|
||||
indices.append(matches[0][0])
|
||||
indices = np.array(indices)
|
||||
return indices
|
||||
|
||||
|
||||
def patch(self, **kwargs):
|
||||
if "fc" not in kwargs:
|
||||
kwargs["fc"] = None
|
||||
kwargs["fill"] = False
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
w = self.position[2]
|
||||
h = self.position[3]
|
||||
return patches.Rectangle(self._origin, w, h, **kwargs)
|
||||
else:
|
||||
return patches.Circle(self._origin, self._extent, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Region: '{self._name}' of {self._shape_type} shape."
|
||||
|
||||
|
||||
class Arena(Region):
|
||||
|
||||
def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
|
||||
illumination=Illumination.Backlight) -> None:
|
||||
super().__init__(origin, extent, inverted_y, name, arena_shape)
|
||||
self._illumination = illumination
|
||||
self.regions = []
|
||||
|
||||
def add_region(self, origin, extent, name="", shape_type=RegionShape.Rectangular, region=None):
|
||||
if region is None:
|
||||
region = Region(origin, extent, name=name, region_shape=shape_type, parent=self)
|
||||
else:
|
||||
region._parent = self
|
||||
if self.fits(region):
|
||||
self.regions.append(region)
|
||||
else:
|
||||
raise Warning(f"Region {region} fits not! Not added to the list of regions!")
|
||||
|
||||
def __repr__(self):
|
||||
return f"Arena: '{self._name}' of {self._shape_type} shape."
|
||||
|
||||
def plot(self, axis=None):
|
||||
if axis is None:
|
||||
fig = plt.figure()
|
||||
axis = fig.add_subplot(111)
|
||||
axis.add_patch(self.patch())
|
||||
axis.set_xlim([self._origin[0], self._max_extent[0]])
|
||||
axis.set_ylim([self._origin[1], self._max_extent[1]])
|
||||
for r in self.regions:
|
||||
axis.add_patch(r.patch())
|
||||
return axis
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
|
||||
a.add_region((0, 0), (100, 300), name="small rect1")
|
||||
a.add_region((150, 0), (100, 300), name="small rect2")
|
||||
a.add_region((300, 0), (100, 300), name="small rect3")
|
||||
a.add_region((600, 400), 150, name="circ", shape_type=RegionShape.Circular)
|
||||
axis = a.plot()
|
||||
x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
|
||||
y = np.asarray((np.sin(x*0.01) + 1) * a.position[3] / 2 + a.position[1] -1, dtype=int)
|
||||
#y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
|
||||
axis.scatter(x, y, c="k", s=2)
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind], label="circ full")
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseX)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")
|
||||
|
||||
|
||||
ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind]-10, label="rect collapseX")
|
||||
|
||||
ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] + 10, label="rect collapseY")
|
||||
|
||||
ind = a.regions[2].points_in_region(x, y, AnalysisType.Full)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind]+20, label="rect full")
|
||||
axis.legend()
|
||||
plt.show()
|
||||
|
||||
a.plot()
|
||||
plt.show()
|
@ -185,13 +185,13 @@ if __name__ == '__main__':
|
||||
x3 = np.interp(time, time2, x2)
|
||||
y3 = np.interp(time, time2, y2)
|
||||
|
||||
|
||||
fig, axes = plt.subplots(3,1, sharex=True)
|
||||
axes[0].plot(time, x)
|
||||
axes[0].plot(time, x3)
|
||||
fig, axes = plt.subplots(3,1, sharex=True)
|
||||
axes[0].plot(time, x)
|
||||
axes[0].plot(time, x3)
|
||||
axes[1].plot(time, y)
|
||||
axes[1].plot(time, y3)
|
||||
axes[2].plot(time, l)
|
||||
axes[1].plot(time, y3)
|
||||
|
||||
axes[2].plot(time, l)
|
||||
plt.show()
|
||||
|
||||
embed()
|
22
etrack/util.py
Normal file
22
etrack/util.py
Normal file
@ -0,0 +1,22 @@
|
||||
from enum import Enum
|
||||
|
||||
class Illumination(Enum):
|
||||
Backlight = 0
|
||||
Incident = 1
|
||||
|
||||
|
||||
class RegionShape(Enum):
|
||||
Circular = 0
|
||||
Rectangular = 1
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class AnalysisType(Enum):
|
||||
Full = 0
|
||||
CollapseX = 1
|
||||
CollapseY = 2
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
Loading…
Reference in New Issue
Block a user