forked from jgrewe/efish_tracking
initials
This commit is contained in:
commit
a6e5033597
35
README.md
Normal file
35
README.md
Normal file
@ -0,0 +1,35 @@
|
||||
# E-Fish tracking
|
||||
|
||||
Tool for easier handling of tracking results.
|
||||
|
||||
## Installation
|
||||
|
||||
### 1. Clone git repository
|
||||
|
||||
```shell
|
||||
git clone https://whale.am28.uni-tuebingen.de/git/jgrewe/efish_tracking.git
|
||||
```
|
||||
|
||||
### 2. Change into directory
|
||||
|
||||
```shell
|
||||
cd efish_tracking
|
||||
````
|
||||
|
||||
### 3. Install with pip
|
||||
|
||||
```shell
|
||||
pip3 install -e . --user
|
||||
```
|
||||
|
||||
The ```-e``` installs the package in an *editable* model that you do not need to reinstall whenever you pull upstream changes.
|
||||
|
||||
If you leave away the ```--user``` the package will be installed system-wide.
|
||||
|
||||
## TrackingResults
|
||||
|
||||
Is a class that wraps around the *.h5 files written by DeppLabCut
|
||||
|
||||
## ImageMarker
|
||||
|
||||
Class that allows for creating MarkerTasks to get specific positions in a video.
|
2
etrack/__init__.py
Normal file
2
etrack/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .image_marker import ImageMarker, MarkerTask
|
||||
from .tracking_result import TrackingResult
|
151
etrack/image_marker.py
Normal file
151
etrack/image_marker.py
Normal file
@ -0,0 +1,151 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
import os
|
||||
import sys
|
||||
from IPython import embed
|
||||
|
||||
class ImageMarker:
|
||||
|
||||
def __init__(self, tasks=[]) -> None:
|
||||
super().__init__()
|
||||
self._fig = plt.figure()
|
||||
self._tasks = tasks
|
||||
self._task_index = -1
|
||||
self._current_task = None
|
||||
self._marker_set = False
|
||||
self._interrupt = False
|
||||
self._fig.canvas.mpl_connect('button_press_event', self._on_click_event)
|
||||
self._fig.canvas.mpl_connect('close_event', self._fig_close_event)
|
||||
self._fig.canvas.mpl_connect('key_press_event', self._key_press_event)
|
||||
|
||||
def mark_movie(self, filename, frame_number=0):
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("file %s does not exist!" % filename)
|
||||
video = cv2.VideoCapture()
|
||||
video.open(filename)
|
||||
frame_counter = 0
|
||||
success = True
|
||||
frame = None
|
||||
while success and frame_counter <= frame_number:
|
||||
print("Reading frame: %i" % frame_counter, end="\r")
|
||||
success, frame = video.read()
|
||||
frame_counter += 1
|
||||
if success:
|
||||
self._fig.gca().imshow(frame)
|
||||
else:
|
||||
print("Could not read frame number %i either failed to open movie or beyond maximum frame number!" % frame_number)
|
||||
return []
|
||||
plt.ion()
|
||||
plt.show(block=False)
|
||||
|
||||
self._task_index = -1
|
||||
if len(self._tasks) > 0:
|
||||
self._next_task()
|
||||
|
||||
while not self._tasks_done:
|
||||
plt.pause(0.250)
|
||||
if self._interrupt:
|
||||
return []
|
||||
|
||||
self._fig.gca().set_title("All set and done!\n Window will close in 2s")
|
||||
self._fig.canvas.draw()
|
||||
plt.pause(2.0)
|
||||
return [t.marker_positions for t in self._tasks]
|
||||
|
||||
def _key_press_event(self, event):
|
||||
print("Key pressed: %s!" % event.key)
|
||||
|
||||
@property
|
||||
def _tasks_done(self):
|
||||
done = self._task_index == len(self._tasks) and self._current_task is not None and self._current_task.task_done
|
||||
return done
|
||||
|
||||
def _next_task(self):
|
||||
if self._current_task is None:
|
||||
self._task_index += 1
|
||||
self._current_task = self._tasks[self._task_index]
|
||||
|
||||
if self._current_task is not None and not self._current_task.task_done:
|
||||
self._fig.gca().set_title("%s: \n%s: %s" % (self._current_task.name, self._current_task.message, self._current_task.current_marker))
|
||||
self._fig.canvas.draw()
|
||||
elif self._current_task is not None and self._current_task.task_done:
|
||||
self._task_index += 1
|
||||
if self._task_index < len(self._tasks):
|
||||
self._current_task = self._tasks[self._task_index]
|
||||
self._fig.gca().set_title("%s: \n%s: %s" % (self._current_task.name, self._current_task.message, self._current_task.current_marker))
|
||||
self._fig.canvas.draw()
|
||||
|
||||
def _on_click_event(self, event):
|
||||
self._fig.gca().scatter(event.xdata, event.ydata, marker=self._current_task.marker_symbol, color=self._current_task.marker_color, s=20)
|
||||
event.canvas.draw()
|
||||
self._current_task.set_position(self._current_task.current_marker, event.xdata, event.ydata)
|
||||
self._next_task()
|
||||
|
||||
def _fig_close_event(self, even):
|
||||
self._interrupt = True
|
||||
|
||||
class MarkerTask():
|
||||
def __init__(self, name:str, marker_names=[], message="", marker="o", color="tab:blue") -> None:
|
||||
super().__init__()
|
||||
self._positions = {}
|
||||
self._marker_names = marker_names
|
||||
self._name = name
|
||||
self._message = message
|
||||
self._current_marker = marker_names[0] if len(marker_names) > 0 else None
|
||||
self._current_index = 0
|
||||
self._marker = marker
|
||||
self._marker_color = color
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return self._positions
|
||||
|
||||
@property
|
||||
def name(self)->str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def message(self)->str:
|
||||
return self._message
|
||||
|
||||
def set_position(self, marker_name, x, y):
|
||||
self._positions[marker_name] = (x, y)
|
||||
if not self.task_done:
|
||||
self._current_index += 1
|
||||
self._current_marker = self._marker_names[self._current_index]
|
||||
|
||||
@property
|
||||
def marker_positions(self):
|
||||
return self._positions
|
||||
|
||||
@property
|
||||
def task_done(self):
|
||||
return len(self._positions) == len(self._marker_names)
|
||||
|
||||
@property
|
||||
def current_marker(self):
|
||||
return self._current_marker
|
||||
|
||||
@property
|
||||
def marker_symbol(self):
|
||||
return self._marker
|
||||
|
||||
@property
|
||||
def marker_color(self):
|
||||
return self._marker_color
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "MarkerTask %s with markers: %s" % (self.name, [mn for mn in self._marker_names])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tank_task = MarkerTask("tank limits", ["bottom left corner", "top left corner", "top right corner", "bottom right corner"], "Mark tank corners")
|
||||
feeder_task = MarkerTask("Feeder positions", list(map(str, range(1, 2))), "Mark feeder positions")
|
||||
tasks = [tank_task, feeder_task]
|
||||
im = ImageMarker(tasks)
|
||||
# vid1 = "2020.12.11_lepto48DLC_resnet50_boldnessDec11shuffle1_200000_labeled.mp4"
|
||||
print(sys.argv[0])
|
||||
print (sys.argv[1])
|
||||
vid1 = sys.argv[1]
|
||||
marker_positions = im.mark_movie(vid1, 10)
|
||||
print(marker_positions)
|
178
etrack/tracking_result.py
Normal file
178
etrack/tracking_result.py
Normal file
@ -0,0 +1,178 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import numbers as nb
|
||||
import os
|
||||
|
||||
"""
|
||||
x_0 = 0
|
||||
width = 1230
|
||||
y_0 = 0
|
||||
height = 1100
|
||||
x_factor = 0.81/width # Einheit m/px
|
||||
y_factor = 0.81/height # Einheit m/px
|
||||
center = (np.round(x_0 + width/2), np.round(y_0 + height/2))
|
||||
center_meter = ((center[0] - x_0) * x_factor, (center[1] - y_0) * y_factor)
|
||||
"""
|
||||
|
||||
class TrackingResult(object):
|
||||
|
||||
def __init__(self, results_file, x_0=0, y_0= 0, width_pixel=1230, height_pixel=1100, width_meter=0.81, height_meter=0.81) -> None:
|
||||
super().__init__()
|
||||
if not os.path.exists(results_file):
|
||||
raise ValueError("File %s does not exist!" % results_file)
|
||||
self._file_name = results_file
|
||||
self.x_0 = x_0
|
||||
self.y_0 = y_0
|
||||
self.width_pix = width_pixel
|
||||
self.width_m = width_meter
|
||||
self.height_pix = height_pixel
|
||||
self.height_m = height_meter
|
||||
self.x_factor = self.width_m / self.width_pix # m/pix
|
||||
self.x_factor = self.height_m / self.height_pix # m/pix
|
||||
|
||||
self.center = (np.round(self.x_0 + self.width_pix/2), np.round(self.y_0 + self.height_pix/2))
|
||||
self.center_meter = ((self.center[0] - x_0) * self.x_factor, (self.centerenter[1] - y_0) * self.y_factor)
|
||||
|
||||
self._data_frame = pd.read_hdf(results_file)
|
||||
self._level_shape = self._data_frame.columns.levshape
|
||||
self._scorer = self._data_frame.columns.levels[0].values
|
||||
self._bodyparts = self._data_frame.columns.levels[1].values if self._level_shape[1] > 0 else []
|
||||
self._positions = self._data_frame.columns.levels[2].values if self._level_shape[2] > 0 else []
|
||||
|
||||
def angle_to_center(self, bodypart=0, twopi=True, origin="topleft", min_likelihood=0.95):
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
|
||||
_, x, y, _, _ = self.position_values(bodypart=bp, min_likelihood=min_likelihood)
|
||||
if x is None:
|
||||
print("Error: no valid angles for %s" % self._file_name)
|
||||
return []
|
||||
x_meter = x - self.center_meter[0]
|
||||
y_meter = y - self.center_meter[1]
|
||||
if origin.lower() == "topleft":
|
||||
y_meter *= -1
|
||||
phi = np.arctan2(y_meter, x_meter) * 180 / np.pi
|
||||
if twopi:
|
||||
phi[phi < 0] = 360 + phi[phi < 0]
|
||||
return phi
|
||||
|
||||
def coordinate_transformation(self, position):
|
||||
x = (position[0] - self.x_0) * self.x_factor
|
||||
y = (position[1] - self.y_0) * self.y_factor
|
||||
return (x, y) #in m
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def dataframe(self):
|
||||
return self._data_frame
|
||||
|
||||
@property
|
||||
def scorer(self):
|
||||
return self._scorer
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
return self._bodyparts
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return self._positions
|
||||
|
||||
def position_values(self, scorer=0, bodypart=0, framerate=30, interpolate=True, min_likelihood=0.95):
|
||||
"""returns the x and y positions in m and the likelihood of the positions.
|
||||
|
||||
Args:
|
||||
scorer (int, optional): [description]. Defaults to 0.
|
||||
bodypart (int, optional): [description]. Defaults to 0.
|
||||
framerate (int, optional): [description]. Defaults to 30.
|
||||
|
||||
Raises:
|
||||
ValueError: [description]
|
||||
ValueError: [description]
|
||||
|
||||
Returns:
|
||||
time [np.array]: the time axis
|
||||
x [np.array]: the x-position in m
|
||||
y [np.array]: the y-position in m
|
||||
l [np.array]: the likelihood of the position estimation
|
||||
bp string: the body part
|
||||
[type]: [description]
|
||||
"""
|
||||
|
||||
if isinstance(scorer, nb.Number):
|
||||
sc = self._scorer[scorer]
|
||||
elif isinstance(scorer, str) and scorer in self._scorer:
|
||||
sc = scorer
|
||||
else:
|
||||
raise ValueError("Scorer %s is not in dataframe!" % scorer)
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
|
||||
|
||||
x = self._data_frame[sc][bp]["x"] if "x" in self._positions else []
|
||||
x = (np.asarray(x) - self.x_0) * self.x_factor
|
||||
y = self._data_frame[sc][bp]["y"] if "y" in self._positions else []
|
||||
y = (np.asarray(y) - self.y_0) * self.y_factor
|
||||
l = self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else []
|
||||
|
||||
time = np.arange(len(self._data_frame))/framerate
|
||||
time2 = time[l > min_likelihood]
|
||||
if len(l[l > min_likelihood]) < 100:
|
||||
print("%s has not datapoints with likelihood larger than %.2f" % (self._file_name, min_likelihood) )
|
||||
return None, None, None, None, None
|
||||
x2 = x[l > min_likelihood]
|
||||
y2 = y[l > min_likelihood]
|
||||
x3 = np.interp(time, time2, x2)
|
||||
y3 = np.interp(time, time2, y2)
|
||||
return time, x3, y3, l, bp
|
||||
|
||||
def plot(self, scorer=0, bodypart=0, threshold=0.9, framerate=30):
|
||||
t, x, y, l, name = self.position_values(scorer=scorer, bodypart=bodypart, framerate=framerate)
|
||||
plt.scatter(x[l > threshold], y[l > threshold], c=t[l > threshold], label=name)
|
||||
plt.scatter(self.center_meter[0], self.center_meter[1], marker="*")
|
||||
plt.plot(x[l > threshold], y[l > threshold])
|
||||
plt.xlabel("x position")
|
||||
plt.ylabel("y position")
|
||||
plt.gca().invert_yaxis()
|
||||
bar = plt.colorbar()
|
||||
bar.set_label("time [s]")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
from IPython import embed
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from IPython import embed
|
||||
filename = "2020.12.04_lepto48DLC_resnet50_boldnessDec11shuffle1_200000.h5"
|
||||
path = "/mnt/movies/merle_verena/boldness/labeled_videos/day_4/"
|
||||
tr = TrackingResult(path+filename)
|
||||
time, x, y, l, bp = tr.position_values(bodypart=2)
|
||||
|
||||
|
||||
thresh = 0.95
|
||||
time2 = time[l>thresh]
|
||||
x2 = x[l>thresh]
|
||||
y2 = y[l>thresh]
|
||||
x3 = np.interp(time, time2, x2)
|
||||
y3 = np.interp(time, time2, y2)
|
||||
|
||||
|
||||
fig, axes = plt.subplots(3,1, sharex=True)
|
||||
axes[0].plot(time, x)
|
||||
axes[0].plot(time, x3)
|
||||
axes[1].plot(time, y)
|
||||
axes[1].plot(time, y3)
|
||||
axes[2].plot(time, l)
|
||||
plt.show()
|
||||
|
||||
embed()
|
33
setup.py
Normal file
33
setup.py
Normal file
@ -0,0 +1,33 @@
|
||||
from setuptools import setup
|
||||
|
||||
NAME = "etrack"
|
||||
VERSION = 0.5
|
||||
AUTHOR = "Jan Grewe"
|
||||
CONTACT = "jan.grewe@g-node.org"
|
||||
CLASSIFIERS = "science"
|
||||
DESCRIPTION = "helpers for handling depp lab cut tracking results"
|
||||
|
||||
README = "README.md"
|
||||
with open(README) as f:
|
||||
description_text = f.read()
|
||||
|
||||
packages = [
|
||||
"etrack",
|
||||
]
|
||||
|
||||
install_req = ["h5py", "pandas", "matplotlib", "numpy", "opencv-python"]
|
||||
|
||||
setup(
|
||||
name=NAME,
|
||||
version=VERSION,
|
||||
description=DESCRIPTION,
|
||||
author=AUTHOR,
|
||||
author_email=CONTACT,
|
||||
packages=packages,
|
||||
install_requires=install_req,
|
||||
include_package_data=True,
|
||||
long_description=description_text,
|
||||
long_description_content_type="text/markdown",
|
||||
classifiers=CLASSIFIERS,
|
||||
license="BSD"
|
||||
)
|
Loading…
Reference in New Issue
Block a user