[cleanup] moving stuff, rename DataController to TrackingData

This commit is contained in:
Jan Grewe 2025-02-09 11:13:33 +01:00
parent ff7e1e85ae
commit b7d4097e73
3 changed files with 80 additions and 77 deletions

View File

@ -0,0 +1,74 @@
import logging
import numpy as np
from PySide6.QtCore import Qt, QAbstractTableModel, QSortFilterProxyModel
class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"]
columns = ["frame", "track"]
def __init__(self, dataframe, parent=None):
super().__init__(parent)
self._data = dataframe
self._frames = self._data.frame.values
self._tracks = self._data.track.values
self._indices = self._data.index.values
self._column_data = [self._frames, self._tracks]
def columnCount(self, parent=None):
return len(self.columns)
def rowCount(self, parent=None):
if self._data is not None:
return len(self._data)
else:
return 0
def data(self, index, role = ...):
value = self._column_data[index.column()][index.row()]
if role == Qt.ItemDataRole.DisplayRole:
return str(value)
elif role == Qt.ItemDataRole.UserRole:
return value
return None
def headerData(self, section, orientation, role = ...):
if role == Qt.ItemDataRole.DisplayRole:
if orientation == Qt.Orientation.Horizontal:
return self.column_header[section]
else:
return str(self._indices[section])
else:
return None
def mapIdToRow(self, id):
row = np.where(self._indices == id)[0]
if len(row) == 0:
return -1
return row[0]
class FilterProxyModel(QSortFilterProxyModel):
def __init__(self, parent=None):
super().__init__(parent)
self._range = None
def setFilterRange(self, start, stop):
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
self._range = (start, stop)
self.invalidateRowsFilter()
def all(self):
self._range = None
def filterAcceptsRow(self, source_row, source_parent):
if self._range is None:
return True
else:
idx = self.sourceModel().index(source_row, 0, source_parent);
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
return val >= self._range[0] and val < self._range[1]
def filterAcceptsColumn(self, source_column, source_parent):
return True

View File

@ -41,8 +41,8 @@ class Skeleton(QGraphicsRectItem):
@property @property
def length(self): def length(self):
bodykps = self._keypoints[self.bodyaxis, :] bodykpts = self._keypoints[self.bodyaxis, :]
dist = np.sum(np.sqrt(np.sum(np.diff(bodykps, axis=0)**2, axis=1)), axis=0) dist = np.sum(np.sqrt(np.sum(np.diff(bodykpts, axis=0)**2, axis=1)), axis=0)
return dist return dist
# def mousePressEvent(self, event): # def mousePressEvent(self, event):

View File

@ -1,10 +1,9 @@
import logging import logging
import pathlib
import pickle import pickle
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QSize, QObject from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject
from PySide6.QtGui import QImage, QBrush, QColor, QFont from PySide6.QtGui import QImage, QBrush, QColor, QFont
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
@ -17,77 +16,6 @@ from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget from fixtracks.widgets.classifier import ClassifierWidget
class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"]
columns = ["frame", "track"]
def __init__(self, dataframe, parent=None):
super().__init__(parent)
self._data = dataframe
self._frames = self._data.frame.values
self._tracks = self._data.track.values
self._indices = self._data.index.values
self._column_data = [self._frames, self._tracks]
def columnCount(self, parent=None):
return len(self.columns)
def rowCount(self, parent=None):
if self._data is not None:
return len(self._data)
else:
return 0
def data(self, index, role = ...):
value = self._column_data[index.column()][index.row()]
if role == Qt.ItemDataRole.DisplayRole:
return str(value)
elif role == Qt.ItemDataRole.UserRole:
return value
return None
def headerData(self, section, orientation, role = ...):
if role == Qt.ItemDataRole.DisplayRole:
if orientation == Qt.Orientation.Horizontal:
return self.column_header[section]
else:
return str(self._indices[section])
else:
return None
def mapIdToRow(self, id):
row = np.where(self._indices == id)[0]
if len(row) == 0:
return -1
return row[0]
class FilterProxyModel(QSortFilterProxyModel):
def __init__(self, parent=None):
super().__init__(parent)
self._range = None
def setFilterRange(self, start, stop):
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
self._range = (start, stop)
self.invalidateRowsFilter()
def all(self):
self._range = None
def filterAcceptsRow(self, source_row, source_parent):
if self._range is None:
return True
else:
idx = self.sourceModel().index(source_row, 0, source_parent);
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
return val >= self._range[0] and val < self._range[1]
def filterAcceptsColumn(self, source_column, source_parent):
return True
class SelectionControls(QWidget): class SelectionControls(QWidget):
fwd = Signal(float) fwd = Signal(float)
back = Signal(float) back = Signal(float)
@ -230,7 +158,7 @@ class SelectionControls(QWidget):
self._total = len(tracks) self._total = len(tracks)
class DataController(QObject): class TrackingData(QObject):
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self._data = None self._data = None
@ -314,6 +242,7 @@ class DataController(QObject):
def coordinates(self): def coordinates(self):
return np.stack(self._data["keypoints"]).astype(np.float32) return np.stack(self._data["keypoints"]).astype(np.float32)
class FixTracks(QWidget): class FixTracks(QWidget):
back = Signal() back = Signal()
trackone_id = 1 trackone_id = 1
@ -327,7 +256,7 @@ class FixTracks(QWidget):
self._reader = None self._reader = None
self._image = None self._image = None
self._clear_detections = True self._clear_detections = True
self._data = DataController() self._data = TrackingData()
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")), self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
"assigned_right": QBrush(QColor.fromString("green")), "assigned_right": QBrush(QColor.fromString("green")),
"unassigned": QBrush(QColor.fromString("red")) "unassigned": QBrush(QColor.fromString("red"))