efish_tracking/etrack/arena.py
2023-02-10 18:45:57 +01:00

399 lines
13 KiB
Python

import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.draw import disk
from .util import RegionShape, AnalysisType, Illumination
from IPython import embed
class Region(object):
    """A rectangular or circular region in tracking coordinates.

    For rectangular regions ``origin`` is the corner with the smallest x and y
    values and ``extent`` is ``(width, height)``. For circular regions
    ``origin`` is the center and ``extent`` is the radius.
    """

    def __init__(
        self,
        origin,
        extent,
        inverted_y=True,
        name="",
        region_shape=RegionShape.Rectangular,
        parent=None,
    ) -> None:
        """Create a region.

        Parameters
        ----------
        origin : 2-element sequence
            (x, y) of the rectangle corner or of the circle center.
        extent : 2-element sequence or numeric scalar
            (width, height) for rectangular regions, radius for circular ones.
        inverted_y : bool, optional
            stored flag, exposed via the ``inverted_y`` property, by default True
        name : str, optional
            the name of the region, by default ""
        region_shape : RegionShape, optional
            the shape of the region, by default RegionShape.Rectangular
        parent : Region, optional
            the region this one belongs to, by default None

        Raises
        ------
        ValueError
            if the extent does not match the shape or the shape is invalid.
        """
        logging.debug(
            f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
        )
        assert len(origin) == 2
        self._origin = origin
        self._extent = extent
        self._inverted_y = inverted_y
        self._name = name
        self._shape_type = region_shape
        # Needs self._shape_type, hence called after the assignment above.
        self._check_extent(extent)
        self._parent = parent

    @staticmethod
    def circular_mask(width, height, center, radius):
        """Return a (height, width) uint8 mask that is 1 inside a disk.

        Parameters
        ----------
        width : int
            number of mask columns (x direction)
        height : int
            number of mask rows (y direction)
        center : 2-element sequence
            (x, y) center of the disk
        radius : numeric scalar
            radius of the disk

        Returns
        -------
        np.ndarray
            the mask, 1 inside the disk and 0 elsewhere.
        """
        # The disk has to lie strictly inside the mask: x is bounded by the
        # width (columns), y by the height (rows).
        assert center[0] + radius < width and center[0] - radius > 0
        assert center[1] + radius < height and center[1] - radius > 0
        mask = np.zeros((height, width), dtype=np.uint8)
        # skimage expects the center as (row, column), i.e. (y, x).
        rr, cc = disk((center[1], center[0]), radius)
        mask[rr, cc] = 1
        return mask

    @property
    def name(self):
        """The name of the region."""
        return self._name

    @property
    def inverted_y(self):
        """The inverted_y flag passed to the constructor."""
        return self._inverted_y

    @property
    def _max_extent(self):
        """(x, y) of the largest coordinates of the bounding box."""
        if self._shape_type == RegionShape.Rectangular:
            return (
                self._origin[0] + self._extent[0],
                self._origin[1] + self._extent[1],
            )
        return (
            self._origin[0] + self._extent,
            self._origin[1] + self._extent,
        )

    @property
    def _min_extent(self):
        """(x, y) of the smallest coordinates of the bounding box."""
        if self._shape_type == RegionShape.Rectangular:
            return self._origin
        return (
            self._origin[0] - self._extent,
            self._origin[1] - self._extent,
        )

    @property
    def xmax(self):
        """Largest x value of the bounding box."""
        return self._max_extent[0]

    @property
    def xmin(self):
        """Smallest x value of the bounding box."""
        return self._min_extent[0]

    @property
    def ymin(self):
        """Smallest y value of the bounding box."""
        return self._min_extent[1]

    @property
    def ymax(self):
        """Largest y value of the bounding box."""
        return self._max_extent[1]

    @property
    def position(self):
        """Returns the position and extent of the region as 4-tuple, (x, y, width, height)"""
        xmin, ymin = self._min_extent
        xmax, ymax = self._max_extent
        return xmin, ymin, xmax - xmin, ymax - ymin

    def _check_extent(self, ext):
        """Checks whether the extent matches the shape.

        Rectangular regions need a length 2 list, tuple or array, circular
        regions need a single numerical value.

        Parameters
        ----------
        ext : tuple, or numeric scalar

        Raises
        ------
        ValueError
            if the extent does not match the shape or the shape is invalid.
        """
        if self._shape_type == RegionShape.Rectangular:
            # "or": both the container type and the length must be right.
            if not isinstance(ext, (list, tuple, np.ndarray)) or len(ext) != 2:
                raise ValueError(
                    "Extent must be a length 2 list or tuple for rectangular regions!"
                )
        elif self._shape_type == RegionShape.Circular:
            # NumPy scalars are accepted as well, np.int64 is not an int.
            if not isinstance(ext, (int, float, np.integer, np.floating)):
                raise ValueError(
                    "Extent must be a numerical scalar for circular regions!"
                )
        else:
            raise ValueError(f"Invalid ShapeType, {self._shape_type}!")

    def fits(self, other) -> bool:
        """
        Returns true if the other region fits inside this region!

        The comparison is done on the bounding boxes of both regions.
        """
        assert isinstance(other, Region)
        checks = {
            "min x": other._min_extent[0] >= self._min_extent[0],
            "min y": other._min_extent[1] >= self._min_extent[1],
            "max x": other._max_extent[0] <= self._max_extent[0],
            "max y": other._max_extent[1] <= self._max_extent[1],
        }
        does_fit = all(checks.values())
        if not does_fit:
            details = ", ".join(f"{key}: {ok}" for key, ok in checks.items())
            logging.debug(
                f"Region {other.name} does not fit into {self.name}. {details}"
            )
        return does_fit

    @property
    def is_child(self):
        """True if the region has a parent."""
        return self._parent is not None

    def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
        """returns the indices of the points specified by 'x' and 'y' that fall into this region.

        Parameters
        ----------
        x : np.ndarray
            the x positions
        y : np.ndarray
            the y positions
        analysis_type : AnalysisType, optional
            defines how the positions are evaluated, by default AnalysisType.Full.
            With AnalysisType.CollapseX only the x positions are compared to the
            bounding box, with any other non-Full type only the y positions.

        Returns
        -------
        np.ndarray
            integer indices, in ascending order, of the matching points.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        xmin, ymin = self._min_extent
        xmax, ymax = self._max_extent

        if analysis_type == AnalysisType.Full:
            if self._shape_type == RegionShape.Rectangular:
                selected = (
                    (y >= ymin) & (y <= ymax) & (x >= xmin) & (x <= xmax)
                )
            else:
                # Distance test instead of a rasterised mask: works for
                # non-integer positions, keeps duplicate points and the
                # original ordering. Float conversion avoids wrap-around for
                # unsigned integer input. Points on the circle line are
                # excluded, like in the rasterised disk.
                dx = x.astype(float) - self._origin[0]
                dy = y.astype(float) - self._origin[1]
                selected = dx**2 + dy**2 < self._extent**2
        elif analysis_type == AnalysisType.CollapseX:
            selected = (x >= xmin) & (x <= xmax)
        else:
            selected = (y >= ymin) & (y <= ymax)
        return np.where(selected)[0]

    def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
        """Returns the times at which the animal entered and left the region.

        Consecutive indices returned by ``points_in_region`` form one visit.
        The entering time of a visit is the time of its first point, the
        leaving time the time of its last point.

        Parameters
        ----------
        x : np.ndarray
            the x positions
        y : np.ndarray
            the y positions
        time : np.ndarray
            the time of each position
        analysis_type : AnalysisType, optional
            defines how the positions are evaluated, by default AnalysisType.Full

        Returns
        -------
        np.ndarray
            the entering times
        np.ndarray
            the leaving times
        """
        indices = self.points_in_region(x, y, analysis_type)
        if len(indices) == 0:
            return np.array([]), np.array([])
        time = np.asarray(time)
        # A gap in the index sequence marks the end of one visit and the
        # start of the next one.
        gaps = np.where(np.diff(indices) > 1)[0]
        first = np.concatenate(([indices[0]], indices[gaps + 1]))
        last = np.concatenate((indices[gaps], [indices[-1]]))
        return time[first], time[last]

    def patch(self, **kwargs):
        """Returns a matplotlib patch of the region.

        Parameters
        ----------
        **kwargs
            passed on to the patch constructor. Without a face color ("fc" or
            "facecolor") the patch is not filled.

        Returns
        -------
        matplotlib.patches.Rectangle or matplotlib.patches.Circle
        """
        # "fc" and "facecolor" are aliases, passing both is an error.
        if "fc" not in kwargs and "facecolor" not in kwargs:
            kwargs["fc"] = None
            kwargs["fill"] = False
        if self._shape_type == RegionShape.Rectangular:
            _, _, width, height = self.position
            return patches.Rectangle(self._origin, width, height, **kwargs)
        return patches.Circle(self._origin, self._extent, **kwargs)

    def __repr__(self):
        return f"Region: '{self._name}' of {self._shape_type} shape."
class Arena(Region):
    """A region that holds a set of named sub-regions."""

    def __init__(
        self,
        origin,
        extent,
        inverted_y=True,
        name="",
        arena_shape=RegionShape.Rectangular,
        illumination=Illumination.Backlight,
    ) -> None:
        """Create an arena.

        Parameters
        ----------
        origin : 2-element sequence
            (x, y) of the rectangle corner or of the circle center.
        extent : 2-element sequence or numeric scalar
            (width, height) for rectangular arenas, radius for circular ones.
        inverted_y : bool, optional
            if True, ``plot`` draws the y-axis with the largest value at the
            bottom, by default True
        name : str, optional
            the name of the arena, by default ""
        arena_shape : RegionShape, optional
            the shape of the arena, by default RegionShape.Rectangular
        illumination : Illumination, optional
            stored with the arena, by default Illumination.Backlight
        """
        super().__init__(origin, extent, inverted_y, name, arena_shape)
        self._illumination = illumination
        self.regions = {}

    def add_region(
        self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None
    ):
        """Add a region to the arena.

        Parameters
        ----------
        name : str
            key under which the region is stored, must be unique.
        origin : 2-element sequence
            origin of the new region, ignored if ``region`` is given.
        extent : 2-element sequence or numeric scalar
            extent of the new region, ignored if ``region`` is given.
        shape_type : RegionShape, optional
            shape of the new region, by default RegionShape.Rectangular
        region : Region, optional
            an existing region to add instead of creating a new one. Its
            parent is set to this arena.

        Raises
        ------
        ValueError
            if the name is None or already in use.
        """
        if name is None or name in self.regions:
            raise ValueError(
                f"Region name '{name}' is invalid. The name must not be None and must be unique among the regions."
            )
        if region is None:
            region = Region(
                origin, extent, name=name, region_shape=shape_type, parent=self
            )
        else:
            region._parent = self
        # A region that does not fit is still added, only a warning is logged.
        if not self.fits(region):
            logging.warning(
                f"Warning! Region {region.name} with size {region.position} does not fit into {self.name} with size {self.position}!"
            )
        self.regions[name] = region

    def remove_region(self, name):
        """Remove the region stored under ``name``, unknown names are ignored."""
        self.regions.pop(name, None)

    def __repr__(self):
        return f"Arena: '{self._name}' of {self._shape_type} shape."

    def plot(self, axis=None):
        """Draw the arena and its regions.

        Parameters
        ----------
        axis : matplotlib axis, optional
            the axis to draw into, a new figure is created if None.

        Returns
        -------
        matplotlib axis
            the axis that was drawn into.
        """
        if axis is None:
            fig = plt.figure()
            axis = fig.add_subplot(111)
        axis.add_patch(self.patch())
        # Use the bounding box: for circular arenas the origin is the center,
        # not the lower limit.
        xmin, ymin = self._min_extent
        xmax, ymax = self._max_extent
        axis.set_xlim([xmin, xmax])
        if self.inverted_y:
            axis.set_ylim([ymax, ymin])
        else:
            axis.set_ylim([ymin, ymax])
        for region in self.regions.values():
            axis.add_patch(region.patch())
        return axis

    def region_vector(self, x, y):
        """Returns a vector that contains the region names within which the agent was found.

        FIXME: This does not work well with overlapping regions, the region
        that was added last wins.

        Parameters
        ----------
        x : np.array
            the x-positions
        y : np.ndarray
            the y-positions

        Returns
        -------
        np.array
            vector of the same size as x and y. Each entry is the region to which the position is assigned to. If the point is not assigned to a region, the entry will be empty.
        """
        # The string dtype must be wide enough for the longest name, a plain
        # ``str`` dtype would keep only the first character of each name.
        width = max((len(str(name)) for name in self.regions), default=1)
        rv = np.full(np.shape(x), "", dtype=f"U{max(width, 1)}")
        for name, region in self.regions.items():
            indices = np.asarray(region.points_in_region(x, y), dtype=int)
            rv[indices] = name
        return rv

    def in_region(self, x, y):
        """Returns a dict that maps each region name to the indices of the points inside that region.

        Parameters
        ----------
        x : np.ndarray
            the x-positions
        y : np.ndarray
            the y-positions
        """
        return {
            name: region.points_in_region(x, y)
            for name, region in self.regions.items()
        }

    def __getitem__(self, key):
        """Returns a region by name (str key) or by insertion position (any other key)."""
        if isinstance(key, str):
            return self.regions[key]
        # dict views cannot be indexed, go through a list.
        return list(self.regions.values())[key]
if __name__ == "__main__":
    # Demo: an arena with three rectangular and one circular region and a
    # sine-shaped track. The points found in the regions are drawn on top of
    # the track, shifted in y to keep the analysis types apart.
    a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
    a.add_region("small rect1", (0, 0), (100, 300))
    a.add_region("small rect2", (150, 0), (100, 300))
    a.add_region("small rect3", (300, 0), (100, 300))
    a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
    axis = a.plot()

    x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
    y = np.asarray(
        (np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int
    )
    axis.scatter(x, y, c="k", s=2)

    # Regions are stored by name, so they are looked up by name here.
    # (region name, analysis type, y offset of the markers, legend label)
    overlays = (
        ("circ", AnalysisType.Full, 0, "circ full"),
        ("circ", AnalysisType.CollapseX, -10, "circ collapseX"),
        ("circ", AnalysisType.CollapseY, 10, "circ collapseY"),
        ("small rect1", AnalysisType.CollapseX, -10, "rect collapseX"),
        ("small rect2", AnalysisType.CollapseY, 10, "rect collapseY"),
        ("small rect3", AnalysisType.Full, 20, "rect full"),
    )
    for region_name, analysis, y_offset, label in overlays:
        ind = a.regions[region_name].points_in_region(x, y, analysis)
        if len(ind) > 0:
            axis.scatter(x[ind], y[ind] + y_offset, label=label)
    axis.legend()
    plt.show()

    a.plot()
    plt.show()