restructuring project with toml add a test and docstrings
This commit is contained in:
7
src/etrack/__init__.py
Normal file
7
src/etrack/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .image_marker import ImageMarker, MarkerTask
|
||||
from .tracking_result import TrackingResult, coordinate_transformation
|
||||
from .arena import Arena, Region
|
||||
from .tracking_data import TrackingData
|
||||
from .io.dlc_data import DLCReader
|
||||
from .io.nixtrack_data import NixtrackData
|
||||
from .util import RegionShape, AnalysisType
|
||||
519
src/etrack/arena.py
Normal file
519
src/etrack/arena.py
Normal file
@@ -0,0 +1,519 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
|
||||
from skimage.draw import disk
|
||||
|
||||
from .util import RegionShape, AnalysisType, Illumination
|
||||
from IPython import embed
|
||||
|
||||
|
||||
class Region(object):
|
||||
"""
|
||||
Class representing a region (of interest). Regions can be either circular or rectangular.
|
||||
A Region can have a parent, i.e. it is contained inside a parent region. It can also have children.
|
||||
|
||||
Coordinates are given in absolute coordinates. The extent is treated depending on the shape. In case of a circular
|
||||
shape, it is the radius and the origin is the center of the circle. Otherwise the origin is the bottom, or top-left corner, depending on the y-axis orientation, if inverted, then it is top-left. FIXME: check this
|
||||
|
||||
"""
|
||||
def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
|
||||
"""Region constructor.
|
||||
Parameters
|
||||
----------
|
||||
origin : 2-tuple
|
||||
x, and y coordinates
|
||||
extent : scalar or 2-tuple, scalar only allowed to circular regions, 2-tuple for rectangular.
|
||||
inverted_y : bool, optional
|
||||
_description_, by default True
|
||||
name : str, optional
|
||||
_description_, by default ""
|
||||
region_shape : _type_, optional
|
||||
_description_, by default RegionShape.Rectangular
|
||||
parent : _type_, optional
|
||||
_description_, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
Raises Value error when origin or extent are invalid
|
||||
"""
|
||||
logging.debug(
|
||||
f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
|
||||
)
|
||||
if len(origin) != 2:
|
||||
raise ValueError("Region: origin must be 2-tuple!")
|
||||
self._parent = parent
|
||||
self._name = name
|
||||
self._shape_type = region_shape
|
||||
self._origin = origin
|
||||
self._check_extent(extent)
|
||||
self._extent = extent
|
||||
self._inverted_y = inverted_y
|
||||
|
||||
@staticmethod
|
||||
def circular_mask(width, height, center, radius):
|
||||
assert center[1] + radius < width and center[1] - radius > 0
|
||||
assert center[0] + radius < height and center[0] - radius > 0
|
||||
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
rr, cc = disk(reversed(center), radius)
|
||||
mask[rr, cc] = 1
|
||||
|
||||
return mask
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def inverted_y(self):
|
||||
return self._inverted_y
|
||||
|
||||
@property
|
||||
def _max_extent(self):
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
max_extent = (
|
||||
self._origin[0] + self._extent[0],
|
||||
self._origin[1] + self._extent[1],
|
||||
)
|
||||
else:
|
||||
max_extent = (
|
||||
self._origin[0] + self._extent,
|
||||
self._origin[1] + self._extent,
|
||||
)
|
||||
return np.asarray(max_extent)
|
||||
|
||||
@property
|
||||
def _min_extent(self):
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
min_extent = self._origin
|
||||
else:
|
||||
min_extent = (
|
||||
self._origin[0] - self._extent,
|
||||
self._origin[1] - self._extent,
|
||||
)
|
||||
return np.asarray(min_extent)
|
||||
|
||||
@property
|
||||
def xmax(self):
|
||||
return self._max_extent[0]
|
||||
|
||||
@property
|
||||
def xmin(self):
|
||||
return self._min_extent[0]
|
||||
|
||||
@property
|
||||
def ymin(self):
|
||||
return self._min_extent[1]
|
||||
|
||||
@property
|
||||
def ymax(self):
|
||||
return self._max_extent[1]
|
||||
|
||||
@property
|
||||
def position(self):
|
||||
"""Returns the position and extent of the region as 4-tuple, (x, y, width, height)"""
|
||||
x = self._min_extent[0]
|
||||
y = self._min_extent[1]
|
||||
width = self._max_extent[0] - self._min_extent[0]
|
||||
height = self._max_extent[1] - self._min_extent[1]
|
||||
return x, y, width, height
|
||||
|
||||
def _check_extent(self, ext):
|
||||
"""Checks whether the extent matches the shape. i.e. if the shape is Rectangular, extent must be a length 2 list, tuple, otherwise, if the region is circular, extent must be a single numerical value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ext : tuple, or numeric scalar
|
||||
"""
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2:
|
||||
raise ValueError(
|
||||
"Extent must be a length 2 list or tuple for rectangular regions!"
|
||||
)
|
||||
elif self._shape_type == RegionShape.Circular:
|
||||
if not isinstance(ext, (int, float)):
|
||||
raise ValueError(
|
||||
"Extent must be a numerical scalar for circular regions!"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid ShapeType, {self._shape_type}!")
|
||||
|
||||
def fits(self, other) -> bool:
|
||||
"""
|
||||
Checks if the given region fits into the current region.
|
||||
|
||||
Args:
|
||||
other (Region): The region to check if it fits.
|
||||
|
||||
Returns:
|
||||
bool: True if the given region fits into the current region, False otherwise.
|
||||
"""
|
||||
assert isinstance(other, Region)
|
||||
does_fit = all(
|
||||
(
|
||||
other._min_extent[0] >= self._min_extent[0],
|
||||
other._min_extent[1] >= self._min_extent[1],
|
||||
other._max_extent[0] <= self._max_extent[0],
|
||||
other._max_extent[1] <= self._max_extent[1],
|
||||
)
|
||||
)
|
||||
if not does_fit:
|
||||
m = (
|
||||
f"Region {other.name} does not fit into {self.name}. "
|
||||
f"min x: {other._min_extent[0] >= self._min_extent[0]},",
|
||||
f"min y: {other._min_extent[1] >= self._min_extent[1]},",
|
||||
f"max x: {other._max_extent[0] <= self._max_extent[0]},",
|
||||
f"max y: {other._max_extent[1] <= self._max_extent[1]}",
|
||||
)
|
||||
logging.debug(m)
|
||||
return does_fit
|
||||
|
||||
@property
|
||||
def is_child(self):
|
||||
"""
|
||||
Check if the current instance is a child.
|
||||
|
||||
Returns:
|
||||
bool: True if the instance has a parent, False otherwise.
|
||||
"""
|
||||
return self._parent is not None
|
||||
|
||||
def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
|
||||
"""Returns the indices of the points specified by 'x' and 'y' that fall into this region.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : np.ndarray
|
||||
the x positions
|
||||
y : np.ndarray
|
||||
the y positions
|
||||
analysis_type : AnalysisType, optional
|
||||
defines how the positions are evaluated, by default AnalysisType.Full
|
||||
FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
|
||||
"""
|
||||
if self._shape_type == RegionShape.Rectangular or (
|
||||
self._shape_type == RegionShape.Circular
|
||||
and analysis_type != AnalysisType.Full
|
||||
):
|
||||
if analysis_type == AnalysisType.Full:
|
||||
indices = np.where(
|
||||
((y >= self._min_extent[1]) & (y <= self._max_extent[1]))
|
||||
& ((x >= self._min_extent[0]) & (x <= self._max_extent[0]))
|
||||
)[0]
|
||||
indices = np.array(indices, dtype=int)
|
||||
elif analysis_type == AnalysisType.CollapseX:
|
||||
x_indices = np.where(
|
||||
(x >= self._min_extent[0]) & (x <= self._max_extent[0])
|
||||
)[0]
|
||||
indices = np.asarray(x_indices, dtype=int)
|
||||
else:
|
||||
y_indices = np.where(
|
||||
(y >= self._min_extent[1]) & (y <= self._max_extent[1])
|
||||
)[0]
|
||||
indices = np.asarray(y_indices, dtype=int)
|
||||
else:
|
||||
if self.is_child:
|
||||
mask = self.circular_mask(
|
||||
self._parent.position[2],
|
||||
self._parent.position[3],
|
||||
self._origin,
|
||||
self._extent,
|
||||
)
|
||||
else:
|
||||
mask = self.circular_mask(
|
||||
self.position[2], self.position[3], self._origin, self._extent
|
||||
)
|
||||
img = np.zeros_like(mask)
|
||||
img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1
|
||||
temp = np.where(img & mask)
|
||||
indices = []
|
||||
for i, j in zip(list(temp[1]), list(temp[0])):
|
||||
matches = np.where((x == i) & (y == j))
|
||||
if len(matches[0]) == 0:
|
||||
continue
|
||||
indices.append(matches[0][0])
|
||||
indices = np.array(indices)
|
||||
return indices
|
||||
|
||||
def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
|
||||
"""Returns the entering and leaving times at which the animal entered
|
||||
and left a region. In case the animal was not observed after entering
|
||||
this region (for example when hidden in a tube) the leaving time is
|
||||
the maximum time entry.
|
||||
Whether the full position, or only the x- or y-position should be considered
|
||||
is controlled with the analysis_type parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : np.ndarray
|
||||
The animal's x-positions
|
||||
y : np.ndarray
|
||||
the animal's y-positions
|
||||
time : np.ndarray
|
||||
the time array
|
||||
analysis_type : AnalysisType, optional
|
||||
The type of analysis, by default AnalysisType.Full
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The entering times
|
||||
np.ndarray
|
||||
The leaving times
|
||||
"""
|
||||
indices = self.points_in_region(x, y, analysis_type)
|
||||
if len(indices) == 0:
|
||||
return np.array([]), np.array([])
|
||||
|
||||
diffs = np.diff(indices)
|
||||
if len(diffs) == sum(diffs):
|
||||
entering = [time[indices[0]]]
|
||||
leaving = [time[indices[-1]]]
|
||||
else:
|
||||
entering = []
|
||||
leaving = []
|
||||
jumps = np.where(diffs > 1)[0]
|
||||
start = time[indices[0]]
|
||||
for i in range(len(jumps)):
|
||||
end = time[indices[jumps[i]]]
|
||||
entering.append(start)
|
||||
leaving.append(end)
|
||||
start = time[indices[jumps[i] + 1]]
|
||||
|
||||
end = time[indices[-1]]
|
||||
entering.append(start)
|
||||
leaving.append(end)
|
||||
return np.array(entering), np.array(leaving)
|
||||
|
||||
def patch(self, **kwargs):
|
||||
"""
|
||||
Create and return a matplotlib patch object based on the shape type of the arena.
|
||||
|
||||
Parameters:
|
||||
- kwargs: Additional keyword arguments to customize the patch object.
|
||||
|
||||
Returns:
|
||||
- A matplotlib patch object representing the arena shape.
|
||||
|
||||
If the 'fc' (facecolor) keyword argument is not provided, it will default to None.
|
||||
If the 'fill' keyword argument is not provided, it will default to False.
|
||||
|
||||
For rectangular arenas, the patch object will be a Rectangle with width and height
|
||||
based on the arena's position.
|
||||
For circular arenas, the patch object will be a Circle with radius based on the
|
||||
arena's extent.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
arena = Arena()
|
||||
patch = arena.patch(fc='blue', fill=True)
|
||||
ax.add_patch(patch)
|
||||
```
|
||||
"""
|
||||
if "fc" not in kwargs:
|
||||
kwargs["fc"] = None
|
||||
kwargs["fill"] = False
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
w = self.position[2]
|
||||
h = self.position[3]
|
||||
return patches.Rectangle(self._origin, w, h, **kwargs)
|
||||
else:
|
||||
return patches.Circle(self._origin, self._extent, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Region: '{self._name}' of {self._shape_type} shape."
|
||||
|
||||
|
||||
class Arena(Region):
|
||||
"""
|
||||
Class to represent the experimental arena. Arena is derived from Region and can be either rectangular or circular.
|
||||
An arena can not have a parent.
|
||||
See Region for more details.
|
||||
"""
|
||||
def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
|
||||
illumination=Illumination.Backlight) -> None:
|
||||
""" Construct a new Area with a given origin and extent.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
_description_
|
||||
"""
|
||||
super().__init__(origin, extent, inverted_y, name, arena_shape)
|
||||
self._illumination = illumination
|
||||
self.regions = {}
|
||||
|
||||
def add_region(
|
||||
self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None
|
||||
):
|
||||
if name is None or name in self.regions.keys():
|
||||
raise ValueError(
|
||||
"Region name '{name}' is invalid. The name must not be None and must be unique among the regions."
|
||||
)
|
||||
if region is None:
|
||||
region = Region(
|
||||
origin, extent, name=name, region_shape=shape_type, parent=self
|
||||
)
|
||||
else:
|
||||
region._parent = self
|
||||
doesfit = self.fits(region)
|
||||
if not doesfit:
|
||||
logging.warn(
|
||||
f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!"
|
||||
)
|
||||
self.regions[name] = region
|
||||
|
||||
def remove_region(self, name):
|
||||
"""
|
||||
Remove a region from the arena.
|
||||
|
||||
Parameter:
|
||||
name : str
|
||||
The name of the region to remove.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if name in self.regions:
|
||||
self.regions.pop(name)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Arena: '{self._name}' of {self._shape_type} shape."
|
||||
|
||||
def plot(self, axis=None):
|
||||
"""
|
||||
Plots the arena on the given axis.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- axis (matplotlib.axes.Axes, optional): The axis on which to plot the arena. If not provided, a new figure and axis will be created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- matplotlib.axes.Axes: The axis on which the arena is plotted.
|
||||
"""
|
||||
if axis is None:
|
||||
fig = plt.figure()
|
||||
axis = fig.add_subplot(111)
|
||||
axis.add_patch(self.patch())
|
||||
axis.set_xlim([self._origin[0], self._max_extent[0]])
|
||||
|
||||
if self.inverted_y:
|
||||
axis.set_ylim([self._max_extent[1], self._origin[1]])
|
||||
else:
|
||||
axis.set_ylim([self._origin[1], self._max_extent[1]])
|
||||
for r in self.regions:
|
||||
axis.add_patch(self.regions[r].patch())
|
||||
return axis
|
||||
|
||||
def region_vector(self, x, y):
|
||||
"""Returns a vector that contains the region names within which the agent was found.
|
||||
FIXME: This does not work well with overlapping regions!@!
|
||||
Parameters
|
||||
----------
|
||||
x : np.array
|
||||
the x-positions
|
||||
y : np.ndarray
|
||||
the y-positions
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.array
|
||||
vector of the same size as x and y. Each entry is the region to which the position is assigned to. If the point is not assigned to a region, the entry will be empty.
|
||||
"""
|
||||
if not isinstance(x, np.ndarray):
|
||||
x = np.asarray(x)
|
||||
if not isinstance(y, np.ndarray):
|
||||
y = np.asarray(y)
|
||||
rv = np.empty(x.shape, dtype=str)
|
||||
for r in self.regions:
|
||||
indices = self.regions[r].points_in_region(x, y)
|
||||
rv[indices] = r
|
||||
return rv
|
||||
|
||||
def in_region(self, x, y):
|
||||
"""
|
||||
Determines if the given coordinates (x, y) are within any of the defined regions in the arena.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : float
|
||||
The x-coordinate of the point to check.
|
||||
y : float
|
||||
The y-coordinate of the point to check.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
A dictionary containing the region names as keys and a list of indices of points within each region as values.
|
||||
"""
|
||||
tmp = {}
|
||||
for r in self.regions:
|
||||
print(r)
|
||||
indices = self.regions[r].points_in_region(x, y)
|
||||
tmp[r] = indices
|
||||
return tmp
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, (str)):
|
||||
return self.regions[key]
|
||||
else:
|
||||
return self.regions[self.regions.keys()[key]]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
|
||||
a.add_region("small rect1", (0, 0), (100, 300))
|
||||
a.add_region("small rect2", (150, 0), (100, 300))
|
||||
a.add_region("small rect3", (300, 0), (100, 300))
|
||||
a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
|
||||
axis = a.plot()
|
||||
x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
|
||||
y = np.asarray(
|
||||
(np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int
|
||||
)
|
||||
# y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
|
||||
axis.scatter(x, y, c="k", s=2)
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind], label="circ full")
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseX)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")
|
||||
|
||||
ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] - 10, label="rect collapseX")
|
||||
|
||||
ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] + 10, label="rect collapseY")
|
||||
|
||||
ind = a.regions[2].points_in_region(x, y, AnalysisType.Full)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] + 20, label="rect full")
|
||||
axis.legend()
|
||||
plt.show()
|
||||
|
||||
a.plot()
|
||||
plt.show()
|
||||
153
src/etrack/image_marker.py
Normal file
153
src/etrack/image_marker.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
import os
|
||||
import sys
|
||||
from IPython import embed
|
||||
|
||||
class ImageMarker:
|
||||
|
||||
def __init__(self, tasks=[]) -> None:
|
||||
super().__init__()
|
||||
self._fig = plt.figure()
|
||||
self._tasks = tasks
|
||||
self._task_index = -1
|
||||
self._current_task = None
|
||||
self._marker_set = False
|
||||
self._interrupt = False
|
||||
self._fig.canvas.mpl_connect('button_press_event', self._on_click_event)
|
||||
self._fig.canvas.mpl_connect('close_event', self._fig_close_event)
|
||||
self._fig.canvas.mpl_connect('key_press_event', self._key_press_event)
|
||||
|
||||
def mark_movie(self, filename, frame_number=0):
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("file %s does not exist!" % filename)
|
||||
video = cv2.VideoCapture()
|
||||
video.open(filename)
|
||||
frame_counter = 0
|
||||
success = True
|
||||
frame = None
|
||||
while success and frame_counter <= frame_number:
|
||||
print("Reading frame: %i" % frame_counter, end="\r")
|
||||
success, frame = video.read()
|
||||
frame_counter += 1
|
||||
|
||||
if success:
|
||||
self._fig.gca().imshow(frame)
|
||||
else:
|
||||
print("Could not read frame number %i either failed to open movie or beyond maximum frame number!" % frame_number)
|
||||
return []
|
||||
plt.ion()
|
||||
plt.show(block=False)
|
||||
|
||||
self._task_index = -1
|
||||
if len(self._tasks) > 0:
|
||||
self._next_task()
|
||||
|
||||
while not self._tasks_done:
|
||||
plt.pause(0.250)
|
||||
if self._interrupt:
|
||||
return []
|
||||
|
||||
self._fig.gca().set_title("All set and done!\n Window will close in 2s")
|
||||
self._fig.canvas.draw()
|
||||
plt.pause(2.0)
|
||||
return [t.marker_positions for t in self._tasks]
|
||||
|
||||
def _key_press_event(self, event):
|
||||
print("Key pressed: %s!" % event.key)
|
||||
|
||||
@property
|
||||
def _tasks_done(self):
|
||||
done = self._task_index == len(self._tasks) and self._current_task is not None and self._current_task.task_done
|
||||
return done
|
||||
|
||||
def _next_task(self):
|
||||
if self._current_task is None:
|
||||
self._task_index += 1
|
||||
self._current_task = self._tasks[self._task_index]
|
||||
|
||||
if self._current_task is not None and not self._current_task.task_done:
|
||||
self._fig.gca().set_title("%s: \n%s: %s" % (self._current_task.name, self._current_task.message, self._current_task.current_marker))
|
||||
self._fig.canvas.draw()
|
||||
elif self._current_task is not None and self._current_task.task_done:
|
||||
self._task_index += 1
|
||||
if self._task_index < len(self._tasks):
|
||||
self._current_task = self._tasks[self._task_index]
|
||||
self._fig.gca().set_title("%s: \n%s: %s" % (self._current_task.name, self._current_task.message, self._current_task.current_marker))
|
||||
self._fig.canvas.draw()
|
||||
|
||||
def _on_click_event(self, event):
|
||||
self._fig.gca().scatter(event.xdata, event.ydata, marker=self._current_task.marker_symbol, color=self._current_task.marker_color, s=20)
|
||||
event.canvas.draw()
|
||||
self._current_task.set_position(self._current_task.current_marker, event.xdata, event.ydata)
|
||||
self._next_task()
|
||||
|
||||
def _fig_close_event(self, even):
|
||||
self._interrupt = True
|
||||
|
||||
class MarkerTask():
|
||||
def __init__(self, name:str, marker_names=[], message="", marker="o", color="tab:blue") -> None:
|
||||
super().__init__()
|
||||
self._positions = {}
|
||||
self._marker_names = marker_names
|
||||
self._name = name
|
||||
self._message = message
|
||||
self._current_marker = marker_names[0] if len(marker_names) > 0 else None
|
||||
self._current_index = 0
|
||||
self._marker = marker
|
||||
self._marker_color = color
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return self._positions
|
||||
|
||||
@property
|
||||
def name(self)->str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def message(self)->str:
|
||||
return self._message
|
||||
|
||||
def set_position(self, marker_name, x, y):
|
||||
self._positions[marker_name] = (x, y)
|
||||
if not self.task_done:
|
||||
self._current_index += 1
|
||||
self._current_marker = self._marker_names[self._current_index]
|
||||
|
||||
@property
|
||||
def marker_positions(self):
|
||||
return self._positions
|
||||
|
||||
@property
|
||||
def task_done(self):
|
||||
return len(self._positions) == len(self._marker_names)
|
||||
|
||||
@property
|
||||
def current_marker(self):
|
||||
return self._current_marker
|
||||
|
||||
@property
|
||||
def marker_symbol(self):
|
||||
return self._marker
|
||||
|
||||
@property
|
||||
def marker_color(self):
|
||||
return self._marker_color
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "MarkerTask %s with markers: %s" % (self.name, [mn for mn in self._marker_names])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Hello Jan!")
|
||||
tank_task = MarkerTask("tank limits", ["bottom left corner", "top left corner", "top right corner", "bottom right corner"], "Mark tank corners")
|
||||
#feeder_task = MarkerTask("Feeder positions", list(map(str, range(1, 2))), "Mark feeder positions")
|
||||
#tasks = [tank_task, feeder_task]
|
||||
im = ImageMarker([tank_task])
|
||||
vid1 = "/data/personality/secondhome/fischies/lepto_03/position/lepto03_position_2021.06.07_60.mp4"
|
||||
# print(sys.argv[0])
|
||||
# print (sys.argv[1])
|
||||
# vid1 = sys.argv[1]
|
||||
marker_positions = im.mark_movie(vid1, 00)
|
||||
print(marker_positions)
|
||||
10
src/etrack/info.json
Normal file
10
src/etrack/info.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"VERSION": "0.5.0",
|
||||
"STATUS": "Release",
|
||||
"RELEASE": "0.5.0 Release",
|
||||
"AUTHOR": "Jan Grewe",
|
||||
"COPYRIGHT": "2024, University of Tuebingen, Neuroethology, Jan Grewe",
|
||||
"CONTACT": "jan.grewe@g-node.org",
|
||||
"BRIEF": "Efish tracking helpers for handling tracking data.",
|
||||
"HOMEPAGE": "https://github.com/G-Node/nixpy"
|
||||
}
|
||||
25
src/etrack/info.py
Normal file
25
src/etrack/info.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright © 2024, Jan Grewe
|
||||
#
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted under the terms of the BSD License. See
|
||||
# LICENSE file in the root of the Project.
|
||||
import os
|
||||
import json
|
||||
|
||||
here = os.path.dirname(__file__)
|
||||
|
||||
with open(os.path.join(here, "info.json")) as infofile:
|
||||
infodict = json.load(infofile)
|
||||
|
||||
|
||||
VERSION = infodict["VERSION"]
|
||||
STATUS = infodict["STATUS"]
|
||||
RELEASE = infodict["RELEASE"]
|
||||
AUTHOR = infodict["AUTHOR"]
|
||||
COPYRIGHT = infodict["COPYRIGHT"]
|
||||
CONTACT = infodict["CONTACT"]
|
||||
BRIEF = infodict["BRIEF"]
|
||||
HOMEPAGE = infodict["HOMEPAGE"]
|
||||
0
src/etrack/io/__init__.py
Normal file
0
src/etrack/io/__init__.py
Normal file
78
src/etrack/io/dlc_data.py
Normal file
78
src/etrack/io/dlc_data.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import numbers as nb
|
||||
|
||||
from ..tracking_data import TrackingData
|
||||
|
||||
|
||||
class DLCReader(object):
|
||||
|
||||
def __init__(self, results_file, crop=(0, 0)) -> None:
|
||||
"""
|
||||
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
crop : 2-tuple
|
||||
tuple of (xoffset, yoffset)
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError if crop value is not a 2-tuple
|
||||
"""
|
||||
if not os.path.exists(results_file):
|
||||
raise ValueError("File %s does not exist!" % results_file)
|
||||
if not isinstance(crop, tuple) or len(crop) < 2:
|
||||
raise ValueError("Cropping info must be a 2-tuple of (x, y)")
|
||||
self._file_name = results_file
|
||||
self._crop = crop
|
||||
self._data_frame = pd.read_hdf(results_file)
|
||||
self._level_shape = self._data_frame.columns.levshape
|
||||
self._scorer = self._data_frame.columns.levels[0].values
|
||||
self._bodyparts = self._data_frame.columns.levels[1].values if self._level_shape[1] > 0 else []
|
||||
self._positions = self._data_frame.columns.levels[2].values if self._level_shape[2] > 0 else []
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def dataframe(self):
|
||||
return self._data_frame
|
||||
|
||||
@property
|
||||
def scorer(self):
|
||||
return self._scorer
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
return self._bodyparts
|
||||
|
||||
def _correct_cropping(self, orgx, orgy):
|
||||
x = orgx + self._crop[0]
|
||||
y = orgy + self._crop[1]
|
||||
return x, y
|
||||
|
||||
def track(self, scorer=0, bodypart=0, framerate=30):
|
||||
if isinstance(scorer, nb.Number):
|
||||
sc = self._scorer[scorer]
|
||||
elif isinstance(scorer, str) and scorer in self._scorer:
|
||||
sc = scorer
|
||||
else:
|
||||
raise ValueError(f"Scorer {scorer} is not in dataframe!")
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError(f"Body part {bodypart} is not in dataframe!")
|
||||
|
||||
x = np.asarray(self._data_frame[sc][bp]["x"] if "x" in self._positions else [])
|
||||
y = np.asarray(self._data_frame[sc][bp]["y"] if "y" in self._positions else [])
|
||||
x, y = self._correct_cropping(x, y)
|
||||
l = np.asarray(self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else [])
|
||||
|
||||
time = np.arange(len(x))/framerate
|
||||
|
||||
return TrackingData(x, y, time, l, bp, fps=framerate)
|
||||
71
src/etrack/io/nixtrack_data.py
Normal file
71
src/etrack/io/nixtrack_data.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import numbers as nb
|
||||
import nixtrack as nt
|
||||
|
||||
from ..tracking_data import TrackingData
|
||||
from IPython import embed
|
||||
|
||||
|
||||
class NixtrackData(object):
|
||||
|
||||
def __init__(self, filename, crop=(0, 0)) -> None:
|
||||
"""
|
||||
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str
|
||||
full filename
|
||||
crop : 2-tuple
|
||||
tuple of (xoffset, yoffset)
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError if crop value is not a 2-tuple
|
||||
"""
|
||||
if not os.path.exists(filename):
|
||||
raise ValueError("File %s does not exist!" % filename)
|
||||
if not isinstance(crop, tuple) or len(crop) < 2:
|
||||
raise ValueError("Cropping info must be a 2-tuple of (x, y)")
|
||||
self._file_name = filename
|
||||
self._crop = crop
|
||||
self._dataset = nt.Dataset(self._file_name)
|
||||
if not self._dataset.is_open:
|
||||
raise ValueError(f"An error occurred opening file {self._file_name}! File is not open!")
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
return self._dataset.nodes
|
||||
|
||||
def _correct_cropping(self, orgx, orgy):
|
||||
x = orgx + self._crop[0]
|
||||
y = orgy + self._crop[1]
|
||||
return x, y
|
||||
|
||||
@property
|
||||
def tracks(self):
|
||||
return self._dataset.tracks
|
||||
|
||||
def track(self, bodypart=0, fps=None):
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self.bodyparts[bodypart]
|
||||
elif isinstance(bodypart, (str)) and bodypart in self.bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError(f"Body part {bodypart} is not a tracked node!")
|
||||
if fps is None:
|
||||
fps = self._dataset.fps
|
||||
|
||||
positions, time, _, nscore = self._dataset.positions(node=bp, axis_type=nt.AxisType.Time)
|
||||
valid = ~np.isnan(positions[:, 0])
|
||||
positions = positions[valid,:]
|
||||
time = time[valid]
|
||||
score = nscore[valid]
|
||||
|
||||
return TrackingData(positions[:, 0], positions[:, 1], time, score, bp, fps=fps)
|
||||
219
src/etrack/tracking_data.py
Normal file
219
src/etrack/tracking_data.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrackingData(object):
|
||||
"""Class that represents tracking data, i.e. positions of an agent tracked in an environment.
|
||||
These data are the x, and y-positions, the time at which the agent was detected, and the quality associated with the position estimation.
|
||||
TrackingData contains these data and offers a few functions to work with it.
|
||||
Using the 'quality_threshold', 'temporal_limits', or the 'position_limits' data can be filtered (see filter_tracks function).
|
||||
The 'interpolate' function allows to fill up the gaps that be result from filtering with linearly interpolated data points.
|
||||
|
||||
More may follow...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x,
|
||||
y,
|
||||
time,
|
||||
quality,
|
||||
node="",
|
||||
fps=None,
|
||||
quality_threshold=None,
|
||||
temporal_limits=None,
|
||||
position_limits=None,
|
||||
) -> None:
|
||||
self._orgx = x
|
||||
self._orgy = y
|
||||
self._orgtime = time
|
||||
self._orgquality = quality
|
||||
self._x = x
|
||||
self._y = y
|
||||
self._time = time
|
||||
self._quality = quality
|
||||
self._node = node
|
||||
self._threshold = quality_threshold
|
||||
self._position_limits = position_limits
|
||||
self._time_limits = temporal_limits
|
||||
self._fps = fps
|
||||
|
||||
@property
|
||||
def original_positions(self):
|
||||
return self._orgx, self._orgy
|
||||
|
||||
@property
|
||||
def original_quality(self):
|
||||
return self._orgquality
|
||||
|
||||
def interpolate(self, start_time=None, end_time=None, min_count=5):
|
||||
if len(self._x) < min_count:
|
||||
print(
|
||||
f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!"
|
||||
)
|
||||
return None, None, None
|
||||
start = self._time[0] if start_time is None else start_time
|
||||
end = self._time[-1] if end_time is None else end_time
|
||||
time = np.arange(start, end, 1.0 / self._fps)
|
||||
x = np.interp(time, self._time, self._x)
|
||||
y = np.interp(time, self._time, self._y)
|
||||
|
||||
return x, y, time
|
||||
|
||||
@property
|
||||
def quality_threshold(self):
|
||||
return self._threshold
|
||||
|
||||
@quality_threshold.setter
|
||||
def quality_threshold(self, new_threshold):
|
||||
"""Setter of the quality threshold that should be applied when filterin the data. Setting this to None removes the quality filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_threshold : float
|
||||
|
||||
"""
|
||||
self._threshold = new_threshold
|
||||
|
||||
@property
|
||||
def position_limits(self):
|
||||
return self._position_limits
|
||||
|
||||
@position_limits.setter
|
||||
def position_limits(self, new_limits):
|
||||
"""Sets the limits for the position filter. 'new_limits' must be a 4-tuple of the form (x0, y0, width, height). If None, the limits will be removed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_limits: 4-tuple
|
||||
tuple of x-position, y-position, the width and the height. Passing None removes the filter
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError, if new_value is not a 4-tuple
|
||||
"""
|
||||
if new_limits is not None and not (
|
||||
isinstance(new_limits, (tuple, list)) and len(new_limits) == 4
|
||||
):
|
||||
raise ValueError(
|
||||
f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)"
|
||||
)
|
||||
self._position_limits = new_limits
|
||||
|
||||
@property
|
||||
def temporal_limits(self):
|
||||
return self._time_limits
|
||||
|
||||
@temporal_limits.setter
|
||||
def temporal_limits(self, new_limits):
|
||||
"""Limits for temporal filter. The limits must be a 2-tuple of start and end time. Setting this to None removes the filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_limits : 2-tuple
|
||||
The new limits in the form (start, end) given in seconds.
|
||||
"""
|
||||
if new_limits is not None and not (
|
||||
isinstance(new_limits, (tuple, list)) and len(new_limits) == 2
|
||||
):
|
||||
raise ValueError(
|
||||
f"The new_limits vector must be a 2-tuple of the form (start, end). "
|
||||
)
|
||||
self._time_limits = new_limits
|
||||
|
||||
def filter_tracks(self, align_time=True):
|
||||
"""Applies the filters to the tracking data. All filters will be applied sequentially, i.e. an AND connection.
|
||||
To change the filter settings use the setters for 'quality_threshold', 'temporal_limits', 'position_limits'. Setting them to None disables the respective filter discarding the setting.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
align_time: bool
|
||||
Controls whether the time vector is aligned to the first time point at which the agent is within the positional_limits. Default = True
|
||||
"""
|
||||
self._x = self._orgx.copy()
|
||||
self._y = self._orgy.copy()
|
||||
self._time = self._orgtime.copy()
|
||||
self._quality = self.original_quality.copy()
|
||||
|
||||
if self.position_limits is not None:
|
||||
x_max = self.position_limits[0] + self.position_limits[2]
|
||||
y_max = self.position_limits[1] + self.position_limits[3]
|
||||
indices = np.where(
|
||||
(self._x >= self.position_limits[0])
|
||||
& (self._x < x_max)
|
||||
& (self._y >= self.position_limits[1])
|
||||
& (self._y < y_max)
|
||||
)
|
||||
self._x = self._x[indices]
|
||||
self._y = self._y[indices]
|
||||
self._time = self._time[indices] - self._time[0] if align_time else 0.0
|
||||
self._quality = self._quality[indices]
|
||||
|
||||
if self.temporal_limits is not None:
|
||||
indices = np.where(
|
||||
(self._time >= self.temporal_limits[0])
|
||||
& (self._time < self.temporal_limits[1])
|
||||
)
|
||||
self._x = self._x[indices]
|
||||
self._y = self._y[indices]
|
||||
self._time = self._time[indices]
|
||||
self._quality = self._quality[indices]
|
||||
|
||||
if self.quality_threshold is not None:
|
||||
indices = np.where((self._quality >= self.quality_threshold))
|
||||
self._x = self._x[indices]
|
||||
self._y = self._y[indices]
|
||||
self._time = self._time[indices]
|
||||
self._quality = self._quality[indices]
|
||||
|
||||
def positions(self):
|
||||
"""Returns the filtered data (if filters have been applied).
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The x-positions
|
||||
np.ndarray
|
||||
The y-positions
|
||||
np.ndarray
|
||||
The time vector
|
||||
np.ndarray
|
||||
The detection quality
|
||||
"""
|
||||
return self._x, self._y, self._time, self._quality
|
||||
|
||||
def speed(self, x=None, y=None, t=None):
|
||||
""" Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points. If any of the arguments is not provided, the function will use the x,y coordinates that are stored within the object, otherwise, if all are provided, the user-provided values will be used.
|
||||
|
||||
Since the velocities are estimated from the difference between two sample points the returned velocities are assigned to positions and times between the respective sampled positions/times.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x: np.ndarray
|
||||
The x-coordinates, defaults to None
|
||||
y: np.ndarray
|
||||
The y-coordinates, defaults to None
|
||||
t: np.ndarray
|
||||
The time vector, defaults to None
|
||||
Returns
|
||||
-------
|
||||
np.ndarray:
|
||||
The time vector.
|
||||
np.ndarray:
|
||||
The speed.
|
||||
tuple of np.ndarray
|
||||
The position
|
||||
"""
|
||||
if x is None or y is None or t is None:
|
||||
x = self._x
|
||||
y = self._y
|
||||
t = self._time
|
||||
speed = np.sqrt(np.diff(x)**2 + np.diff(y)**2) / np.diff(t)
|
||||
t = t[:-1] + np.diff(t) / 2
|
||||
x = x[:-1] + np.diff(x) / 2
|
||||
y = y[:-1] + np.diff(y) / 2
|
||||
|
||||
return t, speed, (x, y)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"Tracking data of node '{self._node}'!"
|
||||
return s
|
||||
197
src/etrack/tracking_result.py
Normal file
197
src/etrack/tracking_result.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import numbers as nb
|
||||
import os
|
||||
|
||||
"""
|
||||
x_0 = 0
|
||||
width = 1230
|
||||
y_0 = 0
|
||||
height = 1100
|
||||
x_factor = 0.81/width # Einheit m/px
|
||||
y_factor = 0.81/height # Einheit m/px
|
||||
center = (np.round(x_0 + width/2), np.round(y_0 + height/2))
|
||||
center_meter = ((center[0] - x_0) * x_factor, (center[1] - y_0) * y_factor)
|
||||
"""
|
||||
def coordinate_transformation(position,x_0, y_0, x_factor, y_factor):
|
||||
x = (position[0] - x_0) * x_factor
|
||||
y = (position[1] - y_0) * y_factor
|
||||
return (x, y) #in m
|
||||
|
||||
class TrackingResult(object):
|
||||
|
||||
def __init__(self, results_file, x_0=0, y_0= 0, width_pixel=1230, height_pixel=1100, width_meter=0.81, height_meter=0.81) -> None:
|
||||
super().__init__()
|
||||
if not os.path.exists(results_file):
|
||||
raise ValueError("File %s does not exist!" % results_file)
|
||||
self._file_name = results_file
|
||||
self.x_0 = x_0
|
||||
self.y_0 = y_0
|
||||
self.width_pix = width_pixel
|
||||
self.width_m = width_meter
|
||||
self.height_pix = height_pixel
|
||||
self.height_m = height_meter
|
||||
self.x_factor = self.width_m / self.width_pix # m/pix
|
||||
self.y_factor = self.height_m / self.height_pix # m/pix
|
||||
|
||||
self.center = (np.round(self.x_0 + self.width_pix/2), np.round(self.y_0 + self.height_pix/2))
|
||||
self.center_meter = ((self.center[0] - self.x_0) * self.x_factor, (self.center[1] - self.y_0) * self.y_factor)
|
||||
|
||||
self._data_frame = pd.read_hdf(results_file)
|
||||
self._level_shape = self._data_frame.columns.levshape
|
||||
self._scorer = self._data_frame.columns.levels[0].values
|
||||
self._bodyparts = self._data_frame.columns.levels[1].values if self._level_shape[1] > 0 else []
|
||||
self._positions = self._data_frame.columns.levels[2].values if self._level_shape[2] > 0 else []
|
||||
|
||||
def angle_to_center(self, bodypart=0, twopi=True, origin="topleft", min_likelihood=0.95):
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
|
||||
_, x, y, _, _ = self.position_values(bodypart=bp, min_likelihood=min_likelihood)
|
||||
if x is None:
|
||||
print("Error: no valid angles for %s" % self._file_name)
|
||||
return []
|
||||
x_meter = x - self.center_meter[0]
|
||||
y_meter = y - self.center_meter[1]
|
||||
if origin.lower() == "topleft":
|
||||
y_meter *= -1
|
||||
phi = np.arctan2(y_meter, x_meter) * 180 / np.pi
|
||||
if twopi:
|
||||
phi[phi < 0] = 360 + phi[phi < 0]
|
||||
return phi
|
||||
|
||||
def coordinate_transformation(self, position):
|
||||
x = (position[0] - self.x_0) * self.x_factor
|
||||
y = (position[1] - self.y_0) * self.y_factor
|
||||
return (x, y) #in m
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def dataframe(self):
|
||||
return self._data_frame
|
||||
|
||||
@property
|
||||
def scorer(self):
|
||||
return self._scorer
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
return self._bodyparts
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return self._positions
|
||||
|
||||
def position_values(self, scorer=0, bodypart=0, framerate=30, interpolate=True, min_likelihood=0.95):
|
||||
"""returns the x and y positions in m and the likelihood of the positions.
|
||||
|
||||
Args:
|
||||
scorer (int, optional): [description]. Defaults to 0.
|
||||
bodypart (int, optional): [description]. Defaults to 0.
|
||||
framerate (int, optional): [description]. Defaults to 30.
|
||||
|
||||
Raises:
|
||||
ValueError: [description]
|
||||
ValueError: [description]
|
||||
|
||||
Returns:
|
||||
time [np.array]: the time axis
|
||||
x [np.array]: the x-position in m
|
||||
y [np.array]: the y-position in m
|
||||
l [np.array]: the likelihood of the position estimation
|
||||
bp string: the body part
|
||||
[type]: [description]
|
||||
"""
|
||||
time, x, y, l, bp = self.pixel_positions(scorer, bodypart, framerate, interpolate, min_likelihood)
|
||||
x, y = self._to_meter(x, y)
|
||||
return time, x, y, l, bp
|
||||
|
||||
def pixel_positions(self, scorer=0, bodypart=0, framerate=30, interpolate=True, min_likelihood=0.95):
|
||||
if isinstance(scorer, nb.Number):
|
||||
sc = self._scorer[scorer]
|
||||
elif isinstance(scorer, str) and scorer in self._scorer:
|
||||
sc = scorer
|
||||
else:
|
||||
raise ValueError("Scorer %s is not in dataframe!" % scorer)
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
|
||||
|
||||
x = np.asarray(self._data_frame[sc][bp]["x"] if "x" in self._positions else [])
|
||||
y = np.asarray(self._data_frame[sc][bp]["y"] if "y" in self._positions else [])
|
||||
l = np.asarray(self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else [])
|
||||
|
||||
time = np.arange(len(x))/framerate
|
||||
if interpolate:
|
||||
x, y = self.interpolate(time, x, y, l, min_likelihood)
|
||||
return time, x, y, l, bp
|
||||
|
||||
def _to_meter(self, x, y):
|
||||
new_x = (np.asarray(x) - self.x_0) * self.x_factor
|
||||
new_y = (np.asarray(y) - self.y_0) * self.y_factor
|
||||
return new_x, new_y
|
||||
|
||||
def _speed(self, t, x, y):
|
||||
speed = np.sqrt(np.diff(x)**2 + np.diff(y)**2) / np.diff(t)
|
||||
return speed
|
||||
|
||||
def interpolate(self, t, x, y, l, min_likelihood=0.9):
|
||||
time2 = t[l > min_likelihood]
|
||||
if len(l[l > min_likelihood]) < 10:
|
||||
print("%s has less than 10 datapoints with likelihood larger than %.2f" % (self._file_name, min_likelihood) )
|
||||
return None, None
|
||||
x2 = x[l > min_likelihood]
|
||||
y2 = y[l > min_likelihood]
|
||||
x3 = np.interp(t, time2, x2)
|
||||
y3 = np.interp(t, time2, y2)
|
||||
return x3, y3
|
||||
|
||||
def plot(self, scorer=0, bodypart=0, threshold=0.9, framerate=30):
|
||||
t, x, y, l, name = self.position_values(scorer=scorer, bodypart=bodypart, framerate=framerate, min_likelihood=threshold)
|
||||
plt.scatter(x[l > threshold], y[l > threshold], c=t[l > threshold], label=name)
|
||||
plt.scatter(self.center_meter[0], self.center_meter[1], marker="*")
|
||||
plt.plot(x[l > threshold], y[l > threshold])
|
||||
plt.xlabel("x position")
|
||||
plt.ylabel("y position")
|
||||
plt.gca().invert_yaxis()
|
||||
bar = plt.colorbar()
|
||||
bar.set_label("time [s]")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from IPython import embed
|
||||
filename = "2020.12.04_lepto48DLC_resnet50_boldnessDec11shuffle1_200000.h5"
|
||||
path = "/mnt/movies/merle_verena/boldness/labeled_videos/day_4/"
|
||||
tr = TrackingResult(path+filename)
|
||||
time, x, y, l, bp = tr.position_values(bodypart=2)
|
||||
|
||||
|
||||
thresh = 0.95
|
||||
time2 = time[l>thresh]
|
||||
x2 = x[l>thresh]
|
||||
y2 = y[l>thresh]
|
||||
x3 = np.interp(time, time2, x2)
|
||||
y3 = np.interp(time, time2, y2)
|
||||
|
||||
fig, axes = plt.subplots(3,1, sharex=True)
|
||||
axes[0].plot(time, x)
|
||||
axes[0].plot(time, x3)
|
||||
axes[1].plot(time, y)
|
||||
axes[1].plot(time, y3)
|
||||
|
||||
axes[2].plot(time, l)
|
||||
plt.show()
|
||||
|
||||
embed()
|
||||
34
src/etrack/util.py
Normal file
34
src/etrack/util.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from enum import Enum
|
||||
|
||||
class Illumination(Enum):
|
||||
Backlight = 0
|
||||
Incident = 1
|
||||
|
||||
|
||||
class RegionShape(Enum):
|
||||
"""
|
||||
Enumeration representing the shape of a region.
|
||||
|
||||
Attributes:
|
||||
Circular: Represents a circular region.
|
||||
Rectangular: Represents a rectangular region.
|
||||
"""
|
||||
|
||||
Circular = 0
|
||||
Rectangular = 1
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class AnalysisType(Enum):
|
||||
Full = 0
|
||||
CollapseX = 1
|
||||
CollapseY = 2
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
class PositionType(Enum):
|
||||
Absolute = 0
|
||||
Cropped = 1
|
||||
Reference in New Issue
Block a user