efish_tracking/src/etrack/arena.py
2024-06-01 12:19:43 +02:00

525 lines
18 KiB
Python

import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.draw import disk
from .util import RegionShape, AnalysisType, Illumination
class Region(object):
    """
    Class representing a region (of interest). Regions can be either circular or rectangular.

    A Region can have a parent, i.e. it is contained inside a parent region.
    Coordinates are given in absolute coordinates. How origin and extent are
    interpreted depends on the shape:

    * Circular: ``origin`` is the center of the circle and ``extent`` is its radius.
    * Rectangular: ``origin`` is the corner with the smallest x- and y-coordinate
      and ``extent`` is ``(width, height)``. With an inverted y-axis this corner
      is drawn top-left, otherwise bottom-left.

    The ``inverted_y`` flag is only stored here; it does not change any of the
    coordinate computations of this class (Arena.plot uses it to flip the y-axis).
    """

    def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
        """Region constructor.

        Parameters
        ----------
        origin : 2-tuple
            x, and y coordinates of the origin (see class documentation).
        extent : scalar or 2-tuple
            scalar (radius) for circular regions, 2-tuple (width, height) for rectangular ones.
        inverted_y : bool, optional
            whether the y-axis is inverted (origin displayed top-left), by default True
        name : str, optional
            the name of the region, by default ""
        region_shape : RegionShape, optional
            the shape of the region, by default RegionShape.Rectangular
        parent : Region, optional
            the region that contains this region, by default None

        Raises
        ------
        ValueError
            Raises Value error when origin or extent are invalid
        """
        logging.debug(
            f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
        )
        if len(origin) != 2:
            raise ValueError("Region: origin must be 2-tuple!")
        self._parent = parent
        self._name = name
        # the shape type must be known before the extent can be validated
        self._shape_type = region_shape
        self._origin = origin
        self._check_extent(extent)
        self._extent = extent
        self._inverted_y = inverted_y

    @staticmethod
    def circular_mask(width, height, center, radius):
        """Create a binary image mask that is 1 inside a disk and 0 elsewhere.

        Parameters
        ----------
        width : int
            width (number of columns) of the mask
        height : int
            height (number of rows) of the mask
        center : 2-tuple
            (x, y) position of the center of the disk
        radius : scalar
            radius of the disk

        Returns
        -------
        np.ndarray
            uint8 array of shape (height, width). Parts of the disk that lie
            outside the mask are clipped.
        """
        mask = np.zeros((int(height), int(width)), dtype=np.uint8)
        # disk() expects the center as (row, column), i.e. (y, x). Passing the
        # mask shape clips the disk at the borders instead of indexing outside.
        rr, cc = disk((center[1], center[0]), radius, shape=mask.shape)
        mask[rr, cc] = 1
        return mask

    @property
    def name(self):
        """The name of the region."""
        return self._name

    @property
    def inverted_y(self):
        """True if the y-axis is treated as inverted."""
        return self._inverted_y

    @property
    def _max_extent(self):
        """Largest x- and y-coordinate covered by the region (bounding box), as np.ndarray."""
        if self._shape_type == RegionShape.Rectangular:
            max_extent = (
                self._origin[0] + self._extent[0],
                self._origin[1] + self._extent[1],
            )
        else:
            max_extent = (
                self._origin[0] + self._extent,
                self._origin[1] + self._extent,
            )
        return np.asarray(max_extent)

    @property
    def _min_extent(self):
        """Smallest x- and y-coordinate covered by the region (bounding box), as np.ndarray."""
        if self._shape_type == RegionShape.Rectangular:
            min_extent = self._origin
        else:
            min_extent = (
                self._origin[0] - self._extent,
                self._origin[1] - self._extent,
            )
        return np.asarray(min_extent)

    @property
    def xmax(self):
        """Largest x-coordinate of the bounding box."""
        return self._max_extent[0]

    @property
    def xmin(self):
        """Smallest x-coordinate of the bounding box."""
        return self._min_extent[0]

    @property
    def ymin(self):
        """Smallest y-coordinate of the bounding box."""
        return self._min_extent[1]

    @property
    def ymax(self):
        """Largest y-coordinate of the bounding box."""
        return self._max_extent[1]

    @property
    def position(self):
        """
        Get the position of the region's bounding box.

        Returns
        -------
        tuple
            A tuple containing the smallest x-coordinate, the smallest y-coordinate,
            the width, and the height of the bounding box.
        """
        lower = self._min_extent
        upper = self._max_extent
        return lower[0], lower[1], upper[0] - lower[0], upper[1] - lower[1]

    def _check_extent(self, ext):
        """Checks whether the extent matches the shape. i.e. if the shape is Rectangular,
        extent must be a length 2 list, tuple, or array, otherwise, if the region is
        circular, extent must be a single numerical value.

        Parameters
        ----------
        ext : tuple, or numeric scalar

        Raises
        ------
        ValueError
            if the extent does not match the shape or the shape type is unknown.
        """
        if self._shape_type == RegionShape.Rectangular:
            # 'or' is required here: a scalar must be rejected before len() is
            # called on it, and sequences of the wrong length must be rejected too.
            if not isinstance(ext, (list, tuple, np.ndarray)) or len(ext) != 2:
                raise ValueError(
                    "Extent must be a length 2 list or tuple for rectangular regions!"
                )
        elif self._shape_type == RegionShape.Circular:
            # numpy scalars are accepted as well (np.int64 is not a subclass of int)
            if not isinstance(ext, (int, float, np.integer, np.floating)):
                raise ValueError(
                    "Extent must be a numerical scalar for circular regions!"
                )
        else:
            raise ValueError(f"Invalid ShapeType, {self._shape_type}!")

    def fits(self, other) -> bool:
        """
        Checks if the given region fits into the current region.

        The check is based on the bounding boxes of both regions.

        Parameters
        ----------
        other : Region
            The region to check if it fits.

        Returns
        -------
        bool
            True if the given region fits into the current region, False otherwise.
        """
        assert isinstance(other, Region)
        lower_ok = other._min_extent >= self._min_extent
        upper_ok = other._max_extent <= self._max_extent
        does_fit = bool(np.all(lower_ok) and np.all(upper_ok))
        if not does_fit:
            # build a single string; stray commas would turn the message into a tuple
            m = (
                f"Region {other.name} does not fit into {self.name}. "
                f"min x: {lower_ok[0]}, "
                f"min y: {lower_ok[1]}, "
                f"max x: {upper_ok[0]}, "
                f"max y: {upper_ok[1]}"
            )
            logging.debug(m)
        return does_fit

    @property
    def is_child(self):
        """
        Check if the current instance is a child.

        Returns
        -------
        bool
            True if the instance has a parent, False otherwise.
        """
        return self._parent is not None

    def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
        """Returns the indices of the points specified by 'x' and 'y' that fall into this region.

        Parameters
        ----------
        x : np.ndarray
            the x positions
        y : np.ndarray
            the y positions
        analysis_type : AnalysisType, optional
            defines how the positions are evaluated, by default AnalysisType.Full.
            With AnalysisType.CollapseX only the x-positions are compared to the
            bounding box of the region, with any other non-Full value only the
            y-positions are compared.

        Returns
        -------
        np.ndarray
            the indices, in ascending order, of all points inside the region.
            For rectangular regions the borders are inclusive. For circular
            regions with AnalysisType.Full a point is inside if its distance to
            the center is strictly smaller than the radius.
        """
        x = np.atleast_1d(np.asarray(x))
        y = np.atleast_1d(np.asarray(y))
        lower = self._min_extent
        upper = self._max_extent

        if analysis_type == AnalysisType.Full:
            if self._shape_type == RegionShape.Circular:
                # Test the distance to the center directly. This works for
                # non-integer positions, for repeated positions and for points
                # outside the parent, and keeps the indices in sample order.
                dx = np.asarray(x, dtype=float) - self._origin[0]
                dy = np.asarray(y, dtype=float) - self._origin[1]
                inside = dx * dx + dy * dy < float(self._extent) ** 2
            else:
                inside = ((y >= lower[1]) & (y <= upper[1])) & (
                    (x >= lower[0]) & (x <= upper[0])
                )
        elif analysis_type == AnalysisType.CollapseX:
            inside = (x >= lower[0]) & (x <= upper[0])
        else:
            inside = (y >= lower[1]) & (y <= upper[1])
        return np.asarray(np.where(inside)[0], dtype=int)

    def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
        """Returns the entering and leaving times at which the animal entered
        and left a region. In case the animal was not observed after entering
        this region (for example when hidden in a tube) the leaving time is
        the maximum time entry.
        Whether the full position, or only the x- or y-position should be considered
        is controlled with the analysis_type parameter.

        Parameters
        ----------
        x : np.ndarray
            The animal's x-positions
        y : np.ndarray
            the animal's y-positions
        time : np.ndarray
            the time array
        analysis_type : AnalysisType, optional
            The type of analysis, by default AnalysisType.Full

        Returns
        -------
        np.ndarray
            The entering times
        np.ndarray
            The leaving times
        """
        indices = self.points_in_region(x, y, analysis_type)
        if len(indices) == 0:
            return np.array([]), np.array([])
        time = np.asarray(time)
        # A gap larger than one sample between consecutive in-region indices
        # separates two visits: the visit ends before the gap and the next one
        # starts right after it.
        jumps = np.where(np.diff(indices) > 1)[0]
        first_samples = np.concatenate((indices[:1], indices[jumps + 1]))
        last_samples = np.concatenate((indices[jumps], indices[-1:]))
        return np.array(time[first_samples]), np.array(time[last_samples])

    def patch(self, **kwargs):
        """
        Create and return a matplotlib patch object based on the shape type of the region.

        Parameters:
        - kwargs: Additional keyword arguments to customize the patch object.

        Returns:
        - A matplotlib patch object representing the region shape.

        If no facecolor ('fc' or 'facecolor') is provided, 'fc' is set to None and
        'fill' defaults to False unless it is passed explicitly.
        For rectangular regions, the patch object will be a Rectangle with width and height
        based on the region's position.
        For circular regions, the patch object will be a Circle with radius based on the
        region's extent.

        Example usage:
        ```
        arena = Arena()
        patch = arena.patch(fc='blue', fill=True)
        ax.add_patch(patch)
        ```
        """
        # 'facecolor' is the long form of 'fc'; adding 'fc' next to it would
        # hand the same property to matplotlib twice.
        if "fc" not in kwargs and "facecolor" not in kwargs:
            kwargs["fc"] = None
            # respect an explicitly requested fill
            kwargs.setdefault("fill", False)
        if self._shape_type == RegionShape.Rectangular:
            _, _, w, h = self.position
            return patches.Rectangle(self._origin, w, h, **kwargs)
        return patches.Circle(self._origin, self._extent, **kwargs)

    def __repr__(self):
        return f"Region: '{self._name}' of {self._shape_type} shape."
class Arena(Region):
    """
    Class to represent the experimental arena. Arena is derived from Region and can be either rectangular or circular.
    An arena can not have a parent.
    See Region for more details.
    """

    def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
                 illumination=Illumination.Backlight) -> None:
        """Construct a new Arena with a given origin and extent.

        Parameters
        ----------
        origin : 2-tuple
            x, and y coordinates of the origin
        extent : scalar or 2-tuple
            scalar for circular arenas, 2-tuple for rectangular ones
        inverted_y : bool, optional
            whether the y-axis is inverted, by default True
        name : str, optional
            name of the arena, by default ""
        arena_shape : RegionShape, optional
            shape of the arena, by default RegionShape.Rectangular
        illumination : Illumination, optional
            the illumination type, by default Illumination.Backlight

        Raises
        ------
        ValueError
            if origin or extent are rejected by the Region constructor
        """
        super().__init__(origin, extent, inverted_y, name, arena_shape)
        self._illumination = illumination
        # maps region name -> Region, in insertion order
        self.regions = {}

    def add_region(
        self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None
    ):
        """Add a region to the arena.

        Parameters
        ----------
        name : str
            name under which the region is stored, must be unique within the arena
        origin : 2-tuple
            origin of the new region, ignored if 'region' is given
        extent : scalar or 2-tuple
            extent of the new region, ignored if 'region' is given
        shape_type : RegionShape, optional
            shape of the new region, by default RegionShape.Rectangular
        region : Region, optional
            an existing region to add instead of creating a new one. The arena
            becomes its parent.

        Raises
        ------
        ValueError
            if name is None or already used for another region.
        """
        if name is None or name in self.regions:
            raise ValueError(
                f"Region name '{name}' is invalid. The name must not be None and must be unique among the regions."
            )
        if region is None:
            region = Region(
                origin, extent, name=name, region_shape=shape_type, parent=self
            )
        else:
            region._parent = self
        # a region that does not fit is still added, the mismatch is only reported
        if not self.fits(region):
            logging.warning(
                f"Warning! Region {region.name} with size {region.position} does not fit into {self.name} with size {self.position}!"
            )
        self.regions[name] = region

    def remove_region(self, name):
        """
        Remove a region from the arena. Unknown names are silently ignored.

        Parameter:
        name : str
        The name of the region to remove.

        Returns:
        None
        """
        self.regions.pop(name, None)

    def __repr__(self):
        return f"Arena: '{self._name}' of {self._shape_type} shape."

    def plot(self, axis=None):
        """
        Plots the arena and its regions on the given axis.

        Parameters
        ----------
        axis : matplotlib.axes.Axes, optional
            The axis on which to plot the arena. If not provided, a new figure and axis will be created.

        Returns
        -------
        matplotlib.axes.Axes
            The axis on which the arena is plotted.
        """
        if axis is None:
            fig = plt.figure()
            axis = fig.add_subplot(111)
        axis.add_patch(self.patch())
        # Use the bounding box rather than the origin: for circular arenas the
        # origin is the center, not the lower limit.
        axis.set_xlim([self.xmin, self.xmax])
        if self.inverted_y:
            axis.set_ylim([self.ymax, self.ymin])
        else:
            axis.set_ylim([self.ymin, self.ymax])
        for region in self.regions.values():
            axis.add_patch(region.patch())
        return axis

    def region_vector(self, x, y):
        """Returns a vector that contains the region names within which the agent was found.

        FIXME: This does not work well with overlapping regions! A position that
        falls into several regions is assigned to the region that was added last.

        Parameters
        ----------
        x : np.array
            the x-positions
        y : np.ndarray
            the y-positions

        Returns
        -------
        np.array
            vector of the same size as x and y. Each entry is the region to which the position is assigned to. If the point is not assigned to a region, the entry will be empty.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        # dtype=str alone yields one-character strings and would truncate the
        # names; make the entries wide enough for the longest region name.
        longest = max((len(str(name)) for name in self.regions), default=0)
        rv = np.full(x.shape, "", dtype=f"<U{max(longest, 1)}")
        for name, region in self.regions.items():
            rv[region.points_in_region(x, y)] = name
        return rv

    def in_region(self, x, y):
        """
        Determines which of the given positions are within each of the defined regions of the arena.

        Parameters
        ----------
        x : float or np.ndarray
            The x-coordinate(s) of the point(s) to check.
        y : float or np.ndarray
            The y-coordinate(s) of the point(s) to check.

        Returns
        -------
        dict:
            A dictionary containing the region names as keys and the indices of the points within each region as values.
        """
        return {
            name: region.points_in_region(x, y)
            for name, region in self.regions.items()
        }

    def __getitem__(self, key):
        """Return a region by name (str key) or by insertion position (integer key)."""
        if isinstance(key, str):
            return self.regions[key]
        # dict views cannot be indexed, materialize the regions first
        return list(self.regions.values())[key]
if __name__ == "__main__":
    # Small visual demo: build an arena with three rectangular and one circular
    # region and mark which points of a sine-shaped track fall into them.
    a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
    a.add_region("small rect1", (0, 0), (100, 300))
    a.add_region("small rect2", (150, 0), (100, 300))
    a.add_region("small rect3", (300, 0), (100, 300))
    a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
    axis = a.plot()

    x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
    y = np.asarray(
        (np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int
    )
    # y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
    axis.scatter(x, y, c="k", s=2)

    # 'regions' is a dict keyed by region name, so the regions are looked up by name
    circ = a.regions["circ"]
    ind = circ.points_in_region(x, y)
    if len(ind) > 0:
        axis.scatter(x[ind], y[ind], label="circ full")
    ind = circ.points_in_region(x, y, AnalysisType.CollapseX)
    if len(ind) > 0:
        axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
    ind = circ.points_in_region(x, y, AnalysisType.CollapseY)
    if len(ind) > 0:
        axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")

    ind = a.regions["small rect1"].points_in_region(x, y, AnalysisType.CollapseX)
    if len(ind) > 0:
        axis.scatter(x[ind], y[ind] - 10, label="rect collapseX")
    ind = a.regions["small rect2"].points_in_region(x, y, AnalysisType.CollapseY)
    if len(ind) > 0:
        axis.scatter(x[ind], y[ind] + 10, label="rect collapseY")
    ind = a.regions["small rect3"].points_in_region(x, y, AnalysisType.Full)
    if len(ind) > 0:
        axis.scatter(x[ind], y[ind] + 20, label="rect full")
    axis.legend()
    plt.show()

    a.plot()
    plt.show()