restructuring project with toml add a test and docstrings
This commit is contained in:
parent
32c0a65c58
commit
1dd318f23e
1
.gitignore
vendored
1
.gitignore
vendored
@ -5,3 +5,4 @@ requires.txt
|
||||
SOURCES.txt
|
||||
dependency_links.txt
|
||||
top_level.txt
|
||||
.DS_Store
|
48
pyproject.toml
Normal file
48
pyproject.toml
Normal file
@ -0,0 +1,48 @@
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "etrack"
|
||||
dynamic = ["version"]
|
||||
dependencies = [
|
||||
"nixio>=1.5",
|
||||
"nixtrack",
|
||||
"numpy",
|
||||
"matplotlib",
|
||||
"opencv-python",
|
||||
"pandas",
|
||||
"scikit-image",
|
||||
]
|
||||
requires-python = ">=3.6"
|
||||
authors = [
|
||||
{name = "Jan Grewe", email = "jan.grewe@g-node.org"},
|
||||
]
|
||||
maintainers = [
|
||||
{name = "Jan Grewe", email = "jan.grewe@g-node.org"},
|
||||
]
|
||||
description = "Goodies for working with tracking data of efishes."
|
||||
readme = "README.md"
|
||||
license = {file = "LICENSE"}
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Environment :: Console",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: BSD-2-Clause",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Scientific/Engineering",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/bendalab/etrack"
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {attr = "etrack.info.VERSION"}
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = "src"
|
43
setup.py
43
setup.py
@ -1,43 +0,0 @@
|
||||
import os
|
||||
from setuptools import setup
|
||||
import json
|
||||
from setuptools import setup
|
||||
|
||||
# load info from nixio/info.json
|
||||
with open(os.path.join("etrack", "info.json")) as infofile:
|
||||
infodict = json.load(infofile)
|
||||
|
||||
|
||||
NAME = "etrack"
|
||||
VERSION = infodict["VERSION"]
|
||||
AUTHOR = infodict["AUTHOR"]
|
||||
CONTACT = infodict["CONTACT"]
|
||||
BRIEF = infodict["BRIEF"]
|
||||
HOMEPAGE = infodict["HOMEPAGE"]
|
||||
CLASSIFIERS = "science"
|
||||
README = "README.md"
|
||||
|
||||
with open(README) as f:
|
||||
description_text = f.read()
|
||||
DESCRIPTION = description_text
|
||||
|
||||
packages = [
|
||||
"etrack", "etrack.io"
|
||||
]
|
||||
|
||||
install_req = ["h5py", "pandas", "matplotlib", "numpy", "opencv-python"]
|
||||
|
||||
setup(
|
||||
name=NAME,
|
||||
version=VERSION,
|
||||
description=DESCRIPTION,
|
||||
author=AUTHOR,
|
||||
author_email=CONTACT,
|
||||
packages=packages,
|
||||
install_requires=install_req,
|
||||
include_package_data=True,
|
||||
long_description=description_text,
|
||||
long_description_content_type="text/markdown",
|
||||
classifiers=CLASSIFIERS,
|
||||
license="BSD"
|
||||
)
|
@ -4,3 +4,4 @@ from .arena import Arena, Region
|
||||
from .tracking_data import TrackingData
|
||||
from .io.dlc_data import DLCReader
|
||||
from .io.nixtrack_data import NixtrackData
|
||||
from .util import RegionShape, AnalysisType
|
@ -10,26 +10,52 @@ from IPython import embed
|
||||
|
||||
|
||||
class Region(object):
|
||||
def __init__(
|
||||
self,
|
||||
origin,
|
||||
extent,
|
||||
inverted_y=True,
|
||||
name="",
|
||||
region_shape=RegionShape.Rectangular,
|
||||
parent=None,
|
||||
) -> None:
|
||||
"""
|
||||
Class representing a region (of interest). Regions can be either circular or rectangular.
|
||||
A Region can have a parent, i.e. it is contained inside a parent region. It can also have children.
|
||||
|
||||
Coordinates are given in absolute coordinates. The extent is treated depending on the shape. In case of a circular
|
||||
shape, it is the radius and the origin is the center of the circle. Otherwise the origin is the bottom, or top-left corner, depending on the y-axis orientation, if inverted, then it is top-left. FIXME: check this
|
||||
|
||||
"""
|
||||
def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
|
||||
"""Region constructor.
|
||||
Parameters
|
||||
----------
|
||||
origin : 2-tuple
|
||||
x, and y coordinates
|
||||
extent : scalar or 2-tuple, scalar only allowed to circular regions, 2-tuple for rectangular.
|
||||
inverted_y : bool, optional
|
||||
_description_, by default True
|
||||
name : str, optional
|
||||
_description_, by default ""
|
||||
region_shape : _type_, optional
|
||||
_description_, by default RegionShape.Rectangular
|
||||
parent : _type_, optional
|
||||
_description_, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
Raises Value error when origin or extent are invalid
|
||||
"""
|
||||
logging.debug(
|
||||
f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
|
||||
)
|
||||
assert len(origin) == 2
|
||||
self._origin = origin
|
||||
self._extent = extent
|
||||
self._inverted_y = inverted_y
|
||||
if len(origin) != 2:
|
||||
raise ValueError("Region: origin must be 2-tuple!")
|
||||
self._parent = parent
|
||||
self._name = name
|
||||
self._shape_type = region_shape
|
||||
self._origin = origin
|
||||
self._check_extent(extent)
|
||||
self._parent = parent
|
||||
self._extent = extent
|
||||
self._inverted_y = inverted_y
|
||||
|
||||
@staticmethod
|
||||
def circular_mask(width, height, center, radius):
|
||||
@ -62,7 +88,7 @@ class Region(object):
|
||||
self._origin[0] + self._extent,
|
||||
self._origin[1] + self._extent,
|
||||
)
|
||||
return max_extent
|
||||
return np.asarray(max_extent)
|
||||
|
||||
@property
|
||||
def _min_extent(self):
|
||||
@ -73,7 +99,7 @@ class Region(object):
|
||||
self._origin[0] - self._extent,
|
||||
self._origin[1] - self._extent,
|
||||
)
|
||||
return min_extent
|
||||
return np.asarray(min_extent)
|
||||
|
||||
@property
|
||||
def xmax(self):
|
||||
@ -122,7 +148,13 @@ class Region(object):
|
||||
|
||||
def fits(self, other) -> bool:
|
||||
"""
|
||||
Returns true if the other region fits inside this region!
|
||||
Checks if the given region fits into the current region.
|
||||
|
||||
Args:
|
||||
other (Region): The region to check if it fits.
|
||||
|
||||
Returns:
|
||||
bool: True if the given region fits into the current region, False otherwise.
|
||||
"""
|
||||
assert isinstance(other, Region)
|
||||
does_fit = all(
|
||||
@ -146,10 +178,16 @@ class Region(object):
|
||||
|
||||
@property
|
||||
def is_child(self):
|
||||
"""
|
||||
Check if the current instance is a child.
|
||||
|
||||
Returns:
|
||||
bool: True if the instance has a parent, False otherwise.
|
||||
"""
|
||||
return self._parent is not None
|
||||
|
||||
def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
|
||||
"""returns the indices of the points specified by 'x' and 'y' that fall into this region.
|
||||
"""Returns the indices of the points specified by 'x' and 'y' that fall into this region.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -210,23 +248,26 @@ class Region(object):
|
||||
and left a region. In case the animal was not observed after entering
|
||||
this region (for example when hidden in a tube) the leaving time is
|
||||
the maximum time entry.
|
||||
Whether the full position, or only the x- or y-position should be considered
|
||||
is controlled with the analysis_type parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : _type_
|
||||
_description_
|
||||
y : _type_
|
||||
_description_
|
||||
time : _type_
|
||||
_description_
|
||||
analysis_type : _type_, optional
|
||||
_description_, by default AnalysisType.Full
|
||||
x : np.ndarray
|
||||
The animal's x-positions
|
||||
y : np.ndarray
|
||||
the animal's y-positions
|
||||
time : np.ndarray
|
||||
the time array
|
||||
analysis_type : AnalysisType, optional
|
||||
The type of analysis, by default AnalysisType.Full
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
|
||||
np.ndarray
|
||||
The entering times
|
||||
np.ndarray
|
||||
The leaving times
|
||||
"""
|
||||
indices = self.points_in_region(x, y, analysis_type)
|
||||
if len(indices) == 0:
|
||||
@ -253,30 +294,65 @@ class Region(object):
|
||||
return np.array(entering), np.array(leaving)
|
||||
|
||||
def patch(self, **kwargs):
|
||||
if "fc" not in kwargs:
|
||||
kwargs["fc"] = None
|
||||
kwargs["fill"] = False
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
w = self.position[2]
|
||||
h = self.position[3]
|
||||
return patches.Rectangle(self._origin, w, h, **kwargs)
|
||||
else:
|
||||
return patches.Circle(self._origin, self._extent, **kwargs)
|
||||
"""
|
||||
Create and return a matplotlib patch object based on the shape type of the arena.
|
||||
|
||||
Parameters:
|
||||
- kwargs: Additional keyword arguments to customize the patch object.
|
||||
|
||||
Returns:
|
||||
- A matplotlib patch object representing the arena shape.
|
||||
|
||||
If the 'fc' (facecolor) keyword argument is not provided, it will default to None.
|
||||
If the 'fill' keyword argument is not provided, it will default to False.
|
||||
|
||||
For rectangular arenas, the patch object will be a Rectangle with width and height
|
||||
based on the arena's position.
|
||||
For circular arenas, the patch object will be a Circle with radius based on the
|
||||
arena's extent.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
arena = Arena()
|
||||
patch = arena.patch(fc='blue', fill=True)
|
||||
ax.add_patch(patch)
|
||||
```
|
||||
"""
|
||||
if "fc" not in kwargs:
|
||||
kwargs["fc"] = None
|
||||
kwargs["fill"] = False
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
w = self.position[2]
|
||||
h = self.position[3]
|
||||
return patches.Rectangle(self._origin, w, h, **kwargs)
|
||||
else:
|
||||
return patches.Circle(self._origin, self._extent, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Region: '{self._name}' of {self._shape_type} shape."
|
||||
|
||||
|
||||
class Arena(Region):
|
||||
def __init__(
|
||||
self,
|
||||
origin,
|
||||
extent,
|
||||
inverted_y=True,
|
||||
name="",
|
||||
arena_shape=RegionShape.Rectangular,
|
||||
illumination=Illumination.Backlight,
|
||||
) -> None:
|
||||
"""
|
||||
Class to represent the experimental arena. Arena is derived from Region and can be either rectangular or circular.
|
||||
An arena can not have a parent.
|
||||
See Region for more details.
|
||||
"""
|
||||
def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
|
||||
illumination=Illumination.Backlight) -> None:
|
||||
""" Construct a new Area with a given origin and extent.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
_description_
|
||||
"""
|
||||
super().__init__(origin, extent, inverted_y, name, arena_shape)
|
||||
self._illumination = illumination
|
||||
self.regions = {}
|
||||
@ -302,6 +378,16 @@ class Arena(Region):
|
||||
self.regions[name] = region
|
||||
|
||||
def remove_region(self, name):
|
||||
"""
|
||||
Remove a region from the arena.
|
||||
|
||||
Parameter:
|
||||
name : str
|
||||
The name of the region to remove.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if name in self.regions:
|
||||
self.regions.pop(name)
|
||||
|
||||
@ -309,6 +395,17 @@ class Arena(Region):
|
||||
return f"Arena: '{self._name}' of {self._shape_type} shape."
|
||||
|
||||
def plot(self, axis=None):
|
||||
"""
|
||||
Plots the arena on the given axis.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- axis (matplotlib.axes.Axes, optional): The axis on which to plot the arena. If not provided, a new figure and axis will be created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- matplotlib.axes.Axes: The axis on which the arena is plotted.
|
||||
"""
|
||||
if axis is None:
|
||||
fig = plt.figure()
|
||||
axis = fig.add_subplot(111)
|
||||
@ -336,8 +433,12 @@ class Arena(Region):
|
||||
Returns
|
||||
-------
|
||||
np.array
|
||||
vector of the same size as x and y. Each entry is the region to which the position is assinged to. If the point is not assigned to a region, the entry will be empty.
|
||||
vector of the same size as x and y. Each entry is the region to which the position is assigned to. If the point is not assigned to a region, the entry will be empty.
|
||||
"""
|
||||
if not isinstance(x, np.ndarray):
|
||||
x = np.asarray(x)
|
||||
if not isinstance(y, np.ndarray):
|
||||
y = np.asarray(y)
|
||||
rv = np.empty(x.shape, dtype=str)
|
||||
for r in self.regions:
|
||||
indices = self.regions[r].points_in_region(x, y)
|
||||
@ -345,8 +446,24 @@ class Arena(Region):
|
||||
return rv
|
||||
|
||||
def in_region(self, x, y):
|
||||
"""
|
||||
Determines if the given coordinates (x, y) are within any of the defined regions in the arena.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : float
|
||||
The x-coordinate of the point to check.
|
||||
y : float
|
||||
The y-coordinate of the point to check.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
A dictionary containing the region names as keys and a list of indices of points within each region as values.
|
||||
"""
|
||||
tmp = {}
|
||||
for r in self.regions:
|
||||
print(r)
|
||||
indices = self.regions[r].points_in_region(x, y)
|
||||
tmp[r] = indices
|
||||
return tmp
|
@ -23,4 +23,3 @@ COPYRIGHT = infodict["COPYRIGHT"]
|
||||
CONTACT = infodict["CONTACT"]
|
||||
BRIEF = infodict["BRIEF"]
|
||||
HOMEPAGE = infodict["HOMEPAGE"]
|
||||
~
|
@ -6,6 +6,14 @@ class Illumination(Enum):
|
||||
|
||||
|
||||
class RegionShape(Enum):
|
||||
"""
|
||||
Enumeration representing the shape of a region.
|
||||
|
||||
Attributes:
|
||||
Circular: Represents a circular region.
|
||||
Rectangular: Represents a rectangular region.
|
||||
"""
|
||||
|
||||
Circular = 0
|
||||
Rectangular = 1
|
||||
|
82
test/test_arena.py
Normal file
82
test/test_arena.py
Normal file
@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mp
|
||||
|
||||
from etrack import Arena, Region, RegionShape
|
||||
|
||||
|
||||
def test_region():
|
||||
# Create a parent region
|
||||
parent_region = Region((0, 0), (100, 100), name="parent", region_shape=RegionShape.Rectangular)
|
||||
|
||||
# Create a child region
|
||||
child_region = Region((10, 10), (50, 50), name="child", region_shape=RegionShape.Rectangular, parent=parent_region)
|
||||
|
||||
# Test properties
|
||||
assert child_region.name == "child"
|
||||
assert child_region.inverted_y == True
|
||||
assert (child_region._max_extent == np.array((60, 60))).all()
|
||||
assert (child_region._min_extent == np.array((10, 10))).all()
|
||||
assert child_region.xmax == 60
|
||||
assert child_region.xmin == 10
|
||||
assert child_region.ymin == 10
|
||||
assert child_region.ymax == 60
|
||||
assert child_region.position == (10, 10, 50, 50)
|
||||
assert child_region.is_child == True
|
||||
|
||||
# Test fits method
|
||||
assert parent_region.fits(child_region) == True
|
||||
|
||||
# Test points_in_region method
|
||||
x = [15, 20, 25, 30, 35, 5]
|
||||
y = [15, 20, 25, 30, 35, 5]
|
||||
assert (child_region.points_in_region(x, y) == np.array([0, 1, 2, 3, 4])).all()
|
||||
|
||||
# Test time_in_region method
|
||||
x = [5, 15, 20, 25, 30, 35, 35]
|
||||
y = [5, 15, 20, 25, 30, 35, 65]
|
||||
time = np.arange(0, len(x), 1)
|
||||
enter, leave = child_region.time_in_region(x, y, time)
|
||||
assert enter[0] == 1
|
||||
assert leave[0] == 5
|
||||
|
||||
# Test patch method
|
||||
patch = child_region.patch(color='red')
|
||||
assert isinstance(patch, mp.Patch)
|
||||
|
||||
# Test __repr__ method
|
||||
assert repr(child_region) == "Region: 'child' of Rectangular shape."
|
||||
|
||||
|
||||
def test_arena():
|
||||
# Create an arena
|
||||
arena = Arena((0, 0), (100, 100), name="arena", arena_shape=RegionShape.Rectangular)
|
||||
# Test add_region method
|
||||
arena.add_region("small rect1", (0, 0), (50, 50))
|
||||
assert len(arena.regions) == 1
|
||||
assert arena.regions["small rect1"].name == "small rect1"
|
||||
# Test remove_region method
|
||||
arena.remove_region("small rect1")
|
||||
assert len(arena.regions) == 0
|
||||
# Test plot method
|
||||
axis = arena.plot()
|
||||
assert isinstance(axis, plt.Axes)
|
||||
# Test region_vector method
|
||||
x = [10, 20, 30]
|
||||
y = [10, 20, 30]
|
||||
assert (arena.region_vector(x, y) == "").all()
|
||||
|
||||
# Test in_region method
|
||||
# assert len(arena.in_region(10, 10)) > 0
|
||||
# print(arena.in_region(10, 10))
|
||||
|
||||
# print(arena.in_region(110, 110))
|
||||
# assert arena.in_region(110, 110) == False
|
||||
# Test __getitem__ method
|
||||
arena.add_region("small rect2", (0, 0), (50, 50))
|
||||
assert arena["small rect2"].name == "small rect2"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
Loading…
Reference in New Issue
Block a user