From 701cda106979c43da116e31f4dc17df016381bd4 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Mon, 12 Sep 2022 16:57:26 +0200 Subject: [PATCH] [arena] add time in region function --- etrack/arena.py | 59 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/etrack/arena.py b/etrack/arena.py index 6f270db..44997fa 100644 --- a/etrack/arena.py +++ b/etrack/arena.py @@ -99,13 +99,15 @@ class Region(object): FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points? """ if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full): - x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0] - y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0] if analysis_type == AnalysisType.Full: - indices = np.array(list(set(x_indices).intersection(set(y_indices))), dtype=int) + indices = np.where(((y >= self._min_extent[1]) & (y <= self._max_extent[1])) & + ((x >= self._min_extent[0]) & (x <= self._max_extent[0])))[0] + indices = np.array(indices, dtype=int) elif analysis_type == AnalysisType.CollapseX: + x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0] indices = np.asarray(x_indices, dtype=int) else: + y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0] indices = np.asarray(y_indices, dtype=int) else: if self.is_child: @@ -124,6 +126,49 @@ class Region(object): indices = np.array(indices) return indices + def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full): + """_summary_ + + Parameters + ---------- + x : _type_ + _description_ + y : _type_ + _description_ + time : _type_ + _description_ + analysis_type : _type_, optional + _description_, by default AnalysisType.Full + + Returns + ------- + _type_ + _description_ + """ + indices = self.points_in_region(x, y, analysis_type) + if len(indices) == 0: + return np.array([]), np.array([]) + + diffs = np.diff(indices) + if len(diffs) == sum(diffs): + entering = [time[indices[0]]] + leaving = [time[indices[-1]]] + else: + entering = [] + leaving = [] + jumps = np.where(diffs > 1)[0] + start = time[indices[0]] + for i in range(len(jumps)): + end = time[indices[jumps[i]]] + entering.append(start) + leaving.append(end) + start = time[indices[jumps[i] + 1]] + + end = time[indices[-1]] + entering.append(start) + leaving.append(end) + return np.array(entering), np.array(leaving) + def patch(self, **kwargs): if "fc" not in kwargs: kwargs["fc"] = None @@ -179,7 +224,7 @@ class Arena(Region): def region_vector(self, x, y): """Returns a vector that contains the region names within which the agent was found. - + FIXME: This does not work well with overlapping regions!@! Parameters ---------- x : np.array @@ -198,6 +243,12 @@ class Arena(Region): rv[indices] = r return rv + def in_region(self, x, y): + tmp = {} + for r in self.regions: + indices = self.regions[r].points_in_region(x, y) + tmp[r] = indices + return tmp if __name__ == "__main__": a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)