restructuring project with toml add a test and docstrings

This commit is contained in:
Jan Grewe 2024-05-30 23:59:16 +02:00
parent 32c0a65c58
commit 1dd318f23e
15 changed files with 307 additions and 94 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ requires.txt
SOURCES.txt
dependency_links.txt
top_level.txt
.DS_Store

48
pyproject.toml Normal file
View File

@ -0,0 +1,48 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "etrack"
dynamic = ["version"]
dependencies = [
"nixio>=1.5",
"nixtrack",
"numpy",
"matplotlib",
"opencv-python",
"pandas",
"scikit-image",
]
requires-python = ">=3.6"
authors = [
{name = "Jan Grewe", email = "jan.grewe@g-node.org"},
]
maintainers = [
{name = "Jan Grewe", email = "jan.grewe@g-node.org"},
]
description = "Goodies for working with tracking data of efishes."
readme = "README.md"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD-2-Clause",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.12",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering",
"Topic :: Software Development :: Libraries :: Python Modules",
]
[project.urls]
Repository = "https://github.com/bendalab/etrack"
[tool.setuptools.dynamic]
version = {attr = "etrack.info.VERSION"}
[tool.pytest.ini_options]
pythonpath = "src"

View File

@ -1,43 +0,0 @@
import os
from setuptools import setup
import json
from setuptools import setup
# load info from nixio/info.json
with open(os.path.join("etrack", "info.json")) as infofile:
infodict = json.load(infofile)
NAME = "etrack"
VERSION = infodict["VERSION"]
AUTHOR = infodict["AUTHOR"]
CONTACT = infodict["CONTACT"]
BRIEF = infodict["BRIEF"]
HOMEPAGE = infodict["HOMEPAGE"]
CLASSIFIERS = "science"
README = "README.md"
with open(README) as f:
description_text = f.read()
DESCRIPTION = description_text
packages = [
"etrack", "etrack.io"
]
install_req = ["h5py", "pandas", "matplotlib", "numpy", "opencv-python"]
setup(
name=NAME,
version=VERSION,
description=DESCRIPTION,
author=AUTHOR,
author_email=CONTACT,
packages=packages,
install_requires=install_req,
include_package_data=True,
long_description=description_text,
long_description_content_type="text/markdown",
classifiers=CLASSIFIERS,
license="BSD"
)

View File

@ -4,3 +4,4 @@ from .arena import Arena, Region
from .tracking_data import TrackingData
from .io.dlc_data import DLCReader
from .io.nixtrack_data import NixtrackData
from .util import RegionShape, AnalysisType

View File

@ -10,26 +10,52 @@ from IPython import embed
class Region(object):
def __init__(
self,
origin,
extent,
inverted_y=True,
name="",
region_shape=RegionShape.Rectangular,
parent=None,
) -> None:
"""
Class representing a region (of interest). Regions can be either circular or rectangular.
A Region can have a parent, i.e. it is contained inside a parent region. It can also have children.
Coordinates are given in absolute coordinates. The extent is treated depending on the shape. In case of a circular
shape, it is the radius and the origin is the center of the circle. Otherwise the origin is the bottom, or top-left corner, depending on the y-axis orientation, if inverted, then it is top-left. FIXME: check this
"""
def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
"""Region constructor.
Parameters
----------
origin : 2-tuple
x, and y coordinates
extent : scalar or 2-tuple, scalar only allowed to circular regions, 2-tuple for rectangular.
inverted_y : bool, optional
_description_, by default True
name : str, optional
_description_, by default ""
region_shape : _type_, optional
_description_, by default RegionShape.Rectangular
parent : _type_, optional
_description_, by default None
Returns
-------
_type_
_description_
Raises
------
ValueError
Raises Value error when origin or extent are invalid
"""
logging.debug(
f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
)
assert len(origin) == 2
self._origin = origin
self._extent = extent
self._inverted_y = inverted_y
if len(origin) != 2:
raise ValueError("Region: origin must be 2-tuple!")
self._parent = parent
self._name = name
self._shape_type = region_shape
self._origin = origin
self._check_extent(extent)
self._parent = parent
self._extent = extent
self._inverted_y = inverted_y
@staticmethod
def circular_mask(width, height, center, radius):
@ -62,7 +88,7 @@ class Region(object):
self._origin[0] + self._extent,
self._origin[1] + self._extent,
)
return max_extent
return np.asarray(max_extent)
@property
def _min_extent(self):
@ -73,7 +99,7 @@ class Region(object):
self._origin[0] - self._extent,
self._origin[1] - self._extent,
)
return min_extent
return np.asarray(min_extent)
@property
def xmax(self):
@ -122,7 +148,13 @@ class Region(object):
def fits(self, other) -> bool:
"""
Returns true if the other region fits inside this region!
Checks if the given region fits into the current region.
Args:
other (Region): The region to check if it fits.
Returns:
bool: True if the given region fits into the current region, False otherwise.
"""
assert isinstance(other, Region)
does_fit = all(
@ -146,10 +178,16 @@ class Region(object):
@property
def is_child(self):
"""
Check if the current instance is a child.
Returns:
bool: True if the instance has a parent, False otherwise.
"""
return self._parent is not None
def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
"""returns the indices of the points specified by 'x' and 'y' that fall into this region.
"""Returns the indices of the points specified by 'x' and 'y' that fall into this region.
Parameters
----------
@ -210,23 +248,26 @@ class Region(object):
and left a region. In case the animal was not observed after entering
this region (for example when hidden in a tube) the leaving time is
the maximum time entry.
Whether the full position, or only the x- or y-position should be considered
is controlled with the analysis_type parameter.
Parameters
----------
x : _type_
_description_
y : _type_
_description_
time : _type_
_description_
analysis_type : _type_, optional
_description_, by default AnalysisType.Full
x : np.ndarray
The animal's x-positions
y : np.ndarray
the animal's y-positions
time : np.ndarray
the time array
analysis_type : AnalysisType, optional
The type of analysis, by default AnalysisType.Full
Returns
-------
_type_
_description_
np.ndarray
The entering times
np.ndarray
The leaving times
"""
indices = self.points_in_region(x, y, analysis_type)
if len(indices) == 0:
@ -253,30 +294,65 @@ class Region(object):
return np.array(entering), np.array(leaving)
def patch(self, **kwargs):
if "fc" not in kwargs:
kwargs["fc"] = None
kwargs["fill"] = False
if self._shape_type == RegionShape.Rectangular:
w = self.position[2]
h = self.position[3]
return patches.Rectangle(self._origin, w, h, **kwargs)
else:
return patches.Circle(self._origin, self._extent, **kwargs)
"""
Create and return a matplotlib patch object based on the shape type of the arena.
Parameters:
- kwargs: Additional keyword arguments to customize the patch object.
Returns:
- A matplotlib patch object representing the arena shape.
If the 'fc' (facecolor) keyword argument is not provided, it will default to None.
If the 'fill' keyword argument is not provided, it will default to False.
For rectangular arenas, the patch object will be a Rectangle with width and height
based on the arena's position.
For circular arenas, the patch object will be a Circle with radius based on the
arena's extent.
Example usage:
```
arena = Arena()
patch = arena.patch(fc='blue', fill=True)
ax.add_patch(patch)
```
"""
if "fc" not in kwargs:
kwargs["fc"] = None
kwargs["fill"] = False
if self._shape_type == RegionShape.Rectangular:
w = self.position[2]
h = self.position[3]
return patches.Rectangle(self._origin, w, h, **kwargs)
else:
return patches.Circle(self._origin, self._extent, **kwargs)
def __repr__(self):
return f"Region: '{self._name}' of {self._shape_type} shape."
class Arena(Region):
def __init__(
self,
origin,
extent,
inverted_y=True,
name="",
arena_shape=RegionShape.Rectangular,
illumination=Illumination.Backlight,
) -> None:
"""
Class to represent the experimental arena. Arena is derived from Region and can be either rectangular or circular.
An arena can not have a parent.
See Region for more details.
"""
def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
illumination=Illumination.Backlight) -> None:
""" Construct a new Area with a given origin and extent.
Returns
-------
_type_
_description_
Raises
------
ValueError
_description_
"""
super().__init__(origin, extent, inverted_y, name, arena_shape)
self._illumination = illumination
self.regions = {}
@ -302,6 +378,16 @@ class Arena(Region):
self.regions[name] = region
def remove_region(self, name):
"""
Remove a region from the arena.
Parameter:
name : str
The name of the region to remove.
Returns:
None
"""
if name in self.regions:
self.regions.pop(name)
@ -309,6 +395,17 @@ class Arena(Region):
return f"Arena: '{self._name}' of {self._shape_type} shape."
def plot(self, axis=None):
"""
Plots the arena on the given axis.
Parameters
----------
- axis (matplotlib.axes.Axes, optional): The axis on which to plot the arena. If not provided, a new figure and axis will be created.
Returns
-------
- matplotlib.axes.Axes: The axis on which the arena is plotted.
"""
if axis is None:
fig = plt.figure()
axis = fig.add_subplot(111)
@ -336,8 +433,12 @@ class Arena(Region):
Returns
-------
np.array
vector of the same size as x and y. Each entry is the region to which the position is assinged to. If the point is not assigned to a region, the entry will be empty.
vector of the same size as x and y. Each entry is the region to which the position is assigned to. If the point is not assigned to a region, the entry will be empty.
"""
if not isinstance(x, np.ndarray):
x = np.asarray(x)
if not isinstance(y, np.ndarray):
y = np.asarray(y)
rv = np.empty(x.shape, dtype=str)
for r in self.regions:
indices = self.regions[r].points_in_region(x, y)
@ -345,8 +446,24 @@ class Arena(Region):
return rv
def in_region(self, x, y):
"""
Determines if the given coordinates (x, y) are within any of the defined regions in the arena.
Parameters
----------
x : float
The x-coordinate of the point to check.
y : float
The y-coordinate of the point to check.
Returns
-------
dict:
A dictionary containing the region names as keys and a list of indices of points within each region as values.
"""
tmp = {}
for r in self.regions:
print(r)
indices = self.regions[r].points_in_region(x, y)
tmp[r] = indices
return tmp

View File

@ -23,4 +23,3 @@ COPYRIGHT = infodict["COPYRIGHT"]
CONTACT = infodict["CONTACT"]
BRIEF = infodict["BRIEF"]
HOMEPAGE = infodict["HOMEPAGE"]
~

View File

@ -6,6 +6,14 @@ class Illumination(Enum):
class RegionShape(Enum):
"""
Enumeration representing the shape of a region.
Attributes:
Circular: Represents a circular region.
Rectangular: Represents a rectangular region.
"""
Circular = 0
Rectangular = 1

82
test/test_arena.py Normal file
View File

@ -0,0 +1,82 @@
import pytest
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mp
from etrack import Arena, Region, RegionShape
def test_region():
# Create a parent region
parent_region = Region((0, 0), (100, 100), name="parent", region_shape=RegionShape.Rectangular)
# Create a child region
child_region = Region((10, 10), (50, 50), name="child", region_shape=RegionShape.Rectangular, parent=parent_region)
# Test properties
assert child_region.name == "child"
assert child_region.inverted_y == True
assert (child_region._max_extent == np.array((60, 60))).all()
assert (child_region._min_extent == np.array((10, 10))).all()
assert child_region.xmax == 60
assert child_region.xmin == 10
assert child_region.ymin == 10
assert child_region.ymax == 60
assert child_region.position == (10, 10, 50, 50)
assert child_region.is_child == True
# Test fits method
assert parent_region.fits(child_region) == True
# Test points_in_region method
x = [15, 20, 25, 30, 35, 5]
y = [15, 20, 25, 30, 35, 5]
assert (child_region.points_in_region(x, y) == np.array([0, 1, 2, 3, 4])).all()
# Test time_in_region method
x = [5, 15, 20, 25, 30, 35, 35]
y = [5, 15, 20, 25, 30, 35, 65]
time = np.arange(0, len(x), 1)
enter, leave = child_region.time_in_region(x, y, time)
assert enter[0] == 1
assert leave[0] == 5
# Test patch method
patch = child_region.patch(color='red')
assert isinstance(patch, mp.Patch)
# Test __repr__ method
assert repr(child_region) == "Region: 'child' of Rectangular shape."
def test_arena():
# Create an arena
arena = Arena((0, 0), (100, 100), name="arena", arena_shape=RegionShape.Rectangular)
# Test add_region method
arena.add_region("small rect1", (0, 0), (50, 50))
assert len(arena.regions) == 1
assert arena.regions["small rect1"].name == "small rect1"
# Test remove_region method
arena.remove_region("small rect1")
assert len(arena.regions) == 0
# Test plot method
axis = arena.plot()
assert isinstance(axis, plt.Axes)
# Test region_vector method
x = [10, 20, 30]
y = [10, 20, 30]
assert (arena.region_vector(x, y) == "").all()
# Test in_region method
# assert len(arena.in_region(10, 10)) > 0
# print(arena.in_region(10, 10))
# print(arena.in_region(110, 110))
# assert arena.in_region(110, 110) == False
# Test __getitem__ method
arena.add_region("small rect2", (0, 0), (50, 50))
assert arena["small rect2"].name == "small rect2"
if __name__ == "__main__":
pytest.main()