From 2bba750e1fcb08140e8f5d45cf3a0922dbefa6ef Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Sat, 3 Dec 2022 11:00:24 +0100 Subject: [PATCH] latest changes --- etrack/arena.py | 36 +++++++++++++++++++++++++++--------- etrack/io/nixtrack_data.py | 28 ++++++++++++++++++---------- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/etrack/arena.py b/etrack/arena.py index 44997fa..52061af 100644 --- a/etrack/arena.py +++ b/etrack/arena.py @@ -5,7 +5,7 @@ import matplotlib.patches as patches from skimage.draw import disk from .util import RegionShape, AnalysisType, Illumination - +from IPython import embed class Region(object): @@ -30,6 +30,14 @@ class Region(object): return mask + @property + def name(self): + return self._name + + @property + def inverted_y(self): + return self._inverted_y + @property def _max_extent(self): if self._shape_type == RegionShape.Rectangular: @@ -127,7 +135,7 @@ class Region(object): return indices def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full): - """_summary_ + """Returns the entering and leaving times at which the animal entered and left a region. In case the animal was not observed after entering this region (for example when hidden in a tube) the leaving time is the maximum time entry. Parameters ---------- @@ -148,7 +156,7 @@ class Region(object): indices = self.points_in_region(x, y, analysis_type) if len(indices) == 0: return np.array([]), np.array([]) - + diffs = np.diff(indices) if len(diffs) == sum(diffs): entering = [time[indices[0]]] @@ -199,10 +207,9 @@ class Arena(Region): region = Region(origin, extent, name=name, region_shape=shape_type, parent=self) else: region._parent = self - if self.fits(region): - self.regions[name] = region - else: - Warning(f"Region {region} fits not! Not added to the list of regions!") + if ~self.fits(region): + print(f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!") + self.regions[name] = region def remove_region(self, name): if name in self.regions: @@ -217,9 +224,13 @@ class Arena(Region): axis = fig.add_subplot(111) axis.add_patch(self.patch()) axis.set_xlim([self._origin[0], self._max_extent[0]]) - axis.set_ylim([self._origin[1], self._max_extent[1]]) + + if self.inverted_y: + axis.set_ylim([self._max_extent[1], self._origin[1]]) + else: + axis.set_ylim([self._origin[1], self._max_extent[1]]) for r in self.regions: - axis.add_patch(r.patch()) + axis.add_patch(self.regions[r].patch()) return axis def region_vector(self, x, y): @@ -250,6 +261,13 @@ class Arena(Region): tmp[r] = indices return tmp + def __getitem__(self, key): + if isinstance(key, (str)): + return self.regions[key] + else: + return self.regions[self.regions.keys()[key]] + + if __name__ == "__main__": a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular) a.add_region("small rect1", (0, 0), (100, 300)) diff --git a/etrack/io/nixtrack_data.py b/etrack/io/nixtrack_data.py index 9eeab79..6d0ec8d 100644 --- a/etrack/io/nixtrack_data.py +++ b/etrack/io/nixtrack_data.py @@ -10,12 +10,14 @@ from IPython import embed class NixtrackData(object): - def __init__(self, results_file, crop=(0, 0)) -> None: + def __init__(self, filename, crop=(0, 0)) -> None: """ If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame. Parameters ---------- + filename : str + full filename crop : 2-tuple tuple of (xoffset, yoffset) @@ -23,13 +25,13 @@ class NixtrackData(object): ------ ValueError if crop value is not a 2-tuple """ - if not os.path.exists(results_file): - raise ValueError("File %s does not exist!" % results_file) + if not os.path.exists(filename): + raise ValueError("File %s does not exist!" % filename) if not isinstance(crop, tuple) or len(crop) < 2: raise ValueError("Cropping info must be a 2-tuple of (x, y)") - self._file_name = results_file + self._file_name = filename self._crop = crop - self._dataset = nt.Dataset(self._file_name, nt.util.FileMode.ReadOnly) + self._dataset = nt.Dataset(self._file_name) if not self._dataset.is_open: raise ValueError(f"An error occurred opening file {self._file_name}! File is not open!") @@ -50,14 +52,20 @@ class NixtrackData(object): def tracks(self): return self._dataset.tracks - def track(self, track=None, bodypart=0): + def track(self, bodypart=0, fps=None): if isinstance(bodypart, nb.Number): bp = self.bodyparts[bodypart] - elif isinstance(bodypart, str) and bodypart in self.bodyparts: + elif isinstance(bodypart, (str)) and bodypart in self.bodyparts: bp = bodypart else: raise ValueError(f"Body part {bodypart} is not a tracked node!") - positions, time, iscore, nscore = self._dataset.positions(node=bp, axis_type=nt.AxisType.Time) - embed() + if fps is None: + fps = self._dataset.fps - return TrackingData(positions[:, 0], positions[:, 1], time, nscore, bp, fps=self._dataset.fps) \ No newline at end of file + positions, time, _, nscore = self._dataset.positions(node=bp, axis_type=nt.AxisType.Time) + valid = ~np.isnan(positions[:, 0]) + positions = positions[valid,:] + time = time[valid] + score = nscore[valid] + + return TrackingData(positions[:, 0], positions[:, 1], time, score, bp, fps=fps)