diff --git a/src/etrack/io/nixtrack_data.py b/src/etrack/io/nixtrack_data.py index 6d0ec8d..0ef6685 100644 --- a/src/etrack/io/nixtrack_data.py +++ b/src/etrack/io/nixtrack_data.py @@ -9,7 +9,8 @@ from IPython import embed class NixtrackData(object): - + """Wrapper around a nix data file that has been written accorind to the nixtrack model (https://github.com/bendalab/nixtrack) + """ def __init__(self, filename, crop=(0, 0)) -> None: """ If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame. @@ -37,28 +38,93 @@ class NixtrackData(object): @property def filename(self): + """ + Returns the name of the file associated with the NixtrackData object. + + Returns: + str: The name of the file. + """ return self._file_name @property def bodyparts(self): + """ + Returns the bodyparts of the dataset. + + Returns: + list: A list of bodyparts. + """ return self._dataset.nodes def _correct_cropping(self, orgx, orgy): + """ + Corrects the coordinates based on the cropping values, If it cropping was done during tracking. + + Args: + orgx (int): The original x-coordinate. + orgy (int): The original y-coordinate. + + Returns: + tuple: A tuple containing the corrected x and y coordinates. + """ x = orgx + self._crop[0] y = orgy + self._crop[1] return x, y - + + @property + def fps(self): + """Property that holds frames per second of the original video. + Returns + ------- + int : the frames of second + """ + return self._dataset.fps + @property def tracks(self): - return self._dataset.tracks + """ + Returns a list of tracks from the dataset. + + Returns: + list: A list of tracks. + """ + return [t[0] for t in self._dataset.tracks] + + def track_data(self, bodypart=0, track=-1, fps=None): + """ + Retrieve tracking data for a specific body part and track. + + Parameters + ---------- + bodypart : int or str + Index or name of the body part to retrieve tracking data for. + track : int or str + Index of the track to retrieve tracking data for. + fps : float + Frames per second of the tracking data. If not provided, it will be retrieved from the dataset. + + Returns + ------- + TrackingData: An object containing the x and y positions, time, score, body part name, and frames per second. - def track(self, bodypart=0, fps=None): + Raises + ------ + ValueError: If the body part or track is not valid. + """ if isinstance(bodypart, nb.Number): bp = self.bodyparts[bodypart] elif isinstance(bodypart, (str)) and bodypart in self.bodyparts: bp = bodypart else: raise ValueError(f"Body part {bodypart} is not a tracked node!") + + if track not in self.tracks: + raise ValueError(f"Track {track} is not a valid track name!") + if not isinstance(track, (list, tuple)): + track = [track] + elif isinstance(track, int): + track = [self.tracks[track]] + if fps is None: fps = self._dataset.fps diff --git a/test/2022lepto01_converted_2024.03.27_0.mp4.nix b/test/2022lepto01_converted_2024.03.27_0.mp4.nix new file mode 100644 index 0000000..378af3a Binary files /dev/null and b/test/2022lepto01_converted_2024.03.27_0.mp4.nix differ diff --git a/test/test_nixtrackio.py b/test/test_nixtrackio.py new file mode 100644 index 0000000..bce61e1 --- /dev/null +++ b/test/test_nixtrackio.py @@ -0,0 +1,32 @@ +import pytest +import etrack as et + +from IPython import embed + +dataset = "test/2022lepto01_converted_2024.03.27_0.mp4.nix" + + +@pytest.fixture +def nixtrack_data(): + # Create a NixTrackData object with some test data + return et.NixtrackData(dataset) + + +def test_basics(nixtrack_data): + assert nixtrack_data.filename == dataset + assert len(nixtrack_data.bodyparts) == 5 + assert len(nixtrack_data.tracks) == 2 + assert nixtrack_data.fps == 25 + + +def test_trackingdata(nixtrack_data): + with pytest.raises(ValueError): + nixtrack_data.track_data(bodypart="test") + nixtrack_data.track_data(track="fish") + + assert nixtrack_data.track_data("center") is not None + assert nixtrack_data.track_data("center", "none") is not None + + +if __name__ == "__main__": + pytest.main() \ No newline at end of file