This commit is contained in:
Jan Grewe 2023-02-10 18:45:57 +01:00
parent 2bba750e1f
commit 469a35724d
2 changed files with 170 additions and 53 deletions

View File

@ -1,3 +1,4 @@
import logging
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.patches as patches import matplotlib.patches as patches
@ -7,9 +8,20 @@ from skimage.draw import disk
from .util import RegionShape, AnalysisType, Illumination from .util import RegionShape, AnalysisType, Illumination
from IPython import embed from IPython import embed
class Region(object):
def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None: class Region(object):
def __init__(
self,
origin,
extent,
inverted_y=True,
name="",
region_shape=RegionShape.Rectangular,
parent=None,
) -> None:
logging.debug(
f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
)
assert len(origin) == 2 assert len(origin) == 2
self._origin = origin self._origin = origin
self._extent = extent self._extent = extent
@ -41,9 +53,15 @@ class Region(object):
@property @property
def _max_extent(self): def _max_extent(self):
if self._shape_type == RegionShape.Rectangular: if self._shape_type == RegionShape.Rectangular:
max_extent = (self._origin[0] + self._extent[0], self._origin[1] + self._extent[1]) max_extent = (
self._origin[0] + self._extent[0],
self._origin[1] + self._extent[1],
)
else: else:
max_extent = (self._origin[0] + self._extent, self._origin[1] + self._extent) max_extent = (
self._origin[0] + self._extent,
self._origin[1] + self._extent,
)
return max_extent return max_extent
@property @property
@ -51,13 +69,31 @@ class Region(object):
if self._shape_type == RegionShape.Rectangular: if self._shape_type == RegionShape.Rectangular:
min_extent = self._origin min_extent = self._origin
else: else:
min_extent = (self._origin[0] - self._extent, self._origin[1] - self._extent) min_extent = (
self._origin[0] - self._extent,
self._origin[1] - self._extent,
)
return min_extent return min_extent
@property
def xmax(self):
return self._max_extent[0]
@property
def xmin(self):
return self._min_extent[0]
@property
def ymin(self):
return self._min_extent[1]
@property
def ymax(self):
return self._max_extent[1]
@property @property
def position(self): def position(self):
"""Returns the position and extent of the region as 4-tuple, (x, y, width, height) """Returns the position and extent of the region as 4-tuple, (x, y, width, height)"""
"""
x = self._min_extent[0] x = self._min_extent[0]
y = self._min_extent[1] y = self._min_extent[1]
width = self._max_extent[0] - self._min_extent[0] width = self._max_extent[0] - self._min_extent[0]
@ -73,20 +109,39 @@ class Region(object):
""" """
if self._shape_type == RegionShape.Rectangular: if self._shape_type == RegionShape.Rectangular:
if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2: if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2:
raise ValueError("Extent must be a length 2 list or tuple for rectangular regions!") raise ValueError(
"Extent must be a length 2 list or tuple for rectangular regions!"
)
elif self._shape_type == RegionShape.Circular: elif self._shape_type == RegionShape.Circular:
if not isinstance(ext, (int, float)): if not isinstance(ext, (int, float)):
raise ValueError("Extent must be a numerical scalar for circular regions!") raise ValueError(
"Extent must be a numerical scalar for circular regions!"
)
else: else:
raise ValueError(f"Invalid ShapeType, {self._shape_type}!") raise ValueError(f"Invalid ShapeType, {self._shape_type}!")
def fits(self, other) -> bool: def fits(self, other) -> bool:
""" """
Returns true if the other region fits inside this region! Returns true if the other region fits inside this region!
""" """
assert isinstance(other, Region) assert isinstance(other, Region)
does_fit = all((other._min_extent[0] >= self._min_extent[0], other._min_extent[1] >= self._min_extent[1], does_fit = all(
other._max_extent[0] <= self._max_extent[0], other._max_extent[1] <= self._max_extent[1])) (
other._min_extent[0] >= self._min_extent[0],
other._min_extent[1] >= self._min_extent[1],
other._max_extent[0] <= self._max_extent[0],
other._max_extent[1] <= self._max_extent[1],
)
)
if not does_fit:
m = (
f"Region {other.name} does not fit into {self.name}. "
f"min x: {other._min_extent[0] >= self._min_extent[0]},",
f"min y: {other._min_extent[1] >= self._min_extent[1]},",
f"max x: {other._max_extent[0] <= self._max_extent[0]},",
f"max y: {other._max_extent[1] <= self._max_extent[1]}",
)
logging.debug(m)
return does_fit return does_fit
@property @property
@ -106,22 +161,38 @@ class Region(object):
defines how the positions are evaluated, by default AnalysisType.Full defines how the positions are evaluated, by default AnalysisType.Full
FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points? FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
""" """
if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full): if self._shape_type == RegionShape.Rectangular or (
self._shape_type == RegionShape.Circular
and analysis_type != AnalysisType.Full
):
if analysis_type == AnalysisType.Full: if analysis_type == AnalysisType.Full:
indices = np.where(((y >= self._min_extent[1]) & (y <= self._max_extent[1])) & indices = np.where(
((x >= self._min_extent[0]) & (x <= self._max_extent[0])))[0] ((y >= self._min_extent[1]) & (y <= self._max_extent[1]))
& ((x >= self._min_extent[0]) & (x <= self._max_extent[0]))
)[0]
indices = np.array(indices, dtype=int) indices = np.array(indices, dtype=int)
elif analysis_type == AnalysisType.CollapseX: elif analysis_type == AnalysisType.CollapseX:
x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0] x_indices = np.where(
(x >= self._min_extent[0]) & (x <= self._max_extent[0])
)[0]
indices = np.asarray(x_indices, dtype=int) indices = np.asarray(x_indices, dtype=int)
else: else:
y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0] y_indices = np.where(
(y >= self._min_extent[1]) & (y <= self._max_extent[1])
)[0]
indices = np.asarray(y_indices, dtype=int) indices = np.asarray(y_indices, dtype=int)
else: else:
if self.is_child: if self.is_child:
mask = self.circular_mask(self._parent.position[2], self._parent.position[3], self._origin, self._extent) mask = self.circular_mask(
self._parent.position[2],
self._parent.position[3],
self._origin,
self._extent,
)
else: else:
mask = self.circular_mask(self.position[2], self.position[3], self._origin, self._extent) mask = self.circular_mask(
self.position[2], self.position[3], self._origin, self._extent
)
img = np.zeros_like(mask) img = np.zeros_like(mask)
img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1 img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1
temp = np.where(img & mask) temp = np.where(img & mask)
@ -156,7 +227,7 @@ class Region(object):
indices = self.points_in_region(x, y, analysis_type) indices = self.points_in_region(x, y, analysis_type)
if len(indices) == 0: if len(indices) == 0:
return np.array([]), np.array([]) return np.array([]), np.array([])
diffs = np.diff(indices) diffs = np.diff(indices)
if len(diffs) == sum(diffs): if len(diffs) == sum(diffs):
entering = [time[indices[0]]] entering = [time[indices[0]]]
@ -164,7 +235,7 @@ class Region(object):
else: else:
entering = [] entering = []
leaving = [] leaving = []
jumps = np.where(diffs > 1)[0] jumps = np.where(diffs > 1)[0]
start = time[indices[0]] start = time[indices[0]]
for i in range(len(jumps)): for i in range(len(jumps)):
end = time[indices[jumps[i]]] end = time[indices[jumps[i]]]
@ -193,22 +264,37 @@ class Region(object):
class Arena(Region): class Arena(Region):
def __init__(
def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular, self,
illumination=Illumination.Backlight) -> None: origin,
extent,
inverted_y=True,
name="",
arena_shape=RegionShape.Rectangular,
illumination=Illumination.Backlight,
) -> None:
super().__init__(origin, extent, inverted_y, name, arena_shape) super().__init__(origin, extent, inverted_y, name, arena_shape)
self._illumination = illumination self._illumination = illumination
self.regions = {} self.regions = {}
def add_region(self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None): def add_region(
self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None
):
if name is None or name in self.regions.keys(): if name is None or name in self.regions.keys():
raise ValueError("Region name '{name}' is invalid. The name must not be None and must be unique among the regions.") raise ValueError(
"Region name '{name}' is invalid. The name must not be None and must be unique among the regions."
)
if region is None: if region is None:
region = Region(origin, extent, name=name, region_shape=shape_type, parent=self) region = Region(
origin, extent, name=name, region_shape=shape_type, parent=self
)
else: else:
region._parent = self region._parent = self
if ~self.fits(region): doesfit = self.fits(region)
print(f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!") if not doesfit:
logging.warn(
f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!"
)
self.regions[name] = region self.regions[name] = region
def remove_region(self, name): def remove_region(self, name):
@ -276,8 +362,10 @@ if __name__ == "__main__":
a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular) a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
axis = a.plot() axis = a.plot()
x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int) x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
y = np.asarray((np.sin(x*0.01) + 1) * a.position[3] / 2 + a.position[1] -1, dtype=int) y = np.asarray(
#y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int) (np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int
)
# y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
axis.scatter(x, y, c="k", s=2) axis.scatter(x, y, c="k", s=2)
ind = a.regions[3].points_in_region(x, y) ind = a.regions[3].points_in_region(x, y)
@ -288,14 +376,13 @@ if __name__ == "__main__":
if len(ind) > 0: if len(ind) > 0:
axis.scatter(x[ind], y[ind] - 10, label="circ collapseX") axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY) ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)
if len(ind) > 0: if len(ind) > 0:
axis.scatter(x[ind], y[ind] + 10, label="circ collapseY") axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")
ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX) ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX)
if len(ind) > 0: if len(ind) > 0:
axis.scatter(x[ind], y[ind]-10, label="rect collapseX") axis.scatter(x[ind], y[ind] - 10, label="rect collapseX")
ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY) ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY)
if len(ind) > 0: if len(ind) > 0:
@ -303,9 +390,9 @@ if __name__ == "__main__":
ind = a.regions[2].points_in_region(x, y, AnalysisType.Full) ind = a.regions[2].points_in_region(x, y, AnalysisType.Full)
if len(ind) > 0: if len(ind) > 0:
axis.scatter(x[ind], y[ind]+20, label="rect full") axis.scatter(x[ind], y[ind] + 20, label="rect full")
axis.legend() axis.legend()
plt.show() plt.show()
a.plot() a.plot()
plt.show() plt.show()

View File

@ -8,9 +8,21 @@ class TrackingData(object):
Using the 'quality_threshold', 'temporal_limits', or the 'position_limits' data can be filtered (see filter_tracks function). Using the 'quality_threshold', 'temporal_limits', or the 'position_limits' data can be filtered (see filter_tracks function).
The 'interpolate' function allows to fill up the gaps that be result from filtering with linearly interpolated data points. The 'interpolate' function allows to fill up the gaps that be result from filtering with linearly interpolated data points.
More may follow... More may follow...
""" """
def __init__(self, x, y, time, quality, node="", fps=None, quality_threshold=None, temporal_limits=None, position_limits=None) -> None:
def __init__(
self,
x,
y,
time,
quality,
node="",
fps=None,
quality_threshold=None,
temporal_limits=None,
position_limits=None,
) -> None:
self._orgx = x self._orgx = x
self._orgy = y self._orgy = y
self._orgtime = time self._orgtime = time
@ -35,11 +47,13 @@ class TrackingData(object):
def interpolate(self, start_time=None, end_time=None, min_count=5): def interpolate(self, start_time=None, end_time=None, min_count=5):
if len(self._x) < min_count: if len(self._x) < min_count:
print(f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!") print(
f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!"
)
return None, None, None return None, None, None
start = self._time[0] if start_time is None else start_time start = self._time[0] if start_time is None else start_time
end = self._time[-1] if end_time is None else end_time end = self._time[-1] if end_time is None else end_time
time = np.arange(start, end, 1./self._fps) time = np.arange(start, end, 1.0 / self._fps)
x = np.interp(time, self._time, self._x) x = np.interp(time, self._time, self._x)
y = np.interp(time, self._time, self._y) y = np.interp(time, self._time, self._y)
@ -56,7 +70,7 @@ class TrackingData(object):
Parameters Parameters
---------- ----------
new_threshold : float new_threshold : float
""" """
self._threshold = new_threshold self._threshold = new_threshold
@ -77,8 +91,12 @@ class TrackingData(object):
------ ------
ValueError, if new_value is not a 4-tuple ValueError, if new_value is not a 4-tuple
""" """
if new_limits is not None and not (isinstance(new_limits, (tuple, list)) and len(new_limits) == 4): if new_limits is not None and not (
raise ValueError(f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)") isinstance(new_limits, (tuple, list)) and len(new_limits) == 4
):
raise ValueError(
f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)"
)
self._position_limits = new_limits self._position_limits = new_limits
@property @property
@ -94,8 +112,12 @@ class TrackingData(object):
new_limits : 2-tuple new_limits : 2-tuple
The new limits in the form (start, end) given in seconds. The new limits in the form (start, end) given in seconds.
""" """
if new_limits is not None and not (isinstance(new_limits, (tuple, list)) and len(new_limits) == 2): if new_limits is not None and not (
raise ValueError(f"The new_limits vector must be a 2-tuple of the form (start, end). ") isinstance(new_limits, (tuple, list)) and len(new_limits) == 2
):
raise ValueError(
f"The new_limits vector must be a 2-tuple of the form (start, end). "
)
self._time_limits = new_limits self._time_limits = new_limits
def filter_tracks(self, align_time=True): def filter_tracks(self, align_time=True):
@ -115,16 +137,22 @@ class TrackingData(object):
if self.position_limits is not None: if self.position_limits is not None:
x_max = self.position_limits[0] + self.position_limits[2] x_max = self.position_limits[0] + self.position_limits[2]
y_max = self.position_limits[1] + self.position_limits[3] y_max = self.position_limits[1] + self.position_limits[3]
indices = np.where((self._x >= self.position_limits[0]) & (self._x < x_max) & indices = np.where(
(self._y >= self.position_limits[1]) & (self._y < y_max)) (self._x >= self.position_limits[0])
& (self._x < x_max)
& (self._y >= self.position_limits[1])
& (self._y < y_max)
)
self._x = self._x[indices] self._x = self._x[indices]
self._y = self._y[indices] self._y = self._y[indices]
self._time = self._time[indices] - self._time[0] if align_time else 0.0 self._time = self._time[indices] - self._time[0] if align_time else 0.0
self._quality = self._quality[indices] self._quality = self._quality[indices]
if self.temporal_limits is not None: if self.temporal_limits is not None:
indices = np.where((self._time >= self.temporal_limits[0]) & indices = np.where(
(self._time < self.temporal_limits[1])) (self._time >= self.temporal_limits[0])
& (self._time < self.temporal_limits[1])
)
self._x = self._x[indices] self._x = self._x[indices]
self._y = self._y[indices] self._y = self._y[indices]
self._time = self._time[indices] self._time = self._time[indices]
@ -138,7 +166,7 @@ class TrackingData(object):
self._quality = self._quality[indices] self._quality = self._quality[indices]
def positions(self): def positions(self):
"""Returns the filtered data (if filters have been applied). """Returns the filtered data (if filters have been applied).
Returns Returns
------- -------
@ -154,7 +182,7 @@ class TrackingData(object):
return self._x, self._y, self._time, self._quality return self._x, self._y, self._time, self._quality
def speed(self): def speed(self):
""" Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points. """Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points.
Returns Returns
------- -------
@ -165,7 +193,9 @@ class TrackingData(object):
tuple of np.ndarray tuple of np.ndarray
The position The position
""" """
speed = np.sqrt(np.diff(self._x)**2 + np.diff(self._y)**2) / np.diff(self._time) speed = np.sqrt(np.diff(self._x) ** 2 + np.diff(self._y) ** 2) / np.diff(
self._time
)
t = self._time[:-1] + np.diff(self._time) / 2 t = self._time[:-1] + np.diff(self._time) / 2
x = self._x[:-1] + np.diff(self._x) / 2 x = self._x[:-1] + np.diff(self._x) / 2
y = self._y[:-1] + np.diff(self._y) / 2 y = self._y[:-1] + np.diff(self._y) / 2
@ -174,4 +204,4 @@ class TrackingData(object):
def __repr__(self) -> str: def __repr__(self) -> str:
s = f"Tracking data of node '{self._node}'!" s = f"Tracking data of node '{self._node}'!"
return s return s