forked from jgrewe/efish_tracking
Compare commits
23 Commits
Author | SHA1 | Date | |
---|---|---|---|
e3b5d2d6cc | |||
94fa5e3d14 | |||
cbd0541a54 | |||
1dd318f23e | |||
32c0a65c58 | |||
b3ba30ced6 | |||
469a35724d | |||
bf8635d2fd | |||
2bba750e1f | |||
0291ef088a | |||
6dd4a4f5de | |||
701cda1069 | |||
3e1cbe4b9b | |||
6f9633a74e | |||
f56e21d9b1 | |||
30a035f82d | |||
9046e70592 | |||
6487cb07ff | |||
e5c9653bdd | |||
2593f21f3a | |||
ae277ce8fb | |||
e854ab591f | |||
16873702d4 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -5,3 +5,4 @@ requires.txt
|
||||
SOURCES.txt
|
||||
dependency_links.txt
|
||||
top_level.txt
|
||||
.DS_Store
|
74
build_docs.sh
Executable file
74
build_docs.sh
Executable file
@ -0,0 +1,74 @@
|
||||
#!/bin/bash
|
||||
|
||||
die() { echo "ERROR: $*"; exit 2; }
|
||||
warn() { echo "WARNING: $*"; }
|
||||
|
||||
for cmd in mkdocs pdoc3 genbadge; do
|
||||
command -v "$cmd" >/dev/null ||
|
||||
warn "missing $cmd: run \`pip install $cmd\`"
|
||||
done
|
||||
|
||||
PACKAGE="etrack"
|
||||
PACKAGESRC="src/$PACKAGE"
|
||||
PACKAGEROOT="$(dirname "$(realpath "$0")")"
|
||||
BUILDROOT="$PACKAGEROOT/site"
|
||||
|
||||
# check for code coverage report:
|
||||
# need to call nosetest with --with-coverage --cover-html --cover-xml
|
||||
HAS_COVER=false
|
||||
test -d cover && HAS_COVER=true
|
||||
|
||||
echo
|
||||
echo "Clean up documentation of $PACKAGE"
|
||||
echo
|
||||
|
||||
rm -rf "$BUILDROOT" 2> /dev/null || true
|
||||
mkdir -p "$BUILDROOT"
|
||||
|
||||
if command -v mkdocs >/dev/null; then
|
||||
echo
|
||||
echo "Building general documentation for $PACKAGE"
|
||||
echo
|
||||
|
||||
cd "$PACKAGEROOT"
|
||||
cp .mkdocs.yml mkdocs-tmp.yml
|
||||
if $HAS_COVER; then
|
||||
echo " - Coverage: 'cover/index.html'" >> mkdocs-tmp.yml
|
||||
fi
|
||||
mkdir -p docs
|
||||
sed -e 's|docs/||; /\[Documentation\]/d; /\[API Reference\]/d' README.md > docs/index.md
|
||||
mkdocs build --config-file mkdocs.yml --site-dir "$BUILDROOT"
|
||||
rm mkdocs-tmp.yml docs/index.md
|
||||
cd - > /dev/null
|
||||
fi
|
||||
|
||||
if $HAS_COVER; then
|
||||
echo
|
||||
echo "Copy code coverage report and generate badge for $PACKAGE"
|
||||
echo
|
||||
|
||||
cd "$PACKAGEROOT"
|
||||
cp -r cover "$BUILDROOT/"
|
||||
genbadge coverage -i coverage.xml
|
||||
# https://smarie.github.io/python-genbadge/
|
||||
mv coverage-badge.svg site/coverage.svg
|
||||
cd - > /dev/null
|
||||
fi
|
||||
|
||||
if command -v pdoc3 >/dev/null; then
|
||||
echo
|
||||
echo "Building API reference docs for $PACKAGE"
|
||||
echo
|
||||
|
||||
cd "$PACKAGEROOT"
|
||||
pdoc3 --html --config latex_math=True --config sort_identifiers=False --output-dir "$BUILDROOT/api-tmp" $PACKAGESRC
|
||||
mv "$BUILDROOT/api-tmp/$PACKAGE" "$BUILDROOT/api"
|
||||
rmdir "$BUILDROOT/api-tmp"
|
||||
cd - > /dev/null
|
||||
fi
|
||||
|
||||
echo
|
||||
echo "Done. Docs in:"
|
||||
echo
|
||||
echo " file://$BUILDROOT/index.html"
|
||||
echo
|
35
docs/etrack.md
Normal file
35
docs/etrack.md
Normal file
@ -0,0 +1,35 @@
|
||||
# E-Fish tracking
|
||||
|
||||
Tool for easier handling of tracking results.
|
||||
|
||||
## Installation
|
||||
|
||||
### 1. Clone git repository
|
||||
|
||||
```shell
|
||||
git clone https://whale.am28.uni-tuebingen.de/git/jgrewe/efish_tracking.git
|
||||
```
|
||||
|
||||
### 2. Change into directory
|
||||
|
||||
```shell
|
||||
cd efish_tracking
|
||||
````
|
||||
|
||||
### 3. Install with pip
|
||||
|
||||
```shell
|
||||
pip3 install -e . --user
|
||||
```
|
||||
|
||||
The ```-e``` installs the package in an *editable* model that you do not need to reinstall whenever you pull upstream changes.
|
||||
|
||||
If you leave away the ```--user``` the package will be installed system-wide.
|
||||
|
||||
## TrackingResults
|
||||
|
||||
Is a class that wraps around the *.h5 files written by DeepLabCut
|
||||
|
||||
## ImageMarker
|
||||
|
||||
Class that allows for creating MarkerTasks to get specific positions in a video.
|
4
docs/trackingdata.md
Normal file
4
docs/trackingdata.md
Normal file
@ -0,0 +1,4 @@
|
||||
# TrackingData
|
||||
|
||||
Class that represents the position data associated with one noe/bodypart.
|
||||
|
@ -1,3 +0,0 @@
|
||||
from .image_marker import ImageMarker, MarkerTask
|
||||
from .tracking_result import TrackingResult
|
||||
from .distance_calibration import DistanceCalibration
|
@ -1,193 +0,0 @@
|
||||
from turtle import left
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from IPython import embed
|
||||
from etrack import MarkerTask, ImageMarker
|
||||
|
||||
|
||||
def mark_crop_positions(self):
|
||||
task = MarkerTask("crop area", ["bottom left corner", "top left corner", "top right corner", "bottom right corner"], "Mark crop area")
|
||||
im = ImageMarker([task])
|
||||
|
||||
marker_positions = im.mark_movie(file_name, frame_number)
|
||||
print(marker_positions)
|
||||
|
||||
np.save('marker_positions', marker_positions)
|
||||
|
||||
return marker_positions
|
||||
|
||||
|
||||
def assign_marker_positions(marker_positions):
|
||||
bottom_left_x = marker_positions[0]['bottom left corner'][0]
|
||||
bottom_left_y = marker_positions[0]['bottom left corner'][1]
|
||||
bottom_right_x = marker_positions[0]['bottom right corner'][0]
|
||||
bottom_right_y = marker_positions[0]['bottom right corner'][1]
|
||||
top_left_x = marker_positions[0]['top left corner'][0]
|
||||
top_left_y = marker_positions[0]['top left corner'][1]
|
||||
top_right_x = marker_positions[0]['top right corner'][0]
|
||||
top_right_y = marker_positions[0]['top right corner'][1]
|
||||
return bottom_left_x, bottom_left_y, bottom_right_x, bottom_right_y, top_left_x, top_left_y, top_right_x, top_right_y
|
||||
|
||||
|
||||
def assign_checkerboard_positions(checkerboard_marker_positions):
|
||||
checkerboard_top_right = checkerboard_marker_positions[0]['top right corner']
|
||||
checkerboard_top_left = checkerboard_marker_positions[0]['top left corner']
|
||||
checkerboard_bottom_right = checkerboard_marker_positions[0]['bottom right corner']
|
||||
checkerboard_bottom_left = checkerboard_marker_positions[0]['bottom left corner']
|
||||
return checkerboard_top_right, checkerboard_top_left, checkerboard_bottom_right, checkerboard_bottom_left
|
||||
|
||||
|
||||
def crop_frame(frame, marker_positions):
|
||||
|
||||
# load the four marker positions
|
||||
bottom_left_x, bottom_left_y, bottom_right_x, bottom_right_y, top_left_x, top_left_y, top_right_x, top_right_y = assign_marker_positions(marker_positions)
|
||||
|
||||
# define boundaries of frame, taken by average of points on same line but slightly different pixel values
|
||||
left_bound = int(np.mean([bottom_left_x, top_left_x]))
|
||||
right_bound = int(np.mean([bottom_right_x, top_right_x]))
|
||||
top_bound = int(np.mean([top_left_y, top_right_y]))
|
||||
bottom_bound = int(np.mean([bottom_left_y, bottom_right_y]))
|
||||
|
||||
# crop the frame by boundary values
|
||||
cropped_frame = frame[top_bound:bottom_bound, left_bound:right_bound]
|
||||
cropped_frame = np.mean(cropped_frame, axis=2) # mean over 3rd dimension (RGB/color values)
|
||||
|
||||
# mean over short or long side of the frame corresponding to x or y axis of picture
|
||||
frame_width = np.mean(cropped_frame,axis=0)
|
||||
frame_height = np.mean(cropped_frame,axis=1)
|
||||
|
||||
# differences of color values lying next to each other --> derivation
|
||||
diff_width = np.diff(frame_width)
|
||||
diff_height = np.diff(frame_height)
|
||||
|
||||
# two x vectors for better plotting
|
||||
x_width = np.arange(0, len(diff_width), 1)
|
||||
x_height = np.arange(0, len(diff_height), 1)
|
||||
|
||||
return cropped_frame, frame_width, frame_height, diff_width, diff_height, x_width, x_height
|
||||
|
||||
def rotation_angle():
|
||||
pass
|
||||
|
||||
|
||||
def threshold_crossings(data, threshold_factor):
|
||||
# upper and lower threshold
|
||||
median_data = np.median(data)
|
||||
median_lower = median_data + np.min(data)
|
||||
median_upper = np.max(data) - median_data
|
||||
lower_threshold = median_lower / threshold_factor
|
||||
upper_threshold = median_upper / threshold_factor
|
||||
|
||||
# array with values if data >/< than threshold = True or not
|
||||
lower_crossings = np.diff(data < lower_threshold, prepend=False) # prepend: point after crossing
|
||||
upper_crossings = np.diff(data > upper_threshold, append=False) # append: point before crossing
|
||||
|
||||
# indices where crossings are
|
||||
lower_crossings_indices = np.argwhere(lower_crossings)
|
||||
upper_crossings_indices = np.argwhere(upper_crossings)
|
||||
|
||||
# sort out several crossings of same edge of checkerboard (due to noise)
|
||||
half_window_size = 10
|
||||
lower_peaks = []
|
||||
upper_peaks = []
|
||||
for lower_idx in lower_crossings_indices: # for every lower crossing..
|
||||
if lower_idx < half_window_size: # ..if indice smaller than window size near indice 0
|
||||
half_window_size = lower_idx
|
||||
lower_window = data[lower_idx[0] - int(half_window_size):lower_idx[0] + int(half_window_size)] # create data window from -window_size to +window_size
|
||||
min_window = np.min(lower_window) # take minimum of window
|
||||
min_idx = np.where(data == min_window) # find indice where minimum is
|
||||
|
||||
lower_peaks.append(min_idx) # append to list
|
||||
for upper_idx in upper_crossings_indices: # same for upper crossings with max of window
|
||||
if upper_idx < half_window_size:
|
||||
half_window_size = upper_idx
|
||||
upper_window = data[upper_idx[0] - int(half_window_size) : upper_idx[0] + int(half_window_size)]
|
||||
|
||||
max_window = np.max(upper_window)
|
||||
max_idx = np.where(data == max_window)
|
||||
upper_peaks.append(max_idx)
|
||||
|
||||
# if several crossings create same peaks due to overlapping windows, only one (unique) will be taken
|
||||
lower_peaks = np.unique(lower_peaks)
|
||||
upper_peaks = np.unique(upper_peaks)
|
||||
|
||||
return lower_peaks, upper_peaks
|
||||
|
||||
|
||||
def checkerboard_position(lower_crossings_indices, upper_crossings_indices):
|
||||
"""Take crossing positions to generate a characteristic sequence for a corresponding position of the checkerboard inside the frame.
|
||||
Positional description has to be interpreted depending on the input data.
|
||||
|
||||
Args:
|
||||
lower_crossings_indices: Indices where lower threshold was crossed by derivation data.
|
||||
upper_crossings_indices: Indices where upper threshold was crossed by derivation data
|
||||
|
||||
Returns:
|
||||
checkerboard_position: General position where the checkerboard lays inside the frame along the axis of the input data.
|
||||
"""
|
||||
|
||||
# create zipped list with both indices
|
||||
zip_list = []
|
||||
for zl in lower_crossings_indices:
|
||||
zip_list.append(zl)
|
||||
for zu in upper_crossings_indices:
|
||||
zip_list.append(zu)
|
||||
|
||||
zip_list = np.sort(zip_list) # order by indice
|
||||
|
||||
# compare and assign zipped list to original indices lists and corresponding direction (to upper or lower threshold)
|
||||
sequence = []
|
||||
for z in zip_list:
|
||||
if z in lower_crossings_indices:
|
||||
sequence.append('down')
|
||||
else:
|
||||
sequence.append('up')
|
||||
print('sequence:', sequence)
|
||||
|
||||
# depending on order of crossings through upper or lower treshold, we get a characteristic sequence for a position of the checkerboard in the frame
|
||||
if sequence == ['up', 'down', 'up', 'down']: # first down, second up are edges of checkerboard
|
||||
print('in middle')
|
||||
checkerboard_position = 'middle'
|
||||
left_checkerboard_edge = zip_list[1]
|
||||
right_checkerboard_edge = zip_list[2]
|
||||
elif sequence == ['up', 'up', 'down']: # first and second up are edges of checkerboard
|
||||
print('at left')
|
||||
checkerboard_position = 'left'
|
||||
left_checkerboard_edge = zip_list[0]
|
||||
right_checkerboard_edge = zip_list[1]
|
||||
else: # first and second down are edges of checkerboard
|
||||
print('at right')
|
||||
checkerboard_position = 'right'
|
||||
left_checkerboard_edge = zip_list[1]
|
||||
right_checkerboard_edge = zip_list[2]
|
||||
|
||||
return checkerboard_position, left_checkerboard_edge, right_checkerboard_edge # position of checkerboard then will be returned
|
||||
|
||||
|
||||
def filter_data(data, n):
|
||||
"""Filter/smooth data with kernel of length n.
|
||||
|
||||
Args:
|
||||
data: Raw data.
|
||||
n: Number of datapoints the mean gets computed over.
|
||||
|
||||
Returns:
|
||||
filtered_data: Filtered data.
|
||||
"""
|
||||
new_data = np.zeros(len(data)) # empty vector where data will be put in in the following steps
|
||||
for k in np.arange(0, len(data) - n):
|
||||
kk = int(k)
|
||||
f = np.mean(data[kk:kk+n]) # mean over data over window from kk to kk+n
|
||||
kkk = int(kk+n / 2) # position where mean datapoint will be placed (so to say)
|
||||
if k == 0:
|
||||
new_data[:kkk] = f
|
||||
new_data[kkk] = f # assignment of value to datapoint
|
||||
new_data[kkk:] = f
|
||||
for nd in new_data[0:n-1]: # correction of left boundary effects (boundaries up to length of n were same number)
|
||||
nd_idx = np.argwhere(nd)
|
||||
new_data[nd_idx] = data[nd_idx]
|
||||
for nd in new_data[-1 - (n-1):-1]: # same as above, correction of right boundary effect
|
||||
nd_idx = np.argwhere(nd)
|
||||
new_data[nd_idx] = data[nd_idx]
|
||||
|
||||
return new_data
|
@ -1,258 +0,0 @@
|
||||
from multiprocessing import allow_connection_pickling
|
||||
from turtle import left
|
||||
from xml.dom.expatbuilder import FILTER_ACCEPT
|
||||
from cv2 import MARKER_TRIANGLE_UP, calibrationMatrixValues, mean, threshold
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from IPython import embed
|
||||
from calibration_functions import *
|
||||
|
||||
|
||||
|
||||
class DistanceCalibration():
|
||||
|
||||
def __init__(self, file_name, frame_number, x_0=154, y_0=1318, cam_dist=1.36, tank_width=1.35, tank_height=0.805, width_pixel=1900, height_pixel=200,
|
||||
checkerboard_width=0.24, checkerboard_height=0.18, checkerboard_width_pixel=500, checkerboard_height_pixel=350) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._file_name = file_name
|
||||
self._x_0 = x_0
|
||||
self._y_0 = y_0
|
||||
self._width_pix = width_pixel
|
||||
self._height_pix = height_pixel
|
||||
self._cam_dist = cam_dist
|
||||
self._tank_width = tank_width
|
||||
self._tank_height = tank_height
|
||||
self._cb_width = checkerboard_width
|
||||
self._cb_height = checkerboard_height
|
||||
self._cb_width_pix = checkerboard_width_pixel
|
||||
self._cb_height_pix = checkerboard_height_pixel
|
||||
self._x_factor = tank_width / width_pixel # m/pix
|
||||
self._y_factor = tank_height / height_pixel # m/pix
|
||||
|
||||
self.distance_factor_calculation
|
||||
self.mark_crop_positions
|
||||
|
||||
# if needed include setter: @y_0.setter def y_0(self, value): self._y_0 = value
|
||||
@property
|
||||
def x_0(self):
|
||||
return self._x_0
|
||||
|
||||
@property
|
||||
def y_0(self):
|
||||
return self._y_0
|
||||
|
||||
@property
|
||||
def cam_dist(self):
|
||||
return self._cam_dist
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self._width
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self._height
|
||||
|
||||
@property
|
||||
def width_pix(self):
|
||||
return self._width_pix
|
||||
|
||||
@property
|
||||
def height_pix(self):
|
||||
return self._height_pix
|
||||
|
||||
@property
|
||||
def cb_width(self):
|
||||
return self._cb_width
|
||||
|
||||
@property
|
||||
def cb_height(self):
|
||||
return self._cb_height
|
||||
|
||||
@property
|
||||
def x_factor(self):
|
||||
return self._x_factor
|
||||
|
||||
@property
|
||||
def y_factor(self):
|
||||
return self._y_factor
|
||||
|
||||
|
||||
def mark_crop_positions(self):
|
||||
task = MarkerTask("crop area", ["bottom left corner", "top left corner", "top right corner", "bottom right corner"], "Mark crop area")
|
||||
im = ImageMarker([task])
|
||||
|
||||
marker_positions = im.mark_movie(file_name, frame_number)
|
||||
print(marker_positions)
|
||||
|
||||
np.save('marker_positions', marker_positions)
|
||||
|
||||
return marker_positions
|
||||
|
||||
|
||||
def detect_checkerboard(self, filename, frame_number, marker_positions):
|
||||
# load frame
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("file %s does not exist!" % filename)
|
||||
video = cv2.VideoCapture()
|
||||
video.open(filename)
|
||||
frame_counter = 0
|
||||
success = True
|
||||
frame = None
|
||||
|
||||
while success and frame_counter <= frame_number: # iterating until frame_counter == frame_number --> success (True)
|
||||
print("Reading frame: %i" % frame_counter, end="\r")
|
||||
success, frame = video.read()
|
||||
frame_counter += 1
|
||||
|
||||
marker_positions = np.load('marker_positions.npy', allow_pickle=True) # load saved numpy marker positions file
|
||||
|
||||
# care: y-axis is inverted, top values are low, bottom values are high
|
||||
|
||||
cropped_frame, frame_width, frame_height, diff_width, diff_height, _, _ = crop_frame(frame, marker_positions) # crop frame to given marker positions
|
||||
|
||||
bottom_left_x = 0
|
||||
bottom_left_y = np.shape(cropped_frame)[0]
|
||||
bottom_right_x = np.shape(cropped_frame)[1]
|
||||
bottom_right_y = np.shape(cropped_frame)[0]
|
||||
top_left_x = 0
|
||||
top_left_y = 0
|
||||
top_right_x = np.shape(cropped_frame)[1]
|
||||
top_right_y = 0
|
||||
|
||||
cropped_marker_positions = [{'bottom left corner': (bottom_left_x, bottom_left_y), 'top left corner': (top_left_x, top_left_y),
|
||||
'top right corner': (top_right_x, top_right_y), 'bottom right corner': (bottom_right_x, bottom_right_y)}]
|
||||
|
||||
thresh_fact = 7 # factor by which the min/max is divided to calculate the upper and lower thresholds
|
||||
|
||||
# filtering/smoothing of data using kernel with n datapoints
|
||||
kernel = 4
|
||||
diff_width = filter_data(diff_width, n=kernel) # for widht (x-axis)
|
||||
diff_height = filter_data(diff_height, n=kernel) # for height (y-axis)
|
||||
|
||||
# input data is derivation of color values of frame
|
||||
lci_width, uci_width = threshold_crossings(diff_width, threshold_factor=thresh_fact) # threshold crossings (=edges of checkerboard) for width (x-axis)
|
||||
lci_height, uci_height = threshold_crossings(diff_height, threshold_factor=thresh_fact) # ..for height (y-axis)
|
||||
|
||||
print('lower crossings:', lci_width)
|
||||
print('upper crossings:', uci_width)
|
||||
|
||||
# position of checkerboard in width
|
||||
print('width..')
|
||||
width_position, left_width_position, right_width_position = checkerboard_position(lci_width, uci_width)
|
||||
|
||||
# position of checkerboard in height
|
||||
print('height..')
|
||||
height_position, left_height_position, right_height_position = checkerboard_position(lci_height, uci_height) # left height refers to top, right height to bottom
|
||||
|
||||
if width_position == 'left' and height_position == 'left':
|
||||
checkerboard_position_tank = 'top left'
|
||||
elif width_position == 'left' and height_position == 'right':
|
||||
checkerboard_position_tank = 'bottom left'
|
||||
elif width_position == 'right' and height_position == 'right':
|
||||
checkerboard_position_tank = 'bottom right'
|
||||
elif width_position == 'right' and height_position == 'left':
|
||||
checkerboard_position_tank = 'top right'
|
||||
else:
|
||||
checkerboard_position_tank = 'middle'
|
||||
|
||||
print(checkerboard_position_tank)
|
||||
|
||||
# final corner positions of checkerboard
|
||||
checkerboard_marker_positions = [{'bottom left corner': (left_width_position, right_height_position), 'top left corner': (left_width_position, left_height_position),
|
||||
'top right corner': (right_width_position, left_height_position), 'bottom right corner': (right_width_position, right_height_position)}]
|
||||
|
||||
print(checkerboard_marker_positions)
|
||||
|
||||
checkerboard_top_right, checkerboard_top_left, checkerboard_bottom_right, checkerboard_bottom_left = assign_checkerboard_positions(checkerboard_marker_positions)
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(cropped_frame)
|
||||
for p in checkerboard_top_left, checkerboard_top_right, checkerboard_bottom_left, checkerboard_bottom_right:
|
||||
ax.scatter(p[0], p[1])
|
||||
ax.scatter(bottom_left_x, bottom_left_y)
|
||||
ax.scatter(bottom_right_x, bottom_right_y)
|
||||
ax.scatter(top_left_x, top_left_y)
|
||||
ax.scatter(top_right_x, top_right_y)
|
||||
plt.show()
|
||||
|
||||
|
||||
return checkerboard_marker_positions, cropped_marker_positions, checkerboard_position_tank
|
||||
|
||||
|
||||
def distance_factor_calculation(self, checkerboard_marker_positions, marker_positions):
|
||||
|
||||
checkerboard_top_right, checkerboard_top_left, checkerboard_bottom_right, checkerboard_bottom_left = assign_checkerboard_positions(checkerboard_marker_positions)
|
||||
|
||||
checkerboard_width = 0.24
|
||||
checkerboard_height = 0.18
|
||||
|
||||
checkerboard_width_pixel = checkerboard_top_right[0] - checkerboard_top_left[0]
|
||||
checkerboard_height_pixel = checkerboard_bottom_right[1] - checkerboard_top_right[1]
|
||||
|
||||
x_factor = checkerboard_width / checkerboard_width_pixel
|
||||
y_factor = checkerboard_height / checkerboard_height_pixel
|
||||
|
||||
bottom_left_x, bottom_left_y, bottom_right_x, bottom_right_y, top_left_x, top_left_y, top_right_x, top_right_y = assign_marker_positions(marker_positions)
|
||||
|
||||
tank_width_pixel = np.mean([bottom_right_x - bottom_left_x, top_right_x - top_left_x])
|
||||
tank_height_pixel = np.mean([bottom_left_y - top_left_y, bottom_right_y - top_right_y])
|
||||
|
||||
tank_width = tank_width_pixel * x_factor
|
||||
tank_height = tank_height_pixel * y_factor
|
||||
|
||||
print(tank_width, tank_height)
|
||||
|
||||
return x_factor, y_factor
|
||||
|
||||
|
||||
def distance_factor_interpolation(x_factors, y_factors):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
all_x_factor = []
|
||||
all_y_factor = []
|
||||
all_checkerboard_position_tank = []
|
||||
for file_name in glob.glob("/home/efish/etrack/videos/*"):
|
||||
# file_name = "/home/efish/etrack/videos/2022.03.28_4.mp4"
|
||||
frame_number = 10
|
||||
dc = DistanceCalibration(file_name=file_name, frame_number=frame_number)
|
||||
|
||||
dc.mark_crop_positions()
|
||||
|
||||
checkerboard_marker_positions, cropped_marker_positions, checkerboard_position_tank = dc.detect_checkerboard(file_name, frame_number=frame_number, marker_positions=np.load('marker_positions.npy', allow_pickle=True))
|
||||
|
||||
x_factor, y_factor = dc.distance_factor_calculation(checkerboard_marker_positions, marker_positions=cropped_marker_positions)
|
||||
|
||||
all_x_factor.append(x_factor)
|
||||
all_y_factor.append(y_factor)
|
||||
all_checkerboard_position_tank.append(checkerboard_position_tank)
|
||||
|
||||
x_factors = np.load('x_factors.npy')
|
||||
y_factors = np.load('y_factors.npy')
|
||||
all_checkerboard_position_tank = np.load('all_checkerboard_position_tank.npy')
|
||||
|
||||
|
||||
embed()
|
||||
quit()
|
||||
|
||||
# next up: distance calculation with angle
|
||||
# is this needed or are current videos enough?:
|
||||
# laying checkerboard at position directly above and below / left and right to centered checkerboard near edge of tank
|
||||
# calculating x and y factor for centered checkerboard, then for the ones at the edge
|
||||
# --> afterwards interpolate between them to have continuous factors for whole tank
|
||||
# maybe smaller object in tank to have more accurate factor
|
||||
|
||||
# make function to refine checkerboard detection at edges of tank by saying if no lower color values appears near edge --> checkerboard position then == corner of tank?
|
||||
#
|
||||
# mark_crop_positions why failing plot at end?
|
||||
# with rectangles of checkerboard?
|
||||
|
||||
# embed()
|
@ -1,218 +0,0 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import numbers as nb
|
||||
import os
|
||||
|
||||
"""
|
||||
x_0 = 0
|
||||
width = 1230
|
||||
y_0 = 0
|
||||
height = 1100
|
||||
x_factor = 0.81/width # Einheit m/px
|
||||
y_factor = 0.81/height # Einheit m/px
|
||||
center = (np.round(x_0 + width/2), np.round(y_0 + height/2))
|
||||
center_meter = ((center[0] - x_0) * x_factor, (center[1] - y_0) * y_factor)
|
||||
"""
|
||||
|
||||
class TrackingResult(object):
|
||||
|
||||
def __init__(self, results_file, x_0=0, y_0= 0, width_pixel=1975, height_pixel=1375, width_meter=0.81, height_meter=0.81) -> None:
|
||||
super().__init__()
|
||||
"""Width refers to the "x-axis" of the tank, height to the "y-axis" of it.
|
||||
|
||||
Args:
|
||||
results_file (_type_): Results file of the before done animal tracking.
|
||||
x_0 (int, optional): . Defaults to 95.
|
||||
y_0 (int, optional): _description_. Defaults to 185.
|
||||
width_pixel (int, optional): Width from one lightened corner of the tank to the other. Defaults to 1975.
|
||||
height_pixel (int, optional): Heigth from one lightened corner of the tank to the other. Defaults to 1375.
|
||||
width_meter (float, optional): Width of the tank in meter. Defaults to 0.81.
|
||||
height_meter (float, optional): Height of the tank in meter. Defaults to 0.81.
|
||||
"""
|
||||
if not os.path.exists(results_file):
|
||||
raise ValueError("File %s does not exist!" % results_file)
|
||||
self._file_name = results_file
|
||||
self.x_0 = x_0
|
||||
self.y_0 = y_0
|
||||
self.width_pix = width_pixel
|
||||
self.width_m = width_meter
|
||||
self.height_pix = height_pixel
|
||||
self.height_m = height_meter
|
||||
self.x_factor = self.width_m / self.width_pix # m/pix
|
||||
self.y_factor = self.height_m / self.height_pix # m/pix
|
||||
|
||||
self.center = (np.round(self.x_0 + self.width_pix/2), np.round(self.y_0 + self.height_pix/2)) # middle of width and height --> center
|
||||
self.center_meter = ((self.center[0] - self.x_0) * self.x_factor, (self.center[1] - self.y_0) * self.y_factor) # center in meter by multipling with factor
|
||||
|
||||
self._data_frame = pd.read_hdf(results_file) # read dataframe of scorer
|
||||
self._level_shape = self._data_frame.columns.levshape # shape of dataframe (?)
|
||||
self._scorer = self._data_frame.columns.levels[0].values # scorer of dataset
|
||||
self._bodyparts = self._data_frame.columns.levels[1].values if self._level_shape[1] > 0 else [] # tracked body parts
|
||||
self._positions = self._data_frame.columns.levels[2].values if self._level_shape[2] > 0 else [] # position in x and y values and the likelihood of it
|
||||
|
||||
def angle_to_center(self, bodypart=0, twopi=True, inversed_yaxis=False, min_likelihood=0.95):
|
||||
"""Angel of animal position in relation to the center of the tank.
|
||||
|
||||
Args:
|
||||
bodypart (int, optional): Bodypart of the animal. Defaults to 0.
|
||||
twopi (bool, optional): _description_. Defaults to True.
|
||||
inversed_yaxis (bool, optional): Inversed y-axis = True when 0 is at the top of axis. Defaults to False.
|
||||
min_likelihood (float, optional): The likelihood of the position estimation. Defaults to 0.95.
|
||||
|
||||
Raises:
|
||||
ValueError: No valid x-position values.
|
||||
|
||||
Returns:
|
||||
phi: Angle of animal in relation to center.
|
||||
"""
|
||||
if isinstance(bodypart, nb.Number): # check if the instance bodypart of this class is a number
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts: # or if bodypart is a string
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart) # or if it is existing
|
||||
_, x, y, _, _ = self.position_values(bodypart=bp, min_likelihood=min_likelihood) # set x and y values, already in meter from position_values
|
||||
if x is None:
|
||||
print("Error: no valid angles for %s" % self._file_name)
|
||||
return []
|
||||
x_to_center = x - self.center_meter[0] #
|
||||
y_to_center = y - self.center_meter[1]
|
||||
if inversed_yaxis == True:
|
||||
y_to_center *= -1
|
||||
phi = np.arctan2(y_to_center, x_to_center) * 180 / np.pi
|
||||
if twopi:
|
||||
phi[phi < 0] = 360 + phi[phi < 0]
|
||||
|
||||
return phi
|
||||
|
||||
def coordinate_transformation(self, position):
|
||||
x = (position[0] - self.x_0) * self.x_factor
|
||||
y = (position[1] - self.y_0) * self.y_factor
|
||||
return (x, y) #in m
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def dataframe(self):
|
||||
return self._data_frame
|
||||
|
||||
@property
|
||||
def scorer(self):
|
||||
return self._scorer
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
return self._bodyparts
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return self._positions
|
||||
|
||||
def position_values(self, scorer=0, bodypart=0, framerate=25, interpolate=True, min_likelihood=0.95):
|
||||
"""Returns the x and y positions of a bodypart over time and the likelihood of it.
|
||||
|
||||
Args:
|
||||
scorer (int, optional): Scorer of dataset. Defaults to 0.
|
||||
bodypart (int, optional): Bodypart of the animal. Can be seen in etrack.TrackingResults.bodyparts. Defaults to 0.
|
||||
framerate (int, optional): Framerate of the video. Defaults to 25.
|
||||
|
||||
Raises:
|
||||
ValueError: Scorer not existing in dataframe.
|
||||
ValueError: Bodypart not existing in dataframe.
|
||||
|
||||
Returns:
|
||||
time [np.array]: The time axis.
|
||||
x [np.array]: x-position in meter.
|
||||
y [np.array]: y-position in meter.
|
||||
l [np.array]: The likelihood of the position estimation. Originating from animal tracking done before.
|
||||
bp string: The body part of the animal.
|
||||
[type]: [description]
|
||||
"""
|
||||
|
||||
if isinstance(scorer, nb.Number):
|
||||
sc = self._scorer[scorer]
|
||||
elif isinstance(scorer, str) and scorer in self._scorer:
|
||||
sc = scorer
|
||||
else:
|
||||
raise ValueError("Scorer %s is not in dataframe!" % scorer)
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
|
||||
|
||||
x = self._data_frame[sc][bp]["x"] if "x" in self._positions else []
|
||||
x = (np.asarray(x) - self.x_0) * self.x_factor
|
||||
y = self._data_frame[sc][bp]["y"] if "y" in self._positions else []
|
||||
y = (np.asarray(y) - self.y_0) * self.y_factor
|
||||
l = self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else []
|
||||
|
||||
time = np.arange(len(self._data_frame))/framerate
|
||||
time2 = time[l > min_likelihood]
|
||||
if len(l[l > min_likelihood]) < 100:
|
||||
print("%s has not datapoints with likelihood larger than %.2f" % (self._file_name, min_likelihood) )
|
||||
return None, None, None, None, None
|
||||
x2 = x[l > min_likelihood]
|
||||
y2 = y[l > min_likelihood]
|
||||
x3 = np.interp(time, time2, x2)
|
||||
y3 = np.interp(time, time2, y2)
|
||||
return time, x3, y3, l, bp
|
||||
|
||||
|
||||
def plot(self, scorer=0, bodypart=0, threshold=0.9, framerate=25):
|
||||
"""Plot the position of a bodypart in the tank over time.
|
||||
|
||||
Args:
|
||||
scorer (int, optional): Scorer of dataset. Defaults to 0.
|
||||
bodypart (int, optional): Given bodypart to plot. Defaults to 0.
|
||||
threshold (float, optional): Threshold below which the likelihood has to be. Defaults to 0.9.
|
||||
framerate (int, optional): Framerate of the video. Defaults to 25.
|
||||
"""
|
||||
t, x, y, l, name = self.position_values(scorer=scorer, bodypart=bodypart, framerate=framerate)
|
||||
plt.scatter(x[l > threshold], y[l > threshold], c=t[l > threshold], label=name)
|
||||
plt.scatter(self.center_meter[0], self.center_meter[1], marker="*")
|
||||
plt.plot(x[l > threshold], y[l > threshold])
|
||||
plt.xlabel("x position")
|
||||
plt.ylabel("y position")
|
||||
plt.gca().invert_yaxis()
|
||||
bar = plt.colorbar()
|
||||
bar.set_label("time [s]")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
from IPython import embed
|
||||
filename = "/2022.01.12_3DLC_resnet50_efish_tracking3Mar21shuffle1_300000.h5"
|
||||
path = "/home/efish/efish_tracking/efish_tracking3-Xaver-2022-03-21/videos"
|
||||
|
||||
tr = TrackingResult(path+filename) # usage of class with given file
|
||||
time, x, y, l, bp = tr.position_values(bodypart=2) # time, x and y values, likelihood of position estimation, tracked bodypart
|
||||
phi = tr.angle_to_center(0, True, False, 0.95)
|
||||
|
||||
thresh = 0.95
|
||||
time2 = time[l>thresh] # time values where likelihood of position estimation > threshold
|
||||
x2 = x[l>thresh] # x values with likelihood > threshold
|
||||
y2 = y[l>thresh] # y values -"-
|
||||
x3 = np.interp(time, time2, x2) # x value interpolation at points where likelihood has been under threshold
|
||||
y3 = np.interp(time, time2, y2) # y value -"-
|
||||
|
||||
|
||||
fig, axes = plt.subplots(3,1, sharex=True)
|
||||
axes[0].plot(time, x)
|
||||
axes[0].plot(time, x3)
|
||||
axes[0].set_ylabel('x-position')
|
||||
axes[1].plot(time, y)
|
||||
axes[1].plot(time, y3)
|
||||
axes[1].set_ylabel('y-position')
|
||||
axes[2].plot(time, l)
|
||||
axes[2].set_xlabel('time [s]')
|
||||
axes[2].set_ylabel('likelihood')
|
||||
plt.show()
|
||||
|
||||
embed()
|
17
mkdocs.yml
Normal file
17
mkdocs.yml
Normal file
@ -0,0 +1,17 @@
|
||||
site_name: etrack
|
||||
|
||||
repo_url: https://github.com/bendalab/etrack/
|
||||
|
||||
edit_uri: ""
|
||||
|
||||
site_author: Jan Grewe jan.grewe@g-node.org
|
||||
|
||||
theme: readthedocs
|
||||
|
||||
nav:
|
||||
- Home: 'index.md'
|
||||
- 'User guide':
|
||||
- 'etrack': 'etrack.md'
|
||||
- 'TrackingData' : 'trackingdata.md'
|
||||
- 'Code':
|
||||
- API reference: 'api/index.html'
|
48
pyproject.toml
Normal file
48
pyproject.toml
Normal file
@ -0,0 +1,48 @@
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "etrack"
|
||||
dynamic = ["version"]
|
||||
dependencies = [
|
||||
"hdf5",
|
||||
"nixtrack",
|
||||
"numpy",
|
||||
"matplotlib",
|
||||
"opencv-python",
|
||||
"pandas",
|
||||
"scikit-image",
|
||||
]
|
||||
requires-python = ">=3.6"
|
||||
authors = [
|
||||
{name = "Jan Grewe", email = "jan.grewe@g-node.org"},
|
||||
]
|
||||
maintainers = [
|
||||
{name = "Jan Grewe", email = "jan.grewe@g-node.org"},
|
||||
]
|
||||
description = "Goodies for working with tracking data of efishes."
|
||||
readme = "README.md"
|
||||
license = {file = "LICENSE"}
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Environment :: Console",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: BSD-2-Clause",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Scientific/Engineering",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/bendalab/etrack"
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {attr = "etrack.info.VERSION"}
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = "src"
|
33
setup.py
33
setup.py
@ -1,33 +0,0 @@
|
||||
from setuptools import setup
|
||||
|
||||
NAME = "etrack"
|
||||
VERSION = 0.5
|
||||
AUTHOR = "Jan Grewe"
|
||||
CONTACT = "jan.grewe@g-node.org"
|
||||
CLASSIFIERS = "science"
|
||||
DESCRIPTION = "helpers for handling deep lab cut tracking results"
|
||||
|
||||
README = "README.md"
|
||||
with open(README) as f:
|
||||
description_text = f.read()
|
||||
|
||||
packages = [
|
||||
"etrack",
|
||||
]
|
||||
|
||||
install_req = ["h5py", "pandas", "matplotlib", "numpy", "opencv-python"]
|
||||
|
||||
setup(
|
||||
name=NAME,
|
||||
version=VERSION,
|
||||
description=DESCRIPTION,
|
||||
author=AUTHOR,
|
||||
author_email=CONTACT,
|
||||
packages=packages,
|
||||
install_requires=install_req,
|
||||
include_package_data=True,
|
||||
long_description=description_text,
|
||||
long_description_content_type="text/markdown",
|
||||
classifiers=CLASSIFIERS,
|
||||
license="BSD"
|
||||
)
|
14
src/etrack/__init__.py
Normal file
14
src/etrack/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
""" etrack package for easier reading and handling of efish tracking data.
|
||||
|
||||
Copyright © 2024, Jan Grewe
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted under the terms of the BSD License. See LICENSE file in the root of the Project.
|
||||
"""
|
||||
from .image_marker import ImageMarker, MarkerTask
|
||||
from .tracking_result import TrackingResult, coordinate_transformation
|
||||
from .arena import Arena, Region
|
||||
from .tracking_data import TrackingData
|
||||
from .io.dlc_data import DLCReader
|
||||
from .io.nixtrack_data import NixtrackData
|
||||
from .util import RegionShape, AnalysisType
|
527
src/etrack/arena.py
Normal file
527
src/etrack/arena.py
Normal file
@ -0,0 +1,527 @@
|
||||
"""
|
||||
Classes to construct the arena in which the animals were tracked.
|
||||
"""
|
||||
import logging
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
|
||||
from skimage.draw import disk
|
||||
from .util import RegionShape, AnalysisType, Illumination
|
||||
|
||||
|
||||
class Region(object):
|
||||
"""
|
||||
Class representing a region (of interest). Regions can be either circular or rectangular.
|
||||
A Region can have a parent, i.e. it is contained inside a parent region. It can also have children.
|
||||
|
||||
Coordinates are given in absolute coordinates. The extent is treated depending on the shape. In case of a circular
|
||||
shape, it is the radius and the origin is the center of the circle. Otherwise the origin is the bottom, or top-left corner, depending on the y-axis orientation, if inverted, then it is top-left. FIXME: check this
|
||||
|
||||
"""
|
||||
def __init__(self, origin, extent, inverted_y=True, name="", region_shape=RegionShape.Rectangular, parent=None) -> None:
|
||||
"""Region constructor.
|
||||
Parameters
|
||||
----------
|
||||
origin : 2-tuple
|
||||
x, and y coordinates
|
||||
extent : scalar or 2-tuple, scalar only allowed to circular regions, 2-tuple for rectangular.
|
||||
inverted_y : bool, optional
|
||||
_description_, by default True
|
||||
name : str, optional
|
||||
_description_, by default ""
|
||||
region_shape : _type_, optional
|
||||
_description_, by default RegionShape.Rectangular
|
||||
parent : _type_, optional
|
||||
_description_, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
Raises Value error when origin or extent are invalid
|
||||
"""
|
||||
logging.debug(
|
||||
f"etrack.Region: Create {str(region_shape)} region {name} with props origin {origin}, extent {extent} and parent {parent}"
|
||||
)
|
||||
if len(origin) != 2:
|
||||
raise ValueError("Region: origin must be 2-tuple!")
|
||||
self._parent = parent
|
||||
self._name = name
|
||||
self._shape_type = region_shape
|
||||
self._origin = origin
|
||||
self._check_extent(extent)
|
||||
self._extent = extent
|
||||
self._inverted_y = inverted_y
|
||||
|
||||
@staticmethod
|
||||
def circular_mask(width, height, center, radius):
|
||||
assert center[1] + radius < width and center[1] - radius > 0
|
||||
assert center[0] + radius < height and center[0] - radius > 0
|
||||
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
rr, cc = disk(reversed(center), radius)
|
||||
mask[rr, cc] = 1
|
||||
|
||||
return mask
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def inverted_y(self):
|
||||
return self._inverted_y
|
||||
|
||||
@property
|
||||
def _max_extent(self):
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
max_extent = (
|
||||
self._origin[0] + self._extent[0],
|
||||
self._origin[1] + self._extent[1],
|
||||
)
|
||||
else:
|
||||
max_extent = (
|
||||
self._origin[0] + self._extent,
|
||||
self._origin[1] + self._extent,
|
||||
)
|
||||
return np.asarray(max_extent)
|
||||
|
||||
@property
|
||||
def _min_extent(self):
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
min_extent = self._origin
|
||||
else:
|
||||
min_extent = (
|
||||
self._origin[0] - self._extent,
|
||||
self._origin[1] - self._extent,
|
||||
)
|
||||
return np.asarray(min_extent)
|
||||
|
||||
@property
|
||||
def xmax(self):
|
||||
return self._max_extent[0]
|
||||
|
||||
@property
|
||||
def xmin(self):
|
||||
return self._min_extent[0]
|
||||
|
||||
@property
|
||||
def ymin(self):
|
||||
return self._min_extent[1]
|
||||
|
||||
@property
|
||||
def ymax(self):
|
||||
return self._max_extent[1]
|
||||
|
||||
@property
|
||||
def position(self):
|
||||
"""
|
||||
Get the position of the arena.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple
|
||||
A tuple containing the x-coordinate, y-coordinate, width, and height of the arena.
|
||||
"""
|
||||
x = self._min_extent[0]
|
||||
y = self._min_extent[1]
|
||||
width = self._max_extent[0] - self._min_extent[0]
|
||||
height = self._max_extent[1] - self._min_extent[1]
|
||||
return x, y, width, height
|
||||
|
||||
def _check_extent(self, ext):
|
||||
"""Checks whether the extent matches the shape. i.e. if the shape is Rectangular, extent must be a length 2 list, tuple, otherwise, if the region is circular, extent must be a single numerical value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ext : tuple, or numeric scalar
|
||||
"""
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
if not isinstance(ext, (list, tuple, np.ndarray)) and len(ext) != 2:
|
||||
raise ValueError(
|
||||
"Extent must be a length 2 list or tuple for rectangular regions!"
|
||||
)
|
||||
elif self._shape_type == RegionShape.Circular:
|
||||
if not isinstance(ext, (int, float)):
|
||||
raise ValueError(
|
||||
"Extent must be a numerical scalar for circular regions!"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid ShapeType, {self._shape_type}!")
|
||||
|
||||
def fits(self, other) -> bool:
|
||||
"""
|
||||
Checks if the given region fits into the current region.
|
||||
|
||||
Args:
|
||||
other (Region): The region to check if it fits.
|
||||
|
||||
Returns:
|
||||
bool: True if the given region fits into the current region, False otherwise.
|
||||
"""
|
||||
assert isinstance(other, Region)
|
||||
does_fit = all(
|
||||
(
|
||||
other._min_extent[0] >= self._min_extent[0],
|
||||
other._min_extent[1] >= self._min_extent[1],
|
||||
other._max_extent[0] <= self._max_extent[0],
|
||||
other._max_extent[1] <= self._max_extent[1],
|
||||
)
|
||||
)
|
||||
if not does_fit:
|
||||
m = (
|
||||
f"Region {other.name} does not fit into {self.name}. "
|
||||
f"min x: {other._min_extent[0] >= self._min_extent[0]},",
|
||||
f"min y: {other._min_extent[1] >= self._min_extent[1]},",
|
||||
f"max x: {other._max_extent[0] <= self._max_extent[0]},",
|
||||
f"max y: {other._max_extent[1] <= self._max_extent[1]}",
|
||||
)
|
||||
logging.debug(m)
|
||||
return does_fit
|
||||
|
||||
@property
|
||||
def is_child(self):
|
||||
"""
|
||||
Check if the current instance is a child.
|
||||
|
||||
Returns:
|
||||
bool: True if the instance has a parent, False otherwise.
|
||||
"""
|
||||
return self._parent is not None
|
||||
|
||||
def points_in_region(self, x, y, analysis_type=AnalysisType.Full):
|
||||
"""Returns the indices of the points specified by 'x' and 'y' that fall into this region.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : np.ndarray
|
||||
the x positions
|
||||
y : np.ndarray
|
||||
the y positions
|
||||
analysis_type : AnalysisType, optional
|
||||
defines how the positions are evaluated, by default AnalysisType.Full
|
||||
FIXME: some of this can probably be solved using linear algebra, what with multiple exact same points?
|
||||
"""
|
||||
if self._shape_type == RegionShape.Rectangular or (
|
||||
self._shape_type == RegionShape.Circular
|
||||
and analysis_type != AnalysisType.Full
|
||||
):
|
||||
if analysis_type == AnalysisType.Full:
|
||||
indices = np.where(
|
||||
((y >= self._min_extent[1]) & (y <= self._max_extent[1]))
|
||||
& ((x >= self._min_extent[0]) & (x <= self._max_extent[0]))
|
||||
)[0]
|
||||
indices = np.array(indices, dtype=int)
|
||||
elif analysis_type == AnalysisType.CollapseX:
|
||||
x_indices = np.where(
|
||||
(x >= self._min_extent[0]) & (x <= self._max_extent[0])
|
||||
)[0]
|
||||
indices = np.asarray(x_indices, dtype=int)
|
||||
else:
|
||||
y_indices = np.where(
|
||||
(y >= self._min_extent[1]) & (y <= self._max_extent[1])
|
||||
)[0]
|
||||
indices = np.asarray(y_indices, dtype=int)
|
||||
else:
|
||||
if self.is_child:
|
||||
mask = self.circular_mask(
|
||||
self._parent.position[2],
|
||||
self._parent.position[3],
|
||||
self._origin,
|
||||
self._extent,
|
||||
)
|
||||
else:
|
||||
mask = self.circular_mask(
|
||||
self.position[2], self.position[3], self._origin, self._extent
|
||||
)
|
||||
img = np.zeros_like(mask)
|
||||
img[np.asarray(y, dtype=int), np.asarray(x, dtype=int)] = 1
|
||||
temp = np.where(img & mask)
|
||||
indices = []
|
||||
for i, j in zip(list(temp[1]), list(temp[0])):
|
||||
matches = np.where((x == i) & (y == j))
|
||||
if len(matches[0]) == 0:
|
||||
continue
|
||||
indices.append(matches[0][0])
|
||||
indices = np.array(indices)
|
||||
return indices
|
||||
|
||||
def time_in_region(self, x, y, time, analysis_type=AnalysisType.Full):
|
||||
"""Returns the entering and leaving times at which the animal entered
|
||||
and left a region. In case the animal was not observed after entering
|
||||
this region (for example when hidden in a tube) the leaving time is
|
||||
the maximum time entry.
|
||||
Whether the full position, or only the x- or y-position should be considered
|
||||
is controlled with the analysis_type parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : np.ndarray
|
||||
The animal's x-positions
|
||||
y : np.ndarray
|
||||
the animal's y-positions
|
||||
time : np.ndarray
|
||||
the time array
|
||||
analysis_type : AnalysisType, optional
|
||||
The type of analysis, by default AnalysisType.Full
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The entering times
|
||||
np.ndarray
|
||||
The leaving times
|
||||
"""
|
||||
indices = self.points_in_region(x, y, analysis_type)
|
||||
if len(indices) == 0:
|
||||
return np.array([]), np.array([])
|
||||
|
||||
diffs = np.diff(indices)
|
||||
if len(diffs) == sum(diffs):
|
||||
entering = [time[indices[0]]]
|
||||
leaving = [time[indices[-1]]]
|
||||
else:
|
||||
entering = []
|
||||
leaving = []
|
||||
jumps = np.where(diffs > 1)[0]
|
||||
start = time[indices[0]]
|
||||
for i in range(len(jumps)):
|
||||
end = time[indices[jumps[i]]]
|
||||
entering.append(start)
|
||||
leaving.append(end)
|
||||
start = time[indices[jumps[i] + 1]]
|
||||
|
||||
end = time[indices[-1]]
|
||||
entering.append(start)
|
||||
leaving.append(end)
|
||||
return np.array(entering), np.array(leaving)
|
||||
|
||||
def patch(self, **kwargs):
|
||||
"""
|
||||
Create and return a matplotlib patch object based on the shape type of the arena.
|
||||
|
||||
Parameters:
|
||||
- kwargs: Additional keyword arguments to customize the patch object.
|
||||
|
||||
Returns:
|
||||
- A matplotlib patch object representing the arena shape.
|
||||
|
||||
If the 'fc' (facecolor) keyword argument is not provided, it will default to None.
|
||||
If the 'fill' keyword argument is not provided, it will default to False.
|
||||
|
||||
For rectangular arenas, the patch object will be a Rectangle with width and height
|
||||
based on the arena's position.
|
||||
For circular arenas, the patch object will be a Circle with radius based on the
|
||||
arena's extent.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
arena = Arena()
|
||||
patch = arena.patch(fc='blue', fill=True)
|
||||
ax.add_patch(patch)
|
||||
```
|
||||
"""
|
||||
if "fc" not in kwargs:
|
||||
kwargs["fc"] = None
|
||||
kwargs["fill"] = False
|
||||
if self._shape_type == RegionShape.Rectangular:
|
||||
w = self.position[2]
|
||||
h = self.position[3]
|
||||
return patches.Rectangle(self._origin, w, h, **kwargs)
|
||||
else:
|
||||
return patches.Circle(self._origin, self._extent, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Region: '{self._name}' of {self._shape_type} shape."
|
||||
|
||||
|
||||
class Arena(Region):
|
||||
"""
|
||||
Class to represent the experimental arena. Arena is derived from Region and can be either rectangular or circular.
|
||||
An arena can not have a parent.
|
||||
See Region for more details.
|
||||
"""
|
||||
def __init__(self, origin, extent, inverted_y=True, name="", arena_shape=RegionShape.Rectangular,
|
||||
illumination=Illumination.Backlight) -> None:
|
||||
""" Construct a new Area with a given origin and extent.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
_description_
|
||||
"""
|
||||
super().__init__(origin, extent, inverted_y, name, arena_shape)
|
||||
self._illumination = illumination
|
||||
self.regions = {}
|
||||
|
||||
def add_region(
|
||||
self, name, origin, extent, shape_type=RegionShape.Rectangular, region=None
|
||||
):
|
||||
if name is None or name in self.regions.keys():
|
||||
raise ValueError(
|
||||
"Region name '{name}' is invalid. The name must not be None and must be unique among the regions."
|
||||
)
|
||||
if region is None:
|
||||
region = Region(
|
||||
origin, extent, name=name, region_shape=shape_type, parent=self
|
||||
)
|
||||
else:
|
||||
region._parent = self
|
||||
doesfit = self.fits(region)
|
||||
if not doesfit:
|
||||
logging.warn(
|
||||
f"Warning! Region {region.name} with size {region.position} does fit into {self.name} with size {self.position}!"
|
||||
)
|
||||
self.regions[name] = region
|
||||
|
||||
def remove_region(self, name):
|
||||
"""
|
||||
Remove a region from the arena.
|
||||
|
||||
Parameter:
|
||||
name : str
|
||||
The name of the region to remove.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if name in self.regions:
|
||||
self.regions.pop(name)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Arena: '{self._name}' of {self._shape_type} shape."
|
||||
|
||||
def plot(self, axis=None):
|
||||
"""
|
||||
Plots the arena on the given axis.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- axis (matplotlib.axes.Axes, optional): The axis on which to plot the arena. If not provided, a new figure and axis will be created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- matplotlib.axes.Axes: The axis on which the arena is plotted.
|
||||
"""
|
||||
if axis is None:
|
||||
fig = plt.figure()
|
||||
axis = fig.add_subplot(111)
|
||||
axis.add_patch(self.patch())
|
||||
axis.set_xlim([self._origin[0], self._max_extent[0]])
|
||||
|
||||
if self.inverted_y:
|
||||
axis.set_ylim([self._max_extent[1], self._origin[1]])
|
||||
else:
|
||||
axis.set_ylim([self._origin[1], self._max_extent[1]])
|
||||
for r in self.regions:
|
||||
axis.add_patch(self.regions[r].patch())
|
||||
return axis
|
||||
|
||||
def region_vector(self, x, y):
|
||||
"""Returns a vector that contains the region names within which the agent was found.
|
||||
FIXME: This does not work well with overlapping regions!@!
|
||||
Parameters
|
||||
----------
|
||||
x : np.array
|
||||
the x-positions
|
||||
y : np.ndarray
|
||||
the y-positions
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.array
|
||||
vector of the same size as x and y. Each entry is the region to which the position is assigned to. If the point is not assigned to a region, the entry will be empty.
|
||||
"""
|
||||
if not isinstance(x, np.ndarray):
|
||||
x = np.asarray(x)
|
||||
if not isinstance(y, np.ndarray):
|
||||
y = np.asarray(y)
|
||||
rv = np.empty(x.shape, dtype=str)
|
||||
for r in self.regions:
|
||||
indices = self.regions[r].points_in_region(x, y)
|
||||
rv[indices] = r
|
||||
return rv
|
||||
|
||||
def in_region(self, x, y):
|
||||
"""
|
||||
Determines if the given coordinates (x, y) are within any of the defined regions in the arena.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : float
|
||||
The x-coordinate of the point to check.
|
||||
y : float
|
||||
The y-coordinate of the point to check.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
A dictionary containing the region names as keys and a list of indices of points within each region as values.
|
||||
"""
|
||||
tmp = {}
|
||||
for r in self.regions:
|
||||
print(r)
|
||||
indices = self.regions[r].points_in_region(x, y)
|
||||
tmp[r] = indices
|
||||
return tmp
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, (str)):
|
||||
return self.regions[key]
|
||||
else:
|
||||
return self.regions[self.regions.keys()[key]]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = Arena((0, 0), (1024, 768), name="arena", arena_shape=RegionShape.Rectangular)
|
||||
a.add_region("small rect1", (0, 0), (100, 300))
|
||||
a.add_region("small rect2", (150, 0), (100, 300))
|
||||
a.add_region("small rect3", (300, 0), (100, 300))
|
||||
a.add_region("circ", (600, 400), 150, shape_type=RegionShape.Circular)
|
||||
axis = a.plot()
|
||||
x = np.linspace(a.position[0], a.position[0] + a.position[2] - 1, 100, dtype=int)
|
||||
y = np.asarray(
|
||||
(np.sin(x * 0.01) + 1) * a.position[3] / 2 + a.position[1] - 1, dtype=int
|
||||
)
|
||||
# y = np.linspace(a.position[1], a.position[1] + a.position[3] - 1, 100, dtype=int)
|
||||
axis.scatter(x, y, c="k", s=2)
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind], label="circ full")
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseX)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] - 10, label="circ collapseX")
|
||||
|
||||
ind = a.regions[3].points_in_region(x, y, AnalysisType.CollapseY)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] + 10, label="circ collapseY")
|
||||
|
||||
ind = a.regions[0].points_in_region(x, y, AnalysisType.CollapseX)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] - 10, label="rect collapseX")
|
||||
|
||||
ind = a.regions[1].points_in_region(x, y, AnalysisType.CollapseY)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] + 10, label="rect collapseY")
|
||||
|
||||
ind = a.regions[2].points_in_region(x, y, AnalysisType.Full)
|
||||
if len(ind) > 0:
|
||||
axis.scatter(x[ind], y[ind] + 20, label="rect full")
|
||||
axis.legend()
|
||||
plt.show()
|
||||
|
||||
a.plot()
|
||||
plt.show()
|
@ -1,10 +1,11 @@
|
||||
from cv2 import calibrationMatrixValues
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import cv2
|
||||
"""
|
||||
Module that defines the ImageMarker and MarkerTask classes to manually mark things in individual images.
|
||||
"""
|
||||
import os
|
||||
import cv2
|
||||
import sys
|
||||
from IPython import embed
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class ImageMarker:
|
||||
|
||||
@ -19,21 +20,8 @@ class ImageMarker:
|
||||
self._fig.canvas.mpl_connect('button_press_event', self._on_click_event)
|
||||
self._fig.canvas.mpl_connect('close_event', self._fig_close_event)
|
||||
self._fig.canvas.mpl_connect('key_press_event', self._key_press_event)
|
||||
|
||||
def mark_movie(self, filename, frame_number=10):
|
||||
""" Interactive GUI to mark the corners of the tank. A specific frame of the video can be chosen. Returns marker positions.
|
||||
|
||||
Args:
|
||||
filename: Videofile
|
||||
frame_number (int, optional): Number of a frame in the videofile. Defaults to 0.
|
||||
|
||||
Raises:
|
||||
IOError: File does not exist.
|
||||
|
||||
Returns:
|
||||
marker_positions: Marker positions of tank corners.
|
||||
"""
|
||||
|
||||
def mark_movie(self, filename, frame_number=0):
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("file %s does not exist!" % filename)
|
||||
video = cv2.VideoCapture()
|
||||
@ -41,17 +29,18 @@ class ImageMarker:
|
||||
frame_counter = 0
|
||||
success = True
|
||||
frame = None
|
||||
while success and frame_counter <= frame_number: # iterating until frame_counter == frame_number --> success (True)
|
||||
while success and frame_counter <= frame_number:
|
||||
print("Reading frame: %i" % frame_counter, end="\r")
|
||||
success, frame = video.read()
|
||||
frame_counter += 1
|
||||
|
||||
if success:
|
||||
self._fig.gca().imshow(frame) # plot wanted frame of video
|
||||
self._fig.gca().imshow(frame)
|
||||
else:
|
||||
print("Could not read frame number %i either failed to open movie or beyond maximum frame number!" % frame_number)
|
||||
return []
|
||||
plt.ion() # turn on interactive mode
|
||||
plt.show(block=False) # block=False allows to continue interact in terminal while the figure is open
|
||||
plt.ion()
|
||||
plt.show(block=False)
|
||||
|
||||
self._task_index = -1
|
||||
if len(self._tasks) > 0:
|
||||
@ -65,7 +54,6 @@ class ImageMarker:
|
||||
self._fig.gca().set_title("All set and done!\n Window will close in 2s")
|
||||
self._fig.canvas.draw()
|
||||
plt.pause(2.0)
|
||||
plt.close() #self._fig.gca().imshow(frame))
|
||||
return [t.marker_positions for t in self._tasks]
|
||||
|
||||
def _key_press_event(self, event):
|
||||
@ -155,17 +143,14 @@ class MarkerTask():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Hello Jan!")
|
||||
tank_task = MarkerTask("tank limits", ["bottom left corner", "top left corner", "top right corner", "bottom right corner"], "Mark tank corners")
|
||||
feeder_task = MarkerTask("Feeder positions", list(map(str, range(1, 2))), "Mark feeder positions")
|
||||
tasks = [tank_task] # feeder_task]
|
||||
im = ImageMarker(tasks)
|
||||
|
||||
vid1 = "/home/efish/efish_tracking/efish_tracking3-Xaver-2022-03-21/videos/2022.01.12_3DLC_resnet50_efish_tracking3Mar21shuffle1_300000_labeled.mp4"
|
||||
marker_positions = im.mark_movie(vid1, 10)
|
||||
print(marker_positions)
|
||||
|
||||
#feeder_task = MarkerTask("Feeder positions", list(map(str, range(1, 2))), "Mark feeder positions")
|
||||
#tasks = [tank_task, feeder_task]
|
||||
im = ImageMarker([tank_task])
|
||||
vid1 = "/data/personality/secondhome/fischies/lepto_03/position/lepto03_position_2021.06.07_60.mp4"
|
||||
# print(sys.argv[0])
|
||||
# print (sys.argv[1])
|
||||
# vid1 = sys.argv[1]
|
||||
|
||||
embed()
|
||||
marker_positions = im.mark_movie(vid1, 00)
|
||||
print(marker_positions)
|
10
src/etrack/info.json
Normal file
10
src/etrack/info.json
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"VERSION": "0.5.0",
|
||||
"STATUS": "Release",
|
||||
"RELEASE": "0.5.0 Release",
|
||||
"AUTHOR": "Jan Grewe",
|
||||
"COPYRIGHT": "2024, University of Tuebingen, Neuroethology, Jan Grewe",
|
||||
"CONTACT": "jan.grewe@g-node.org",
|
||||
"BRIEF": "Efish tracking helpers for handling tracking data.",
|
||||
"HOMEPAGE": "https://github.com/G-Node/nixpy"
|
||||
}
|
28
src/etrack/info.py
Normal file
28
src/etrack/info.py
Normal file
@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright © 2024, Jan Grewe
|
||||
#
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted under the terms of the BSD License. See
|
||||
# LICENSE file in the root of the Project.
|
||||
"""
|
||||
Package info.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
|
||||
here = os.path.dirname(__file__)
|
||||
|
||||
with open(os.path.join(here, "info.json")) as infofile:
|
||||
infodict = json.load(infofile)
|
||||
|
||||
|
||||
VERSION = infodict["VERSION"]
|
||||
STATUS = infodict["STATUS"]
|
||||
RELEASE = infodict["RELEASE"]
|
||||
AUTHOR = infodict["AUTHOR"]
|
||||
COPYRIGHT = infodict["COPYRIGHT"]
|
||||
CONTACT = infodict["CONTACT"]
|
||||
BRIEF = infodict["BRIEF"]
|
||||
HOMEPAGE = infodict["HOMEPAGE"]
|
3
src/etrack/io/__init__.py
Normal file
3
src/etrack/io/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
Reader classes for DeepLabCut, or SLEAP written data files.
|
||||
"""
|
78
src/etrack/io/dlc_data.py
Normal file
78
src/etrack/io/dlc_data.py
Normal file
@ -0,0 +1,78 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import numbers as nb
|
||||
|
||||
from ..tracking_data import TrackingData
|
||||
|
||||
|
||||
class DLCReader(object):
|
||||
"""Class that represents the tracking data stored in a DeepLabCut hdf5 file."""
|
||||
def __init__(self, results_file, crop=(0, 0)) -> None:
|
||||
"""
|
||||
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
crop : 2-tuple
|
||||
tuple of (xoffset, yoffset)
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError if crop value is not a 2-tuple
|
||||
"""
|
||||
if not os.path.exists(results_file):
|
||||
raise ValueError("File %s does not exist!" % results_file)
|
||||
if not isinstance(crop, tuple) or len(crop) < 2:
|
||||
raise ValueError("Cropping info must be a 2-tuple of (x, y)")
|
||||
self._file_name = results_file
|
||||
self._crop = crop
|
||||
self._data_frame = pd.read_hdf(results_file)
|
||||
self._level_shape = self._data_frame.columns.levshape
|
||||
self._scorer = self._data_frame.columns.levels[0].values
|
||||
self._bodyparts = self._data_frame.columns.levels[1].values if self._level_shape[1] > 0 else []
|
||||
self._positions = self._data_frame.columns.levels[2].values if self._level_shape[2] > 0 else []
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def dataframe(self):
|
||||
return self._data_frame
|
||||
|
||||
@property
|
||||
def scorer(self):
|
||||
return self._scorer
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
return self._bodyparts
|
||||
|
||||
def _correct_cropping(self, orgx, orgy):
|
||||
x = orgx + self._crop[0]
|
||||
y = orgy + self._crop[1]
|
||||
return x, y
|
||||
|
||||
def track(self, scorer=0, bodypart=0, framerate=30):
|
||||
if isinstance(scorer, nb.Number):
|
||||
sc = self._scorer[scorer]
|
||||
elif isinstance(scorer, str) and scorer in self._scorer:
|
||||
sc = scorer
|
||||
else:
|
||||
raise ValueError(f"Scorer {scorer} is not in dataframe!")
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError(f"Body part {bodypart} is not in dataframe!")
|
||||
|
||||
x = np.asarray(self._data_frame[sc][bp]["x"] if "x" in self._positions else [])
|
||||
y = np.asarray(self._data_frame[sc][bp]["y"] if "y" in self._positions else [])
|
||||
x, y = self._correct_cropping(x, y)
|
||||
l = np.asarray(self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else [])
|
||||
|
||||
time = np.arange(len(x))/framerate
|
||||
|
||||
return TrackingData(x, y, time, l, bp, fps=framerate)
|
137
src/etrack/io/nixtrack_data.py
Normal file
137
src/etrack/io/nixtrack_data.py
Normal file
@ -0,0 +1,137 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import numbers as nb
|
||||
import nixtrack as nt
|
||||
|
||||
from ..tracking_data import TrackingData
|
||||
from IPython import embed
|
||||
|
||||
|
||||
class NixtrackData(object):
|
||||
"""Wrapper around a nix data file that has been written accorind to the nixtrack model (https://github.com/bendalab/nixtrack)
|
||||
"""
|
||||
def __init__(self, filename, crop=(0, 0)) -> None:
|
||||
"""
|
||||
If the video data was cropped before tracking and the tracked positions are with respect to the cropped images, we may want to correct for this to convert the data back to absolute positions in the video frame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str
|
||||
full filename
|
||||
crop : 2-tuple
|
||||
tuple of (xoffset, yoffset)
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError if crop value is not a 2-tuple
|
||||
"""
|
||||
if not os.path.exists(filename):
|
||||
raise ValueError("File %s does not exist!" % filename)
|
||||
if not isinstance(crop, tuple) or len(crop) < 2:
|
||||
raise ValueError("Cropping info must be a 2-tuple of (x, y)")
|
||||
self._file_name = filename
|
||||
self._crop = crop
|
||||
self._dataset = nt.Dataset(self._file_name)
|
||||
if not self._dataset.is_open:
|
||||
raise ValueError(f"An error occurred opening file {self._file_name}! File is not open!")
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
"""
|
||||
Returns the name of the file associated with the NixtrackData object.
|
||||
|
||||
Returns:
|
||||
str: The name of the file.
|
||||
"""
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
"""
|
||||
Returns the bodyparts of the dataset.
|
||||
|
||||
Returns:
|
||||
list: A list of bodyparts.
|
||||
"""
|
||||
return self._dataset.nodes
|
||||
|
||||
def _correct_cropping(self, orgx, orgy):
|
||||
"""
|
||||
Corrects the coordinates based on the cropping values, If it cropping was done during tracking.
|
||||
|
||||
Args:
|
||||
orgx (int): The original x-coordinate.
|
||||
orgy (int): The original y-coordinate.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the corrected x and y coordinates.
|
||||
"""
|
||||
x = orgx + self._crop[0]
|
||||
y = orgy + self._crop[1]
|
||||
return x, y
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
"""Property that holds frames per second of the original video.
|
||||
Returns
|
||||
-------
|
||||
int : the frames of second
|
||||
"""
|
||||
return self._dataset.fps
|
||||
|
||||
@property
|
||||
def tracks(self):
|
||||
"""
|
||||
Returns a list of tracks from the dataset.
|
||||
|
||||
Returns:
|
||||
list: A list of tracks.
|
||||
"""
|
||||
return [t[0] for t in self._dataset.tracks]
|
||||
|
||||
def track_data(self, bodypart=0, track=-1, fps=None):
|
||||
"""
|
||||
Retrieve tracking data for a specific body part and track.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bodypart : int or str
|
||||
Index or name of the body part to retrieve tracking data for.
|
||||
track : int or str
|
||||
Index of the track to retrieve tracking data for.
|
||||
fps : float
|
||||
Frames per second of the tracking data. If not provided, it will be retrieved from the dataset.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TrackingData: An object containing the x and y positions, time, score, body part name, and frames per second.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If the body part or track is not valid.
|
||||
"""
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self.bodyparts[bodypart]
|
||||
elif isinstance(bodypart, (str)) and bodypart in self.bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError(f"Body part {bodypart} is not a tracked node!")
|
||||
|
||||
if track not in self.tracks:
|
||||
raise ValueError(f"Track {track} is not a valid track name!")
|
||||
if not isinstance(track, (list, tuple)):
|
||||
track = [track]
|
||||
elif isinstance(track, int):
|
||||
track = [self.tracks[track]]
|
||||
|
||||
if fps is None:
|
||||
fps = self._dataset.fps
|
||||
|
||||
positions, time, _, nscore = self._dataset.positions(node=bp, axis_type=nt.AxisType.Time)
|
||||
valid = ~np.isnan(positions[:, 0])
|
||||
positions = positions[valid,:]
|
||||
time = time[valid]
|
||||
score = nscore[valid]
|
||||
|
||||
return TrackingData(positions[:, 0], positions[:, 1], time, score, bp, fps=fps)
|
267
src/etrack/tracking_data.py
Normal file
267
src/etrack/tracking_data.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""
|
||||
Module that defines the TrackingData class that wraps the position data for a given node/bodypart that has been tracked.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrackingData(object):
|
||||
"""Class that represents tracking data, i.e. positions of an agent tracked in an environment.
|
||||
These data are the x, and y-positions, the time at which the agent was detected, and the quality associated with the position estimation.
|
||||
TrackingData contains these data and offers a few functions to work with it.
|
||||
Using the 'quality_threshold', 'temporal_limits', or the 'position_limits' data can be filtered (see filter_tracks function).
|
||||
The 'interpolate' function allows to fill up the gaps that may result from filtering with linearly interpolated data points.
|
||||
|
||||
More may follow...
|
||||
"""
|
||||
|
||||
def __init__(self, x, y, time, quality, node="", fps=None,
|
||||
quality_threshold=None, temporal_limits=None, position_limits=None) -> None:
|
||||
"""
|
||||
Initialize a TrackingData object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : float
|
||||
The x-coordinates of the tracking data.
|
||||
y : float
|
||||
The y-coordinates of the tracking data.
|
||||
time : float
|
||||
The time vector associated with the x-, and y-coordinates.
|
||||
quality : float
|
||||
The quality score associated with the position estimates.
|
||||
node : str, optional
|
||||
The node name associated with the data. Default is an empty string.
|
||||
fps : float, optional
|
||||
The frames per second of the tracking data. Default is None.
|
||||
quality_threshold : float, optional
|
||||
The quality threshold for the tracking data. Default is None.
|
||||
temporal_limits : tuple, optional
|
||||
The temporal limits for the tracking data. Default is None.
|
||||
position_limits : tuple, optional
|
||||
The position limits for the tracking data. Default is None.
|
||||
"""
|
||||
self._orgx = x
|
||||
self._orgy = y
|
||||
self._orgtime = time
|
||||
self._orgquality = quality
|
||||
self._x = x
|
||||
self._y = y
|
||||
self._time = time
|
||||
self._quality = quality
|
||||
self._node = node
|
||||
self._threshold = quality_threshold
|
||||
self._position_limits = position_limits
|
||||
self._time_limits = temporal_limits
|
||||
self._fps = fps
|
||||
|
||||
@property
|
||||
def original_positions(self):
|
||||
return self._orgx, self._orgy
|
||||
|
||||
@property
|
||||
def original_quality(self):
|
||||
return self._orgquality
|
||||
|
||||
def interpolate(self, start_time=None, end_time=None, min_count=5):
|
||||
if len(self._x) < min_count:
|
||||
print(
|
||||
f"{self._node} data has less than {min_count} data points with sufficient quality ({len(self._x)})!"
|
||||
)
|
||||
return None, None, None
|
||||
start = self._time[0] if start_time is None else start_time
|
||||
end = self._time[-1] if end_time is None else end_time
|
||||
time = np.arange(start, end, 1.0 / self._fps)
|
||||
x = np.interp(time, self._time, self._x)
|
||||
y = np.interp(time, self._time, self._y)
|
||||
|
||||
return x, y, time
|
||||
|
||||
@property
|
||||
def quality_threshold(self):
|
||||
"""Property that holds the quality filter setting.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float : the quality threshold
|
||||
"""
|
||||
return self._threshold
|
||||
|
||||
@quality_threshold.setter
|
||||
def quality_threshold(self, new_threshold):
|
||||
"""Setter of the quality threshold that should be applied when filtering the data. Setting this to None removes the quality filter.
|
||||
|
||||
Data points that have a quality score below the given threshold are discarded.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_threshold : float
|
||||
|
||||
"""
|
||||
self._threshold = new_threshold
|
||||
|
||||
@property
|
||||
def position_limits(self):
|
||||
"""
|
||||
Get the position limits of the tracking data.
|
||||
|
||||
Returns:
|
||||
tuple: A 4-tuple containing the start x, and y positions, width and height limits.
|
||||
"""
|
||||
return self._position_limits
|
||||
|
||||
@position_limits.setter
|
||||
def position_limits(self, new_limits):
|
||||
"""Sets the limits for the position filter. 'new_limits' must be a 4-tuple of the form (x0, y0, width, height). If None, the limits will be removed.
|
||||
Data points outside the position limits are discarded.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_limits: 4-tuple
|
||||
tuple of x-position, y-position, the width and the height. Passing None removes the filter
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError, if new_value is not a 4-tuple
|
||||
"""
|
||||
if new_limits is not None and not (
|
||||
isinstance(new_limits, (tuple, list)) and len(new_limits) == 4
|
||||
):
|
||||
raise ValueError(
|
||||
f"The new_limits vector must be a 4-tuple of the form (x, y, width, height)"
|
||||
)
|
||||
self._position_limits = new_limits
|
||||
|
||||
@property
|
||||
def temporal_limits(self):
|
||||
"""
|
||||
Get the temporal limits of the tracking data.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the start and end time of the tracking data.
|
||||
"""
|
||||
return self._time_limits
|
||||
|
||||
@temporal_limits.setter
|
||||
def temporal_limits(self, new_limits):
|
||||
"""Limits for temporal filter. The limits must be a 2-tuple of start and end time. Setting this to None removes the filter.
|
||||
Data points the are associated with times outside the limits are discarded.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_limits : 2-tuple
|
||||
The new limits in the form (start, end) given in seconds.
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError if the limits are not valid.
|
||||
"""
|
||||
if new_limits is not None and not (
|
||||
isinstance(new_limits, (tuple, list)) and len(new_limits) == 2
|
||||
):
|
||||
raise ValueError(
|
||||
f"The new_limits vector must be a 2-tuple of the form (start, end). "
|
||||
)
|
||||
self._time_limits = new_limits
|
||||
|
||||
def filter_tracks(self, align_time=True):
|
||||
"""Applies the filters to the tracking data. All filters will be applied sequentially, i.e. an AND connection.
|
||||
To change the filter settings use the setters for 'quality_threshold', 'temporal_limits', 'position_limits'. Setting them to None disables the respective filter discarding the setting.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
align_time: bool
|
||||
Controls whether the time vector is aligned to the first time point at which the agent is within the positional_limits. Default = True
|
||||
"""
|
||||
self._x = self._orgx.copy()
|
||||
self._y = self._orgy.copy()
|
||||
self._time = self._orgtime.copy()
|
||||
self._quality = self.original_quality.copy()
|
||||
|
||||
if self.position_limits is not None:
|
||||
x_max = self.position_limits[0] + self.position_limits[2]
|
||||
y_max = self.position_limits[1] + self.position_limits[3]
|
||||
indices = np.where(
|
||||
(self._x >= self.position_limits[0])
|
||||
& (self._x < x_max)
|
||||
& (self._y >= self.position_limits[1])
|
||||
& (self._y < y_max)
|
||||
)
|
||||
self._x = self._x[indices]
|
||||
self._y = self._y[indices]
|
||||
self._time = self._time[indices] - self._time[0] if align_time else 0.0
|
||||
self._quality = self._quality[indices]
|
||||
|
||||
if self.temporal_limits is not None:
|
||||
indices = np.where(
|
||||
(self._time >= self.temporal_limits[0])
|
||||
& (self._time < self.temporal_limits[1])
|
||||
)
|
||||
self._x = self._x[indices]
|
||||
self._y = self._y[indices]
|
||||
self._time = self._time[indices]
|
||||
self._quality = self._quality[indices]
|
||||
|
||||
if self.quality_threshold is not None:
|
||||
indices = np.where((self._quality >= self.quality_threshold))
|
||||
self._x = self._x[indices]
|
||||
self._y = self._y[indices]
|
||||
self._time = self._time[indices]
|
||||
self._quality = self._quality[indices]
|
||||
|
||||
def positions(self):
|
||||
"""Returns the filtered data (if filters have been applied, otherwise the original data).
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The x-positions
|
||||
np.ndarray
|
||||
The y-positions
|
||||
np.ndarray
|
||||
The time vector
|
||||
np.ndarray
|
||||
The detection quality
|
||||
"""
|
||||
return self._x, self._y, self._time, self._quality
|
||||
|
||||
def speed(self, x=None, y=None, t=None):
|
||||
""" Returns the agent's speed as a function of time and position. The speed estimation is associated to the time/position between two sample points. If any of the arguments is not provided, the function will use the x,y coordinates that are stored within the object, otherwise, if all are provided, the user-provided values will be used.
|
||||
|
||||
Since the velocities are estimated from the difference between two sample points the returned velocities and positions are assigned to positions and times between the respective sampled positions/times.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x: np.ndarray
|
||||
The x-coordinates, defaults to None
|
||||
y: np.ndarray
|
||||
The y-coordinates, defaults to None
|
||||
t: np.ndarray
|
||||
The time vector, defaults to None
|
||||
Returns
|
||||
-------
|
||||
np.ndarray:
|
||||
The time vector.
|
||||
np.ndarray:
|
||||
The speed.
|
||||
tuple of np.ndarray
|
||||
The position
|
||||
"""
|
||||
if x is None or y is None or t is None:
|
||||
x = self._x.copy()
|
||||
y = self._y.copy()
|
||||
t = self._time.copy()
|
||||
dt = np.diff(t)
|
||||
speed = np.sqrt(np.diff(x)**2 + np.diff(y)**2) / dt
|
||||
t = t[:-1] + dt / 2
|
||||
x = x[:-1] + np.diff(x) / 2
|
||||
y = y[:-1] + np.diff(y) / 2
|
||||
|
||||
return t, speed, (x, y)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"Tracking data of node '{self._node}'!"
|
||||
return s
|
197
src/etrack/tracking_result.py
Normal file
197
src/etrack/tracking_result.py
Normal file
@ -0,0 +1,197 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import numbers as nb
|
||||
import os
|
||||
|
||||
"""
|
||||
x_0 = 0
|
||||
width = 1230
|
||||
y_0 = 0
|
||||
height = 1100
|
||||
x_factor = 0.81/width # Einheit m/px
|
||||
y_factor = 0.81/height # Einheit m/px
|
||||
center = (np.round(x_0 + width/2), np.round(y_0 + height/2))
|
||||
center_meter = ((center[0] - x_0) * x_factor, (center[1] - y_0) * y_factor)
|
||||
"""
|
||||
def coordinate_transformation(position,x_0, y_0, x_factor, y_factor):
|
||||
x = (position[0] - x_0) * x_factor
|
||||
y = (position[1] - y_0) * y_factor
|
||||
return (x, y) #in m
|
||||
|
||||
class TrackingResult(object):
|
||||
|
||||
def __init__(self, results_file, x_0=0, y_0= 0, width_pixel=1230, height_pixel=1100, width_meter=0.81, height_meter=0.81) -> None:
|
||||
super().__init__()
|
||||
if not os.path.exists(results_file):
|
||||
raise ValueError("File %s does not exist!" % results_file)
|
||||
self._file_name = results_file
|
||||
self.x_0 = x_0
|
||||
self.y_0 = y_0
|
||||
self.width_pix = width_pixel
|
||||
self.width_m = width_meter
|
||||
self.height_pix = height_pixel
|
||||
self.height_m = height_meter
|
||||
self.x_factor = self.width_m / self.width_pix # m/pix
|
||||
self.y_factor = self.height_m / self.height_pix # m/pix
|
||||
|
||||
self.center = (np.round(self.x_0 + self.width_pix/2), np.round(self.y_0 + self.height_pix/2))
|
||||
self.center_meter = ((self.center[0] - self.x_0) * self.x_factor, (self.center[1] - self.y_0) * self.y_factor)
|
||||
|
||||
self._data_frame = pd.read_hdf(results_file)
|
||||
self._level_shape = self._data_frame.columns.levshape
|
||||
self._scorer = self._data_frame.columns.levels[0].values
|
||||
self._bodyparts = self._data_frame.columns.levels[1].values if self._level_shape[1] > 0 else []
|
||||
self._positions = self._data_frame.columns.levels[2].values if self._level_shape[2] > 0 else []
|
||||
|
||||
def angle_to_center(self, bodypart=0, twopi=True, origin="topleft", min_likelihood=0.95):
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
|
||||
_, x, y, _, _ = self.position_values(bodypart=bp, min_likelihood=min_likelihood)
|
||||
if x is None:
|
||||
print("Error: no valid angles for %s" % self._file_name)
|
||||
return []
|
||||
x_meter = x - self.center_meter[0]
|
||||
y_meter = y - self.center_meter[1]
|
||||
if origin.lower() == "topleft":
|
||||
y_meter *= -1
|
||||
phi = np.arctan2(y_meter, x_meter) * 180 / np.pi
|
||||
if twopi:
|
||||
phi[phi < 0] = 360 + phi[phi < 0]
|
||||
return phi
|
||||
|
||||
def coordinate_transformation(self, position):
|
||||
x = (position[0] - self.x_0) * self.x_factor
|
||||
y = (position[1] - self.y_0) * self.y_factor
|
||||
return (x, y) #in m
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def dataframe(self):
|
||||
return self._data_frame
|
||||
|
||||
@property
|
||||
def scorer(self):
|
||||
return self._scorer
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
return self._bodyparts
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return self._positions
|
||||
|
||||
def position_values(self, scorer=0, bodypart=0, framerate=30, interpolate=True, min_likelihood=0.95):
|
||||
"""returns the x and y positions in m and the likelihood of the positions.
|
||||
|
||||
Args:
|
||||
scorer (int, optional): [description]. Defaults to 0.
|
||||
bodypart (int, optional): [description]. Defaults to 0.
|
||||
framerate (int, optional): [description]. Defaults to 30.
|
||||
|
||||
Raises:
|
||||
ValueError: [description]
|
||||
ValueError: [description]
|
||||
|
||||
Returns:
|
||||
time [np.array]: the time axis
|
||||
x [np.array]: the x-position in m
|
||||
y [np.array]: the y-position in m
|
||||
l [np.array]: the likelihood of the position estimation
|
||||
bp string: the body part
|
||||
[type]: [description]
|
||||
"""
|
||||
time, x, y, l, bp = self.pixel_positions(scorer, bodypart, framerate, interpolate, min_likelihood)
|
||||
x, y = self._to_meter(x, y)
|
||||
return time, x, y, l, bp
|
||||
|
||||
def pixel_positions(self, scorer=0, bodypart=0, framerate=30, interpolate=True, min_likelihood=0.95):
|
||||
if isinstance(scorer, nb.Number):
|
||||
sc = self._scorer[scorer]
|
||||
elif isinstance(scorer, str) and scorer in self._scorer:
|
||||
sc = scorer
|
||||
else:
|
||||
raise ValueError("Scorer %s is not in dataframe!" % scorer)
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
|
||||
|
||||
x = np.asarray(self._data_frame[sc][bp]["x"] if "x" in self._positions else [])
|
||||
y = np.asarray(self._data_frame[sc][bp]["y"] if "y" in self._positions else [])
|
||||
l = np.asarray(self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else [])
|
||||
|
||||
time = np.arange(len(x))/framerate
|
||||
if interpolate:
|
||||
x, y = self.interpolate(time, x, y, l, min_likelihood)
|
||||
return time, x, y, l, bp
|
||||
|
||||
def _to_meter(self, x, y):
|
||||
new_x = (np.asarray(x) - self.x_0) * self.x_factor
|
||||
new_y = (np.asarray(y) - self.y_0) * self.y_factor
|
||||
return new_x, new_y
|
||||
|
||||
def _speed(self, t, x, y):
|
||||
speed = np.sqrt(np.diff(x)**2 + np.diff(y)**2) / np.diff(t)
|
||||
return speed
|
||||
|
||||
def interpolate(self, t, x, y, l, min_likelihood=0.9):
|
||||
time2 = t[l > min_likelihood]
|
||||
if len(l[l > min_likelihood]) < 10:
|
||||
print("%s has less than 10 datapoints with likelihood larger than %.2f" % (self._file_name, min_likelihood) )
|
||||
return None, None
|
||||
x2 = x[l > min_likelihood]
|
||||
y2 = y[l > min_likelihood]
|
||||
x3 = np.interp(t, time2, x2)
|
||||
y3 = np.interp(t, time2, y2)
|
||||
return x3, y3
|
||||
|
||||
def plot(self, scorer=0, bodypart=0, threshold=0.9, framerate=30):
|
||||
t, x, y, l, name = self.position_values(scorer=scorer, bodypart=bodypart, framerate=framerate, min_likelihood=threshold)
|
||||
plt.scatter(x[l > threshold], y[l > threshold], c=t[l > threshold], label=name)
|
||||
plt.scatter(self.center_meter[0], self.center_meter[1], marker="*")
|
||||
plt.plot(x[l > threshold], y[l > threshold])
|
||||
plt.xlabel("x position")
|
||||
plt.ylabel("y position")
|
||||
plt.gca().invert_yaxis()
|
||||
bar = plt.colorbar()
|
||||
bar.set_label("time [s]")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from IPython import embed
|
||||
filename = "2020.12.04_lepto48DLC_resnet50_boldnessDec11shuffle1_200000.h5"
|
||||
path = "/mnt/movies/merle_verena/boldness/labeled_videos/day_4/"
|
||||
tr = TrackingResult(path+filename)
|
||||
time, x, y, l, bp = tr.position_values(bodypart=2)
|
||||
|
||||
|
||||
thresh = 0.95
|
||||
time2 = time[l>thresh]
|
||||
x2 = x[l>thresh]
|
||||
y2 = y[l>thresh]
|
||||
x3 = np.interp(time, time2, x2)
|
||||
y3 = np.interp(time, time2, y2)
|
||||
|
||||
fig, axes = plt.subplots(3,1, sharex=True)
|
||||
axes[0].plot(time, x)
|
||||
axes[0].plot(time, x3)
|
||||
axes[1].plot(time, y)
|
||||
axes[1].plot(time, y3)
|
||||
|
||||
axes[2].plot(time, l)
|
||||
plt.show()
|
||||
|
||||
embed()
|
52
src/etrack/util.py
Normal file
52
src/etrack/util.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
Module containing utility functions and enum classes.
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
class Illumination(Enum):
|
||||
Backlight = 0
|
||||
Incident = 1
|
||||
|
||||
|
||||
class RegionShape(Enum):
|
||||
"""
|
||||
Enumeration representing the shape of a region.
|
||||
|
||||
Attributes:
|
||||
Circular: Represents a circular region.
|
||||
Rectangular: Represents a rectangular region.
|
||||
"""
|
||||
|
||||
Circular = 0
|
||||
Rectangular = 1
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class AnalysisType(Enum):
|
||||
"""
|
||||
Enumeration representing different types of analysis used when analyzing whether
|
||||
positions fall into a given region.
|
||||
|
||||
Possible types:
|
||||
AnalysisType.Full: considers both, the x- and the y-coordinates
|
||||
AnalysisType.CollapseX: consider only the x-coordinates
|
||||
AnalysisType.CollapseY: consider only the y-coordinates
|
||||
"""
|
||||
Full = 0
|
||||
CollapseX = 1
|
||||
CollapseY = 2
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Returns the string representation of the analysis type.
|
||||
|
||||
Returns:
|
||||
str: The name of the analysis type.
|
||||
"""
|
||||
return self.name
|
||||
|
||||
class PositionType(Enum):
|
||||
Absolute = 0
|
||||
Cropped = 1
|
BIN
test/2022lepto01_converted_2024.03.27_0.mp4.nix
Normal file
BIN
test/2022lepto01_converted_2024.03.27_0.mp4.nix
Normal file
Binary file not shown.
82
test/test_arena.py
Normal file
82
test/test_arena.py
Normal file
@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mp
|
||||
|
||||
from etrack import Arena, Region, RegionShape
|
||||
|
||||
|
||||
def test_region():
|
||||
# Create a parent region
|
||||
parent_region = Region((0, 0), (100, 100), name="parent", region_shape=RegionShape.Rectangular)
|
||||
|
||||
# Create a child region
|
||||
child_region = Region((10, 10), (50, 50), name="child", region_shape=RegionShape.Rectangular, parent=parent_region)
|
||||
|
||||
# Test properties
|
||||
assert child_region.name == "child"
|
||||
assert child_region.inverted_y == True
|
||||
assert (child_region._max_extent == np.array((60, 60))).all()
|
||||
assert (child_region._min_extent == np.array((10, 10))).all()
|
||||
assert child_region.xmax == 60
|
||||
assert child_region.xmin == 10
|
||||
assert child_region.ymin == 10
|
||||
assert child_region.ymax == 60
|
||||
assert child_region.position == (10, 10, 50, 50)
|
||||
assert child_region.is_child == True
|
||||
|
||||
# Test fits method
|
||||
assert parent_region.fits(child_region) == True
|
||||
|
||||
# Test points_in_region method
|
||||
x = [15, 20, 25, 30, 35, 5]
|
||||
y = [15, 20, 25, 30, 35, 5]
|
||||
assert (child_region.points_in_region(x, y) == np.array([0, 1, 2, 3, 4])).all()
|
||||
|
||||
# Test time_in_region method
|
||||
x = [5, 15, 20, 25, 30, 35, 35]
|
||||
y = [5, 15, 20, 25, 30, 35, 65]
|
||||
time = np.arange(0, len(x), 1)
|
||||
enter, leave = child_region.time_in_region(x, y, time)
|
||||
assert enter[0] == 1
|
||||
assert leave[0] == 5
|
||||
|
||||
# Test patch method
|
||||
patch = child_region.patch(color='red')
|
||||
assert isinstance(patch, mp.Patch)
|
||||
|
||||
# Test __repr__ method
|
||||
assert repr(child_region) == "Region: 'child' of Rectangular shape."
|
||||
|
||||
|
||||
def test_arena():
|
||||
# Create an arena
|
||||
arena = Arena((0, 0), (100, 100), name="arena", arena_shape=RegionShape.Rectangular)
|
||||
# Test add_region method
|
||||
arena.add_region("small rect1", (0, 0), (50, 50))
|
||||
assert len(arena.regions) == 1
|
||||
assert arena.regions["small rect1"].name == "small rect1"
|
||||
# Test remove_region method
|
||||
arena.remove_region("small rect1")
|
||||
assert len(arena.regions) == 0
|
||||
# Test plot method
|
||||
axis = arena.plot()
|
||||
assert isinstance(axis, plt.Axes)
|
||||
# Test region_vector method
|
||||
x = [10, 20, 30]
|
||||
y = [10, 20, 30]
|
||||
assert (arena.region_vector(x, y) == "").all()
|
||||
|
||||
# Test in_region method
|
||||
# assert len(arena.in_region(10, 10)) > 0
|
||||
# print(arena.in_region(10, 10))
|
||||
|
||||
# print(arena.in_region(110, 110))
|
||||
# assert arena.in_region(110, 110) == False
|
||||
# Test __getitem__ method
|
||||
arena.add_region("small rect2", (0, 0), (50, 50))
|
||||
assert arena["small rect2"].name == "small rect2"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
32
test/test_nixtrackio.py
Normal file
32
test/test_nixtrackio.py
Normal file
@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
import etrack as et
|
||||
|
||||
from IPython import embed
|
||||
|
||||
dataset = "test/2022lepto01_converted_2024.03.27_0.mp4.nix"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nixtrack_data():
|
||||
# Create a NixTrackData object with some test data
|
||||
return et.NixtrackData(dataset)
|
||||
|
||||
|
||||
def test_basics(nixtrack_data):
|
||||
assert nixtrack_data.filename == dataset
|
||||
assert len(nixtrack_data.bodyparts) == 5
|
||||
assert len(nixtrack_data.tracks) == 2
|
||||
assert nixtrack_data.fps == 25
|
||||
|
||||
|
||||
def test_trackingdata(nixtrack_data):
|
||||
with pytest.raises(ValueError):
|
||||
nixtrack_data.track_data(bodypart="test")
|
||||
nixtrack_data.track_data(track="fish")
|
||||
|
||||
assert nixtrack_data.track_data("center") is not None
|
||||
assert nixtrack_data.track_data("center", "none") is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
Loading…
Reference in New Issue
Block a user