latest changes

This commit is contained in:
Jan Grewe 2022-12-03 11:00:24 +01:00
parent 0291ef088a
commit 2bba750e1f
2 changed files with 45 additions and 19 deletions

View File

@ -5,7 +5,7 @@ import matplotlib.patches as patches
from skimage.draw import disk
from .util import RegionShape, AnalysisType, Illumination
from IPython import embed
class Region(object):
@ -30,6 +30,14 @@ class Region(object):
return mask
@property
def name(self):
return self._name
@property
def inverted_y(self):
return self._inverted_y
@property
def _max_extent(self):
if self._shape_type == RegionShape.Rectangular:
@ -127,7 +135,7 @@ class Region(object):
return indices
def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
"""_summary_
"""Returns the entering and leaving times at which the animal entered and left a region. In case the animal was not observed after entering this region (for example when hidden in a tube) the leaving time is the maximum time entry.
Parameters
----------
@ -148,7 +156,7 @@ class Region(object):
indices = self.points_in_region(x, y, analysis_type)
if len(indices) == 0:
return np.array([]), np.array([])
diffs = np.diff(indices)
if len(diffs) == sum(diffs):
entering = [time[indices[0]]]
@ -199,10 +207,9 @@ class Arena(Region):
region = Region(origin, extent, name=name, region_shape=shape_type, parent=self)
else:
region._parent = self
if self.fits(region):
self.regions[name] = region
else:
Warning(f"Region {region} fits not! Not added to the list of regions!")
if ~self.fits(region):
print(f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!")
self.regions[name] = region
def remove_region(self, name):
if name in self.regions:
@ -217,9 +224,13 @@ class Arena(Region):
axis = fig.add_subplot(111)
axis.add_patch(self.patch())
axis.set_xlim([self._origin[0], self._max_extent[0]])
axis.set_ylim([self._origin[1], self._max_extent[1]])
if self.inverted_y:
axis.set_ylim([self._max_extent[1], self._origin[1]])
else:
axis.set_ylim([self._origin[1], self._max_extent[1]])
for r in self.regions:
axis.add_patch(r.patch())
axis.add_patch(self.regions[r].patch())
return axis
def region_vector(self, x, y):
@ -250,6 +261,13 @@ class Arena(Region):
tmp[r] = indices
return tmp
def __getitem__(self, key):
if isinstance(key, (str)):
return self.regions[key]
else:
return self.regions[self.regions.keys()[key]]
if __name__ == "__main__":
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
a.add_region("small rect1", (0, 0), (100, 300))

View File

@ -10,12 +10,14 @@ from IPython import embed
class NixtrackData(object):
def __init__(self, results_file, crop=(0, 0)) -> None:
def __init__(self, filename, crop=(0, 0)) -> None:
"""
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
Parameters
----------
filename : str
full filename
crop : 2-tuple
tuple of (xoffset, yoffset)
@ -23,13 +25,13 @@ class NixtrackData(object):
------
ValueError if crop value is not a 2-tuple
"""
if not os.path.exists(results_file):
raise ValueError("File %s does not exist!" % results_file)
if not os.path.exists(filename):
raise ValueError("File %s does not exist!" % filename)
if not isinstance(crop, tuple) or len(crop) < 2:
raise ValueError("Cropping info must be a 2-tuple of (x, y)")
self._file_name = results_file
self._file_name = filename
self._crop = crop
self._dataset = nt.Dataset(self._file_name, nt.util.FileMode.ReadOnly)
self._dataset = nt.Dataset(self._file_name)
if not self._dataset.is_open:
raise ValueError(f"An error occurred opening file {self._file_name}! File is not open!")
@ -50,14 +52,20 @@ class NixtrackData(object):
def tracks(self):
return self._dataset.tracks
def track(self, track=None, bodypart=0):
def track(self, bodypart=0, fps=None):
if isinstance(bodypart, nb.Number):
bp = self.bodyparts[bodypart]
elif isinstance(bodypart, str) and bodypart in self.bodyparts:
elif isinstance(bodypart, (str)) and bodypart in self.bodyparts:
bp = bodypart
else:
raise ValueError(f"Body part {bodypart} is not a tracked node!")
positions, time, iscore, nscore = self._dataset.positions(node=bp, axis_type=nt.AxisType.Time)
embed()
if fps is None:
fps = self._dataset.fps
return TrackingData(positions[:, 0], positions[:, 1], time, nscore, bp, fps=self._dataset.fps)
positions, time, _, nscore = self._dataset.positions(node=bp, axis_type=nt.AxisType.Time)
valid = ~np.isnan(positions[:, 0])
positions = positions[valid,:]
time = time[valid]
score = nscore[valid]
return TrackingData(positions[:, 0], positions[:, 1], time, score, bp, fps=fps)