399 lines
13 KiB
Python
399 lines
13 KiB
Python
import logging
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.patches as patches
|
|
|
|
from skimage.draw import disk
|
|
|
|
from .util import RegionShape, AnalysisType, Illumination
|
|
from IPython import embed
|
|
|
|
|
|
class Region(object):
|
|
def __init__(
|
|
self,
|
|
origin,
|
|
extent,
|
|
inverted_y=True,
|
|
name="",
|
|
region_shape=RegionShape.Rectangular,
|
|
parent=None,
|
|
) -> None:
|
|
logging.debug(
|
|
f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
|
|
)
|
|
assert len(origin) == 2
|
|
self._origin = origin
|
|
self._extent = extent
|
|
self._inverted_y = inverted_y
|
|
self._name = name
|
|
self._shape_type = region_shape
|
|
self._check_extent(extent)
|
|
self._parent = parent
|
|
|
|
@staticmethod
|
|
def circular_mask(width, height, center, radius):
|
|
assert center[1] + radius < width and center[1] - radius > 0
|
|
assert center[0] + radius < height and center[0] - radius > 0
|
|
|
|
mask = np.zeros((height, width), dtype=np.uint8)
|
|
rr, cc = disk(reversed(center), radius)
|
|
mask[rr, cc] = 1
|
|
|
|
return mask
|
|
|
|
@property
|
|
def name(self):
|
|
return self._name
|
|
|
|
@property
|
|
def inverted_y(self):
|
|
return self._inverted_y
|
|
|
|
@property
|
|
def _max_extent(self):
|
|
if self._shape_type == RegionShape.Rectangular:
|
|
max_extent = (
|
|
self._origin[0] + self._extent[0],
|
|
self._origin[1] + self._extent[1],
|
|
)
|
|
else:
|
|
max_extent = (
|
|
self._origin[0] + self._extent,
|
|
self._origin[1] + self._extent,
|
|
)
|
|
return max_extent
|
|
|
|
@property
|
|
def _min_extent(self):
|
|
if self._shape_type == RegionShape.Rectangular:
|
|
min_extent = self._origin
|
|
else:
|
|
min_extent = (
|
|
self._origin[0] - self._extent,
|
|
self._origin[1] - self._extent,
|
|
)
|
|
return min_extent
|
|
|
|
@property
|
|
def xmax(self):
|
|
return self._max_extent[0]
|
|
|
|
@property
|
|
def xmin(self):
|
|
return self._min_extent[0]
|
|
|
|
@property
|
|
def ymin(self):
|
|
return self._min_extent[1]
|
|
|
|
@property
|
|
def ymax(self):
|
|
return self._max_extent[1]
|
|
|
|
@property
|
|
def position(self):
|
|
"""Returns the position and extent of the region as 4-tuple, (x, y, width, height)"""
|
|
x = self._min_extent[0]
|
|
y = self._min_extent[1]
|
|
width = self._max_extent[0] - self._min_extent[0]
|
|
height = self._max_extent[1] - self._min_extent[1]
|
|
return x, y, width, height
|
|
|
|
def _check_extent(self, ext):
|
|
"""Checks whether the extent matches the shape. i.e. if the shape is Rectangular, extent must be a length 2 list, tuple, otherwise, if the region is circular, extent must be a single numerical value.
|
|
|
|
Parameters
|
|
----------
|
|
ext : tuple, or numeric scalar
|
|
"""
|
|
if self._shape_type == RegionShape.Rectangular:
|
|
if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2:
|
|
raise ValueError(
|
|
"Extent must be a length 2 list or tuple for rectangular regions!"
|
|
)
|
|
elif self._shape_type == RegionShape.Circular:
|
|
if not isinstance(ext, (int, float)):
|
|
raise ValueError(
|
|
"Extent must be a numerical scalar for circular regions!"
|
|
)
|
|
else:
|
|
raise ValueError(f"Invalid ShapeType, {self._shape_type}!")
|
|
|
|
def fits(self, other) -> bool:
|
|
"""
|
|
Returns true if the other region fits inside this region!
|
|
"""
|
|
assert isinstance(other, Region)
|
|
does_fit = all(
|
|
(
|
|
other._min_extent[0] >= self._min_extent[0],
|
|
other._min_extent[1] >= self._min_extent[1],
|
|
other._max_extent[0] <= self._max_extent[0],
|
|
other._max_extent[1] <= self._max_extent[1],
|
|
)
|
|
)
|
|
if not does_fit:
|
|
m = (
|
|
f"Region {other.name} does not fit into {self.name}. "
|
|
f"min x: {other._min_extent[0] >= self._min_extent[0]},",
|
|
f"min y: {other._min_extent[1] >= self._min_extent[1]},",
|
|
f"max x: {other._max_extent[0] <= self._max_extent[0]},",
|
|
f"max y: {other._max_extent[1] <= self._max_extent[1]}",
|
|
)
|
|
logging.debug(m)
|
|
return does_fit
|
|
|
|
@property
|
|
def is_child(self):
|
|
return self._parent is not None
|
|
|
|
def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
|
|
"""returns the indices of the points specified by 'x' and 'y' that fall into this region.
|
|
|
|
Parameters
|
|
----------
|
|
x : np.ndarray
|
|
the x positions
|
|
y : np.ndarray
|
|
the y positions
|
|
analysis_type : AnalysisType, optional
|
|
defines how the positions are evaluated, by default AnalysisType.Full
|
|
FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
|
|
"""
|
|
if self._shape_type == RegionShape.Rectangular or (
|
|
self._shape_type == RegionShape.Circular
|
|
and analysis_type != AnalysisType.Full
|
|
):
|
|
if analysis_type == AnalysisType.Full:
|
|
indices = np.where(
|
|
((y >= self._min_extent[1]) & (y <= self._max_extent[1]))
|
|
& ((x >= self._min_extent[0]) & (x <= self._max_extent[0]))
|
|
)[0]
|
|
indices = np.array(indices, dtype=int)
|
|
elif analysis_type == AnalysisType.CollapseX:
|
|
x_indices = np.where(
|
|
(x >= self._min_extent[0]) & (x <= self._max_extent[0])
|
|
)[0]
|
|
indices = np.asarray(x_indices, dtype=int)
|
|
else:
|
|
y_indices = np.where(
|
|
(y >= self._min_extent[1]) & (y <= self._max_extent[1])
|
|
)[0]
|
|
indices = np.asarray(y_indices, dtype=int)
|
|
else:
|
|
if self.is_child:
|
|
mask = self.circular_mask(
|
|
self._parent.position[2],
|
|
self._parent.position[3],
|
|
self._origin,
|
|
self._extent,
|
|
)
|
|
else:
|
|
mask = self.circular_mask(
|
|
self.position[2], self.position[3], self._origin, self._extent
|
|
)
|
|
img = np.zeros_like(mask)
|
|
img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1
|
|
temp = np.where(img & mask)
|
|
indices = []
|
|
for i, j in zip(list(temp[1]), list(temp[0])):
|
|
matches = np.where((x == i) & (y == j))
|
|
if len(matches[0]) == 0:
|
|
continue
|
|
indices.append(matches[0][0])
|
|
indices = np.array(indices)
|
|
return indices
|
|
|
|
def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
|
|
"""Returns the entering and leaving times at which the animal entered and left a region. In case the animal was not observed after entering this region (for example when hidden in a tube) the leaving time is the maximum time entry.
|
|
|
|
Parameters
|
|
----------
|
|
x : _type_
|
|
_description_
|
|
y : _type_
|
|
_description_
|
|
time : _type_
|
|
_description_
|
|
analysis_type : _type_, optional
|
|
_description_, by default AnalysisType.Full
|
|
|
|
Returns
|
|
-------
|
|
_type_
|
|
_description_
|
|
"""
|
|
indices = self.points_in_region(x, y, analysis_type)
|
|
if len(indices) == 0:
|
|
return np.array([]), np.array([])
|
|
|
|
diffs = np.diff(indices)
|
|
if len(diffs) == sum(diffs):
|
|
entering = [time[indices[0]]]
|
|
leaving = [time[indices[-1]]]
|
|
else:
|
|
entering = []
|
|
leaving = []
|
|
jumps = np.where(diffs > 1)[0]
|
|
start = time[indices[0]]
|
|
for i in range(len(jumps)):
|
|
end = time[indices[jumps[i]]]
|
|
entering.append(start)
|
|
leaving.append(end)
|
|
start = time[indices[jumps[i] + 1]]
|
|
|
|
end = time[indices[-1]]
|
|
entering.append(start)
|
|
leaving.append(end)
|
|
return np.array(entering), np.array(leaving)
|
|
|
|
def patch(self, **kwargs):
|
|
if "fc" not in kwargs:
|
|
kwargs["fc"] = None
|
|
kwargs["fill"] = False
|
|
if self._shape_type == RegionShape.Rectangular:
|
|
w = self.position[2]
|
|
h = self.position[3]
|
|
return patches.Rectangle(self._origin, w, h, **kwargs)
|
|
else:
|
|
return patches.Circle(self._origin, self._extent, **kwargs)
|
|
|
|
def __repr__(self):
|
|
return f"Region: '{self._name}' of {self._shape_type} shape."
|
|
|
|
|
|
class Arena(Region):
|
|
def __init__(
|
|
self,
|
|
origin,
|
|
extent,
|
|
inverted_y=True,
|
|
name="",
|
|
arena_shape=RegionShape.Rectangular,
|
|
illumination=Illumination.Backlight,
|
|
) -> None:
|
|
super().__init__(origin, extent, inverted_y, name, arena_shape)
|
|
self._illumination = illumination
|
|
self.regions = {}
|
|
|
|
def add_region(
|
|
self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None
|
|
):
|
|
if name is None or name in self.regions.keys():
|
|
raise ValueError(
|
|
"Region name '{name}' is invalid. The name must not be None and must be unique among the regions."
|
|
)
|
|
if region is None:
|
|
region = Region(
|
|
origin, extent, name=name, region_shape=shape_type, parent=self
|
|
)
|
|
else:
|
|
region._parent = self
|
|
doesfit = self.fits(region)
|
|
if not doesfit:
|
|
logging.warn(
|
|
f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!"
|
|
)
|
|
self.regions[name] = region
|
|
|
|
def remove_region(self, name):
|
|
if name in self.regions:
|
|
self.regions.pop(name)
|
|
|
|
def __repr__(self):
|
|
return f"Arena: '{self._name}' of {self._shape_type} shape."
|
|
|
|
def plot(self, axis=None):
|
|
if axis is None:
|
|
fig = plt.figure()
|
|
axis = fig.add_subplot(111)
|
|
axis.add_patch(self.patch())
|
|
axis.set_xlim([self._origin[0], self._max_extent[0]])
|
|
|
|
if self.inverted_y:
|
|
axis.set_ylim([self._max_extent[1], self._origin[1]])
|
|
else:
|
|
axis.set_ylim([self._origin[1], self._max_extent[1]])
|
|
for r in self.regions:
|
|
axis.add_patch(self.regions[r].patch())
|
|
return axis
|
|
|
|
def region_vector(self, x, y):
|
|
"""Returns a vector that contains the region names within which the agent was found.
|
|
FIXME: This does not work well with overlapping regions!@!
|
|
Parameters
|
|
----------
|
|
x : np.array
|
|
the x-positions
|
|
y : np.ndarray
|
|
the y-positions
|
|
|
|
Returns
|
|
-------
|
|
np.array
|
|
vector of the same size as x and y. Each entry is the region to which the position is assinged to. If the point is not assigned to a region, the entry will be empty.
|
|
"""
|
|
rv = np.empty(x.shape, dtype=str)
|
|
for r in self.regions:
|
|
indices = self.regions[r].points_in_region(x, y)
|
|
rv[indices] = r
|
|
return rv
|
|
|
|
def in_region(self, x, y):
|
|
tmp = {}
|
|
for r in self.regions:
|
|
indices = self.regions[r].points_in_region(x, y)
|
|
tmp[r] = indices
|
|
return tmp
|
|
|
|
def __getitem__(self, key):
|
|
if isinstance(key, (str)):
|
|
return self.regions[key]
|
|
else:
|
|
return self.regions[self.regions.keys()[key]]
|
|
|
|
|
|
if __name__ == "__main__":
|
|
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
|
|
a.add_region("small rect1", (0, 0), (100, 300))
|
|
a.add_region("small rect2", (150, 0), (100, 300))
|
|
a.add_region("small rect3", (300, 0), (100, 300))
|
|
a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
|
|
axis = a.plot()
|
|
x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
|
|
y = np.asarray(
|
|
(np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int
|
|
)
|
|
# y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
|
|
axis.scatter(x, y, c="k", s=2)
|
|
|
|
ind = a.regions[3].points_in_region(x, y)
|
|
if len(ind) > 0:
|
|
axis.scatter(x[ind], y[ind], label="circ full")
|
|
|
|
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseX)
|
|
if len(ind) > 0:
|
|
axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
|
|
|
|
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)
|
|
if len(ind) > 0:
|
|
axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")
|
|
|
|
ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX)
|
|
if len(ind) > 0:
|
|
axis.scatter(x[ind], y[ind] - 10, label="rect collapseX")
|
|
|
|
ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY)
|
|
if len(ind) > 0:
|
|
axis.scatter(x[ind], y[ind] + 10, label="rect collapseY")
|
|
|
|
ind = a.regions[2].points_in_region(x, y, AnalysisType.Full)
|
|
if len(ind) > 0:
|
|
axis.scatter(x[ind], y[ind] + 20, label="rect full")
|
|
axis.legend()
|
|
plt.show()
|
|
|
|
a.plot()
|
|
plt.show()
|