[arena] add time in region function

This commit is contained in:
Jan Grewe 2022-09-12 16:57:26 +02:00
parent 3e1cbe4b9b
commit 701cda1069

View File

@ -99,13 +99,15 @@ class Region(object):
FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
"""
if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full):
x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0]
y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0]
if analysis_type == AnalysisType.Full:
indices = np.array(list(set(x_indices).intersection(set(y_indices))), dtype=int)
indices = np.where(((y >= self._min_extent[1]) & (y <= self._max_extent[1])) &
((x >= self._min_extent[0]) & (x <= self._max_extent[0])))[0]
indices = np.array(indices, dtype=int)
elif analysis_type == AnalysisType.CollapseX:
x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0]
indices = np.asarray(x_indices, dtype=int)
else:
y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0]
indices = np.asarray(y_indices, dtype=int)
else:
if self.is_child:
@ -124,6 +126,49 @@ class Region(object):
indices = np.array(indices)
return indices
def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
"""_summary_
Parameters
----------
x : _type_
_description_
y : _type_
_description_
time : _type_
_description_
analysis_type : _type_, optional
_description_, by default AnalysisType.Full
Returns
-------
_type_
_description_
"""
indices = self.points_in_region(x, y, analysis_type)
if len(indices) == 0:
return np.array([]), np.array([])
diffs = np.diff(indices)
if len(diffs) == sum(diffs):
entering = [time[indices[0]]]
leaving = [time[indices[-1]]]
else:
entering = []
leaving = []
jumps = np.where(diffs > 1)[0]
start = time[indices[0]]
for i in range(len(jumps)):
end = time[indices[jumps[i]]]
entering.append(start)
leaving.append(end)
start = time[indices[jumps[i] + 1]]
end = time[indices[-1]]
entering.append(start)
leaving.append(end)
return np.array(entering), np.array(leaving)
def patch(self, **kwargs):
if "fc" not in kwargs:
kwargs["fc"] = None
@ -179,7 +224,7 @@ class Arena(Region):
def region_vector(self, x, y):
"""Returns a vector that contains the region names within which the agent was found.
FIXME: This does not work well with overlapping regions!@!
Parameters
----------
x : np.array
@ -198,6 +243,12 @@ class Arena(Region):
rv[indices] = r
return rv
def in_region(self, x, y):
tmp = {}
for r in self.regions:
indices = self.regions[r].points_in_region(x, y)
tmp[r] = indices
return tmp
if __name__ == "__main__":
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)