[tracks] works nicely for not too large datasets
This commit is contained in:
parent
9a270f4d97
commit
a98770610b
@ -4,19 +4,18 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel
|
||||
from PySide6.QtGui import QImage, QBrush, QColor
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy
|
||||
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QFileDialog, QProgressBar, QTableView, QSplitter
|
||||
from PySide6.QtGui import QImage, QBrush, QColor, QFont
|
||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox
|
||||
from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QTableView, QSplitter
|
||||
|
||||
from fixtracks.utils.reader import PickleLoader
|
||||
from fixtracks.widgets.detectionview import DetectionView
|
||||
from fixtracks.widgets.timeline import Timeline
|
||||
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
||||
|
||||
class PoseTableModel(QAbstractTableModel):
|
||||
column_header = ["frame", "track"]
|
||||
columns = ["frame", "track"]
|
||||
|
||||
|
||||
def __init__(self, dataframe, parent=None):
|
||||
super().__init__(parent)
|
||||
self._data = dataframe
|
||||
@ -43,14 +42,14 @@ class PoseTableModel(QAbstractTableModel):
|
||||
return None
|
||||
|
||||
def headerData(self, section, orientation, role = ...):
|
||||
if role == Qt.ItemDataRole.DisplayRole:
|
||||
if role == Qt.ItemDataRole.DisplayRole:
|
||||
if orientation == Qt.Orientation.Horizontal:
|
||||
return self.column_header[section]
|
||||
else:
|
||||
return str(self._indices[section])
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def mapIdToRow(self, id):
|
||||
row = np.where(self._indices == id)[0]
|
||||
if len(row) == 0:
|
||||
@ -91,25 +90,25 @@ class FixTracks(QWidget):
|
||||
self._threadpool = QThreadPool()
|
||||
self._reader = None
|
||||
self._image = None
|
||||
self._dataframe = None
|
||||
self._data = None
|
||||
self._unassignedmodel = None
|
||||
self._leftmodel = None
|
||||
self._rightmodel = None
|
||||
self._proxymodel = None
|
||||
self._brushes = {"assigned_left": QBrush(QColor.fromString("red")),
|
||||
"assigned_right": QBrush(QColor.fromString("blue")),
|
||||
"unassigned": QBrush(QColor.fromString("white"))
|
||||
self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")),
|
||||
"assigned_right": QBrush(QColor.fromString("green")),
|
||||
"unassigned": QBrush(QColor.fromString("red"))
|
||||
}
|
||||
self._detectionView = DetectionView()
|
||||
self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected)
|
||||
self._progress_bar = QProgressBar(self)
|
||||
self._progress_bar.setMaximumHeight(20)
|
||||
# self._progress_bar.setRange(0, 0) # Set the progress bar to be indeterminate
|
||||
self._progress_bar.setValue(0)
|
||||
self._tasklabel = QLabel()
|
||||
|
||||
|
||||
self._timeline = DetectionTimeline()
|
||||
self._timeline.signals.windowMoved.connect(self.on_windowChanged)
|
||||
|
||||
self._windowspinner = QSpinBox()
|
||||
self._windowspinner.setRange(100, 10000)
|
||||
self._windowspinner.setSingleStep(100)
|
||||
@ -121,38 +120,74 @@ class FixTracks(QWidget):
|
||||
timelinebox.addWidget(QLabel("Window"))
|
||||
timelinebox.addWidget(self._windowspinner)
|
||||
|
||||
self._left_table = QTableView()
|
||||
self._trackone_table = QTableView()
|
||||
font = QFont()
|
||||
font.setBold(True)
|
||||
font.setPointSize(8)
|
||||
self._trackone_table.setFont(font)
|
||||
assign1 = QPushButton("<<")
|
||||
assign1.clicked.connect(self.on_assignLeft)
|
||||
assign2 = QPushButton(">>")
|
||||
assign2.clicked.connect(self.on_assignRight)
|
||||
self._unassigned_table = QTableView()
|
||||
self._unassigned_table.setFont(font)
|
||||
self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection)
|
||||
self._unassigned_table.setSelectionBehavior(QTableView.SelectionBehavior.SelectRows)
|
||||
self._right_table = QTableView()
|
||||
self._tracktwo_table = QTableView()
|
||||
self._tracktwo_table.setFont(font)
|
||||
|
||||
trackone_label = QLabel("Track 1")
|
||||
trackone_label.setStyleSheet("QLabel { color : orange; }")
|
||||
track1_box = QVBoxLayout()
|
||||
track1_box.addWidget(trackone_label)
|
||||
track1_box.addWidget(self._trackone_table)
|
||||
|
||||
tracktwo_label = QLabel("Track 2")
|
||||
tracktwo_label.setStyleSheet("QLabel { color : green; }")
|
||||
tracktwo_box = QVBoxLayout()
|
||||
tracktwo_box.addWidget(tracktwo_label)
|
||||
tracktwo_box.addWidget(self._tracktwo_table)
|
||||
|
||||
trackother_label = QLabel("Unassigned")
|
||||
trackother_label.setStyleSheet("QLabel { color : red; }")
|
||||
trackother_box = QVBoxLayout()
|
||||
trackother_box.addWidget(trackother_label)
|
||||
trackother_box.addWidget(self._unassigned_table)
|
||||
|
||||
tablebox = QHBoxLayout()
|
||||
tablebox.addWidget(self._left_table)
|
||||
tablebox.addLayout(track1_box)
|
||||
tablebox.addWidget(assign1)
|
||||
tablebox.addWidget(self._unassigned_table)
|
||||
tablebox.addLayout(trackother_box)
|
||||
tablebox.addWidget(assign2)
|
||||
tablebox.addWidget(self._right_table)
|
||||
tablebox.addLayout(tracktwo_box)
|
||||
|
||||
self._openBtn = QPushButton("Open")
|
||||
self._openBtn.setEnabled(True)
|
||||
self._openBtn.clicked.connect(self._on_open)
|
||||
self._saveBtn = QPushButton("Save")
|
||||
self._saveBtn.setEnabled(False)
|
||||
self._saveBtn.clicked.connect(self.on_save)
|
||||
self._backBtn = QPushButton("Back")
|
||||
self._backBtn.clicked.connect(self.on_back)
|
||||
|
||||
self._data_combo = QComboBox()
|
||||
self._data_combo.addItems(self._files)
|
||||
self._data_combo.currentIndexChanged.connect(self.on_dataSelection)
|
||||
self._image_combo = QComboBox()
|
||||
self._image_combo.addItems(self._files)
|
||||
self._image_combo.currentIndexChanged.connect(self.on_imageSelection)
|
||||
|
||||
data_selection_box = QHBoxLayout()
|
||||
data_selection_box.addWidget(QLabel("Select image file"))
|
||||
data_selection_box.addWidget(self._image_combo)
|
||||
data_selection_box.addWidget(QLabel("Select data file"))
|
||||
data_selection_box.addWidget(self._data_combo)
|
||||
data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
|
||||
|
||||
btnBox = QHBoxLayout()
|
||||
btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft)
|
||||
btnBox.addWidget(self._backBtn)
|
||||
btnBox.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed))
|
||||
btnBox.addWidget(self._tasklabel)
|
||||
btnBox.addWidget(self._progress_bar)
|
||||
btnBox.addWidget(self._openBtn)
|
||||
# btnBox.addWidget(self._openBtn)
|
||||
btnBox.addWidget(self._saveBtn)
|
||||
|
||||
vbox = QVBoxLayout()
|
||||
@ -167,78 +202,105 @@ class FixTracks(QWidget):
|
||||
splitter.addWidget(container)
|
||||
splitter.setStretchFactor(0, 3)
|
||||
splitter.setStretchFactor(1, 1)
|
||||
layout = QHBoxLayout()
|
||||
layout = QVBoxLayout()
|
||||
layout.addLayout(data_selection_box)
|
||||
layout.addWidget(splitter)
|
||||
self.setLayout(layout)
|
||||
|
||||
def _on_open(self):
|
||||
infile = None
|
||||
imgfile = None
|
||||
|
||||
self._tasklabel.setText( "Select merged image")
|
||||
file_dialog = QFileDialog(self, "Select merged image")
|
||||
file_dialog.setFileMode(QFileDialog.ExistingFile)
|
||||
file_dialog.setNameFilters([
|
||||
"Image Files (*.png *.jpg *.jpeg)",
|
||||
"All Files (*)"
|
||||
])
|
||||
if file_dialog.exec():
|
||||
imgfile = file_dialog.selectedFiles()[0]
|
||||
if imgfile is not None:
|
||||
img = QImage(imgfile)
|
||||
self._detectionView.setImage(img)
|
||||
self._tasklabel.setText( "Open data")
|
||||
file_dialog = QFileDialog(self, "Select pickled DataFrame", "", "Pandas DataFrame (*.pkl)")
|
||||
if file_dialog.exec():
|
||||
infile = file_dialog.selectedFiles()[0]
|
||||
if infile is not None:
|
||||
self._progress_bar.setRange(0,0)
|
||||
self._reader = PickleLoader(infile)
|
||||
self._reader.signals.finished.connect(self._on_dataOpenend)
|
||||
self._threadpool.start(self._reader)
|
||||
def on_dataSelection(self):
|
||||
filename = self._data_combo.currentText()
|
||||
if "please select" in filename.lower():
|
||||
return
|
||||
self._progress_bar.setRange(0,0)
|
||||
self._reader = PickleLoader(filename)
|
||||
self._reader.signals.finished.connect(self._on_dataOpenend)
|
||||
self._threadpool.start(self._reader)
|
||||
|
||||
def on_imageSelection(self):
|
||||
filename = self._image_combo.currentText()
|
||||
if "please select" in filename.lower():
|
||||
return
|
||||
img = QImage(filename)
|
||||
self._detectionView.setImage(img)
|
||||
|
||||
def populateTables(self):
|
||||
def update_detectionView(df, name):
|
||||
if len(df) == 0:
|
||||
return
|
||||
coords = np.stack(df.keypoints.values).astype(np.float32)[:,0,:]
|
||||
tracks = df.track.values.astype(int)
|
||||
coords = np.stack(df["keypoints"].values).astype(np.float32)[:,0,:]
|
||||
tracks = df["track"].values.astype(int)
|
||||
ids = df.index.values.astype(int)
|
||||
self._detectionView.addDetections(coords, tracks, ids, self._brushes[name])
|
||||
|
||||
left_trackid = 1
|
||||
right_trackid = 2
|
||||
max_frames = np.max(self._dataframe.frame.values)
|
||||
start = self._timeline.rangeStart
|
||||
stop = self._timeline._rangeStop
|
||||
start_frame = np.floor(start * max_frames)
|
||||
stop_frame = np.ceil(stop * max_frames)
|
||||
trackone_id = 1
|
||||
tracktwo_id = 2
|
||||
|
||||
df = self._dataframe[(self._dataframe.frame >= start_frame) & (self._dataframe.frame < stop_frame)]
|
||||
assigned_left = df[(df.track == left_trackid)]
|
||||
assigned_right = df[(df.track == right_trackid)]
|
||||
unassigned = df[(df.track != left_trackid) & (df.track != right_trackid)]
|
||||
max_frames = np.max(self._data["frame"])
|
||||
start = self._timeline.rangeStart
|
||||
stop = self._timeline.rangeStop
|
||||
print(start, stop, max_frames)
|
||||
start_frame = int(np.floor(start * max_frames))
|
||||
stop_frame = int(np.ceil(stop * max_frames))
|
||||
logging.debug("Updating TableModel for range %i, %i", start_frame, stop_frame)
|
||||
indices = np.where((self._data["frame"] >= start_frame) & (self._data["frame"] < stop_frame))[0]
|
||||
# from IPython import embed
|
||||
# embed()
|
||||
# exit()
|
||||
df = pd.DataFrame({"frame": self._data["index"][indices],
|
||||
"track": self._data["track"][indices],
|
||||
"keypoints": self._data["keypoints"][indices]},
|
||||
index= self._data["index"][indices])
|
||||
assigned_left = df[(df.track == trackone_id)]
|
||||
assigned_right = df[(df.track == tracktwo_id)]
|
||||
unassigned = df[(df.track != trackone_id) & (df.track != tracktwo_id)]
|
||||
logging.debug("Updating TableModel: %i track one, %i unassigned, %i tracktwo", len(assigned_left),
|
||||
len(unassigned), len(assigned_right))
|
||||
|
||||
self._unassignedmodel = PoseTableModel(unassigned)
|
||||
self._unassigned_table.setModel(self._unassignedmodel)
|
||||
self._leftmodel = PoseTableModel(assigned_left)
|
||||
self._left_table.setModel(self._leftmodel)
|
||||
self._trackone_table.setModel(self._leftmodel)
|
||||
self._rightmodel = PoseTableModel(assigned_right)
|
||||
self._right_table.setModel(self._rightmodel)
|
||||
self._tracktwo_table.setModel(self._rightmodel)
|
||||
|
||||
self._detectionView.clearDetections()
|
||||
update_detectionView(unassigned, "unassigned")
|
||||
update_detectionView(assigned_left, "assigned_left")
|
||||
update_detectionView(assigned_right, "assigned_right")
|
||||
|
||||
@property
|
||||
def fileList(self):
|
||||
return self._files
|
||||
|
||||
@fileList.setter
|
||||
def fileList(self, file_list):
|
||||
logging.debug("FixTracks.fileList: set new file list")
|
||||
|
||||
logging.debug("FixTracks.fileList: setting image combo box")
|
||||
img_formats = [".jpg", ".png"]
|
||||
self._files = [str(f) for f in file_list if f.suffix in img_formats]
|
||||
self._image_combo.addItem("Please select")
|
||||
self._image_combo.addItems(self.fileList)
|
||||
self._image_combo.setCurrentIndex(0)
|
||||
|
||||
logging.debug("FixTracks.fileList: setting data combo box")
|
||||
dataformats = [".pkl"]
|
||||
self._files = [str(f) for f in file_list if f.suffix in dataformats]
|
||||
self._data_combo.addItem("Please select")
|
||||
self._data_combo.addItems(self.fileList)
|
||||
self._data_combo.setCurrentIndex(0)
|
||||
|
||||
self._data_combo.currentIndexChanged.connect(self.on_rightDataSelection)
|
||||
self._image_combo.currentIndexChanged.connect(self.on_rightvideoSelection)
|
||||
|
||||
def _on_dataOpenend(self, state):
|
||||
logging.info("Finished loading data with state %s", state)
|
||||
self._tasklabel.setText("")
|
||||
self._progress_bar.setRange(0, 100)
|
||||
self._progress_bar.setValue(0)
|
||||
if state and self._reader is not None:
|
||||
self._dataframe = self._reader.data
|
||||
self._timeline.setDetectionData(self._dataframe)
|
||||
self._data = self._reader.asdict
|
||||
self._timeline.setDetectionData(self._data)
|
||||
self.populateTables()
|
||||
|
||||
def on_save(self):
|
||||
@ -248,15 +310,30 @@ class FixTracks(QWidget):
|
||||
logging.debug("Back button pressed!")
|
||||
self.back.emit()
|
||||
|
||||
def assignTrack(self, rows, trackid):
|
||||
logging.debug("Assign %i detections to Track One", len(rows))
|
||||
ids = np.zeros(len(rows), dtype=int)
|
||||
for i,r in enumerate(rows):
|
||||
ids[i] = self._unassignedmodel.headerData(r.row(), Qt.Orientation.Vertical, Qt.ItemDataRole.DisplayRole)
|
||||
|
||||
self._data["track"][ids] = np.zeros_like(ids, dtype=int) + trackid
|
||||
self.populateTables()
|
||||
self._timeline.setDetectionData(self._data)
|
||||
|
||||
def on_assignLeft(self):
|
||||
pass
|
||||
selection = self._unassigned_table.selectionModel()
|
||||
rows = selection.selectedRows()
|
||||
logging.debug("Assign %i detections to Track One", len(rows))
|
||||
self.assignTrack(rows, 1)
|
||||
|
||||
def on_assignRight(self):
|
||||
pass
|
||||
selection = self._unassigned_table.selectionModel()
|
||||
rows = selection.selectedRows()
|
||||
logging.debug("Assign %i detections to Track Two", len(rows))
|
||||
self.assignTrack(rows, 2)
|
||||
|
||||
def on_windowChanged(self, start, stop):
|
||||
logging.info("Timeline reports window change to range %f %f percent of data", start, stop)
|
||||
|
||||
self.populateTables()
|
||||
|
||||
def on_windowSizeChanged(self, value):
|
||||
|
Loading…
Reference in New Issue
Block a user