[cleanup] moving stuff, rename DataController to TrackingData
This commit is contained in:
parent
ff7e1e85ae
commit
b7d4097e73
74
fixtracks/utils/tablemodels.py
Normal file
74
fixtracks/utils/tablemodels.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from PySide6.QtCore import Qt, QAbstractTableModel, QSortFilterProxyModel
|
||||||
|
|
||||||
|
class PoseTableModel(QAbstractTableModel):
|
||||||
|
column_header = ["frame", "track"]
|
||||||
|
columns = ["frame", "track"]
|
||||||
|
|
||||||
|
def __init__(self, dataframe, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._data = dataframe
|
||||||
|
self._frames = self._data.frame.values
|
||||||
|
self._tracks = self._data.track.values
|
||||||
|
self._indices = self._data.index.values
|
||||||
|
self._column_data = [self._frames, self._tracks]
|
||||||
|
|
||||||
|
def columnCount(self, parent=None):
|
||||||
|
return len(self.columns)
|
||||||
|
|
||||||
|
def rowCount(self, parent=None):
|
||||||
|
if self._data is not None:
|
||||||
|
return len(self._data)
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def data(self, index, role = ...):
|
||||||
|
value = self._column_data[index.column()][index.row()]
|
||||||
|
if role == Qt.ItemDataRole.DisplayRole:
|
||||||
|
return str(value)
|
||||||
|
elif role == Qt.ItemDataRole.UserRole:
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
def headerData(self, section, orientation, role = ...):
|
||||||
|
if role == Qt.ItemDataRole.DisplayRole:
|
||||||
|
if orientation == Qt.Orientation.Horizontal:
|
||||||
|
return self.column_header[section]
|
||||||
|
else:
|
||||||
|
return str(self._indices[section])
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def mapIdToRow(self, id):
|
||||||
|
row = np.where(self._indices == id)[0]
|
||||||
|
if len(row) == 0:
|
||||||
|
return -1
|
||||||
|
return row[0]
|
||||||
|
|
||||||
|
|
||||||
|
class FilterProxyModel(QSortFilterProxyModel):
|
||||||
|
def __init__(self, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._range = None
|
||||||
|
|
||||||
|
def setFilterRange(self, start, stop):
|
||||||
|
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
|
||||||
|
self._range = (start, stop)
|
||||||
|
self.invalidateRowsFilter()
|
||||||
|
|
||||||
|
def all(self):
|
||||||
|
self._range = None
|
||||||
|
|
||||||
|
def filterAcceptsRow(self, source_row, source_parent):
|
||||||
|
if self._range is None:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
idx = self.sourceModel().index(source_row, 0, source_parent);
|
||||||
|
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
|
||||||
|
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
|
||||||
|
return val >= self._range[0] and val < self._range[1]
|
||||||
|
|
||||||
|
def filterAcceptsColumn(self, source_column, source_parent):
|
||||||
|
return True
|
@ -41,8 +41,8 @@ class Skeleton(QGraphicsRectItem):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def length(self):
|
def length(self):
|
||||||
bodykps = self._keypoints[self.bodyaxis, :]
|
bodykpts = self._keypoints[self.bodyaxis, :]
|
||||||
dist = np.sum(np.sqrt(np.sum(np.diff(bodykps, axis=0)**2, axis=1)), axis=0)
|
dist = np.sum(np.sqrt(np.sum(np.diff(bodykpts, axis=0)**2, axis=1)), axis=0)
|
||||||
return dist
|
return dist
|
||||||
|
|
||||||
# def mousePressEvent(self, event):
|
# def mousePressEvent(self, event):
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import pathlib
|
|
||||||
import pickle
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QSize, QObject
|
from PySide6.QtCore import Qt, QThreadPool, Signal, QSize, QObject
|
||||||
from PySide6.QtGui import QImage, QBrush, QColor, QFont
|
from PySide6.QtGui import QImage, QBrush, QColor, QFont
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
||||||
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
|
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QSplitter, QGridLayout, QFileDialog, QGridLayout
|
||||||
@ -17,77 +16,6 @@ from fixtracks.widgets.skeleton import SkeletonWidget
|
|||||||
from fixtracks.widgets.classifier import ClassifierWidget
|
from fixtracks.widgets.classifier import ClassifierWidget
|
||||||
|
|
||||||
|
|
||||||
class PoseTableModel(QAbstractTableModel):
|
|
||||||
column_header = ["frame", "track"]
|
|
||||||
columns = ["frame", "track"]
|
|
||||||
|
|
||||||
def __init__(self, dataframe, parent=None):
|
|
||||||
super().__init__(parent)
|
|
||||||
self._data = dataframe
|
|
||||||
self._frames = self._data.frame.values
|
|
||||||
self._tracks = self._data.track.values
|
|
||||||
self._indices = self._data.index.values
|
|
||||||
self._column_data = [self._frames, self._tracks]
|
|
||||||
|
|
||||||
def columnCount(self, parent=None):
|
|
||||||
return len(self.columns)
|
|
||||||
|
|
||||||
def rowCount(self, parent=None):
|
|
||||||
if self._data is not None:
|
|
||||||
return len(self._data)
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def data(self, index, role = ...):
|
|
||||||
value = self._column_data[index.column()][index.row()]
|
|
||||||
if role == Qt.ItemDataRole.DisplayRole:
|
|
||||||
return str(value)
|
|
||||||
elif role == Qt.ItemDataRole.UserRole:
|
|
||||||
return value
|
|
||||||
return None
|
|
||||||
|
|
||||||
def headerData(self, section, orientation, role = ...):
|
|
||||||
if role == Qt.ItemDataRole.DisplayRole:
|
|
||||||
if orientation == Qt.Orientation.Horizontal:
|
|
||||||
return self.column_header[section]
|
|
||||||
else:
|
|
||||||
return str(self._indices[section])
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def mapIdToRow(self, id):
|
|
||||||
row = np.where(self._indices == id)[0]
|
|
||||||
if len(row) == 0:
|
|
||||||
return -1
|
|
||||||
return row[0]
|
|
||||||
|
|
||||||
|
|
||||||
class FilterProxyModel(QSortFilterProxyModel):
|
|
||||||
def __init__(self, parent=None):
|
|
||||||
super().__init__(parent)
|
|
||||||
self._range = None
|
|
||||||
|
|
||||||
def setFilterRange(self, start, stop):
|
|
||||||
logging.info("FilterProxyModel.setFilterRange set to range %i , %i", start, stop)
|
|
||||||
self._range = (start, stop)
|
|
||||||
self.invalidateRowsFilter()
|
|
||||||
|
|
||||||
def all(self):
|
|
||||||
self._range = None
|
|
||||||
|
|
||||||
def filterAcceptsRow(self, source_row, source_parent):
|
|
||||||
if self._range is None:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
idx = self.sourceModel().index(source_row, 0, source_parent);
|
|
||||||
val = self.sourceModel().data(idx, Qt.ItemDataRole.UserRole)
|
|
||||||
print("filteracceptrows: ", val, self._range, val >= self._range[0] and val < self._range[1] )
|
|
||||||
return val >= self._range[0] and val < self._range[1]
|
|
||||||
|
|
||||||
def filterAcceptsColumn(self, source_column, source_parent):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class SelectionControls(QWidget):
|
class SelectionControls(QWidget):
|
||||||
fwd = Signal(float)
|
fwd = Signal(float)
|
||||||
back = Signal(float)
|
back = Signal(float)
|
||||||
@ -230,7 +158,7 @@ class SelectionControls(QWidget):
|
|||||||
self._total = len(tracks)
|
self._total = len(tracks)
|
||||||
|
|
||||||
|
|
||||||
class DataController(QObject):
|
class TrackingData(QObject):
|
||||||
def __init__(self, parent=None):
|
def __init__(self, parent=None):
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self._data = None
|
self._data = None
|
||||||
@ -314,6 +242,7 @@ class DataController(QObject):
|
|||||||
def coordinates(self):
|
def coordinates(self):
|
||||||
return np.stack(self._data["keypoints"]).astype(np.float32)
|
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
class FixTracks(QWidget):
|
class FixTracks(QWidget):
|
||||||
back = Signal()
|
back = Signal()
|
||||||
trackone_id = 1
|
trackone_id = 1
|
||||||
@ -327,7 +256,7 @@ class FixTracks(QWidget):
|
|||||||
self._reader = None
|
self._reader = None
|
||||||
self._image = None
|
self._image = None
|
||||||
self._clear_detections = True
|
self._clear_detections = True
|
||||||
self._data = DataController()
|
self._data = TrackingData()
|
||||||
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
|
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
|
||||||
"assigned_right": QBrush(QColor.fromString("green")),
|
"assigned_right": QBrush(QColor.fromString("green")),
|
||||||
"unassigned": QBrush(QColor.fromString("red"))
|
"unassigned": QBrush(QColor.fromString("red"))
|
||||||
|
Loading…
Reference in New Issue
Block a user