Compare commits

..

No commits in common. "5a128cf28e2c5bbdd6b66aac128016158a1f9a28" and "ff7e1e85ae4d24dd7b637d076a3fc9b34af8c5df" have entirely different histories.

6 changed files with 208 additions and 487 deletions

View File

@ -1,74 +0,0 @@
import logging
import numpy as np
from PySide6.QtCore import Qt, QAbstractTableModel, QSortFilterProxyModel
class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"]
columns = ["frame", "track"]
def __init__(self, dataframe, parent=None):
super().__init__(parent)
self._data = dataframe
self._frames = self._data.frame.values
self._tracks = self._data.track.values
self._indices = self._data.index.values
self._column_data = [self._frames, self._tracks]
def columnCount(self, parent=None):
return len(self.columns)
def rowCount(self, parent=None):
if self._data is not None:
return len(self._data)
else:
return 0
def data(self, index, role = ...):
value = self._column_data[index.column()][index.row()]
if role == Qt.ItemDataRole.DisplayRole:
return str(value)
elif role == Qt.ItemDataRole.UserRole:
return value
return None
def headerData(self, section, orientation, role = ...):
if role == Qt.ItemDataRole.DisplayRole:
if orientation == Qt.Orientation.Horizontal:
return self.column_header[section]
else:
return str(self._indices[section])
else:
return None
def mapIdToRow(self, id):
row = np.where(self._indices == id)[0]
if len(row) == 0:
return -1
return row[0]
class FilterProxyModel(QSortFilterProxyModel):
def __init__(self, parent=None):
super().__init__(parent)
self._range = None
def setFilterRange(self, start, stop):
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
self._range = (start, stop)
self.invalidateRowsFilter()
def all(self):
self._range = None
def filterAcceptsRow(self, source_row, source_parent):
if self._range is None:
return True
else:
idx = self.sourceModel().index(source_row, 0, source_parent);
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
return val >= self._range[0] and val < self._range[1]
def filterAcceptsColumn(self, source_column, source_parent):
return True

View File

@ -1,237 +0,0 @@
import pickle
import logging
import numpy as np
import pandas as pd
from PySide6.QtCore import QObject
class TrackingData(QObject):
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._columns = []
self._start = 0
self._stop = 0
self._indices = None
self._selection_column = None
self._user_selections = None
def setData(self, datadict):
assert isinstance(datadict, dict)
self._data = datadict
self._columns = [k for k in self._data.keys()]
@property
def data(self):
return self._data
@property
def columns(self):
return self._columns
def max(self, col):
if col in self.columns:
return np.max(self._data[col])
else:
logging.error("Column %s not in dictionary", col)
return np.nan
@property
def numDetections(self):
return self._data["track"].shape[0]
@property
def selectionRange(self):
return self._start, self._stop
@property
def selectionRangeColumn(self):
return self._selection_column
@property
def selectionIndices(self):
return self._indices
def setSelectionRange(self, col, start, stop):
self._start = start
self._stop = stop
self._selection_column = col
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
def selectedData(self, col):
return self._data[col][self._indices]
def setUserSelection(self, ids):
"""
Set the user selections. That is, e.g. when the user selected a number of ids.
Parameters
----------
ids : array-like
An array-like object containing the IDs to be set as user selections.
The IDs will be converted to integers.
"""
self._user_selections = ids.astype(int)
def assignUserSelection(self, track_id:int)-> None:
"""Assign a new track_id to the user-selected detections
Parameters
----------
track_id : int
The new track id for the user-selected detections
"""
self._data["track"][self._user_selections] = track_id
def assignTracks(self, tracks):
"""assignTracks _summary_
Parameters
----------
tracks : _type_
_description_
Returns
-------
_type_
_description_
"""
if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!")
return
self._data["track"] = tracks
def save(self, filename):
export_columns = self._columns.copy()
export_columns.remove("index")
dictionary = {c: self._data[c] for c in export_columns}
df = pd.DataFrame(dictionary, index=self._data["index"])
with open(filename, 'wb') as f:
pickle.dump(df, f)
def numKeypoints(self):
if len(self._data["keypoints"]) == 0:
return 0
return self._data["keypoints"][0].shape[0]
def coordinates(self):
"""
Returns the coordinates of all keypoints as a NumPy array.
Returns:
np.ndarray: A NumPy array of shape (N, M, 2) where N is the number of detections,
and M is number of keypoints
"""
return np.stack(self._data["keypoints"]).astype(np.float32)
def keypointScores(self):
"""
Returns the keypoint scores as a NumPy array of type float32.
Returns
-------
numpy.ndarray
A NumPy array of type float32 containing the keypoint scores of the shape (N, M)
with N the number of detections and M the number of keypoints.
"""
return np.stack(self._data["keypoint_score"]).astype(np.float32)
def centerOfGravity(self, threshold=0.8):
"""
Calculate the center of gravity of keypoints weighted by their scores. Ignores keypoints that have a score
less than threshold.
Parameters:
-----------
threshold: float
keypoints with a score less than threshold are ignored
Returns:
--------
np.ndarray:
A NumPy array of shape (N, 2) containing the center of gravity for each detection.
"""
scores = self.keypointScores()
scores[scores < threshold] = 0.0
weighted_coords = self.coordinates() * scores[:, :, np.newaxis]
sum_scores = np.sum(scores, axis=1, keepdims=True)
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
return center_of_gravity
def __getitem__(self, key):
return self._data[key]
# def __setitem__(self, key, value):
# self._data[key] = value
"""
self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data)
self.update()
"""
def main():
import pandas as pd
from IPython import embed
import matplotlib.pyplot as plt
from fixtracks.info import PACKAGE_ROOT
from scipy.spatial.distance import cdist
def as_dict(df:pd.DataFrame):
d = {c: df[c].values for c in df.columns}
d["index"] = df.index.values
return d
def neighborDistances(x, n=5, symmetric=True):
pad_shape = list(x.shape)
pad_shape[0] = 5
pad = np.zeros(pad_shape)
if symmetric:
padded_x = np.vstack((pad, x, pad))
else:
padded_x = np.vstack((pad, x))
dists = np.zeros((padded_x.shape[0], 2*n))
count = 0
r = range(-n, n+1) if symmetric else range(-n, 0)
for i in r:
if i == 0:
continue
shifted_x = np.roll(padded_x, i)
dists[:, count] = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
count += 1
return dists
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)
data = TrackingData()
data.setData(as_dict(df))
all_cogs = data.centerOfGravity()
tracks = data["track"]
cogs = all_cogs[tracks==1]
all_dists = neighborDistances(cogs, 2, False)
plt.hist(all_dists[1:, 0], bins=1000)
print(np.percentile(all_dists[1:, 0], 99))
print(np.percentile(all_dists[1:, 0], 1))
plt.gca().set_xscale("log")
plt.gca().set_yscale("log")
# plt.hist(all_dists[1:, 1], bins=100)
plt.show()
# def compute_neighbor_distances(cogs, window=10):
# distances = []
# for i in range(len(cogs)):
# start = max(0, i - window)
# stop = min(len(cogs), i + window + 1)
# neighbors = cogs[start:stop]
# dists = cdist([cogs[i]], neighbors)[0]
# distances.append(dists)
# return distances
# print("estimating neighorhood distances")
# neighbor_distances = compute_neighbor_distances(cogs)
embed()
if __name__ == "__main__":
main()

View File

@ -1,17 +1,15 @@
import logging
import numpy as np
import pyqtgraph as pg
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
from PySide6.QtCore import Signal
from PySide6.QtGui import QBrush, QColor
from fixtracks.utils.trackingdata import TrackingData
import pyqtgraph as pg
class SizeClassifier(QWidget):
apply = Signal()
name = "SizeClassifier"
def __init__(self, parent=None):
super().__init__(parent)
@ -31,16 +29,17 @@ class SizeClassifier(QWidget):
def setupGraph(self):
track1_brush = QBrush(QColor.fromString("orange"))
track1_brush.color().setAlphaF(0.5)
track2_brush = QBrush(QColor.fromString("green"))
pg.setConfigOptions(antialias=True)
plot_widget = pg.GraphicsLayoutWidget(show=False)
self._t1_selection = pg.LinearRegionItem([100, 200])
self._t1_selection.setZValue(-10)
self._t1_selection.setZValue(-10) # what is that?
self._t1_selection.setBrush(track1_brush)
self._t2_selection = pg.LinearRegionItem([300,400])
self._t2_selection.setZValue(-10)
self._t2_selection.setZValue(-10) # what is that?
self._t2_selection.setBrush(track2_brush)
return plot_widget
@ -48,8 +47,8 @@ class SizeClassifier(QWidget):
if bodyaxis is None:
bodyaxis = [0, 1, 2, 5]
bodycoords = coords[:, bodyaxis, :]
lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
return lengths
dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
return dists
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
min_length = np.percentile(dists, min_threshold)
@ -84,112 +83,6 @@ class SizeClassifier(QWidget):
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks
class NeighborhoodValidator(QWidget):
apply = Signal()
name = "Neighborhood Validator"
def __init__(self, parent = None):
super().__init__(parent)
self._threshold = None
self._positions = None
self._distances = None
self._tracks = None
self._plot = None
self._plot_widget = self.setupGraph()
self._apply_btn = QPushButton("apply")
self._apply_btn.clicked.connect(lambda: self.apply.emit())
layout = QVBoxLayout()
print(isinstance(self._plot_widget, QGraphicsView))
layout.addWidget(self._plot_widget)
layout.addWidget(self._apply_btn)
self.setLayout(layout)
def setupGraph(self):
pg.setConfigOptions(antialias=True)
plot_widget = pg.GraphicsLayoutWidget(show=False)
self._threshold = pg.LineSegmentROI([[10, 64], [120,64]], pen='r')
self._threshold.setZValue(-10) # what is that?
return plot_widget
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
min_dist = np.percentile(dists, min_threshold)
max_dist = np.percentile(dists, max_threshold)
if log:
bins = np.logspace(min_dist, max_dist, bin_count, base=10)
bins = np.linspace(min_dist, max_dist, bin_count)
hist, edges = np.histogram(dists, bins=bins, density=True)
return hist, edges
def neighborDistances(self, x, frames, n=5, symmetric=True):
logging.debug("classifier:NeighborhoodValidator neighborDistance")
pad_shape = list(x.shape)
pad_shape[0] = n
pad = np.atleast_2d(np.zeros(pad_shape))
if symmetric:
padded_x = np.vstack((pad, x, pad))
dists = np.zeros((x.shape[0]-1, 2*n))
else:
padded_x = np.vstack((pad, x))
dists = np.zeros((x.shape[0]-1, n))
count = 0
r = range(-n, n+1) if symmetric else range(-n, 0)
for i in r:
if i == 0:
continue
shifted_x = np.roll(padded_x, i, axis=0)
dist = np.sqrt(np.sum((padded_x - shifted_x)**2, axis=1))
dists[:, count] = dist[n+1:]/np.diff(frames)
count += 1
return dists
def setData(self, positions, tracks, frames):
"""Set the data, the classifier/should be working on.
Parameters
----------
positions : np.ndarray
The position estimates, e.g. the center of gravity for each detection
tracks : np.ndarray
The current track assignment.
frames : np.ndarray
respective frame.
"""
def mouseClicked(self, event):
print("mouse clicked at", event.pos())
track2_brush = QBrush(QColor.fromString("green"))
track1_brush = QBrush(QColor.fromString("orange"))
self._positions = positions
self._tracks = tracks
self._frames = frames
t1_positions = self._positions[self._tracks == 1]
t1_frames = self._frames[self._tracks == 1]
t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
t2_positions = self._positions[self._tracks == 2]
t2_frames = self._frames[self._tracks == 2]
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False)
self._plot = self._plot_widget.addPlot()
bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
self._plot.addItem(bgi1)
n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False)
bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
self._plot.addItem(bgi2)
self._plot.scene().sigMouseClicked.connect(mouseClicked)
self._plot.setLogMode(x=False, y=True)
# plot.setXRange(np.min(t1_distances), np.max(t1_distances))
self._plot.setLabel('left', "prob. density")
self._plot.setLabel('bottom', "distance", units="px/frame")
# plot.addItem(self._threshold)
vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
self._plot.addItem(vLine, ignoreBounds=True)
vb = self._plot.vb
class ClassifierWidget(QTabWidget):
apply_sizeclassifier = Signal(np.ndarray)
@ -197,9 +90,7 @@ class ClassifierWidget(QTabWidget):
def __init__(self, parent=None):
super().__init__(parent)
self._size_classifier = SizeClassifier()
self._neigborhood_validator = NeighborhoodValidator()
self.addTab(self._size_classifier, SizeClassifier.name)
self.addTab(self._neigborhood_validator, NeighborhoodValidator.name)
self.addTab(self._size_classifier, "Size classifier")
self._size_classifier.apply.connect(self._on_applySizeClassifier)
def _on_applySizeClassifier(self):
@ -210,51 +101,48 @@ class ClassifierWidget(QTabWidget):
def size_classifier(self):
return self._size_classifier
@property
def neighborhood_validator(self):
return self._neigborhood_validator
def as_dict(df):
d = {c: df[c].values for c in df.columns}
d["index"] = df.index.values
return d
def main():
test_size = False
import pickle
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
def test_sizeClassifier(coords):
app = QApplication([])
window = QWidget()
window.setMinimumSize(200, 200)
layout = QVBoxLayout()
win = SizeClassifier()
win.setCoordinates(coords)
with open(datafile, "rb") as f:
df = pickle.load(f)
data = TrackingData()
data.setData(as_dict(df))
layout.addWidget(win)
window.setLayout(layout)
window.show()
app.exec()
positions = data.centerOfGravity()
tracks = data["track"]
frames = data["frame"]
coords = data.coordinates()
def test_neighborhoodClassifier(coords):
app = QApplication([])
window = QWidget()
window.setMinimumSize(200, 200)
# if test_size:
# win = SizeClassifier()
# win.setCoordinates(coords)
# else:
w = ClassifierWidget()
w.neighborhood_validator.setData(positions, tracks, frames)
layout = QVBoxLayout()
layout.addWidget(w)
win = SizeClassifier()
win.setCoordinates(coords)
layout.addWidget(win)
window.setLayout(layout)
window.show()
app.exec()
if __name__ == "__main__":
def main():
import pickle
from fixtracks.info import PACKAGE_ROOT
from PySide6.QtWidgets import QApplication
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
print(datafile)
with open(datafile, "rb") as f:
df = pickle.load(f)
coords = np.stack(df.keypoints.values,).astype(np.float32)
frames = df.frame.values
test_sizeClassifier(coords)
if __name__ == "__main__":
main()

View File

@ -92,7 +92,6 @@ class DetectionView(QWidget):
self._view = QGraphicsView()
self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
self._view.setMouseTracking(True)
self._mouseEnabled = True
self._zoomFactor = 1.15
self._minZoom = 0.1
self._maxZoom = 10
@ -103,23 +102,16 @@ class DetectionView(QWidget):
self.setLayout(lyt)
def wheelEvent(self, event):
if not self._mouseEnabled:
super().wheelEvent(event)
return
modifiers = event.modifiers()
if modifiers == Qt.ControlModifier:
delta = event.angleDelta().x()
if delta == 0:
delta = event.angleDelta().y()
sc = 1.001 ** delta
self._view.scale(sc, sc)
else:
super().wheelEvent(event)
# elif modifiers == Qt.ShiftModifier:
# print("Shift key pressed")
# elif modifiers == Qt.AltModifier:
# print("Alt key pressed")
if event.angleDelta().y() > 0: # Zoom in
factor = self._zoomFactor
else: # Zoom out
factor = 1 / self._zoomFactor
newZoom = self._currentZoom * factor
if self._minZoom < newZoom < self._maxZoom:
self._view.scale(factor, factor)
self._currentZoom = newZoom
def setImage(self, image: QImage):
self._img = image

View File

@ -41,8 +41,8 @@ class Skeleton(QGraphicsRectItem):
@property
def length(self):
bodykpts = self._keypoints[self.bodyaxis, :]
dist = np.sum(np.sqrt(np.sum(np.diff(bodykpts, axis=0)**2, axis=1)), axis=0)
bodykps = self._keypoints[self.bodyaxis, :]
dist = np.sum(np.sqrt(np.sum(np.diff(bodykps, axis=0)**2, axis=1)), axis=0)
return dist
# def mousePressEvent(self, event):

View File

@ -1,21 +1,93 @@
import logging
import pathlib
import pickle
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QSize, QObject
from PySide6.QtGui import QImage, QBrush, QColor, QFont
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.writer import PickleWriter
from fixtracks.utils.trackingdata import TrackingData
from fixtracks.widgets.detectionview import DetectionView, DetectionData
from fixtracks.widgets.detectiontimeline import DetectionTimeline
from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget
class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"]
columns = ["frame", "track"]
def __init__(self, dataframe, parent=None):
super().__init__(parent)
self._data = dataframe
self._frames = self._data.frame.values
self._tracks = self._data.track.values
self._indices = self._data.index.values
self._column_data = [self._frames, self._tracks]
def columnCount(self, parent=None):
return len(self.columns)
def rowCount(self, parent=None):
if self._data is not None:
return len(self._data)
else:
return 0
def data(self, index, role = ...):
value = self._column_data[index.column()][index.row()]
if role == Qt.ItemDataRole.DisplayRole:
return str(value)
elif role == Qt.ItemDataRole.UserRole:
return value
return None
def headerData(self, section, orientation, role = ...):
if role == Qt.ItemDataRole.DisplayRole:
if orientation == Qt.Orientation.Horizontal:
return self.column_header[section]
else:
return str(self._indices[section])
else:
return None
def mapIdToRow(self, id):
row = np.where(self._indices == id)[0]
if len(row) == 0:
return -1
return row[0]
class FilterProxyModel(QSortFilterProxyModel):
def __init__(self, parent=None):
super().__init__(parent)
self._range = None
def setFilterRange(self, start, stop):
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
self._range = (start, stop)
self.invalidateRowsFilter()
def all(self):
self._range = None
def filterAcceptsRow(self, source_row, source_parent):
if self._range is None:
return True
else:
idx = self.sourceModel().index(source_row, 0, source_parent);
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
return val >= self._range[0] and val < self._range[1]
def filterAcceptsColumn(self, source_column, source_parent):
return True
class SelectionControls(QWidget):
fwd = Signal(float)
back = Signal(float)
@ -158,6 +230,90 @@ class SelectionControls(QWidget):
self._total = len(tracks)
class DataController(QObject):
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
self._columns = []
self._start = 0
self._stop = 0
self._indices = None
self._selection_column = None
self._user_selections = None
def setData(self, datadict):
assert isinstance(datadict, dict)
self._data = datadict
self._columns = [k for k in self._data.keys()]
@property
def data(self):
return self._data
@property
def columns(self):
return self._columns
def max(self, col):
if col in self.columns:
return np.max(self._data[col])
else:
logging.error("Column %s not in dictionary", col)
return np.nan
@property
def numDetections(self):
return self._data["track"].shape[0]
@property
def selectionRange(self):
return self._start, self._stop
@property
def selectionRangeColumn(self):
return self._selection_column
@property
def selectionIndices(self):
return self._indices
def setSelectionRange(self, col, start, stop):
self._start = start
self._stop = stop
self._selection_column = col
self._indices = np.where((self._data[col] >= self._start) & (self._data[col] < self._stop))[0]
def selectedData(self, col):
return self._data[col][self._indices]
def setUserSelection(self, ids):
self._user_selections = ids.astype(int)
def assignUserSelection(self, track_id):
self._data["track"][self._user_selections] = track_id
def assignTracks(self, tracks):
if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!")
return
self._data["track"] = tracks
def save(self, filename):
export_columns = self._columns.copy()
export_columns.remove("index")
dictionary = {c: self._data[c] for c in export_columns}
df = pd.DataFrame(dictionary, index=self._data["index"])
with open(filename, 'wb') as f:
pickle.dump(df, f)
def numKeypoints(self):
if len(self._data["keypoints"]) == 0:
return 0
return self._data["keypoints"][0].shape[0]
def coordinates(self):
return np.stack(self._data["keypoints"]).astype(np.float32)
class FixTracks(QWidget):
back = Signal()
trackone_id = 1
@ -171,7 +327,7 @@ class FixTracks(QWidget):
self._reader = None
self._image = None
self._clear_detections = True
self._data = TrackingData()
self._data = DataController()
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
"assigned_right": QBrush(QColor.fromString("green")),
"unassigned": QBrush(QColor.fromString("red"))
@ -378,11 +534,7 @@ class FixTracks(QWidget):
rel_width = self._windowspinner.value() / maxframes
self._timeline.setWindowWidth(rel_width)
coordinates = self._data.coordinates()
positions = self._data.centerOfGravity()
tracks = self._data["track"]
frames = self._data["frame"]
self._classifier.size_classifier.setCoordinates(coordinates)
self._classifier.neighborhood_validator.setData(positions, tracks, frames)
self.update()
self._saveBtn.setEnabled(True)