docs, fixes and test for nixtrack data wrapper

This commit is contained in:
Jan Grewe 2024-06-01 12:19:34 +02:00
parent 1dd318f23e
commit cbd0541a54
3 changed files with 102 additions and 4 deletions

View File

@ -9,7 +9,8 @@ from IPython import embed
class NixtrackData(object):
"""Wrapper around a nix data file that has been written accorind to the nixtrack model (https://github.com/bendalab/nixtrack)
"""
def __init__(self, filename, crop=(0, 0)) -> None:
"""
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
@ -37,28 +38,93 @@ class NixtrackData(object):
@property
def filename(self):
"""
Returns the name of the file associated with the NixtrackData object.
Returns:
str: The name of the file.
"""
return self._file_name
@property
def bodyparts(self):
"""
Returns the bodyparts of the dataset.
Returns:
list: A list of bodyparts.
"""
return self._dataset.nodes
def _correct_cropping(self, orgx, orgy):
"""
Corrects the coordinates based on the cropping values, If it cropping was done during tracking.
Args:
orgx (int): The original x-coordinate.
orgy (int): The original y-coordinate.
Returns:
tuple: A tuple containing the corrected x and y coordinates.
"""
x = orgx + self._crop[0]
y = orgy + self._crop[1]
return x, y
@property
def fps(self):
"""Property that holds frames per second of the original video.
Returns
-------
int : the frames of second
"""
return self._dataset.fps
@property
def tracks(self):
return self._dataset.tracks
"""
Returns a list of tracks from the dataset.
Returns:
list: A list of tracks.
"""
return [t[0] for t in self._dataset.tracks]
def track_data(self, bodypart=0, track=-1, fps=None):
"""
Retrieve tracking data for a specific body part and track.
Parameters
----------
bodypart : int or str
Index or name of the body part to retrieve tracking data for.
track : int or str
Index of the track to retrieve tracking data for.
fps : float
Frames per second of the tracking data. If not provided, it will be retrieved from the dataset.
Returns
-------
TrackingData: An object containing the x and y positions, time, score, body part name, and frames per second.
def track(self, bodypart=0, fps=None):
Raises
------
ValueError: If the body part or track is not valid.
"""
if isinstance(bodypart, nb.Number):
bp = self.bodyparts[bodypart]
elif isinstance(bodypart, (str)) and bodypart in self.bodyparts:
bp = bodypart
else:
raise ValueError(f"Body part {bodypart} is not a tracked node!")
if track not in self.tracks:
raise ValueError(f"Track {track} is not a valid track name!")
if not isinstance(track, (list, tuple)):
track = [track]
elif isinstance(track, int):
track = [self.tracks[track]]
if fps is None:
fps = self._dataset.fps

Binary file not shown.

32
test/test_nixtrackio.py Normal file
View File

@ -0,0 +1,32 @@
import pytest
import etrack as et
from IPython import embed
dataset = "test/2022lepto01_converted_2024.03.27_0.mp4.nix"
@pytest.fixture
def nixtrack_data():
# Create a NixTrackData object with some test data
return et.NixtrackData(dataset)
def test_basics(nixtrack_data):
assert nixtrack_data.filename == dataset
assert len(nixtrack_data.bodyparts) == 5
assert len(nixtrack_data.tracks) == 2
assert nixtrack_data.fps == 25
def test_trackingdata(nixtrack_data):
with pytest.raises(ValueError):
nixtrack_data.track_data(bodypart="test")
nixtrack_data.track_data(track="fish")
assert nixtrack_data.track_data("center") is not None
assert nixtrack_data.track_data("center", "none") is not None
if __name__ == "__main__":
pytest.main()