From 1dd318f23e9da9cbea28bf6441978b4167fcb01e Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Thu, 30 May 2024 23:59:16 +0200 Subject: [PATCH] restructuring project with toml add a test and docstrings --- .gitignore | 1 + pyproject.toml | 48 +++++ setup.py | 43 ----- {etrack => src/etrack}/__init__.py | 3 +- {etrack => src/etrack}/arena.py | 213 ++++++++++++++++----- {etrack => src/etrack}/image_marker.py | 0 {etrack => src/etrack}/info.json | 0 {etrack => src/etrack}/info.py | 3 +- {etrack => src/etrack}/io/__init__.py | 0 {etrack => src/etrack}/io/dlc_data.py | 0 {etrack => src/etrack}/io/nixtrack_data.py | 0 {etrack => src/etrack}/tracking_data.py | 0 {etrack => src/etrack}/tracking_result.py | 0 {etrack => src/etrack}/util.py | 8 + test/test_arena.py | 82 ++++++++ 15 files changed, 307 insertions(+), 94 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.py rename {etrack => src/etrack}/__init__.py (72%) rename {etrack => src/etrack}/arena.py (68%) rename {etrack => src/etrack}/image_marker.py (100%) rename {etrack => src/etrack}/info.json (100%) rename {etrack => src/etrack}/info.py (94%) rename {etrack => src/etrack}/io/__init__.py (100%) rename {etrack => src/etrack}/io/dlc_data.py (100%) rename {etrack => src/etrack}/io/nixtrack_data.py (100%) rename {etrack => src/etrack}/tracking_data.py (100%) rename {etrack => src/etrack}/tracking_result.py (100%) rename {etrack => src/etrack}/util.py (65%) create mode 100644 test/test_arena.py diff --git a/.gitignore b/.gitignore index 6eb27cf..d94fcb2 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ requires.txt SOURCES.txt dependency_links.txt top_level.txt +.DS_Store \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a54656d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "etrack" +dynamic = ["version"] +dependencies = [ + "nixio>=1.5", + "nixtrack", + "numpy", + "matplotlib", + "opencv-python", + "pandas", + "scikit-image", +] +requires-python = ">=3.6" +authors = [ + {name = "Jan Grewe", email = "jan.grewe@g-node.org"}, +] +maintainers = [ + {name = "Jan Grewe", email = "jan.grewe@g-node.org"}, +] +description = "Goodies for working with tracking data of efishes." +readme = "README.md" +license = {file = "LICENSE"} +classifiers = [ + "Development Status :: 4 - Beta", + "Environment :: Console", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: BSD-2-Clause", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.12", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +[project.urls] +Repository = "https://github.com/bendalab/etrack" + +[tool.setuptools.dynamic] +version = {attr = "etrack.info.VERSION"} + +[tool.pytest.ini_options] +pythonpath = "src" \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 5837e78..0000000 --- a/setup.py +++ /dev/null @@ -1,43 +0,0 @@ -import os -from setuptools import setup -import json -from setuptools import setup - -# load info from nixio/info.json -with open(os.path.join("etrack", "info.json")) as infofile: - infodict = json.load(infofile) - - -NAME = "etrack" -VERSION = infodict["VERSION"] -AUTHOR = infodict["AUTHOR"] -CONTACT = infodict["CONTACT"] -BRIEF = infodict["BRIEF"] -HOMEPAGE = infodict["HOMEPAGE"] -CLASSIFIERS = "science" -README = "README.md" - -with open(README) as f: - description_text = f.read() -DESCRIPTION = description_text - -packages = [ - "etrack", "etrack.io" -] - -install_req = ["h5py", "pandas", "matplotlib", "numpy", "opencv-python"] - -setup( - name=NAME, - version=VERSION, - description=DESCRIPTION, - author=AUTHOR, - author_email=CONTACT, - packages=packages, - install_requires=install_req, - include_package_data=True, - long_description=description_text, - long_description_content_type="text/markdown", - classifiers=CLASSIFIERS, - license="BSD" -) \ No newline at end of file diff --git a/etrack/__init__.py b/src/etrack/__init__.py similarity index 72% rename from etrack/__init__.py rename to src/etrack/__init__.py index accbac9..dcaf5a6 100644 --- a/etrack/__init__.py +++ b/src/etrack/__init__.py @@ -3,4 +3,5 @@ from .tracking_result import TrackingResult, coordinate_transformation from .arena import Arena, Region from .tracking_data import TrackingData from .io.dlc_data import DLCReader -from .io.nixtrack_data import NixtrackData \ No newline at end of file +from .io.nixtrack_data import NixtrackData +from .util import RegionShape, AnalysisType \ No newline at end of file diff --git a/etrack/arena.py b/src/etrack/arena.py similarity index 68% rename from etrack/arena.py rename to src/etrack/arena.py index c01b3e0..5b74ccb 100644 --- a/etrack/arena.py +++ b/src/etrack/arena.py @@ -10,26 +10,52 @@ from IPython import embed class Region(object): - def __init__( - self, - origin, - extent, - inverted_y=True, - name="", - region_shape=RegionShape.Rectangular, - parent=None, - ) -> None: + """ + Class representing a region (of interest). Regions can be either circular or rectangular. + A Region can have a parent, i.e. it is contained inside a parent region. It can also have children. + + Coordinates are given in absolute coordinates. The extent is treated depending on the shape. In case of a circular + shape, it is the radius and the origin is the center of the circle. Otherwise the origin is the bottom, or top-left corner, depending on the y-axis orientation, if inverted, then it is top-left. FIXME: check this + + """ + def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None: + """Region constructor. + Parameters + ---------- + origin : 2-tuple + x, and y coordinates + extent : scalar or 2-tuple, scalar only allowed to circular regions, 2-tuple for rectangular. + inverted_y : bool, optional + _description_, by default True + name : str, optional + _description_, by default "" + region_shape : _type_, optional + _description_, by default RegionShape.Rectangular + parent : _type_, optional + _description_, by default None + + Returns + ------- + _type_ + _description_ + + Raises + ------ + ValueError + Raises Value error when origin or extent are invalid + """ logging.debug( f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}" ) - assert len(origin) == 2 - self._origin = origin - self._extent = extent - self._inverted_y = inverted_y + if len(origin) != 2: + raise ValueError("Region: origin must be 2-tuple!") + self._parent = parent self._name = name self._shape_type = region_shape + self._origin = origin self._check_extent(extent) - self._parent = parent + self._extent = extent + self._inverted_y = inverted_y @staticmethod def circular_mask(width, height, center, radius): @@ -62,7 +88,7 @@ class Region(object): self._origin[0] + self._extent, self._origin[1] + self._extent, ) - return max_extent + return np.asarray(max_extent) @property def _min_extent(self): @@ -73,7 +99,7 @@ class Region(object): self._origin[0] - self._extent, self._origin[1] - self._extent, ) - return min_extent + return np.asarray(min_extent) @property def xmax(self): @@ -122,7 +148,13 @@ class Region(object): def fits(self, other) -> bool: """ - Returns true if the other region fits inside this region! + Checks if the given region fits into the current region. + + Args: + other (Region): The region to check if it fits. + + Returns: + bool: True if the given region fits into the current region, False otherwise. """ assert isinstance(other, Region) does_fit = all( @@ -146,10 +178,16 @@ class Region(object): @property def is_child(self): + """ + Check if the current instance is a child. + + Returns: + bool: True if the instance has a parent, False otherwise. + """ return self._parent is not None def points_in_region(self, x, y, analysis_type=AnalysisType.Full): - """returns the indices of the points specified by 'x' and 'y' that fall into this region. + """Returns the indices of the points specified by 'x' and 'y' that fall into this region. Parameters ---------- @@ -210,23 +248,26 @@ class Region(object): and left a region. In case the animal was not observed after entering this region (for example when hidden in a tube) the leaving time is the maximum time entry. + Whether the full position, or only the x- or y-position should be considered + is controlled with the analysis_type parameter. Parameters ---------- - x : _type_ - _description_ - y : _type_ - _description_ - time : _type_ - _description_ - analysis_type : _type_, optional - _description_, by default AnalysisType.Full + x : np.ndarray + The animal's x-positions + y : np.ndarray + the animal's y-positions + time : np.ndarray + the time array + analysis_type : AnalysisType, optional + The type of analysis, by default AnalysisType.Full Returns ------- - _type_ - _description_ - + np.ndarray + The entering times + np.ndarray + The leaving times """ indices = self.points_in_region(x, y, analysis_type) if len(indices) == 0: @@ -253,30 +294,65 @@ class Region(object): return np.array(entering), np.array(leaving) def patch(self, **kwargs): - if "fc" not in kwargs: - kwargs["fc"] = None - kwargs["fill"] = False - if self._shape_type == RegionShape.Rectangular: - w = self.position[2] - h = self.position[3] - return patches.Rectangle(self._origin, w, h, **kwargs) - else: - return patches.Circle(self._origin, self._extent, **kwargs) + """ + Create and return a matplotlib patch object based on the shape type of the arena. + + Parameters: + - kwargs: Additional keyword arguments to customize the patch object. + + Returns: + - A matplotlib patch object representing the arena shape. + + If the 'fc' (facecolor) keyword argument is not provided, it will default to None. + If the 'fill' keyword argument is not provided, it will default to False. + + For rectangular arenas, the patch object will be a Rectangle with width and height + based on the arena's position. + For circular arenas, the patch object will be a Circle with radius based on the + arena's extent. + + Example usage: + ``` + arena = Arena() + patch = arena.patch(fc='blue', fill=True) + ax.add_patch(patch) + ``` + """ + if "fc" not in kwargs: + kwargs["fc"] = None + kwargs["fill"] = False + if self._shape_type == RegionShape.Rectangular: + w = self.position[2] + h = self.position[3] + return patches.Rectangle(self._origin, w, h, **kwargs) + else: + return patches.Circle(self._origin, self._extent, **kwargs) def __repr__(self): return f"Region: '{self._name}' of {self._shape_type} shape." class Arena(Region): - def __init__( - self, - origin, - extent, - inverted_y=True, - name="", - arena_shape=RegionShape.Rectangular, - illumination=Illumination.Backlight, - ) -> None: + """ + Class to represent the experimental arena. Arena is derived from Region and can be either rectangular or circular. + An arena can not have a parent. + See Region for more details. + """ + def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular, + illumination=Illumination.Backlight) -> None: + """ Construct a new Area with a given origin and extent. + + + Returns + ------- + _type_ + _description_ + + Raises + ------ + ValueError + _description_ + """ super().__init__(origin, extent, inverted_y, name, arena_shape) self._illumination = illumination self.regions = {} @@ -302,6 +378,16 @@ class Arena(Region): self.regions[name] = region def remove_region(self, name): + """ + Remove a region from the arena. + + Parameter: + name : str + The name of the region to remove. + + Returns: + None + """ if name in self.regions: self.regions.pop(name) @@ -309,6 +395,17 @@ class Arena(Region): return f"Arena: '{self._name}' of {self._shape_type} shape." def plot(self, axis=None): + """ + Plots the arena on the given axis. + + Parameters + ---------- + - axis (matplotlib.axes.Axes, optional): The axis on which to plot the arena. If not provided, a new figure and axis will be created. + + Returns + ------- + - matplotlib.axes.Axes: The axis on which the arena is plotted. + """ if axis is None: fig = plt.figure() axis = fig.add_subplot(111) @@ -336,8 +433,12 @@ class Arena(Region): Returns ------- np.array - vector of the same size as x and y. Each entry is the region to which the position is assinged to. If the point is not assigned to a region, the entry will be empty. + vector of the same size as x and y. Each entry is the region to which the position is assigned to. If the point is not assigned to a region, the entry will be empty. """ + if not isinstance(x, np.ndarray): + x = np.asarray(x) + if not isinstance(y, np.ndarray): + y = np.asarray(y) rv = np.empty(x.shape, dtype=str) for r in self.regions: indices = self.regions[r].points_in_region(x, y) @@ -345,8 +446,24 @@ class Arena(Region): return rv def in_region(self, x, y): + """ + Determines if the given coordinates (x, y) are within any of the defined regions in the arena. + + Parameters + ---------- + x : float + The x-coordinate of the point to check. + y : float + The y-coordinate of the point to check. + + Returns + ------- + dict: + A dictionary containing the region names as keys and a list of indices of points within each region as values. + """ tmp = {} for r in self.regions: + print(r) indices = self.regions[r].points_in_region(x, y) tmp[r] = indices return tmp diff --git a/etrack/image_marker.py b/src/etrack/image_marker.py similarity index 100% rename from etrack/image_marker.py rename to src/etrack/image_marker.py diff --git a/etrack/info.json b/src/etrack/info.json similarity index 100% rename from etrack/info.json rename to src/etrack/info.json diff --git a/etrack/info.py b/src/etrack/info.py similarity index 94% rename from etrack/info.py rename to src/etrack/info.py index 0d4a401..14d4179 100644 --- a/etrack/info.py +++ b/src/etrack/info.py @@ -22,5 +22,4 @@ AUTHOR = infodict["AUTHOR"] COPYRIGHT = infodict["COPYRIGHT"] CONTACT = infodict["CONTACT"] BRIEF = infodict["BRIEF"] -HOMEPAGE = infodict["HOMEPAGE"] -~ +HOMEPAGE = infodict["HOMEPAGE"] \ No newline at end of file diff --git a/etrack/io/__init__.py b/src/etrack/io/__init__.py similarity index 100% rename from etrack/io/__init__.py rename to src/etrack/io/__init__.py diff --git a/etrack/io/dlc_data.py b/src/etrack/io/dlc_data.py similarity index 100% rename from etrack/io/dlc_data.py rename to src/etrack/io/dlc_data.py diff --git a/etrack/io/nixtrack_data.py b/src/etrack/io/nixtrack_data.py similarity index 100% rename from etrack/io/nixtrack_data.py rename to src/etrack/io/nixtrack_data.py diff --git a/etrack/tracking_data.py b/src/etrack/tracking_data.py similarity index 100% rename from etrack/tracking_data.py rename to src/etrack/tracking_data.py diff --git a/etrack/tracking_result.py b/src/etrack/tracking_result.py similarity index 100% rename from etrack/tracking_result.py rename to src/etrack/tracking_result.py diff --git a/etrack/util.py b/src/etrack/util.py similarity index 65% rename from etrack/util.py rename to src/etrack/util.py index f61946f..8b4219c 100644 --- a/etrack/util.py +++ b/src/etrack/util.py @@ -6,6 +6,14 @@ class Illumination(Enum): class RegionShape(Enum): + """ + Enumeration representing the shape of a region. + + Attributes: + Circular: Represents a circular region. + Rectangular: Represents a rectangular region. + """ + Circular = 0 Rectangular = 1 diff --git a/test/test_arena.py b/test/test_arena.py new file mode 100644 index 0000000..e830057 --- /dev/null +++ b/test/test_arena.py @@ -0,0 +1,82 @@ +import pytest +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as mp + +from etrack import Arena, Region, RegionShape + + +def test_region(): + # Create a parent region + parent_region = Region((0, 0), (100, 100), name="parent", region_shape=RegionShape.Rectangular) + + # Create a child region + child_region = Region((10, 10), (50, 50), name="child", region_shape=RegionShape.Rectangular, parent=parent_region) + + # Test properties + assert child_region.name == "child" + assert child_region.inverted_y == True + assert (child_region._max_extent == np.array((60, 60))).all() + assert (child_region._min_extent == np.array((10, 10))).all() + assert child_region.xmax == 60 + assert child_region.xmin == 10 + assert child_region.ymin == 10 + assert child_region.ymax == 60 + assert child_region.position == (10, 10, 50, 50) + assert child_region.is_child == True + + # Test fits method + assert parent_region.fits(child_region) == True + + # Test points_in_region method + x = [15, 20, 25, 30, 35, 5] + y = [15, 20, 25, 30, 35, 5] + assert (child_region.points_in_region(x, y) == np.array([0, 1, 2, 3, 4])).all() + + # Test time_in_region method + x = [5, 15, 20, 25, 30, 35, 35] + y = [5, 15, 20, 25, 30, 35, 65] + time = np.arange(0, len(x), 1) + enter, leave = child_region.time_in_region(x, y, time) + assert enter[0] == 1 + assert leave[0] == 5 + + # Test patch method + patch = child_region.patch(color='red') + assert isinstance(patch, mp.Patch) + + # Test __repr__ method + assert repr(child_region) == "Region: 'child' of Rectangular shape." + + +def test_arena(): + # Create an arena + arena = Arena((0, 0), (100, 100), name="arena", arena_shape=RegionShape.Rectangular) + # Test add_region method + arena.add_region("small rect1", (0, 0), (50, 50)) + assert len(arena.regions) == 1 + assert arena.regions["small rect1"].name == "small rect1" + # Test remove_region method + arena.remove_region("small rect1") + assert len(arena.regions) == 0 + # Test plot method + axis = arena.plot() + assert isinstance(axis, plt.Axes) + # Test region_vector method + x = [10, 20, 30] + y = [10, 20, 30] + assert (arena.region_vector(x, y) == "").all() + + # Test in_region method + # assert len(arena.in_region(10, 10)) > 0 + # print(arena.in_region(10, 10)) + + # print(arena.in_region(110, 110)) + # assert arena.in_region(110, 110) == False + # Test __getitem__ method + arena.add_region("small rect2", (0, 0), (50, 50)) + assert arena["small rect2"].name == "small rect2" + + +if __name__ == "__main__": + pytest.main() \ No newline at end of file