diff --git a/fixtracks/utils/tablemodels.py b/fixtracks/utils/tablemodels.py new file mode 100644 index 0000000..e2c6bdc --- /dev/null +++ b/fixtracks/utils/tablemodels.py @@ -0,0 +1,74 @@ +import logging +import numpy as np + +from PySide6.QtCore import Qt, QAbstractTableModel, QSortFilterProxyModel + +class PoseTableModel(QAbstractTableModel): + column_header = ["frame", "track"] + columns = ["frame", "track"] + + def __init__(self, dataframe, parent=None): + super().__init__(parent) + self._data = dataframe + self._frames = self._data.frame.values + self._tracks = self._data.track.values + self._indices = self._data.index.values + self._column_data = [self._frames, self._tracks] + + def columnCount(self, parent=None): + return len(self.columns) + + def rowCount(self, parent=None): + if self._data is not None: + return len(self._data) + else: + return 0 + + def data(self, index, role = ...): + value = self._column_data[index.column()][index.row()] + if role == Qt.ItemDataRole.DisplayRole: + return str(value) + elif role == Qt.ItemDataRole.UserRole: + return value + return None + + def headerData(self, section, orientation, role = ...): + if role == Qt.ItemDataRole.DisplayRole: + if orientation == Qt.Orientation.Horizontal: + return self.column_header[section] + else: + return str(self._indices[section]) + else: + return None + + def mapIdToRow(self, id): + row = np.where(self._indices == id)[0] + if len(row) == 0: + return -1 + return row[0] + + +class FilterProxyModel(QSortFilterProxyModel): + def __init__(self, parent=None): + super().__init__(parent) + self._range = None + + def setFilterRange(self, start, stop): + logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop) + self._range = (start, stop) + self.invalidateRowsFilter() + + def all(self): + self._range = None + + def filterAcceptsRow(self, source_row, source_parent): + if self._range is None: + return True + else: + idx = self.sourceModel().index(source_row, 0, source_parent); + val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole) + print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] ) + return val >= self._range[0] and val < self._range[1] + + def filterAcceptsColumn(self, source_column, source_parent): + return True diff --git a/fixtracks/widgets/skeleton.py b/fixtracks/widgets/skeleton.py index 13c8f44..be360dc 100644 --- a/fixtracks/widgets/skeleton.py +++ b/fixtracks/widgets/skeleton.py @@ -41,8 +41,8 @@ class Skeleton(QGraphicsRectItem): @property def length(self): - bodykps = self._keypoints[self.bodyaxis, :] - dist = np.sum(np.sqrt(np.sum(np.diff(bodykps, axis=0)**2, axis=1)), axis=0) + bodykpts = self._keypoints[self.bodyaxis, :] + dist = np.sum(np.sqrt(np.sum(np.diff(bodykpts, axis=0)**2, axis=1)), axis=0) return dist # def mousePressEvent(self, event): diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index 2e48f58..9eb6f75 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -1,10 +1,9 @@ import logging -import pathlib import pickle import numpy as np import pandas as pd -from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QSize, QObject +from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject from PySide6.QtGui import QImage, QBrush, QColor, QFont from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout @@ -17,77 +16,6 @@ from fixtracks.widgets.skeleton import SkeletonWidget from fixtracks.widgets.classifier import ClassifierWidget -class PoseTableModel(QAbstractTableModel): - column_header = ["frame", "track"] - columns = ["frame", "track"] - - def __init__(self, dataframe, parent=None): - super().__init__(parent) - self._data = dataframe - self._frames = self._data.frame.values - self._tracks = self._data.track.values - self._indices = self._data.index.values - self._column_data = [self._frames, self._tracks] - - def columnCount(self, parent=None): - return len(self.columns) - - def rowCount(self, parent=None): - if self._data is not None: - return len(self._data) - else: - return 0 - - def data(self, index, role = ...): - value = self._column_data[index.column()][index.row()] - if role == Qt.ItemDataRole.DisplayRole: - return str(value) - elif role == Qt.ItemDataRole.UserRole: - return value - return None - - def headerData(self, section, orientation, role = ...): - if role == Qt.ItemDataRole.DisplayRole: - if orientation == Qt.Orientation.Horizontal: - return self.column_header[section] - else: - return str(self._indices[section]) - else: - return None - - def mapIdToRow(self, id): - row = np.where(self._indices == id)[0] - if len(row) == 0: - return -1 - return row[0] - - -class FilterProxyModel(QSortFilterProxyModel): - def __init__(self, parent=None): - super().__init__(parent) - self._range = None - - def setFilterRange(self, start, stop): - logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop) - self._range = (start, stop) - self.invalidateRowsFilter() - - def all(self): - self._range = None - - def filterAcceptsRow(self, source_row, source_parent): - if self._range is None: - return True - else: - idx = self.sourceModel().index(source_row, 0, source_parent); - val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole) - print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] ) - return val >= self._range[0] and val < self._range[1] - - def filterAcceptsColumn(self, source_column, source_parent): - return True - - class SelectionControls(QWidget): fwd = Signal(float) back = Signal(float) @@ -230,7 +158,7 @@ class SelectionControls(QWidget): self._total = len(tracks) -class DataController(QObject): +class TrackingData(QObject): def __init__(self, parent=None): super().__init__(parent) self._data = None @@ -314,6 +242,7 @@ class DataController(QObject): def coordinates(self): return np.stack(self._data["keypoints"]).astype(np.float32) + class FixTracks(QWidget): back = Signal() trackone_id = 1 @@ -327,7 +256,7 @@ class FixTracks(QWidget): self._reader = None self._image = None self._clear_detections = True - self._data = DataController() + self._data = TrackingData() self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")), "assigned_right": QBrush(QColor.fromString("green")), "unassigned": QBrush(QColor.fromString("red"))