docs, fixes and test for nixtrack data wrapper
This commit is contained in:
parent
1dd318f23e
commit
cbd0541a54
@ -9,7 +9,8 @@ from IPython import embed
|
||||
|
||||
|
||||
class NixtrackData(object):
|
||||
|
||||
"""Wrapper around a nix data file that has been written accorind to the nixtrack model (https://github.com/bendalab/nixtrack)
|
||||
"""
|
||||
def __init__(self, filename, crop=(0, 0)) -> None:
|
||||
"""
|
||||
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
|
||||
@ -37,28 +38,93 @@ class NixtrackData(object):
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
"""
|
||||
Returns the name of the file associated with the NixtrackData object.
|
||||
|
||||
Returns:
|
||||
str: The name of the file.
|
||||
"""
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
"""
|
||||
Returns the bodyparts of the dataset.
|
||||
|
||||
Returns:
|
||||
list: A list of bodyparts.
|
||||
"""
|
||||
return self._dataset.nodes
|
||||
|
||||
def _correct_cropping(self, orgx, orgy):
|
||||
"""
|
||||
Corrects the coordinates based on the cropping values, If it cropping was done during tracking.
|
||||
|
||||
Args:
|
||||
orgx (int): The original x-coordinate.
|
||||
orgy (int): The original y-coordinate.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the corrected x and y coordinates.
|
||||
"""
|
||||
x = orgx + self._crop[0]
|
||||
y = orgy + self._crop[1]
|
||||
return x, y
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
"""Property that holds frames per second of the original video.
|
||||
Returns
|
||||
-------
|
||||
int : the frames of second
|
||||
"""
|
||||
return self._dataset.fps
|
||||
|
||||
@property
|
||||
def tracks(self):
|
||||
return self._dataset.tracks
|
||||
"""
|
||||
Returns a list of tracks from the dataset.
|
||||
|
||||
Returns:
|
||||
list: A list of tracks.
|
||||
"""
|
||||
return [t[0] for t in self._dataset.tracks]
|
||||
|
||||
def track_data(self, bodypart=0, track=-1, fps=None):
|
||||
"""
|
||||
Retrieve tracking data for a specific body part and track.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bodypart : int or str
|
||||
Index or name of the body part to retrieve tracking data for.
|
||||
track : int or str
|
||||
Index of the track to retrieve tracking data for.
|
||||
fps : float
|
||||
Frames per second of the tracking data. If not provided, it will be retrieved from the dataset.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TrackingData: An object containing the x and y positions, time, score, body part name, and frames per second.
|
||||
|
||||
def track(self, bodypart=0, fps=None):
|
||||
Raises
|
||||
------
|
||||
ValueError: If the body part or track is not valid.
|
||||
"""
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self.bodyparts[bodypart]
|
||||
elif isinstance(bodypart, (str)) and bodypart in self.bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError(f"Body part {bodypart} is not a tracked node!")
|
||||
|
||||
if track not in self.tracks:
|
||||
raise ValueError(f"Track {track} is not a valid track name!")
|
||||
if not isinstance(track, (list, tuple)):
|
||||
track = [track]
|
||||
elif isinstance(track, int):
|
||||
track = [self.tracks[track]]
|
||||
|
||||
if fps is None:
|
||||
fps = self._dataset.fps
|
||||
|
||||
|
BIN
test/2022lepto01_converted_2024.03.27_0.mp4.nix
Normal file
BIN
test/2022lepto01_converted_2024.03.27_0.mp4.nix
Normal file
Binary file not shown.
32
test/test_nixtrackio.py
Normal file
32
test/test_nixtrackio.py
Normal file
@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
import etrack as et
|
||||
|
||||
from IPython import embed
|
||||
|
||||
dataset = "test/2022lepto01_converted_2024.03.27_0.mp4.nix"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nixtrack_data():
|
||||
# Create a NixTrackData object with some test data
|
||||
return et.NixtrackData(dataset)
|
||||
|
||||
|
||||
def test_basics(nixtrack_data):
|
||||
assert nixtrack_data.filename == dataset
|
||||
assert len(nixtrack_data.bodyparts) == 5
|
||||
assert len(nixtrack_data.tracks) == 2
|
||||
assert nixtrack_data.fps == 25
|
||||
|
||||
|
||||
def test_trackingdata(nixtrack_data):
|
||||
with pytest.raises(ValueError):
|
||||
nixtrack_data.track_data(bodypart="test")
|
||||
nixtrack_data.track_data(track="fish")
|
||||
|
||||
assert nixtrack_data.track_data("center") is not None
|
||||
assert nixtrack_data.track_data("center", "none") is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
Loading…
Reference in New Issue
Block a user