Compare commits

...

23 Commits

Author SHA1 Message Date
e3b5d2d6cc documentation 2024-06-01 20:21:14 +02:00
94fa5e3d14 docstrings 2024-06-01 12:19:43 +02:00
cbd0541a54 docs, fixes and test for nixtrack data wrapper 2024-06-01 12:19:34 +02:00
1dd318f23e restructuring project with toml add a test and docstrings 2024-05-30 23:59:16 +02:00
32c0a65c58 add info files 2024-05-24 17:31:52 +02:00
b3ba30ced6 Merge branch 'master' of https://whale.am28.uni-tuebingen.de/git/jgrewe/efish_tracking 2023-02-10 18:48:16 +01:00
469a35724d logging 2023-02-10 18:45:57 +01:00
bf8635d2fd [tracking data] add optional parameters to speed estimation to allow passing of custom values 2022-12-04 11:50:32 +01:00
2bba750e1f latest changes 2022-12-03 11:00:24 +01:00
0291ef088a [init] add nixtrack data io to package 2022-09-14 09:53:40 +02:00
6dd4a4f5de [io] adaptor for nixtrack data 2022-09-14 09:53:17 +02:00
701cda1069 [arena] add time in region function 2022-09-12 16:57:26 +02:00
3e1cbe4b9b [trackingdata] add fps, some docs, change interpolate 2022-09-12 16:56:43 +02:00
6f9633a74e [tracking_data] add interpolation function, some docstrings 2022-09-10 11:13:10 +02:00
f56e21d9b1 [init] add imports 2022-09-06 15:25:43 +02:00
30a035f82d [trackingData] new class that represents tracking data ...
With this we separate reading and handling of tracking data. io classes handle reading
2022-09-06 15:25:30 +02:00
9046e70592 [arena] cleanup, change regions to dict, and ...
add function to assign regions to positions
2022-09-06 15:23:55 +02:00
6487cb07ff [util] add PositionType enum 2022-09-06 15:22:30 +02:00
e5c9653bdd [setup] add etrack.io package 2022-09-06 15:01:06 +02:00
2593f21f3a [io] io class for reading dlc h5 data files 2022-09-06 15:00:43 +02:00
ae277ce8fb arena abstraction 2022-08-31 16:56:26 +02:00
e854ab591f [tracking] tear apart the positions function, offer method to access raw pixel positions 2022-08-11 11:06:45 +02:00
16873702d4 tiny adaptations of the package 2021-06-10 09:45:17 +02:00
23 changed files with 1469 additions and 71 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ requires.txt
SOURCES.txt SOURCES.txt
dependency_links.txt dependency_links.txt
top_level.txt top_level.txt
.DS_Store

View File

@ -28,7 +28,7 @@ If you leave away the ```--user``` the package will be installed system-wide.
## TrackingResults ## TrackingResults
Is a class that wraps around the *.h5 files written by DeppLabCut Is a class that wraps around the *.h5 files written by DeepLabCut
## ImageMarker ## ImageMarker

74
build_docs.sh Executable file
View File

@ -0,0 +1,74 @@
#!/bin/bash
die() { echo "ERROR: $*"; exit 2; }
warn() { echo "WARNING: $*"; }
for cmd in mkdocs pdoc3 genbadge; do
command -v "$cmd" >/dev/null ||
warn "missing $cmd: run \`pip install $cmd\`"
done
PACKAGE="etrack"
PACKAGESRC="src/$PACKAGE"
PACKAGEROOT="$(dirname "$(realpath "$0")")"
BUILDROOT="$PACKAGEROOT/site"
# check for code coverage report:
# need to call nosetest with --with-coverage --cover-html --cover-xml
HAS_COVER=false
test -d cover && HAS_COVER=true
echo
echo "Clean up documentation of $PACKAGE"
echo
rm -rf "$BUILDROOT" 2> /dev/null || true
mkdir -p "$BUILDROOT"
if command -v mkdocs >/dev/null; then
echo
echo "Building general documentation for $PACKAGE"
echo
cd "$PACKAGEROOT"
cp .mkdocs.yml mkdocs-tmp.yml
if $HAS_COVER; then
echo " - Coverage: 'cover/index.html'" >> mkdocs-tmp.yml
fi
mkdir -p docs
sed -e 's|docs/||; /\[Documentation\]/d; /\[API Reference\]/d' README.md > docs/index.md
mkdocs build --config-file mkdocs.yml --site-dir "$BUILDROOT"
rm mkdocs-tmp.yml docs/index.md
cd - > /dev/null
fi
if $HAS_COVER; then
echo
echo "Copy code coverage report and generate badge for $PACKAGE"
echo
cd "$PACKAGEROOT"
cp -r cover "$BUILDROOT/"
genbadge coverage -i coverage.xml
# https://smarie.github.io/python-genbadge/
mv coverage-badge.svg site/coverage.svg
cd - > /dev/null
fi
if command -v pdoc3 >/dev/null; then
echo
echo "Building API reference docs for $PACKAGE"
echo
cd "$PACKAGEROOT"
pdoc3 --html --config latex_math=True --config sort_identifiers=False --output-dir "$BUILDROOT/api-tmp" $PACKAGESRC
mv "$BUILDROOT/api-tmp/$PACKAGE" "$BUILDROOT/api"
rmdir "$BUILDROOT/api-tmp"
cd - > /dev/null
fi
echo
echo "Done. Docs in:"
echo
echo " file://$BUILDROOT/index.html"
echo

35
docs/etrack.md Normal file
View File

@ -0,0 +1,35 @@
# E-Fish tracking
Tool for easier handling of tracking results.
## Installation
### 1. Clone git repository
```shell
git clone https://whale.am28.uni-tuebingen.de/git/jgrewe/efish_tracking.git
```
### 2. Change into directory
```shell
cd efish_tracking
````
### 3. Install with pip
```shell
pip3 install -e . --user
```
The ```-e``` installs the package in an *editable* model that you do not need to reinstall whenever you pull upstream changes.
If you leave away the ```--user``` the package will be installed system-wide.
## TrackingResults
Is a class that wraps around the *.h5 files written by DeepLabCut
## ImageMarker
Class that allows for creating MarkerTasks to get specific positions in a video.

4
docs/trackingdata.md Normal file
View File

@ -0,0 +1,4 @@
# TrackingData
Class that represents the position data associated with one noe/bodypart.

View File

@ -1,2 +0,0 @@
from .image_marker import ImageMarker, MarkerTask
from .tracking_result import TrackingResult

17
mkdocs.yml Normal file
View File

@ -0,0 +1,17 @@
site_name: etrack
repo_url: https://github.com/bendalab/etrack/
edit_uri: ""
site_author: Jan Grewe jan.grewe@g-node.org
theme: readthedocs
nav:
- Home: 'index.md'
- 'User guide':
- 'etrack': 'etrack.md'
- 'TrackingData' : 'trackingdata.md'
- 'Code':
- API reference: 'api/index.html'

48
pyproject.toml Normal file
View File

@ -0,0 +1,48 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "etrack"
dynamic = ["version"]
dependencies = [
"hdf5",
"nixtrack",
"numpy",
"matplotlib",
"opencv-python",
"pandas",
"scikit-image",
]
requires-python = ">=3.6"
authors = [
{name = "Jan Grewe", email = "jan.grewe@g-node.org"},
]
maintainers = [
{name = "Jan Grewe", email = "jan.grewe@g-node.org"},
]
description = "Goodies for working with tracking data of efishes."
readme = "README.md"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD-2-Clause",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.12",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering",
"Topic :: Software Development :: Libraries :: Python Modules",
]
[project.urls]
Repository = "https://github.com/bendalab/etrack"
[tool.setuptools.dynamic]
version = {attr = "etrack.info.VERSION"}
[tool.pytest.ini_options]
pythonpath = "src"

View File

@ -1,33 +0,0 @@
from setuptools import setup
NAME = "etrack"
VERSION = 0.5
AUTHOR = "Jan Grewe"
CONTACT = "jan.grewe@g-node.org"
CLASSIFIERS = "science"
DESCRIPTION = "helpers for handling depp lab cut tracking results"
README = "README.md"
with open(README) as f:
description_text = f.read()
packages = [
"etrack",
]
install_req = ["h5py", "pandas", "matplotlib", "numpy", "opencv-python"]
setup(
name=NAME,
version=VERSION,
description=DESCRIPTION,
author=AUTHOR,
author_email=CONTACT,
packages=packages,
install_requires=install_req,
include_package_data=True,
long_description=description_text,
long_description_content_type="text/markdown",
classifiers=CLASSIFIERS,
license="BSD"
)

14
src/etrack/__init__.py Normal file
View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
""" etrack package for easier reading and handling of efish tracking data.
Copyright © 2024, Jan Grewe
Redistribution and use in source and binary forms, with or without modification, are permitted under the terms of the BSD License. See LICENSE file in the root of the Project.
"""
from .image_marker import ImageMarker, MarkerTask
from .tracking_result import TrackingResult, coordinate_transformation
from .arena import Arena, Region
from .tracking_data import TrackingData
from .io.dlc_data import DLCReader
from .io.nixtrack_data import NixtrackData
from .util import RegionShape, AnalysisType

527
src/etrack/arena.py Normal file
View File

@ -0,0 +1,527 @@
"""
Classes to construct the arena in which the animals were tracked.
"""
import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.draw import disk
from .util import RegionShape, AnalysisType, Illumination
class Region(object):
"""
Class representing a region (of interest). Regions can be either circular or rectangular.
A Region can have a parent, i.e. it is contained inside a parent region. It can also have children.
Coordinates are given in absolute coordinates. The extent is treated depending on the shape. In case of a circular
shape, it is the radius and the origin is the center of the circle. Otherwise the origin is the bottom, or top-left corner, depending on the y-axis orientation, if inverted, then it is top-left. FIXME: check this
"""
def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
"""Region constructor.
Parameters
----------
origin : 2-tuple
x, and y coordinates
extent : scalar or 2-tuple, scalar only allowed to circular regions, 2-tuple for rectangular.
inverted_y : bool, optional
_description_, by default True
name : str, optional
_description_, by default ""
region_shape : _type_, optional
_description_, by default RegionShape.Rectangular
parent : _type_, optional
_description_, by default None
Returns
-------
_type_
_description_
Raises
------
ValueError
Raises Value error when origin or extent are invalid
"""
logging.debug(
f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
)
if len(origin) != 2:
raise ValueError("Region: origin must be 2-tuple!")
self._parent = parent
self._name = name
self._shape_type = region_shape
self._origin = origin
self._check_extent(extent)
self._extent = extent
self._inverted_y = inverted_y
@staticmethod
def circular_mask(width, height, center, radius):
assert center[1] + radius < width and center[1] - radius > 0
assert center[0] + radius < height and center[0] - radius > 0
mask = np.zeros((height, width), dtype=np.uint8)
rr, cc = disk(reversed(center), radius)
mask[rr, cc] = 1
return mask
@property
def name(self):
return self._name
@property
def inverted_y(self):
return self._inverted_y
@property
def _max_extent(self):
if self._shape_type == RegionShape.Rectangular:
max_extent = (
self._origin[0] + self._extent[0],
self._origin[1] + self._extent[1],
)
else:
max_extent = (
self._origin[0] + self._extent,
self._origin[1] + self._extent,
)
return np.asarray(max_extent)
@property
def _min_extent(self):
if self._shape_type == RegionShape.Rectangular:
min_extent = self._origin
else:
min_extent = (
self._origin[0] - self._extent,
self._origin[1] - self._extent,
)
return np.asarray(min_extent)
@property
def xmax(self):
return self._max_extent[0]
@property
def xmin(self):
return self._min_extent[0]
@property
def ymin(self):
return self._min_extent[1]
@property
def ymax(self):
return self._max_extent[1]
@property
def position(self):
"""
Get the position of the arena.
Returns
-------
tuple
A tuple containing the x-coordinate, y-coordinate, width, and height of the arena.
"""
x = self._min_extent[0]
y = self._min_extent[1]
width = self._max_extent[0] - self._min_extent[0]
height = self._max_extent[1] - self._min_extent[1]
return x, y, width, height
def _check_extent(self, ext):
"""Checks whether the extent matches the shape. i.e. if the shape is Rectangular, extent must be a length 2 list, tuple, otherwise, if the region is circular, extent must be a single numerical value.
Parameters
----------
ext : tuple, or numeric scalar
"""
if self._shape_type == RegionShape.Rectangular:
if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2:
raise ValueError(
"Extent must be a length 2 list or tuple for rectangular regions!"
)
elif self._shape_type == RegionShape.Circular:
if not isinstance(ext, (int, float)):
raise ValueError(
"Extent must be a numerical scalar for circular regions!"
)
else:
raise ValueError(f"Invalid ShapeType, {self._shape_type}!")
def fits(self, other) -> bool:
"""
Checks if the given region fits into the current region.
Args:
other (Region): The region to check if it fits.
Returns:
bool: True if the given region fits into the current region, False otherwise.
"""
assert isinstance(other, Region)
does_fit = all(
(
other._min_extent[0] >= self._min_extent[0],
other._min_extent[1] >= self._min_extent[1],
other._max_extent[0] <= self._max_extent[0],
other._max_extent[1] <= self._max_extent[1],
)
)
if not does_fit:
m = (
f"Region {other.name} does not fit into {self.name}. "
f"min x: {other._min_extent[0] >= self._min_extent[0]},",
f"min y: {other._min_extent[1] >= self._min_extent[1]},",
f"max x: {other._max_extent[0] <= self._max_extent[0]},",
f"max y: {other._max_extent[1] <= self._max_extent[1]}",
)
logging.debug(m)
return does_fit
@property
def is_child(self):
"""
Check if the current instance is a child.
Returns:
bool: True if the instance has a parent, False otherwise.
"""
return self._parent is not None
def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
"""Returns the indices of the points specified by 'x' and 'y' that fall into this region.
Parameters
----------
x : np.ndarray
the x positions
y : np.ndarray
the y positions
analysis_type : AnalysisType, optional
defines how the positions are evaluated, by default AnalysisType.Full
FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
"""
if self._shape_type == RegionShape.Rectangular or (
self._shape_type == RegionShape.Circular
and analysis_type != AnalysisType.Full
):
if analysis_type == AnalysisType.Full:
indices = np.where(
((y >= self._min_extent[1]) & (y <= self._max_extent[1]))
& ((x >= self._min_extent[0]) & (x <= self._max_extent[0]))
)[0]
indices = np.array(indices, dtype=int)
elif analysis_type == AnalysisType.CollapseX:
x_indices = np.where(
(x >= self._min_extent[0]) & (x <= self._max_extent[0])
)[0]
indices = np.asarray(x_indices, dtype=int)
else:
y_indices = np.where(
(y >= self._min_extent[1]) & (y <= self._max_extent[1])
)[0]
indices = np.asarray(y_indices, dtype=int)
else:
if self.is_child:
mask = self.circular_mask(
self._parent.position[2],
self._parent.position[3],
self._origin,
self._extent,
)
else:
mask = self.circular_mask(
self.position[2], self.position[3], self._origin, self._extent
)
img = np.zeros_like(mask)
img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1
temp = np.where(img & mask)
indices = []
for i, j in zip(list(temp[1]), list(temp[0])):
matches = np.where((x == i) & (y == j))
if len(matches[0]) == 0:
continue
indices.append(matches[0][0])
indices = np.array(indices)
return indices
def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
"""Returns the entering and leaving times at which the animal entered
and left a region. In case the animal was not observed after entering
this region (for example when hidden in a tube) the leaving time is
the maximum time entry.
Whether the full position, or only the x- or y-position should be considered
is controlled with the analysis_type parameter.
Parameters
----------
x : np.ndarray
The animal's x-positions
y : np.ndarray
the animal's y-positions
time : np.ndarray
the time array
analysis_type : AnalysisType, optional
The type of analysis, by default AnalysisType.Full
Returns
-------
np.ndarray
The entering times
np.ndarray
The leaving times
"""
indices = self.points_in_region(x, y, analysis_type)
if len(indices) == 0:
return np.array([]), np.array([])
diffs = np.diff(indices)
if len(diffs) == sum(diffs):
entering = [time[indices[0]]]
leaving = [time[indices[-1]]]
else:
entering = []
leaving = []
jumps = np.where(diffs > 1)[0]
start = time[indices[0]]
for i in range(len(jumps)):
end = time[indices[jumps[i]]]
entering.append(start)
leaving.append(end)
start = time[indices[jumps[i] + 1]]
end = time[indices[-1]]
entering.append(start)
leaving.append(end)
return np.array(entering), np.array(leaving)
def patch(self, **kwargs):
"""
Create and return a matplotlib patch object based on the shape type of the arena.
Parameters:
- kwargs: Additional keyword arguments to customize the patch object.
Returns:
- A matplotlib patch object representing the arena shape.
If the 'fc' (facecolor) keyword argument is not provided, it will default to None.
If the 'fill' keyword argument is not provided, it will default to False.
For rectangular arenas, the patch object will be a Rectangle with width and height
based on the arena's position.
For circular arenas, the patch object will be a Circle with radius based on the
arena's extent.
Example usage:
```
arena = Arena()
patch = arena.patch(fc='blue', fill=True)
ax.add_patch(patch)
```
"""
if "fc" not in kwargs:
kwargs["fc"] = None
kwargs["fill"] = False
if self._shape_type == RegionShape.Rectangular:
w = self.position[2]
h = self.position[3]
return patches.Rectangle(self._origin, w, h, **kwargs)
else:
return patches.Circle(self._origin, self._extent, **kwargs)
def __repr__(self):
return f"Region: '{self._name}' of {self._shape_type} shape."
class Arena(Region):
"""
Class to represent the experimental arena. Arena is derived from Region and can be either rectangular or circular.
An arena can not have a parent.
See Region for more details.
"""
def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
illumination=Illumination.Backlight) -> None:
""" Construct a new Area with a given origin and extent.
Returns
-------
_type_
_description_
Raises
------
ValueError
_description_
"""
super().__init__(origin, extent, inverted_y, name, arena_shape)
self._illumination = illumination
self.regions = {}
def add_region(
self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None
):
if name is None or name in self.regions.keys():
raise ValueError(
"Region name '{name}' is invalid. The name must not be None and must be unique among the regions."
)
if region is None:
region = Region(
origin, extent, name=name, region_shape=shape_type, parent=self
)
else:
region._parent = self
doesfit = self.fits(region)
if not doesfit:
logging.warn(
f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!"
)
self.regions[name] = region
def remove_region(self, name):
"""
Remove a region from the arena.
Parameter:
name : str
The name of the region to remove.
Returns:
None
"""
if name in self.regions:
self.regions.pop(name)
def __repr__(self):
return f"Arena: '{self._name}' of {self._shape_type} shape."
def plot(self, axis=None):
"""
Plots the arena on the given axis.
Parameters
----------
- axis (matplotlib.axes.Axes, optional): The axis on which to plot the arena. If not provided, a new figure and axis will be created.
Returns
-------
- matplotlib.axes.Axes: The axis on which the arena is plotted.
"""
if axis is None:
fig = plt.figure()
axis = fig.add_subplot(111)
axis.add_patch(self.patch())
axis.set_xlim([self._origin[0], self._max_extent[0]])
if self.inverted_y:
axis.set_ylim([self._max_extent[1], self._origin[1]])
else:
axis.set_ylim([self._origin[1], self._max_extent[1]])
for r in self.regions:
axis.add_patch(self.regions[r].patch())
return axis
def region_vector(self, x, y):
"""Returns a vector that contains the region names within which the agent was found.
FIXME: This does not work well with overlapping regions!@!
Parameters
----------
x : np.array
the x-positions
y : np.ndarray
the y-positions
Returns
-------
np.array
vector of the same size as x and y. Each entry is the region to which the position is assigned to. If the point is not assigned to a region, the entry will be empty.
"""
if not isinstance(x, np.ndarray):
x = np.asarray(x)
if not isinstance(y, np.ndarray):
y = np.asarray(y)
rv = np.empty(x.shape, dtype=str)
for r in self.regions:
indices = self.regions[r].points_in_region(x, y)
rv[indices] = r
return rv
def in_region(self, x, y):
"""
Determines if the given coordinates (x, y) are within any of the defined regions in the arena.
Parameters
----------
x : float
The x-coordinate of the point to check.
y : float
The y-coordinate of the point to check.
Returns
-------
dict:
A dictionary containing the region names as keys and a list of indices of points within each region as values.
"""
tmp = {}
for r in self.regions:
print(r)
indices = self.regions[r].points_in_region(x, y)
tmp[r] = indices
return tmp
def __getitem__(self, key):
if isinstance(key, (str)):
return self.regions[key]
else:
return self.regions[self.regions.keys()[key]]
if __name__ == "__main__":
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
a.add_region("small rect1", (0, 0), (100, 300))
a.add_region("small rect2", (150, 0), (100, 300))
a.add_region("small rect3", (300, 0), (100, 300))
a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
axis = a.plot()
x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
y = np.asarray(
(np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int
)
# y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
axis.scatter(x, y, c="k", s=2)
ind = a.regions[3].points_in_region(x, y)
if len(ind) > 0:
axis.scatter(x[ind], y[ind], label="circ full")
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseX)
if len(ind) > 0:
axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)
if len(ind) > 0:
axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")
ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX)
if len(ind) > 0:
axis.scatter(x[ind], y[ind] - 10, label="rect collapseX")
ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY)
if len(ind) > 0:
axis.scatter(x[ind], y[ind] + 10, label="rect collapseY")
ind = a.regions[2].points_in_region(x, y, AnalysisType.Full)
if len(ind) > 0:
axis.scatter(x[ind], y[ind] + 20, label="rect full")
axis.legend()
plt.show()
a.plot()
plt.show()

View File

@ -1,8 +1,11 @@
import matplotlib.pyplot as plt """
import cv2 Module that defines the ImageMarker and MarkerTask classes to manually mark things in individual images.
"""
import os import os
import cv2
import sys import sys
from IPython import embed import matplotlib.pyplot as plt
class ImageMarker: class ImageMarker:
@ -30,6 +33,7 @@ class ImageMarker:
print("Reading frame: %i" % frame_counter, end="\r") print("Reading frame: %i" % frame_counter, end="\r")
success, frame = video.read() success, frame = video.read()
frame_counter += 1 frame_counter += 1
if success: if success:
self._fig.gca().imshow(frame) self._fig.gca().imshow(frame)
else: else:
@ -139,13 +143,14 @@ class MarkerTask():
if __name__ == "__main__": if __name__ == "__main__":
print("Hello Jan!")
tank_task = MarkerTask("tank limits", ["bottom left corner", "top left corner", "top right corner", "bottom right corner"], "Mark tank corners") tank_task = MarkerTask("tank limits", ["bottom left corner", "top left corner", "top right corner", "bottom right corner"], "Mark tank corners")
feeder_task = MarkerTask("Feeder positions", list(map(str, range(1, 2))), "Mark feeder positions") #feeder_task = MarkerTask("Feeder positions", list(map(str, range(1, 2))), "Mark feeder positions")
tasks = [tank_task, feeder_task] #tasks = [tank_task, feeder_task]
im = ImageMarker(tasks) im = ImageMarker([tank_task])
# vid1 = "2020.12.11_lepto48DLC_resnet50_boldnessDec11shuffle1_200000_labeled.mp4" vid1 = "/data/personality/secondhome/fischies/lepto_03/position/lepto03_position_2021.06.07_60.mp4"
print(sys.argv[0]) # print(sys.argv[0])
print (sys.argv[1]) # print (sys.argv[1])
vid1 = sys.argv[1] # vid1 = sys.argv[1]
marker_positions = im.mark_movie(vid1, 10) marker_positions = im.mark_movie(vid1, 00)
print(marker_positions) print(marker_positions)

10
src/etrack/info.json Normal file
View File

@ -0,0 +1,10 @@
{
"VERSION": "0.5.0",
"STATUS": "Release",
"RELEASE": "0.5.0 Release",
"AUTHOR": "Jan Grewe",
"COPYRIGHT": "2024, University of Tuebingen, Neuroethology, Jan Grewe",
"CONTACT": "jan.grewe@g-node.org",
"BRIEF": "Efish tracking helpers for handling tracking data.",
"HOMEPAGE": "https://github.com/G-Node/nixpy"
}

28
src/etrack/info.py Normal file
View File

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright © 2024, Jan Grewe
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted under the terms of the BSD License. See
# LICENSE file in the root of the Project.
"""
Package info.
"""
import os
import json
here = os.path.dirname(__file__)
with open(os.path.join(here, "info.json")) as infofile:
infodict = json.load(infofile)
VERSION = infodict["VERSION"]
STATUS = infodict["STATUS"]
RELEASE = infodict["RELEASE"]
AUTHOR = infodict["AUTHOR"]
COPYRIGHT = infodict["COPYRIGHT"]
CONTACT = infodict["CONTACT"]
BRIEF = infodict["BRIEF"]
HOMEPAGE = infodict["HOMEPAGE"]

View File

@ -0,0 +1,3 @@
"""
Reader classes for DeepLabCut, or SLEAP written data files.
"""

78
src/etrack/io/dlc_data.py Normal file
View File

@ -0,0 +1,78 @@
import os
import numpy as np
import pandas as pd
import numbers as nb
from ..tracking_data import TrackingData
class DLCReader(object):
"""Class that represents the tracking data stored in a DeepLabCut hdf5 file."""
def __init__(self, results_file, crop=(0, 0)) -> None:
"""
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
Parameters
----------
crop : 2-tuple
tuple of (xoffset, yoffset)
Raises
------
ValueError if crop value is not a 2-tuple
"""
if not os.path.exists(results_file):
raise ValueError("File %s does not exist!" % results_file)
if not isinstance(crop, tuple) or len(crop) < 2:
raise ValueError("Cropping info must be a 2-tuple of (x, y)")
self._file_name = results_file
self._crop = crop
self._data_frame = pd.read_hdf(results_file)
self._level_shape = self._data_frame.columns.levshape
self._scorer = self._data_frame.columns.levels[0].values
self._bodyparts = self._data_frame.columns.levels[1].values if self._level_shape[1] > 0 else []
self._positions = self._data_frame.columns.levels[2].values if self._level_shape[2] > 0 else []
@property
def filename(self):
return self._file_name
@property
def dataframe(self):
return self._data_frame
@property
def scorer(self):
return self._scorer
@property
def bodyparts(self):
return self._bodyparts
def _correct_cropping(self, orgx, orgy):
x = orgx + self._crop[0]
y = orgy + self._crop[1]
return x, y
def track(self, scorer=0, bodypart=0, framerate=30):
if isinstance(scorer, nb.Number):
sc = self._scorer[scorer]
elif isinstance(scorer, str) and scorer in self._scorer:
sc = scorer
else:
raise ValueError(f"Scorer {scorer} is not in dataframe!")
if isinstance(bodypart, nb.Number):
bp = self._bodyparts[bodypart]
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
bp = bodypart
else:
raise ValueError(f"Body part {bodypart} is not in dataframe!")
x = np.asarray(self._data_frame[sc][bp]["x"] if "x" in self._positions else [])
y = np.asarray(self._data_frame[sc][bp]["y"] if "y" in self._positions else [])
x, y = self._correct_cropping(x, y)
l = np.asarray(self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else [])
time = np.arange(len(x))/framerate
return TrackingData(x, y, time, l, bp, fps=framerate)

View File

@ -0,0 +1,137 @@
import os
import numpy as np
import pandas as pd
import numbers as nb
import nixtrack as nt
from ..tracking_data import TrackingData
from IPython import embed
class NixtrackData(object):
"""Wrapper around a nix data file that has been written accorind to the nixtrack model (https://github.com/bendalab/nixtrack)
"""
def __init__(self, filename, crop=(0, 0)) -> None:
"""
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
Parameters
----------
filename : str
full filename
crop : 2-tuple
tuple of (xoffset, yoffset)
Raises
------
ValueError if crop value is not a 2-tuple
"""
if not os.path.exists(filename):
raise ValueError("File %s does not exist!" % filename)
if not isinstance(crop, tuple) or len(crop) < 2:
raise ValueError("Cropping info must be a 2-tuple of (x, y)")
self._file_name = filename
self._crop = crop
self._dataset = nt.Dataset(self._file_name)
if not self._dataset.is_open:
raise ValueError(f"An error occurred opening file {self._file_name}! File is not open!")
@property
def filename(self):
"""
Returns the name of the file associated with the NixtrackData object.
Returns:
str: The name of the file.
"""
return self._file_name
@property
def bodyparts(self):
"""
Returns the bodyparts of the dataset.
Returns:
list: A list of bodyparts.
"""
return self._dataset.nodes
def _correct_cropping(self, orgx, orgy):
"""
Corrects the coordinates based on the cropping values, If it cropping was done during tracking.
Args:
orgx (int): The original x-coordinate.
orgy (int): The original y-coordinate.
Returns:
tuple: A tuple containing the corrected x and y coordinates.
"""
x = orgx + self._crop[0]
y = orgy + self._crop[1]
return x, y
@property
def fps(self):
"""Property that holds frames per second of the original video.
Returns
-------
int : the frames of second
"""
return self._dataset.fps
@property
def tracks(self):
"""
Returns a list of tracks from the dataset.
Returns:
list: A list of tracks.
"""
return [t[0] for t in self._dataset.tracks]
def track_data(self, bodypart=0, track=-1, fps=None):
"""
Retrieve tracking data for a specific body part and track.
Parameters
----------
bodypart : int or str
Index or name of the body part to retrieve tracking data for.
track : int or str
Index of the track to retrieve tracking data for.
fps : float
Frames per second of the tracking data. If not provided, it will be retrieved from the dataset.
Returns
-------
TrackingData: An object containing the x and y positions, time, score, body part name, and frames per second.
Raises
------
ValueError: If the body part or track is not valid.
"""
if isinstance(bodypart, nb.Number):
bp = self.bodyparts[bodypart]
elif isinstance(bodypart, (str)) and bodypart in self.bodyparts:
bp = bodypart
else:
raise ValueError(f"Body part {bodypart} is not a tracked node!")
if track not in self.tracks:
raise ValueError(f"Track {track} is not a valid track name!")
if not isinstance(track, (list, tuple)):
track = [track]
elif isinstance(track, int):
track = [self.tracks[track]]
if fps is None:
fps = self._dataset.fps
positions, time, _, nscore = self._dataset.positions(node=bp, axis_type=nt.AxisType.Time)
valid = ~np.isnan(positions[:, 0])
positions = positions[valid,:]
time = time[valid]
score = nscore[valid]
return TrackingData(positions[:, 0], positions[:, 1], time, score, bp, fps=fps)

267
src/etrack/tracking_data.py Normal file
View File

@ -0,0 +1,267 @@
"""
Module that defines the TrackingData class that wraps the position data for a given node/bodypart that has been tracked.
"""
import numpy as np
class TrackingData(object):
"""Class that represents tracking data, i.e. positions of an agent tracked in an environment.
These data are the x, and y-positions, the time at which the agent was detected, and the quality associated with the position estimation.
TrackingData contains these data and offers a few functions to work with it.
Using the 'quality_threshold', 'temporal_limits', or the 'position_limits' data can be filtered (see filter_tracks function).
The 'interpolate' function allows to fill up the gaps that may result from filtering with linearly interpolated data points.
More may follow...
"""
def __init__(self, x, y, time, quality, node="", fps=None,
quality_threshold=None, temporal_limits=None, position_limits=None) -> None:
"""
Initialize a TrackingData object.
Parameters
----------
x : float
The x-coordinates of the tracking data.
y : float
The y-coordinates of the tracking data.
time : float
The time vector associated with the x-, and y-coordinates.
quality : float
The quality score associated with the position estimates.
node : str, optional
The node name associated with the data. Default is an empty string.
fps : float, optional
The frames per second of the tracking data. Default is None.
quality_threshold : float, optional
The quality threshold for the tracking data. Default is None.
temporal_limits : tuple, optional
The temporal limits for the tracking data. Default is None.
position_limits : tuple, optional
The position limits for the tracking data. Default is None.
"""
self._orgx = x
self._orgy = y
self._orgtime = time
self._orgquality = quality
self._x = x
self._y = y
self._time = time
self._quality = quality
self._node = node
self._threshold = quality_threshold
self._position_limits = position_limits
self._time_limits = temporal_limits
self._fps = fps
@property
def original_positions(self):
return self._orgx, self._orgy
@property
def original_quality(self):
return self._orgquality
def interpolate(self, start_time=None, end_time=None, min_count=5):
if len(self._x) < min_count:
print(
f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!"
)
return None, None, None
start = self._time[0] if start_time is None else start_time
end = self._time[-1] if end_time is None else end_time
time = np.arange(start, end, 1.0 / self._fps)
x = np.interp(time, self._time, self._x)
y = np.interp(time, self._time, self._y)
return x, y, time
@property
def quality_threshold(self):
"""Property that holds the quality filter setting.
Returns
-------
float : the quality threshold
"""
return self._threshold
@quality_threshold.setter
def quality_threshold(self, new_threshold):
"""Setter of the quality threshold that should be applied when filtering the data. Setting this to None removes the quality filter.
Data points that have a quality score below the given threshold are discarded.
Parameters
----------
new_threshold : float
"""
self._threshold = new_threshold
@property
def position_limits(self):
"""
Get the position limits of the tracking data.
Returns:
tuple: A 4-tuple containing the start x, and y positions, width and height limits.
"""
return self._position_limits
@position_limits.setter
def position_limits(self, new_limits):
"""Sets the limits for the position filter. 'new_limits' must be a 4-tuple of the form (x0, y0, width, height). If None, the limits will be removed.
Data points outside the position limits are discarded.
Parameters
----------
new_limits: 4-tuple
tuple of x-position, y-position, the width and the height. Passing None removes the filter
Raises
------
ValueError, if new_value is not a 4-tuple
"""
if new_limits is not None and not (
isinstance(new_limits, (tuple, list)) and len(new_limits) == 4
):
raise ValueError(
f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)"
)
self._position_limits = new_limits
@property
def temporal_limits(self):
"""
Get the temporal limits of the tracking data.
Returns:
tuple: A tuple containing the start and end time of the tracking data.
"""
return self._time_limits
@temporal_limits.setter
def temporal_limits(self, new_limits):
"""Limits for temporal filter. The limits must be a 2-tuple of start and end time. Setting this to None removes the filter.
Data points the are associated with times outside the limits are discarded.
Parameters
----------
new_limits : 2-tuple
The new limits in the form (start, end) given in seconds.
Returns
-------
None
Raises
------
ValueError if the limits are not valid.
"""
if new_limits is not None and not (
isinstance(new_limits, (tuple, list)) and len(new_limits) == 2
):
raise ValueError(
f"The new_limits vector must be a 2-tuple of the form (start, end). "
)
self._time_limits = new_limits
def filter_tracks(self, align_time=True):
"""Applies the filters to the tracking data. All filters will be applied sequentially, i.e. an AND connection.
To change the filter settings use the setters for 'quality_threshold', 'temporal_limits', 'position_limits'. Setting them to None disables the respective filter discarding the setting.
Parameters
----------
align_time: bool
Controls whether the time vector is aligned to the first time point at which the agent is within the positional_limits. Default = True
"""
self._x = self._orgx.copy()
self._y = self._orgy.copy()
self._time = self._orgtime.copy()
self._quality = self.original_quality.copy()
if self.position_limits is not None:
x_max = self.position_limits[0] + self.position_limits[2]
y_max = self.position_limits[1] + self.position_limits[3]
indices = np.where(
(self._x >= self.position_limits[0])
& (self._x < x_max)
& (self._y >= self.position_limits[1])
& (self._y < y_max)
)
self._x = self._x[indices]
self._y = self._y[indices]
self._time = self._time[indices] - self._time[0] if align_time else 0.0
self._quality = self._quality[indices]
if self.temporal_limits is not None:
indices = np.where(
(self._time >= self.temporal_limits[0])
& (self._time < self.temporal_limits[1])
)
self._x = self._x[indices]
self._y = self._y[indices]
self._time = self._time[indices]
self._quality = self._quality[indices]
if self.quality_threshold is not None:
indices = np.where((self._quality >= self.quality_threshold))
self._x = self._x[indices]
self._y = self._y[indices]
self._time = self._time[indices]
self._quality = self._quality[indices]
def positions(self):
"""Returns the filtered data (if filters have been applied, otherwise the original data).
Returns
-------
np.ndarray
The x-positions
np.ndarray
The y-positions
np.ndarray
The time vector
np.ndarray
The detection quality
"""
return self._x, self._y, self._time, self._quality
def speed(self, x=None, y=None, t=None):
""" Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points. If any of the arguments is not provided, the function will use the x,y coordinates that are stored within the object, otherwise, if all are provided, the user-provided values will be used.
Since the velocities are estimated from the difference between two sample points the returned velocities and positions are assigned to positions and times between the respective sampled positions/times.
Parameters
----------
x: np.ndarray
The x-coordinates, defaults to None
y: np.ndarray
The y-coordinates, defaults to None
t: np.ndarray
The time vector, defaults to None
Returns
-------
np.ndarray:
The time vector.
np.ndarray:
The speed.
tuple of np.ndarray
The position
"""
if x is None or y is None or t is None:
x = self._x.copy()
y = self._y.copy()
t = self._time.copy()
dt = np.diff(t)
speed = np.sqrt(np.diff(x)**2 + np.diff(y)**2) / dt
t = t[:-1] + dt / 2
x = x[:-1] + np.diff(x) / 2
y = y[:-1] + np.diff(y) / 2
return t, speed, (x, y)
def __repr__(self) -> str:
s = f"Tracking data of node '{self._node}'!"
return s

View File

@ -14,6 +14,10 @@ y_factor = 0.81/height # Einheit m/px
center = (np.round(x_0 + width/2), np.round(y_0 + height/2)) center = (np.round(x_0 + width/2), np.round(y_0 + height/2))
center_meter = ((center[0] - x_0) * x_factor, (center[1] - y_0) * y_factor) center_meter = ((center[0] - x_0) * x_factor, (center[1] - y_0) * y_factor)
""" """
def coordinate_transformation(position,x_0, y_0, x_factor, y_factor):
x = (position[0] - x_0) * x_factor
y = (position[1] - y_0) * y_factor
return (x, y) #in m
class TrackingResult(object): class TrackingResult(object):
@ -105,7 +109,11 @@ class TrackingResult(object):
bp string: the body part bp string: the body part
[type]: [description] [type]: [description]
""" """
time, x, y, l, bp = self.pixel_positions(scorer, bodypart, framerate, interpolate, min_likelihood)
x, y = self._to_meter(x, y)
return time, x, y, l, bp
def pixel_positions(self, scorer=0, bodypart=0, framerate=30, interpolate=True, min_likelihood=0.95):
if isinstance(scorer, nb.Number): if isinstance(scorer, nb.Number):
sc = self._scorer[scorer] sc = self._scorer[scorer]
elif isinstance(scorer, str) and scorer in self._scorer: elif isinstance(scorer, str) and scorer in self._scorer:
@ -119,25 +127,37 @@ class TrackingResult(object):
else: else:
raise ValueError("Bodypart %s is not in dataframe!" % bodypart) raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
x = self._data_frame[sc][bp]["x"] if "x" in self._positions else [] x = np.asarray(self._data_frame[sc][bp]["x"] if "x" in self._positions else [])
x = (np.asarray(x) - self.x_0) * self.x_factor y = np.asarray(self._data_frame[sc][bp]["y"] if "y" in self._positions else [])
y = self._data_frame[sc][bp]["y"] if "y" in self._positions else [] l = np.asarray(self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else [])
y = (np.asarray(y) - self.y_0) * self.y_factor
l = self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else [] time = np.arange(len(x))/framerate
if interpolate:
time = np.arange(len(self._data_frame))/framerate x, y = self.interpolate(time, x, y, l, min_likelihood)
time2 = time[l > min_likelihood] return time, x, y, l, bp
if len(l[l > min_likelihood]) < 100:
print("%s has not datapoints with likelihood larger than %.2f" % (self._file_name, min_likelihood) ) def _to_meter(self, x, y):
return None, None, None, None, None new_x = (np.asarray(x) - self.x_0) * self.x_factor
new_y = (np.asarray(y) - self.y_0) * self.y_factor
return new_x, new_y
def _speed(self, t, x, y):
speed = np.sqrt(np.diff(x)**2 + np.diff(y)**2) / np.diff(t)
return speed
def interpolate(self, t, x, y, l, min_likelihood=0.9):
time2 = t[l > min_likelihood]
if len(l[l > min_likelihood]) < 10:
print("%s has less than 10 datapoints with likelihood larger than %.2f" % (self._file_name, min_likelihood) )
return None, None
x2 = x[l > min_likelihood] x2 = x[l > min_likelihood]
y2 = y[l > min_likelihood] y2 = y[l > min_likelihood]
x3 = np.interp(time, time2, x2) x3 = np.interp(t, time2, x2)
y3 = np.interp(time, time2, y2) y3 = np.interp(t, time2, y2)
return time, x3, y3, l, bp return x3, y3
def plot(self, scorer=0, bodypart=0, threshold=0.9, framerate=30): def plot(self, scorer=0, bodypart=0, threshold=0.9, framerate=30):
t, x, y, l, name = self.position_values(scorer=scorer, bodypart=bodypart, framerate=framerate) t, x, y, l, name = self.position_values(scorer=scorer, bodypart=bodypart, framerate=framerate, min_likelihood=threshold)
plt.scatter(x[l > threshold], y[l > threshold], c=t[l > threshold], label=name) plt.scatter(x[l > threshold], y[l > threshold], c=t[l > threshold], label=name)
plt.scatter(self.center_meter[0], self.center_meter[1], marker="*") plt.scatter(self.center_meter[0], self.center_meter[1], marker="*")
plt.plot(x[l > threshold], y[l > threshold]) plt.plot(x[l > threshold], y[l > threshold])
@ -148,7 +168,6 @@ class TrackingResult(object):
bar.set_label("time [s]") bar.set_label("time [s]")
plt.legend() plt.legend()
plt.show() plt.show()
from IPython import embed
if __name__ == '__main__': if __name__ == '__main__':
@ -166,12 +185,12 @@ if __name__ == '__main__':
x3 = np.interp(time, time2, x2) x3 = np.interp(time, time2, x2)
y3 = np.interp(time, time2, y2) y3 = np.interp(time, time2, y2)
fig, axes = plt.subplots(3,1, sharex=True) fig, axes = plt.subplots(3,1, sharex=True)
axes[0].plot(time, x) axes[0].plot(time, x)
axes[0].plot(time, x3) axes[0].plot(time, x3)
axes[1].plot(time, y) axes[1].plot(time, y)
axes[1].plot(time, y3) axes[1].plot(time, y3)
axes[2].plot(time, l) axes[2].plot(time, l)
plt.show() plt.show()

52
src/etrack/util.py Normal file
View File

@ -0,0 +1,52 @@
"""
Module containing utility functions and enum classes.
"""
from enum import Enum
class Illumination(Enum):
Backlight = 0
Incident = 1
class RegionShape(Enum):
"""
Enumeration representing the shape of a region.
Attributes:
Circular: Represents a circular region.
Rectangular: Represents a rectangular region.
"""
Circular = 0
Rectangular = 1
def __str__(self) -> str:
return self.name
class AnalysisType(Enum):
"""
Enumeration representing different types of analysis used when analyzing whether
positions fall into a given region.
Possible types:
AnalysisType.Full: considers both, the x- and the y-coordinates
AnalysisType.CollapseX: consider only the x-coordinates
AnalysisType.CollapseY: consider only the y-coordinates
"""
Full = 0
CollapseX = 1
CollapseY = 2
def __str__(self) -> str:
"""
Returns the string representation of the analysis type.
Returns:
str: The name of the analysis type.
"""
return self.name
class PositionType(Enum):
Absolute = 0
Cropped = 1

Binary file not shown.

82
test/test_arena.py Normal file
View File

@ -0,0 +1,82 @@
import pytest
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mp
from etrack import Arena, Region, RegionShape
def test_region():
# Create a parent region
parent_region = Region((0, 0), (100, 100), name="parent", region_shape=RegionShape.Rectangular)
# Create a child region
child_region = Region((10, 10), (50, 50), name="child", region_shape=RegionShape.Rectangular, parent=parent_region)
# Test properties
assert child_region.name == "child"
assert child_region.inverted_y == True
assert (child_region._max_extent == np.array((60, 60))).all()
assert (child_region._min_extent == np.array((10, 10))).all()
assert child_region.xmax == 60
assert child_region.xmin == 10
assert child_region.ymin == 10
assert child_region.ymax == 60
assert child_region.position == (10, 10, 50, 50)
assert child_region.is_child == True
# Test fits method
assert parent_region.fits(child_region) == True
# Test points_in_region method
x = [15, 20, 25, 30, 35, 5]
y = [15, 20, 25, 30, 35, 5]
assert (child_region.points_in_region(x, y) == np.array([0, 1, 2, 3, 4])).all()
# Test time_in_region method
x = [5, 15, 20, 25, 30, 35, 35]
y = [5, 15, 20, 25, 30, 35, 65]
time = np.arange(0, len(x), 1)
enter, leave = child_region.time_in_region(x, y, time)
assert enter[0] == 1
assert leave[0] == 5
# Test patch method
patch = child_region.patch(color='red')
assert isinstance(patch, mp.Patch)
# Test __repr__ method
assert repr(child_region) == "Region: 'child' of Rectangular shape."
def test_arena():
# Create an arena
arena = Arena((0, 0), (100, 100), name="arena", arena_shape=RegionShape.Rectangular)
# Test add_region method
arena.add_region("small rect1", (0, 0), (50, 50))
assert len(arena.regions) == 1
assert arena.regions["small rect1"].name == "small rect1"
# Test remove_region method
arena.remove_region("small rect1")
assert len(arena.regions) == 0
# Test plot method
axis = arena.plot()
assert isinstance(axis, plt.Axes)
# Test region_vector method
x = [10, 20, 30]
y = [10, 20, 30]
assert (arena.region_vector(x, y) == "").all()
# Test in_region method
# assert len(arena.in_region(10, 10)) > 0
# print(arena.in_region(10, 10))
# print(arena.in_region(110, 110))
# assert arena.in_region(110, 110) == False
# Test __getitem__ method
arena.add_region("small rect2", (0, 0), (50, 50))
assert arena["small rect2"].name == "small rect2"
if __name__ == "__main__":
pytest.main()

32
test/test_nixtrackio.py Normal file
View File

@ -0,0 +1,32 @@
import pytest
import etrack as et
from IPython import embed
dataset = "test/2022lepto01_converted_2024.03.27_0.mp4.nix"
@pytest.fixture
def nixtrack_data():
# Create a NixTrackData object with some test data
return et.NixtrackData(dataset)
def test_basics(nixtrack_data):
assert nixtrack_data.filename == dataset
assert len(nixtrack_data.bodyparts) == 5
assert len(nixtrack_data.tracks) == 2
assert nixtrack_data.fps == 25
def test_trackingdata(nixtrack_data):
with pytest.raises(ValueError):
nixtrack_data.track_data(bodypart="test")
nixtrack_data.track_data(track="fish")
assert nixtrack_data.track_data("center") is not None
assert nixtrack_data.track_data("center", "none") is not None
if __name__ == "__main__":
pytest.main()