add tracking result tool
This commit is contained in:
parent
f3c0d48557
commit
b7393dff04
72
tracking_result.py
Normal file
72
tracking_result.py
Normal file
@ -0,0 +1,72 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import numbers as nb
|
||||
import os
|
||||
|
||||
class TrackingResult():
|
||||
|
||||
def __init__(self, results_file) -> None:
|
||||
super().__init__()
|
||||
if not os.path.exists(results_file):
|
||||
raise ValueError("File %s does not exist!" % results_file)
|
||||
self._file_name = results_file
|
||||
|
||||
self._data_frame = pd.read_hdf(results_file)
|
||||
self._level_shape = self._data_frame.columns.levshape
|
||||
self._scorer = self._data_frame.columns.levels[0].values
|
||||
self._bodyparts = self._data_frame.columns.levels[1].values if self._level_shape[1] > 0 else []
|
||||
self._positions = self._data_frame.columns.levels[2].values if self._level_shape[2] > 0 else []
|
||||
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self._file_name
|
||||
|
||||
@property
|
||||
def dataframe(self):
|
||||
return self._data_frame
|
||||
|
||||
@property
|
||||
def scorer(self):
|
||||
return self._scorer
|
||||
|
||||
@property
|
||||
def bodyparts(self):
|
||||
return self._bodyparts
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return self._positions
|
||||
|
||||
def position_values(self, scorer=0, bodypart=0, framerate=30):
|
||||
if isinstance(scorer, nb.Number):
|
||||
sc = self._scorer[scorer]
|
||||
elif isinstance(scorer, str) and scorer in self._scorer:
|
||||
sc = scorer
|
||||
else:
|
||||
raise ValueError("Scorer %s is not in dataframe!" % scorer)
|
||||
if isinstance(bodypart, nb.Number):
|
||||
bp = self._bodyparts[bodypart]
|
||||
elif isinstance(bodypart, str) and bodypart in self._bodyparts:
|
||||
bp = bodypart
|
||||
else:
|
||||
raise ValueError("Bodypart %s is not in dataframe!" % bodypart)
|
||||
|
||||
x = self._data_frame[sc][bp]["x"] if "x" in self._positions else []
|
||||
y = self._data_frame[sc][bp]["y"] if "y" in self._positions else []
|
||||
l = self._data_frame[sc][bp]["likelihood"] if "likelihood" in self._positions else []
|
||||
time = np.arange(len(self._data_frame))/framerate
|
||||
return time, x, y, l, bp
|
||||
|
||||
def plot(self, scorer=0, bodypart=0, threshold=0.9, framerate=30):
|
||||
t, x, y, l, name = self.position_values(scorer=scorer, bodypart=bodypart, framerate=framerate)
|
||||
plt.scatter(x[l > threshold], y[l > threshold], c=t[l > threshold], label=name)
|
||||
plt.plot(x[l > threshold], y[l > threshold])
|
||||
plt.xlabel("x position")
|
||||
plt.ylabel("y position")
|
||||
plt.gca().invert_yaxis()
|
||||
bar = plt.colorbar()
|
||||
bar.set_label("time [s]")
|
||||
plt.legend()
|
||||
plt.show()
|
Loading…
Reference in New Issue
Block a user