diff --git a/etrack/arena.py b/etrack/arena.py index 52061af..86b7233 100644 --- a/etrack/arena.py +++ b/etrack/arena.py @@ -1,3 +1,4 @@ +import logging import numpy as np import matplotlib.pyplot as plt import matplotlib.patches as patches @@ -7,9 +8,20 @@ from skimage.draw import disk from .util import RegionShape, AnalysisType, Illumination from IPython import embed -class Region(object): - def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None: +class Region(object): + def __init__( + self, + origin, + extent, + inverted_y=True, + name="", + region_shape=RegionShape.Rectangular, + parent=None, + ) -> None: + logging.debug( + f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}" + ) assert len(origin) == 2 self._origin = origin self._extent = extent @@ -41,9 +53,15 @@ class Region(object): @property def _max_extent(self): if self._shape_type == RegionShape.Rectangular: - max_extent = (self._origin[0] + self._extent[0], self._origin[1] + self._extent[1]) + max_extent = ( + self._origin[0] + self._extent[0], + self._origin[1] + self._extent[1], + ) else: - max_extent = (self._origin[0] + self._extent, self._origin[1] + self._extent) + max_extent = ( + self._origin[0] + self._extent, + self._origin[1] + self._extent, + ) return max_extent @property @@ -51,13 +69,31 @@ class Region(object): if self._shape_type == RegionShape.Rectangular: min_extent = self._origin else: - min_extent = (self._origin[0] - self._extent, self._origin[1] - self._extent) + min_extent = ( + self._origin[0] - self._extent, + self._origin[1] - self._extent, + ) return min_extent + @property + def xmax(self): + return self._max_extent[0] + + @property + def xmin(self): + return self._min_extent[0] + + @property + def ymin(self): + return self._min_extent[1] + + @property + def ymax(self): + return self._max_extent[1] + @property def position(self): - """Returns the position and extent of the region as 4-tuple, (x, y, width, height) - """ + """Returns the position and extent of the region as 4-tuple, (x, y, width, height)""" x = self._min_extent[0] y = self._min_extent[1] width = self._max_extent[0] - self._min_extent[0] @@ -73,20 +109,39 @@ class Region(object): """ if self._shape_type == RegionShape.Rectangular: if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2: - raise ValueError("Extent must be a length 2 list or tuple for rectangular regions!") + raise ValueError( + "Extent must be a length 2 list or tuple for rectangular regions!" + ) elif self._shape_type == RegionShape.Circular: if not isinstance(ext, (int, float)): - raise ValueError("Extent must be a numerical scalar for circular regions!") + raise ValueError( + "Extent must be a numerical scalar for circular regions!" + ) else: raise ValueError(f"Invalid ShapeType, {self._shape_type}!") def fits(self, other) -> bool: """ - Returns true if the other region fits inside this region! + Returns true if the other region fits inside this region! """ assert isinstance(other, Region) - does_fit = all((other._min_extent[0] >= self._min_extent[0], other._min_extent[1] >= self._min_extent[1], - other._max_extent[0] <= self._max_extent[0], other._max_extent[1] <= self._max_extent[1])) + does_fit = all( + ( + other._min_extent[0] >= self._min_extent[0], + other._min_extent[1] >= self._min_extent[1], + other._max_extent[0] <= self._max_extent[0], + other._max_extent[1] <= self._max_extent[1], + ) + ) + if not does_fit: + m = ( + f"Region {other.name} does not fit into {self.name}. " + f"min x: {other._min_extent[0] >= self._min_extent[0]},", + f"min y: {other._min_extent[1] >= self._min_extent[1]},", + f"max x: {other._max_extent[0] <= self._max_extent[0]},", + f"max y: {other._max_extent[1] <= self._max_extent[1]}", + ) + logging.debug(m) return does_fit @property @@ -106,22 +161,38 @@ class Region(object): defines how the positions are evaluated, by default AnalysisType.Full FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points? """ - if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full): + if self._shape_type == RegionShape.Rectangular or ( + self._shape_type == RegionShape.Circular + and analysis_type != AnalysisType.Full + ): if analysis_type == AnalysisType.Full: - indices = np.where(((y >= self._min_extent[1]) & (y <= self._max_extent[1])) & - ((x >= self._min_extent[0]) & (x <= self._max_extent[0])))[0] + indices = np.where( + ((y >= self._min_extent[1]) & (y <= self._max_extent[1])) + & ((x >= self._min_extent[0]) & (x <= self._max_extent[0])) + )[0] indices = np.array(indices, dtype=int) elif analysis_type == AnalysisType.CollapseX: - x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0] + x_indices = np.where( + (x >= self._min_extent[0]) & (x <= self._max_extent[0]) + )[0] indices = np.asarray(x_indices, dtype=int) else: - y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0] + y_indices = np.where( + (y >= self._min_extent[1]) & (y <= self._max_extent[1]) + )[0] indices = np.asarray(y_indices, dtype=int) else: if self.is_child: - mask = self.circular_mask(self._parent.position[2], self._parent.position[3], self._origin, self._extent) + mask = self.circular_mask( + self._parent.position[2], + self._parent.position[3], + self._origin, + self._extent, + ) else: - mask = self.circular_mask(self.position[2], self.position[3], self._origin, self._extent) + mask = self.circular_mask( + self.position[2], self.position[3], self._origin, self._extent + ) img = np.zeros_like(mask) img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1 temp = np.where(img & mask) @@ -156,7 +227,7 @@ class Region(object): indices = self.points_in_region(x, y, analysis_type) if len(indices) == 0: return np.array([]), np.array([]) - + diffs = np.diff(indices) if len(diffs) == sum(diffs): entering = [time[indices[0]]] @@ -164,7 +235,7 @@ class Region(object): else: entering = [] leaving = [] - jumps = np.where(diffs > 1)[0] + jumps = np.where(diffs > 1)[0] start = time[indices[0]] for i in range(len(jumps)): end = time[indices[jumps[i]]] @@ -193,22 +264,37 @@ class Region(object): class Arena(Region): - - def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular, - illumination=Illumination.Backlight) -> None: + def __init__( + self, + origin, + extent, + inverted_y=True, + name="", + arena_shape=RegionShape.Rectangular, + illumination=Illumination.Backlight, + ) -> None: super().__init__(origin, extent, inverted_y, name, arena_shape) self._illumination = illumination self.regions = {} - def add_region(self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None): + def add_region( + self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None + ): if name is None or name in self.regions.keys(): - raise ValueError("Region name '{name}' is invalid. The name must not be None and must be unique among the regions.") + raise ValueError( + "Region name '{name}' is invalid. The name must not be None and must be unique among the regions." + ) if region is None: - region = Region(origin, extent, name=name, region_shape=shape_type, parent=self) + region = Region( + origin, extent, name=name, region_shape=shape_type, parent=self + ) else: region._parent = self - if ~self.fits(region): - print(f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!") + doesfit = self.fits(region) + if not doesfit: + logging.warn( + f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!" + ) self.regions[name] = region def remove_region(self, name): @@ -276,8 +362,10 @@ if __name__ == "__main__": a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular) axis = a.plot() x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int) - y = np.asarray((np.sin(x*0.01) + 1) * a.position[3] / 2 + a.position[1] -1, dtype=int) - #y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int) + y = np.asarray( + (np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int + ) + # y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int) axis.scatter(x, y, c="k", s=2) ind = a.regions[3].points_in_region(x, y) @@ -288,14 +376,13 @@ if __name__ == "__main__": if len(ind) > 0: axis.scatter(x[ind], y[ind] - 10, label="circ collapseX") - ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY) + ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY) if len(ind) > 0: axis.scatter(x[ind], y[ind] + 10, label="circ collapseY") - ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX) if len(ind) > 0: - axis.scatter(x[ind], y[ind]-10, label="rect collapseX") + axis.scatter(x[ind], y[ind] - 10, label="rect collapseX") ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY) if len(ind) > 0: @@ -303,9 +390,9 @@ if __name__ == "__main__": ind = a.regions[2].points_in_region(x, y, AnalysisType.Full) if len(ind) > 0: - axis.scatter(x[ind], y[ind]+20, label="rect full") + axis.scatter(x[ind], y[ind] + 20, label="rect full") axis.legend() plt.show() a.plot() - plt.show() \ No newline at end of file + plt.show() diff --git a/etrack/tracking_data.py b/etrack/tracking_data.py index abcd783..a945975 100644 --- a/etrack/tracking_data.py +++ b/etrack/tracking_data.py @@ -8,9 +8,21 @@ class TrackingData(object): Using the 'quality_threshold', 'temporal_limits', or the 'position_limits' data can be filtered (see filter_tracks function). The 'interpolate' function allows to fill up the gaps that be result from filtering with linearly interpolated data points. - More may follow... + More may follow... """ - def __init__(self, x, y, time, quality, node="", fps=None, quality_threshold=None, temporal_limits=None, position_limits=None) -> None: + + def __init__( + self, + x, + y, + time, + quality, + node="", + fps=None, + quality_threshold=None, + temporal_limits=None, + position_limits=None, + ) -> None: self._orgx = x self._orgy = y self._orgtime = time @@ -35,11 +47,13 @@ class TrackingData(object): def interpolate(self, start_time=None, end_time=None, min_count=5): if len(self._x) < min_count: - print(f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!") + print( + f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!" + ) return None, None, None start = self._time[0] if start_time is None else start_time end = self._time[-1] if end_time is None else end_time - time = np.arange(start, end, 1./self._fps) + time = np.arange(start, end, 1.0 / self._fps) x = np.interp(time, self._time, self._x) y = np.interp(time, self._time, self._y) @@ -56,7 +70,7 @@ class TrackingData(object): Parameters ---------- new_threshold : float - + """ self._threshold = new_threshold @@ -77,8 +91,12 @@ class TrackingData(object): ------ ValueError, if new_value is not a 4-tuple """ - if new_limits is not None and not (isinstance(new_limits, (tuple, list)) and len(new_limits) == 4): - raise ValueError(f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)") + if new_limits is not None and not ( + isinstance(new_limits, (tuple, list)) and len(new_limits) == 4 + ): + raise ValueError( + f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)" + ) self._position_limits = new_limits @property @@ -94,8 +112,12 @@ class TrackingData(object): new_limits : 2-tuple The new limits in the form (start, end) given in seconds. """ - if new_limits is not None and not (isinstance(new_limits, (tuple, list)) and len(new_limits) == 2): - raise ValueError(f"The new_limits vector must be a 2-tuple of the form (start, end). ") + if new_limits is not None and not ( + isinstance(new_limits, (tuple, list)) and len(new_limits) == 2 + ): + raise ValueError( + f"The new_limits vector must be a 2-tuple of the form (start, end). " + ) self._time_limits = new_limits def filter_tracks(self, align_time=True): @@ -115,16 +137,22 @@ class TrackingData(object): if self.position_limits is not None: x_max = self.position_limits[0] + self.position_limits[2] y_max = self.position_limits[1] + self.position_limits[3] - indices = np.where((self._x >= self.position_limits[0]) & (self._x < x_max) & - (self._y >= self.position_limits[1]) & (self._y < y_max)) + indices = np.where( + (self._x >= self.position_limits[0]) + & (self._x < x_max) + & (self._y >= self.position_limits[1]) + & (self._y < y_max) + ) self._x = self._x[indices] self._y = self._y[indices] self._time = self._time[indices] - self._time[0] if align_time else 0.0 self._quality = self._quality[indices] if self.temporal_limits is not None: - indices = np.where((self._time >= self.temporal_limits[0]) & - (self._time < self.temporal_limits[1])) + indices = np.where( + (self._time >= self.temporal_limits[0]) + & (self._time < self.temporal_limits[1]) + ) self._x = self._x[indices] self._y = self._y[indices] self._time = self._time[indices] @@ -138,7 +166,7 @@ class TrackingData(object): self._quality = self._quality[indices] def positions(self): - """Returns the filtered data (if filters have been applied). + """Returns the filtered data (if filters have been applied). Returns ------- @@ -154,7 +182,7 @@ class TrackingData(object): return self._x, self._y, self._time, self._quality def speed(self): - """ Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points. + """Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points. Returns ------- @@ -165,7 +193,9 @@ class TrackingData(object): tuple of np.ndarray The position """ - speed = np.sqrt(np.diff(self._x)**2 + np.diff(self._y)**2) / np.diff(self._time) + speed = np.sqrt(np.diff(self._x) ** 2 + np.diff(self._y) ** 2) / np.diff( + self._time + ) t = self._time[:-1] + np.diff(self._time) / 2 x = self._x[:-1] + np.diff(self._x) / 2 y = self._y[:-1] + np.diff(self._y) / 2 @@ -174,4 +204,4 @@ class TrackingData(object): def __repr__(self) -> str: s = f"Tracking data of node '{self._node}'!" - return s \ No newline at end of file + return s