[arena] add time in region function
This commit is contained in:
parent
3e1cbe4b9b
commit
701cda1069
@ -99,13 +99,15 @@ class Region(object):
|
||||
FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
|
||||
"""
|
||||
if self._shape_type == RegionShape.Rectangular or (self._shape_type == RegionShape.Circular and analysis_type != AnalysisType.Full):
|
||||
x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0]
|
||||
y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0]
|
||||
if analysis_type == AnalysisType.Full:
|
||||
indices = np.array(list(set(x_indices).intersection(set(y_indices))), dtype=int)
|
||||
indices = np.where(((y >= self._min_extent[1]) & (y <= self._max_extent[1])) &
|
||||
((x >= self._min_extent[0]) & (x <= self._max_extent[0])))[0]
|
||||
indices = np.array(indices, dtype=int)
|
||||
elif analysis_type == AnalysisType.CollapseX:
|
||||
x_indices = np.where((x >= self._min_extent[0]) & (x <= self._max_extent[0] ))[0]
|
||||
indices = np.asarray(x_indices, dtype=int)
|
||||
else:
|
||||
y_indices = np.where((y >= self._min_extent[1]) & (y <= self._max_extent[1] ))[0]
|
||||
indices = np.asarray(y_indices, dtype=int)
|
||||
else:
|
||||
if self.is_child:
|
||||
@ -124,6 +126,49 @@ class Region(object):
|
||||
indices = np.array(indices)
|
||||
return indices
|
||||
|
||||
def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
|
||||
"""_summary_
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : _type_
|
||||
_description_
|
||||
y : _type_
|
||||
_description_
|
||||
time : _type_
|
||||
_description_
|
||||
analysis_type : _type_, optional
|
||||
_description_, by default AnalysisType.Full
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
"""
|
||||
indices = self.points_in_region(x, y, analysis_type)
|
||||
if len(indices) == 0:
|
||||
return np.array([]), np.array([])
|
||||
|
||||
diffs = np.diff(indices)
|
||||
if len(diffs) == sum(diffs):
|
||||
entering = [time[indices[0]]]
|
||||
leaving = [time[indices[-1]]]
|
||||
else:
|
||||
entering = []
|
||||
leaving = []
|
||||
jumps = np.where(diffs > 1)[0]
|
||||
start = time[indices[0]]
|
||||
for i in range(len(jumps)):
|
||||
end = time[indices[jumps[i]]]
|
||||
entering.append(start)
|
||||
leaving.append(end)
|
||||
start = time[indices[jumps[i] + 1]]
|
||||
|
||||
end = time[indices[-1]]
|
||||
entering.append(start)
|
||||
leaving.append(end)
|
||||
return np.array(entering), np.array(leaving)
|
||||
|
||||
def patch(self, **kwargs):
|
||||
if "fc" not in kwargs:
|
||||
kwargs["fc"] = None
|
||||
@ -179,7 +224,7 @@ class Arena(Region):
|
||||
|
||||
def region_vector(self, x, y):
|
||||
"""Returns a vector that contains the region names within which the agent was found.
|
||||
|
||||
FIXME: This does not work well with overlapping regions!@!
|
||||
Parameters
|
||||
----------
|
||||
x : np.array
|
||||
@ -198,6 +243,12 @@ class Arena(Region):
|
||||
rv[indices] = r
|
||||
return rv
|
||||
|
||||
def in_region(self, x, y):
|
||||
tmp = {}
|
||||
for r in self.regions:
|
||||
indices = self.regions[r].points_in_region(x, y)
|
||||
tmp[r] = indices
|
||||
return tmp
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
|
||||
|
Loading…
Reference in New Issue
Block a user