[cleanup] moving stuff, rename DataController to TrackingData

This commit is contained in:
Jan Grewe 2025-02-09 11:13:33 +01:00
parent ff7e1e85ae
commit b7d4097e73
3 changed files with 80 additions and 77 deletions

View File

@ -0,0 +1,74 @@
import logging
import numpy as np
from PySide6.QtCore import Qt, QAbstractTableModel, QSortFilterProxyModel
class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"]
columns = ["frame", "track"]
def __init__(self, dataframe, parent=None):
super().__init__(parent)
self._data = dataframe
self._frames = self._data.frame.values
self._tracks = self._data.track.values
self._indices = self._data.index.values
self._column_data = [self._frames, self._tracks]
def columnCount(self, parent=None):
return len(self.columns)
def rowCount(self, parent=None):
if self._data is not None:
return len(self._data)
else:
return 0
def data(self, index, role = ...):
value = self._column_data[index.column()][index.row()]
if role == Qt.ItemDataRole.DisplayRole:
return str(value)
elif role == Qt.ItemDataRole.UserRole:
return value
return None
def headerData(self, section, orientation, role = ...):
if role == Qt.ItemDataRole.DisplayRole:
if orientation == Qt.Orientation.Horizontal:
return self.column_header[section]
else:
return str(self._indices[section])
else:
return None
def mapIdToRow(self, id):
row = np.where(self._indices == id)[0]
if len(row) == 0:
return -1
return row[0]
class FilterProxyModel(QSortFilterProxyModel):
def __init__(self, parent=None):
super().__init__(parent)
self._range = None
def setFilterRange(self, start, stop):
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
self._range = (start, stop)
self.invalidateRowsFilter()
def all(self):
self._range = None
def filterAcceptsRow(self, source_row, source_parent):
if self._range is None:
return True
else:
idx = self.sourceModel().index(source_row, 0, source_parent);
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
return val >= self._range[0] and val < self._range[1]
def filterAcceptsColumn(self, source_column, source_parent):
return True

View File

@ -41,8 +41,8 @@ class Skeleton(QGraphicsRectItem):
@property
def length(self):
bodykps = self._keypoints[self.bodyaxis, :]
dist = np.sum(np.sqrt(np.sum(np.diff(bodykps, axis=0)**2, axis=1)), axis=0)
bodykpts = self._keypoints[self.bodyaxis, :]
dist = np.sum(np.sqrt(np.sum(np.diff(bodykpts, axis=0)**2, axis=1)), axis=0)
return dist
# def mousePressEvent(self, event):

View File

@ -1,10 +1,9 @@
import logging
import pathlib
import pickle
import numpy as np
import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QSize, QObject
from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject
from PySide6.QtGui import QImage, QBrush, QColor, QFont
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
@ -17,77 +16,6 @@ from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget
class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"]
columns = ["frame", "track"]
def __init__(self, dataframe, parent=None):
super().__init__(parent)
self._data = dataframe
self._frames = self._data.frame.values
self._tracks = self._data.track.values
self._indices = self._data.index.values
self._column_data = [self._frames, self._tracks]
def columnCount(self, parent=None):
return len(self.columns)
def rowCount(self, parent=None):
if self._data is not None:
return len(self._data)
else:
return 0
def data(self, index, role = ...):
value = self._column_data[index.column()][index.row()]
if role == Qt.ItemDataRole.DisplayRole:
return str(value)
elif role == Qt.ItemDataRole.UserRole:
return value
return None
def headerData(self, section, orientation, role = ...):
if role == Qt.ItemDataRole.DisplayRole:
if orientation == Qt.Orientation.Horizontal:
return self.column_header[section]
else:
return str(self._indices[section])
else:
return None
def mapIdToRow(self, id):
row = np.where(self._indices == id)[0]
if len(row) == 0:
return -1
return row[0]
class FilterProxyModel(QSortFilterProxyModel):
def __init__(self, parent=None):
super().__init__(parent)
self._range = None
def setFilterRange(self, start, stop):
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
self._range = (start, stop)
self.invalidateRowsFilter()
def all(self):
self._range = None
def filterAcceptsRow(self, source_row, source_parent):
if self._range is None:
return True
else:
idx = self.sourceModel().index(source_row, 0, source_parent);
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
return val >= self._range[0] and val < self._range[1]
def filterAcceptsColumn(self, source_column, source_parent):
return True
class SelectionControls(QWidget):
fwd = Signal(float)
back = Signal(float)
@ -230,7 +158,7 @@ class SelectionControls(QWidget):
self._total = len(tracks)
class DataController(QObject):
class TrackingData(QObject):
def __init__(self, parent=None):
super().__init__(parent)
self._data = None
@ -314,6 +242,7 @@ class DataController(QObject):
def coordinates(self):
return np.stack(self._data["keypoints"]).astype(np.float32)
class FixTracks(QWidget):
back = Signal()
trackone_id = 1
@ -327,7 +256,7 @@ class FixTracks(QWidget):
self._reader = None
self._image = None
self._clear_detections = True
self._data = DataController()
self._data = TrackingData()
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
"assigned_right": QBrush(QColor.fromString("green")),
"unassigned": QBrush(QColor.fromString("red"))