Compare commits
5 Commits
ff7e1e85ae
...
5a128cf28e
Author | SHA1 | Date | |
---|---|---|---|
5a128cf28e | |||
96e4b0b2c5 | |||
6244f7fdbe | |||
1e86a74549 | |||
b7d4097e73 |
74
fixtracks/utils/tablemodels.py
Normal file
74
fixtracks/utils/tablemodels.py
Normal file
@ -0,0 +1,74 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from PySide6.QtCore import Qt, QAbstractTableModel, QSortFilterProxyModel
|
||||
|
||||
class PoseTableModel(QAbstractTableModel):
|
||||
column_header = ["frame", "track"]
|
||||
columns = ["frame", "track"]
|
||||
|
||||
def __init__(self, dataframe, parent=None):
|
||||
super().__init__(parent)
|
||||
self._data = dataframe
|
||||
self._frames = self._data.frame.values
|
||||
self._tracks = self._data.track.values
|
||||
self._indices = self._data.index.values
|
||||
self._column_data = [self._frames, self._tracks]
|
||||
|
||||
def columnCount(self, parent=None):
|
||||
return len(self.columns)
|
||||
|
||||
def rowCount(self, parent=None):
|
||||
if self._data is not None:
|
||||
return len(self._data)
|
||||
else:
|
||||
return 0
|
||||
|
||||
def data(self, index, role = ...):
|
||||
value = self._column_data[index.column()][index.row()]
|
||||
if role == Qt.ItemDataRole.DisplayRole:
|
||||
return str(value)
|
||||
elif role == Qt.ItemDataRole.UserRole:
|
||||
return value
|
||||
return None
|
||||
|
||||
def headerData(self, section, orientation, role = ...):
|
||||
if role == Qt.ItemDataRole.DisplayRole:
|
||||
if orientation == Qt.Orientation.Horizontal:
|
||||
return self.column_header[section]
|
||||
else:
|
||||
return str(self._indices[section])
|
||||
else:
|
||||
return None
|
||||
|
||||
def mapIdToRow(self, id):
|
||||
row = np.where(self._indices == id)[0]
|
||||
if len(row) == 0:
|
||||
return -1
|
||||
return row[0]
|
||||
|
||||
|
||||
class FilterProxyModel(QSortFilterProxyModel):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._range = None
|
||||
|
||||
def setFilterRange(self, start, stop):
|
||||
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
|
||||
self._range = (start, stop)
|
||||
self.invalidateRowsFilter()
|
||||
|
||||
def all(self):
|
||||
self._range = None
|
||||
|
||||
def filterAcceptsRow(self, source_row, source_parent):
|
||||
if self._range is None:
|
||||
return True
|
||||
else:
|
||||
idx = self.sourceModel().index(source_row, 0, source_parent);
|
||||
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
|
||||
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
|
||||
return val >= self._range[0] and val < self._range[1]
|
||||
|
||||
def filterAcceptsColumn(self, source_column, source_parent):
|
||||
return True
|
237
fixtracks/utils/trackingdata.py
Normal file
237
fixtracks/utils/trackingdata.py
Normal file
@ -0,0 +1,237 @@
|
||||
import pickle
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from PySide6.QtCore import QObject
|
||||
|
||||
class TrackingData(QObject):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._data = None
|
||||
self._columns = []
|
||||
self._start = 0
|
||||
self._stop = 0
|
||||
self._indices = None
|
||||
self._selection_column = None
|
||||
self._user_selections = None
|
||||
|
||||
def setData(self, datadict):
|
||||
assert isinstance(datadict, dict)
|
||||
self._data = datadict
|
||||
self._columns = [k for k in self._data.keys()]
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self._columns
|
||||
|
||||
def max(self, col):
|
||||
if col in self.columns:
|
||||
return np.max(self._data[col])
|
||||
else:
|
||||
logging.error("Column %s not in dictionary", col)
|
||||
return np.nan
|
||||
|
||||
@property
|
||||
def numDetections(self):
|
||||
return self._data["track"].shape[0]
|
||||
|
||||
@property
|
||||
def selectionRange(self):
|
||||
return self._start, self._stop
|
||||
|
||||
@property
|
||||
def selectionRangeColumn(self):
|
||||
return self._selection_column
|
||||
|
||||
@property
|
||||
def selectionIndices(self):
|
||||
return self._indices
|
||||
|
||||
def setSelectionRange(self, col, start, stop):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._selection_column = col
|
||||
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
|
||||
def selectedData(self, col):
|
||||
return self._data[col][self._indices]
|
||||
|
||||
def setUserSelection(self, ids):
|
||||
"""
|
||||
Set the user selections. That is, e.g. when the user selected a number of ids.
|
||||
Parameters
|
||||
----------
|
||||
ids : array-like
|
||||
An array-like object containing the IDs to be set as user selections.
|
||||
The IDs will be converted to integers.
|
||||
"""
|
||||
self._user_selections = ids.astype(int)
|
||||
|
||||
def assignUserSelection(self, track_id:int)-> None:
|
||||
"""Assign a new track_id to the user-selected detections
|
||||
|
||||
Parameters
|
||||
----------
|
||||
track_id : int
|
||||
The new track id for the user-selected detections
|
||||
"""
|
||||
self._data["track"][self._user_selections] = track_id
|
||||
|
||||
def assignTracks(self, tracks):
|
||||
"""assignTracks _summary_
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tracks : _type_
|
||||
_description_
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
"""
|
||||
if len(tracks) != self.numDetections:
|
||||
logging.error("DataController: Size of passed tracks does not match data!")
|
||||
return
|
||||
self._data["track"] = tracks
|
||||
|
||||
def save(self, filename):
|
||||
export_columns = self._columns.copy()
|
||||
export_columns.remove("index")
|
||||
dictionary = {c: self._data[c] for c in export_columns}
|
||||
df = pd.DataFrame(dictionary, index=self._data["index"])
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(df, f)
|
||||
|
||||
def numKeypoints(self):
|
||||
if len(self._data["keypoints"]) == 0:
|
||||
return 0
|
||||
return self._data["keypoints"][0].shape[0]
|
||||
|
||||
def coordinates(self):
|
||||
"""
|
||||
Returns the coordinates of all keypoints as a NumPy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A NumPy array of shape (N, M, 2) where N is the number of detections,
|
||||
and M is number of keypoints
|
||||
"""
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
|
||||
def keypointScores(self):
|
||||
"""
|
||||
Returns the keypoint scores as a NumPy array of type float32.
|
||||
|
||||
Returns
|
||||
-------
|
||||
numpy.ndarray
|
||||
A NumPy array of type float32 containing the keypoint scores of the shape (N, M)
|
||||
with N the number of detections and M the number of keypoints.
|
||||
"""
|
||||
return np.stack(self._data["keypoint_score"]).astype(np.float32)
|
||||
|
||||
def centerOfGravity(self, threshold=0.8):
|
||||
"""
|
||||
Calculate the center of gravity of keypoints weighted by their scores. Ignores keypoints that have a score
|
||||
less than threshold.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
threshold: float
|
||||
keypoints with a score less than threshold are ignored
|
||||
|
||||
Returns:
|
||||
--------
|
||||
np.ndarray:
|
||||
A NumPy array of shape (N, 2) containing the center of gravity for each detection.
|
||||
"""
|
||||
scores = self.keypointScores()
|
||||
scores[scores < threshold] = 0.0
|
||||
weighted_coords = self.coordinates() * scores[:, :, np.newaxis]
|
||||
sum_scores = np.sum(scores, axis=1, keepdims=True)
|
||||
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
|
||||
return center_of_gravity
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
# def __setitem__(self, key, value):
|
||||
# self._data[key] = value
|
||||
"""
|
||||
self._data.setSelectionRange("index", 0, self._data.numDetections)
|
||||
self._data.assignTracks(tracks)
|
||||
self._timeline.setDetectionData(self._data.data)
|
||||
self.update()
|
||||
"""
|
||||
|
||||
def main():
|
||||
import pandas as pd
|
||||
from IPython import embed
|
||||
import matplotlib.pyplot as plt
|
||||
from fixtracks.info import PACKAGE_ROOT
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
def as_dict(df:pd.DataFrame):
|
||||
d = {c: df[c].values for c in df.columns}
|
||||
d["index"] = df.index.values
|
||||
return d
|
||||
|
||||
|
||||
def neighborDistances(x, n=5, symmetric=True):
|
||||
pad_shape = list(x.shape)
|
||||
pad_shape[0] = 5
|
||||
pad = np.zeros(pad_shape)
|
||||
if symmetric:
|
||||
padded_x = np.vstack((pad, x, pad))
|
||||
else:
|
||||
padded_x = np.vstack((pad, x))
|
||||
dists = np.zeros((padded_x.shape[0], 2*n))
|
||||
count = 0
|
||||
r = range(-n, n+1) if symmetric else range(-n, 0)
|
||||
for i in r:
|
||||
if i == 0:
|
||||
continue
|
||||
shifted_x = np.roll(padded_x, i)
|
||||
dists[:, count] = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
|
||||
count += 1
|
||||
return dists
|
||||
|
||||
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||
with open(datafile, "rb") as f:
|
||||
df = pickle.load(f)
|
||||
|
||||
data = TrackingData()
|
||||
data.setData(as_dict(df))
|
||||
all_cogs = data.centerOfGravity()
|
||||
tracks = data["track"]
|
||||
cogs = all_cogs[tracks==1]
|
||||
all_dists = neighborDistances(cogs, 2, False)
|
||||
plt.hist(all_dists[1:, 0], bins=1000)
|
||||
print(np.percentile(all_dists[1:, 0], 99))
|
||||
print(np.percentile(all_dists[1:, 0], 1))
|
||||
plt.gca().set_xscale("log")
|
||||
plt.gca().set_yscale("log")
|
||||
# plt.hist(all_dists[1:, 1], bins=100)
|
||||
plt.show()
|
||||
# def compute_neighbor_distances(cogs, window=10):
|
||||
# distances = []
|
||||
# for i in range(len(cogs)):
|
||||
# start = max(0, i - window)
|
||||
# stop = min(len(cogs), i + window + 1)
|
||||
# neighbors = cogs[start:stop]
|
||||
# dists = cdist([cogs[i]], neighbors)[0]
|
||||
# distances.append(dists)
|
||||
# return distances
|
||||
# print("estimating neighorhood distances")
|
||||
# neighbor_distances = compute_neighbor_distances(cogs)
|
||||
embed()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,15 +1,17 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
import pyqtgraph as pg
|
||||
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
|
||||
from PySide6.QtCore import Signal
|
||||
from PySide6.QtGui import QBrush, QColor
|
||||
|
||||
import pyqtgraph as pg
|
||||
from fixtracks.utils.trackingdata import TrackingData
|
||||
|
||||
|
||||
class SizeClassifier(QWidget):
|
||||
apply = Signal()
|
||||
name = "SizeClassifier"
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
@ -29,17 +31,16 @@ class SizeClassifier(QWidget):
|
||||
|
||||
def setupGraph(self):
|
||||
track1_brush = QBrush(QColor.fromString("orange"))
|
||||
track1_brush.color().setAlphaF(0.5)
|
||||
track2_brush = QBrush(QColor.fromString("green"))
|
||||
|
||||
pg.setConfigOptions(antialias=True)
|
||||
plot_widget = pg.GraphicsLayoutWidget(show=False)
|
||||
|
||||
self._t1_selection = pg.LinearRegionItem([100, 200])
|
||||
self._t1_selection.setZValue(-10) # what is that?
|
||||
self._t1_selection.setZValue(-10)
|
||||
self._t1_selection.setBrush(track1_brush)
|
||||
self._t2_selection = pg.LinearRegionItem([300,400])
|
||||
self._t2_selection.setZValue(-10) # what is that?
|
||||
self._t2_selection.setZValue(-10)
|
||||
self._t2_selection.setBrush(track2_brush)
|
||||
return plot_widget
|
||||
|
||||
@ -47,8 +48,8 @@ class SizeClassifier(QWidget):
|
||||
if bodyaxis is None:
|
||||
bodyaxis = [0, 1, 2, 5]
|
||||
bodycoords = coords[:, bodyaxis, :]
|
||||
dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
|
||||
return dists
|
||||
lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
|
||||
return lengths
|
||||
|
||||
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
|
||||
min_length = np.percentile(dists, min_threshold)
|
||||
@ -83,6 +84,112 @@ class SizeClassifier(QWidget):
|
||||
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
|
||||
return tracks
|
||||
|
||||
class NeighborhoodValidator(QWidget):
|
||||
apply = Signal()
|
||||
name = "Neighborhood Validator"
|
||||
|
||||
def __init__(self, parent = None):
|
||||
super().__init__(parent)
|
||||
self._threshold = None
|
||||
self._positions = None
|
||||
self._distances = None
|
||||
self._tracks = None
|
||||
self._plot = None
|
||||
|
||||
self._plot_widget = self.setupGraph()
|
||||
self._apply_btn = QPushButton("apply")
|
||||
self._apply_btn.clicked.connect(lambda: self.apply.emit())
|
||||
|
||||
layout = QVBoxLayout()
|
||||
print(isinstance(self._plot_widget, QGraphicsView))
|
||||
layout.addWidget(self._plot_widget)
|
||||
layout.addWidget(self._apply_btn)
|
||||
self.setLayout(layout)
|
||||
|
||||
def setupGraph(self):
|
||||
pg.setConfigOptions(antialias=True)
|
||||
plot_widget = pg.GraphicsLayoutWidget(show=False)
|
||||
self._threshold = pg.LineSegmentROI([[10, 64], [120,64]], pen='r')
|
||||
self._threshold.setZValue(-10) # what is that?
|
||||
return plot_widget
|
||||
|
||||
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
|
||||
min_dist = np.percentile(dists, min_threshold)
|
||||
max_dist = np.percentile(dists, max_threshold)
|
||||
if log:
|
||||
bins = np.logspace(min_dist, max_dist, bin_count, base=10)
|
||||
bins = np.linspace(min_dist, max_dist, bin_count)
|
||||
hist, edges = np.histogram(dists, bins=bins, density=True)
|
||||
return hist, edges
|
||||
|
||||
def neighborDistances(self, x, frames, n=5, symmetric=True):
|
||||
logging.debug("classifier:NeighborhoodValidator neighborDistance")
|
||||
pad_shape = list(x.shape)
|
||||
pad_shape[0] = n
|
||||
pad = np.atleast_2d(np.zeros(pad_shape))
|
||||
if symmetric:
|
||||
padded_x = np.vstack((pad, x, pad))
|
||||
dists = np.zeros((x.shape[0]-1, 2*n))
|
||||
else:
|
||||
padded_x = np.vstack((pad, x))
|
||||
dists = np.zeros((x.shape[0]-1, n))
|
||||
|
||||
count = 0
|
||||
r = range(-n, n+1) if symmetric else range(-n, 0)
|
||||
for i in r:
|
||||
if i == 0:
|
||||
continue
|
||||
shifted_x = np.roll(padded_x, i, axis=0)
|
||||
dist = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
|
||||
dists[:, count] = dist[n+1:]/np.diff(frames)
|
||||
count += 1
|
||||
return dists
|
||||
|
||||
def setData(self, positions, tracks, frames):
|
||||
"""Set the data, the classifier/should be working on.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
positions : np.ndarray
|
||||
The position estimates, e.g. the center of gravity for each detection
|
||||
tracks : np.ndarray
|
||||
The current track assignment.
|
||||
frames : np.ndarray
|
||||
respective frame.
|
||||
"""
|
||||
def mouseClicked(self, event):
|
||||
print("mouse clicked at", event.pos())
|
||||
|
||||
track2_brush = QBrush(QColor.fromString("green"))
|
||||
track1_brush = QBrush(QColor.fromString("orange"))
|
||||
self._positions = positions
|
||||
self._tracks = tracks
|
||||
self._frames = frames
|
||||
t1_positions = self._positions[self._tracks == 1]
|
||||
t1_frames = self._frames[self._tracks == 1]
|
||||
t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
|
||||
t2_positions = self._positions[self._tracks == 2]
|
||||
t2_frames = self._frames[self._tracks == 2]
|
||||
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
|
||||
|
||||
n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False)
|
||||
self._plot = self._plot_widget.addPlot()
|
||||
bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
|
||||
self._plot.addItem(bgi1)
|
||||
n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False)
|
||||
bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
|
||||
self._plot.addItem(bgi2)
|
||||
self._plot.scene().sigMouseClicked.connect(mouseClicked)
|
||||
self._plot.setLogMode(x=False, y=True)
|
||||
# plot.setXRange(np.min(t1_distances), np.max(t1_distances))
|
||||
self._plot.setLabel('left', "prob. density")
|
||||
self._plot.setLabel('bottom', "distance", units="px/frame")
|
||||
# plot.addItem(self._threshold)
|
||||
vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
|
||||
self._plot.addItem(vLine, ignoreBounds=True)
|
||||
vb = self._plot.vb
|
||||
|
||||
|
||||
|
||||
class ClassifierWidget(QTabWidget):
|
||||
apply_sizeclassifier = Signal(np.ndarray)
|
||||
@ -90,7 +197,9 @@ class ClassifierWidget(QTabWidget):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._size_classifier = SizeClassifier()
|
||||
self.addTab(self._size_classifier, "Size classifier")
|
||||
self._neigborhood_validator = NeighborhoodValidator()
|
||||
self.addTab(self._size_classifier, SizeClassifier.name)
|
||||
self.addTab(self._neigborhood_validator, NeighborhoodValidator.name)
|
||||
self._size_classifier.apply.connect(self._on_applySizeClassifier)
|
||||
|
||||
def _on_applySizeClassifier(self):
|
||||
@ -101,48 +210,51 @@ class ClassifierWidget(QTabWidget):
|
||||
def size_classifier(self):
|
||||
return self._size_classifier
|
||||
|
||||
@property
|
||||
def neighborhood_validator(self):
|
||||
return self._neigborhood_validator
|
||||
|
||||
|
||||
def test_sizeClassifier(coords):
|
||||
app = QApplication([])
|
||||
window = QWidget()
|
||||
window.setMinimumSize(200, 200)
|
||||
layout = QVBoxLayout()
|
||||
win = SizeClassifier()
|
||||
win.setCoordinates(coords)
|
||||
def as_dict(df):
|
||||
d = {c: df[c].values for c in df.columns}
|
||||
d["index"] = df.index.values
|
||||
return d
|
||||
|
||||
layout.addWidget(win)
|
||||
window.setLayout(layout)
|
||||
window.show()
|
||||
app.exec()
|
||||
|
||||
def main():
|
||||
test_size = False
|
||||
import pickle
|
||||
from fixtracks.info import PACKAGE_ROOT
|
||||
|
||||
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||
|
||||
with open(datafile, "rb") as f:
|
||||
df = pickle.load(f)
|
||||
data = TrackingData()
|
||||
data.setData(as_dict(df))
|
||||
|
||||
positions = data.centerOfGravity()
|
||||
tracks = data["track"]
|
||||
frames = data["frame"]
|
||||
coords = data.coordinates()
|
||||
|
||||
def test_neighborhoodClassifier(coords):
|
||||
app = QApplication([])
|
||||
window = QWidget()
|
||||
window.setMinimumSize(200, 200)
|
||||
# if test_size:
|
||||
# win = SizeClassifier()
|
||||
# win.setCoordinates(coords)
|
||||
# else:
|
||||
w = ClassifierWidget()
|
||||
w.neighborhood_validator.setData(positions, tracks, frames)
|
||||
|
||||
layout = QVBoxLayout()
|
||||
win = SizeClassifier()
|
||||
win.setCoordinates(coords)
|
||||
layout.addWidget(win)
|
||||
layout.addWidget(w)
|
||||
window.setLayout(layout)
|
||||
window.show()
|
||||
app.exec()
|
||||
|
||||
|
||||
def main():
|
||||
import pickle
|
||||
from fixtracks.info import PACKAGE_ROOT
|
||||
from PySide6.QtWidgets import QApplication
|
||||
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||
print(datafile)
|
||||
with open(datafile, "rb") as f:
|
||||
df = pickle.load(f)
|
||||
|
||||
coords = np.stack(df.keypoints.values,).astype(np.float32)
|
||||
frames = df.frame.values
|
||||
test_sizeClassifier(coords)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PySide6.QtWidgets import QApplication
|
||||
main()
|
||||
|
@ -92,6 +92,7 @@ class DetectionView(QWidget):
|
||||
self._view = QGraphicsView()
|
||||
self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
||||
self._view.setMouseTracking(True)
|
||||
self._mouseEnabled = True
|
||||
self._zoomFactor = 1.15
|
||||
self._minZoom = 0.1
|
||||
self._maxZoom = 10
|
||||
@ -102,16 +103,23 @@ class DetectionView(QWidget):
|
||||
self.setLayout(lyt)
|
||||
|
||||
def wheelEvent(self, event):
|
||||
if event.angleDelta().y() > 0: # Zoom in
|
||||
factor = self._zoomFactor
|
||||
else: # Zoom out
|
||||
factor = 1 / self._zoomFactor
|
||||
|
||||
newZoom = self._currentZoom * factor
|
||||
if not self._mouseEnabled:
|
||||
super().wheelEvent(event)
|
||||
return
|
||||
modifiers = event.modifiers()
|
||||
if modifiers == Qt.ControlModifier:
|
||||
delta = event.angleDelta().x()
|
||||
if delta == 0:
|
||||
delta = event.angleDelta().y()
|
||||
sc = 1.001 ** delta
|
||||
self._view.scale(sc, sc)
|
||||
else:
|
||||
super().wheelEvent(event)
|
||||
|
||||
if self._minZoom < newZoom < self._maxZoom:
|
||||
self._view.scale(factor, factor)
|
||||
self._currentZoom = newZoom
|
||||
# elif modifiers == Qt.ShiftModifier:
|
||||
# print("Shift key pressed")
|
||||
# elif modifiers == Qt.AltModifier:
|
||||
# print("Alt key pressed")
|
||||
|
||||
def setImage(self, image: QImage):
|
||||
self._img = image
|
||||
|
@ -41,8 +41,8 @@ class Skeleton(QGraphicsRectItem):
|
||||
|
||||
@property
|
||||
def length(self):
|
||||
bodykps = self._keypoints[self.bodyaxis, :]
|
||||
dist = np.sum(np.sqrt(np.sum(np.diff(bodykps, axis=0)**2, axis=1)), axis=0)
|
||||
bodykpts = self._keypoints[self.bodyaxis, :]
|
||||
dist = np.sum(np.sqrt(np.sum(np.diff(bodykpts, axis=0)**2, axis=1)), axis=0)
|
||||
return dist
|
||||
|
||||
# def mousePressEvent(self, event):
|
||||
|
@ -1,93 +1,21 @@
|
||||
import logging
|
||||
import pathlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QSize, QObject
|
||||
from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject
|
||||
from PySide6.QtGui import QImage, QBrush, QColor, QFont
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
||||
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
|
||||
|
||||
from fixtracks.utils.reader import PickleLoader
|
||||
from fixtracks.utils.writer import PickleWriter
|
||||
from fixtracks.utils.trackingdata import TrackingData
|
||||
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
||||
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
||||
from fixtracks.widgets.skeleton import SkeletonWidget
|
||||
from fixtracks.widgets.classifier import ClassifierWidget
|
||||
|
||||
|
||||
class PoseTableModel(QAbstractTableModel):
|
||||
column_header = ["frame", "track"]
|
||||
columns = ["frame", "track"]
|
||||
|
||||
def __init__(self, dataframe, parent=None):
|
||||
super().__init__(parent)
|
||||
self._data = dataframe
|
||||
self._frames = self._data.frame.values
|
||||
self._tracks = self._data.track.values
|
||||
self._indices = self._data.index.values
|
||||
self._column_data = [self._frames, self._tracks]
|
||||
|
||||
def columnCount(self, parent=None):
|
||||
return len(self.columns)
|
||||
|
||||
def rowCount(self, parent=None):
|
||||
if self._data is not None:
|
||||
return len(self._data)
|
||||
else:
|
||||
return 0
|
||||
|
||||
def data(self, index, role = ...):
|
||||
value = self._column_data[index.column()][index.row()]
|
||||
if role == Qt.ItemDataRole.DisplayRole:
|
||||
return str(value)
|
||||
elif role == Qt.ItemDataRole.UserRole:
|
||||
return value
|
||||
return None
|
||||
|
||||
def headerData(self, section, orientation, role = ...):
|
||||
if role == Qt.ItemDataRole.DisplayRole:
|
||||
if orientation == Qt.Orientation.Horizontal:
|
||||
return self.column_header[section]
|
||||
else:
|
||||
return str(self._indices[section])
|
||||
else:
|
||||
return None
|
||||
|
||||
def mapIdToRow(self, id):
|
||||
row = np.where(self._indices == id)[0]
|
||||
if len(row) == 0:
|
||||
return -1
|
||||
return row[0]
|
||||
|
||||
|
||||
class FilterProxyModel(QSortFilterProxyModel):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._range = None
|
||||
|
||||
def setFilterRange(self, start, stop):
|
||||
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
|
||||
self._range = (start, stop)
|
||||
self.invalidateRowsFilter()
|
||||
|
||||
def all(self):
|
||||
self._range = None
|
||||
|
||||
def filterAcceptsRow(self, source_row, source_parent):
|
||||
if self._range is None:
|
||||
return True
|
||||
else:
|
||||
idx = self.sourceModel().index(source_row, 0, source_parent);
|
||||
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
|
||||
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
|
||||
return val >= self._range[0] and val < self._range[1]
|
||||
|
||||
def filterAcceptsColumn(self, source_column, source_parent):
|
||||
return True
|
||||
|
||||
|
||||
class SelectionControls(QWidget):
|
||||
fwd = Signal(float)
|
||||
back = Signal(float)
|
||||
@ -230,90 +158,6 @@ class SelectionControls(QWidget):
|
||||
self._total = len(tracks)
|
||||
|
||||
|
||||
class DataController(QObject):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._data = None
|
||||
self._columns = []
|
||||
self._start = 0
|
||||
self._stop = 0
|
||||
self._indices = None
|
||||
self._selection_column = None
|
||||
self._user_selections = None
|
||||
|
||||
def setData(self, datadict):
|
||||
assert isinstance(datadict, dict)
|
||||
self._data = datadict
|
||||
self._columns = [k for k in self._data.keys()]
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self._columns
|
||||
|
||||
def max(self, col):
|
||||
if col in self.columns:
|
||||
return np.max(self._data[col])
|
||||
else:
|
||||
logging.error("Column %s not in dictionary", col)
|
||||
return np.nan
|
||||
|
||||
@property
|
||||
def numDetections(self):
|
||||
return self._data["track"].shape[0]
|
||||
|
||||
@property
|
||||
def selectionRange(self):
|
||||
return self._start, self._stop
|
||||
|
||||
@property
|
||||
def selectionRangeColumn(self):
|
||||
return self._selection_column
|
||||
|
||||
@property
|
||||
def selectionIndices(self):
|
||||
return self._indices
|
||||
|
||||
def setSelectionRange(self, col, start, stop):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._selection_column = col
|
||||
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
|
||||
|
||||
def selectedData(self, col):
|
||||
return self._data[col][self._indices]
|
||||
|
||||
def setUserSelection(self, ids):
|
||||
self._user_selections = ids.astype(int)
|
||||
|
||||
def assignUserSelection(self, track_id):
|
||||
self._data["track"][self._user_selections] = track_id
|
||||
|
||||
def assignTracks(self, tracks):
|
||||
if len(tracks) != self.numDetections:
|
||||
logging.error("DataController: Size of passed tracks does not match data!")
|
||||
return
|
||||
self._data["track"] = tracks
|
||||
|
||||
def save(self, filename):
|
||||
export_columns = self._columns.copy()
|
||||
export_columns.remove("index")
|
||||
dictionary = {c: self._data[c] for c in export_columns}
|
||||
df = pd.DataFrame(dictionary, index=self._data["index"])
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(df, f)
|
||||
|
||||
def numKeypoints(self):
|
||||
if len(self._data["keypoints"]) == 0:
|
||||
return 0
|
||||
return self._data["keypoints"][0].shape[0]
|
||||
|
||||
def coordinates(self):
|
||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||
|
||||
class FixTracks(QWidget):
|
||||
back = Signal()
|
||||
trackone_id = 1
|
||||
@ -327,7 +171,7 @@ class FixTracks(QWidget):
|
||||
self._reader = None
|
||||
self._image = None
|
||||
self._clear_detections = True
|
||||
self._data = DataController()
|
||||
self._data = TrackingData()
|
||||
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
|
||||
"assigned_right": QBrush(QColor.fromString("green")),
|
||||
"unassigned": QBrush(QColor.fromString("red"))
|
||||
@ -534,7 +378,11 @@ class FixTracks(QWidget):
|
||||
rel_width = self._windowspinner.value() / maxframes
|
||||
self._timeline.setWindowWidth(rel_width)
|
||||
coordinates = self._data.coordinates()
|
||||
positions = self._data.centerOfGravity()
|
||||
tracks = self._data["track"]
|
||||
frames = self._data["frame"]
|
||||
self._classifier.size_classifier.setCoordinates(coordinates)
|
||||
self._classifier.neighborhood_validator.setData(positions, tracks, frames)
|
||||
self.update()
|
||||
self._saveBtn.setEnabled(True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user