diff --git a/fixtracks/widgets/tracks.py b/fixtracks/widgets/tracks.py index 4c262d3..a9e8244 100644 --- a/fixtracks/widgets/tracks.py +++ b/fixtracks/widgets/tracks.py @@ -4,19 +4,18 @@ import numpy as np import pandas as pd from PySide6.QtCore import Qt, QThreadPool, Signal, QAbstractTableModel, QSortFilterProxyModel, QItemSelectionModel -from PySide6.QtGui import QImage, QBrush, QColor -from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy -from PySide6.QtWidgets import QSpinBox, QSpacerItem, QFileDialog, QProgressBar, QTableView, QSplitter +from PySide6.QtGui import QImage, QBrush, QColor, QFont +from PySide6.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QComboBox +from PySide6.QtWidgets import QSpinBox, QSpacerItem, QProgressBar, QTableView, QSplitter from fixtracks.utils.reader import PickleLoader from fixtracks.widgets.detectionview import DetectionView -from fixtracks.widgets.timeline import Timeline from fixtracks.widgets.detectiontimeline import DetectionTimeline class PoseTableModel(QAbstractTableModel): column_header = ["frame", "track"] columns = ["frame", "track"] - + def __init__(self, dataframe, parent=None): super().__init__(parent) self._data = dataframe @@ -43,14 +42,14 @@ class PoseTableModel(QAbstractTableModel): return None def headerData(self, section, orientation, role = ...): - if role == Qt.ItemDataRole.DisplayRole: + if role == Qt.ItemDataRole.DisplayRole: if orientation == Qt.Orientation.Horizontal: return self.column_header[section] else: return str(self._indices[section]) else: return None - + def mapIdToRow(self, id): row = np.where(self._indices == id)[0] if len(row) == 0: @@ -91,25 +90,25 @@ class FixTracks(QWidget): self._threadpool = QThreadPool() self._reader = None self._image = None - self._dataframe = None + self._data = None self._unassignedmodel = None self._leftmodel = None self._rightmodel = None self._proxymodel = None - self._brushes = {"assigned_left": QBrush(QColor.fromString("red")), - "assigned_right": QBrush(QColor.fromString("blue")), - "unassigned": QBrush(QColor.fromString("white")) + self._brushes = {"assigned_left": QBrush(QColor.fromString("orange")), + "assigned_right": QBrush(QColor.fromString("green")), + "unassigned": QBrush(QColor.fromString("red")) } self._detectionView = DetectionView() self._detectionView.signals.itemsSelected.connect(self.on_detectionsSelected) self._progress_bar = QProgressBar(self) self._progress_bar.setMaximumHeight(20) - # self._progress_bar.setRange(0, 0) # Set the progress bar to be indeterminate self._progress_bar.setValue(0) self._tasklabel = QLabel() - + self._timeline = DetectionTimeline() self._timeline.signals.windowMoved.connect(self.on_windowChanged) + self._windowspinner = QSpinBox() self._windowspinner.setRange(100, 10000) self._windowspinner.setSingleStep(100) @@ -121,38 +120,74 @@ class FixTracks(QWidget): timelinebox.addWidget(QLabel("Window")) timelinebox.addWidget(self._windowspinner) - self._left_table = QTableView() + self._trackone_table = QTableView() + font = QFont() + font.setBold(True) + font.setPointSize(8) + self._trackone_table.setFont(font) assign1 = QPushButton("<<") assign1.clicked.connect(self.on_assignLeft) assign2 = QPushButton(">>") assign2.clicked.connect(self.on_assignRight) self._unassigned_table = QTableView() + self._unassigned_table.setFont(font) self._unassigned_table.setSelectionMode(QTableView.SelectionMode.ExtendedSelection) self._unassigned_table.setSelectionBehavior(QTableView.SelectionBehavior.SelectRows) - self._right_table = QTableView() + self._tracktwo_table = QTableView() + self._tracktwo_table.setFont(font) + + trackone_label = QLabel("Track 1") + trackone_label.setStyleSheet("QLabel { color : orange; }") + track1_box = QVBoxLayout() + track1_box.addWidget(trackone_label) + track1_box.addWidget(self._trackone_table) + + tracktwo_label = QLabel("Track 2") + tracktwo_label.setStyleSheet("QLabel { color : green; }") + tracktwo_box = QVBoxLayout() + tracktwo_box.addWidget(tracktwo_label) + tracktwo_box.addWidget(self._tracktwo_table) + + trackother_label = QLabel("Unassigned") + trackother_label.setStyleSheet("QLabel { color : red; }") + trackother_box = QVBoxLayout() + trackother_box.addWidget(trackother_label) + trackother_box.addWidget(self._unassigned_table) + tablebox = QHBoxLayout() - tablebox.addWidget(self._left_table) + tablebox.addLayout(track1_box) tablebox.addWidget(assign1) - tablebox.addWidget(self._unassigned_table) + tablebox.addLayout(trackother_box) tablebox.addWidget(assign2) - tablebox.addWidget(self._right_table) + tablebox.addLayout(tracktwo_box) - self._openBtn = QPushButton("Open") - self._openBtn.setEnabled(True) - self._openBtn.clicked.connect(self._on_open) self._saveBtn = QPushButton("Save") self._saveBtn.setEnabled(False) self._saveBtn.clicked.connect(self.on_save) self._backBtn = QPushButton("Back") self._backBtn.clicked.connect(self.on_back) + self._data_combo = QComboBox() + self._data_combo.addItems(self._files) + self._data_combo.currentIndexChanged.connect(self.on_dataSelection) + self._image_combo = QComboBox() + self._image_combo.addItems(self._files) + self._image_combo.currentIndexChanged.connect(self.on_imageSelection) + + data_selection_box = QHBoxLayout() + data_selection_box.addWidget(QLabel("Select image file")) + data_selection_box.addWidget(self._image_combo) + data_selection_box.addWidget(QLabel("Select data file")) + data_selection_box.addWidget(self._data_combo) + data_selection_box.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) + btnBox = QHBoxLayout() btnBox.setAlignment(Qt.AlignmentFlag.AlignLeft) btnBox.addWidget(self._backBtn) btnBox.addItem(QSpacerItem(100, 10, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed)) btnBox.addWidget(self._tasklabel) btnBox.addWidget(self._progress_bar) - btnBox.addWidget(self._openBtn) + # btnBox.addWidget(self._openBtn) btnBox.addWidget(self._saveBtn) vbox = QVBoxLayout() @@ -167,78 +202,105 @@ class FixTracks(QWidget): splitter.addWidget(container) splitter.setStretchFactor(0, 3) splitter.setStretchFactor(1, 1) - layout = QHBoxLayout() + layout = QVBoxLayout() + layout.addLayout(data_selection_box) layout.addWidget(splitter) self.setLayout(layout) - def _on_open(self): - infile = None - imgfile = None - - self._tasklabel.setText( "Select merged image") - file_dialog = QFileDialog(self, "Select merged image") - file_dialog.setFileMode(QFileDialog.ExistingFile) - file_dialog.setNameFilters([ - "Image Files (*.png *.jpg *.jpeg)", - "All Files (*)" - ]) - if file_dialog.exec(): - imgfile = file_dialog.selectedFiles()[0] - if imgfile is not None: - img = QImage(imgfile) - self._detectionView.setImage(img) - self._tasklabel.setText( "Open data") - file_dialog = QFileDialog(self, "Select pickled DataFrame", "", "Pandas DataFrame (*.pkl)") - if file_dialog.exec(): - infile = file_dialog.selectedFiles()[0] - if infile is not None: - self._progress_bar.setRange(0,0) - self._reader = PickleLoader(infile) - self._reader.signals.finished.connect(self._on_dataOpenend) - self._threadpool.start(self._reader) + def on_dataSelection(self): + filename = self._data_combo.currentText() + if "please select" in filename.lower(): + return + self._progress_bar.setRange(0,0) + self._reader = PickleLoader(filename) + self._reader.signals.finished.connect(self._on_dataOpenend) + self._threadpool.start(self._reader) + + def on_imageSelection(self): + filename = self._image_combo.currentText() + if "please select" in filename.lower(): + return + img = QImage(filename) + self._detectionView.setImage(img) def populateTables(self): def update_detectionView(df, name): if len(df) == 0: return - coords = np.stack(df.keypoints.values).astype(np.float32)[:,0,:] - tracks = df.track.values.astype(int) + coords = np.stack(df["keypoints"].values).astype(np.float32)[:,0,:] + tracks = df["track"].values.astype(int) ids = df.index.values.astype(int) self._detectionView.addDetections(coords, tracks, ids, self._brushes[name]) - left_trackid = 1 - right_trackid = 2 - max_frames = np.max(self._dataframe.frame.values) - start = self._timeline.rangeStart - stop = self._timeline._rangeStop - start_frame = np.floor(start * max_frames) - stop_frame = np.ceil(stop * max_frames) + trackone_id = 1 + tracktwo_id = 2 - df = self._dataframe[(self._dataframe.frame >= start_frame) & (self._dataframe.frame < stop_frame)] - assigned_left = df[(df.track == left_trackid)] - assigned_right = df[(df.track == right_trackid)] - unassigned = df[(df.track != left_trackid) & (df.track != right_trackid)] + max_frames = np.max(self._data["frame"]) + start = self._timeline.rangeStart + stop = self._timeline.rangeStop + print(start, stop, max_frames) + start_frame = int(np.floor(start * max_frames)) + stop_frame = int(np.ceil(stop * max_frames)) + logging.debug("Updating TableModel for range %i, %i", start_frame, stop_frame) + indices = np.where((self._data["frame"] >= start_frame) & (self._data["frame"] < stop_frame))[0] + # from IPython import embed + # embed() + # exit() + df = pd.DataFrame({"frame": self._data["index"][indices], + "track": self._data["track"][indices], + "keypoints": self._data["keypoints"][indices]}, + index= self._data["index"][indices]) + assigned_left = df[(df.track == trackone_id)] + assigned_right = df[(df.track == tracktwo_id)] + unassigned = df[(df.track != trackone_id) & (df.track != tracktwo_id)] + logging.debug("Updating TableModel: %i track one, %i unassigned, %i tracktwo", len(assigned_left), + len(unassigned), len(assigned_right)) self._unassignedmodel = PoseTableModel(unassigned) self._unassigned_table.setModel(self._unassignedmodel) self._leftmodel = PoseTableModel(assigned_left) - self._left_table.setModel(self._leftmodel) + self._trackone_table.setModel(self._leftmodel) self._rightmodel = PoseTableModel(assigned_right) - self._right_table.setModel(self._rightmodel) + self._tracktwo_table.setModel(self._rightmodel) self._detectionView.clearDetections() update_detectionView(unassigned, "unassigned") update_detectionView(assigned_left, "assigned_left") update_detectionView(assigned_right, "assigned_right") + @property + def fileList(self): + return self._files + + @fileList.setter + def fileList(self, file_list): + logging.debug("FixTracks.fileList: set new file list") + + logging.debug("FixTracks.fileList: setting image combo box") + img_formats = [".jpg", ".png"] + self._files = [str(f) for f in file_list if f.suffix in img_formats] + self._image_combo.addItem("Please select") + self._image_combo.addItems(self.fileList) + self._image_combo.setCurrentIndex(0) + + logging.debug("FixTracks.fileList: setting data combo box") + dataformats = [".pkl"] + self._files = [str(f) for f in file_list if f.suffix in dataformats] + self._data_combo.addItem("Please select") + self._data_combo.addItems(self.fileList) + self._data_combo.setCurrentIndex(0) + + self._data_combo.currentIndexChanged.connect(self.on_rightDataSelection) + self._image_combo.currentIndexChanged.connect(self.on_rightvideoSelection) + def _on_dataOpenend(self, state): logging.info("Finished loading data with state %s", state) self._tasklabel.setText("") self._progress_bar.setRange(0, 100) self._progress_bar.setValue(0) if state and self._reader is not None: - self._dataframe = self._reader.data - self._timeline.setDetectionData(self._dataframe) + self._data = self._reader.asdict + self._timeline.setDetectionData(self._data) self.populateTables() def on_save(self): @@ -248,15 +310,30 @@ class FixTracks(QWidget): logging.debug("Back button pressed!") self.back.emit() + def assignTrack(self, rows, trackid): + logging.debug("Assign %i detections to Track One", len(rows)) + ids = np.zeros(len(rows), dtype=int) + for i,r in enumerate(rows): + ids[i] = self._unassignedmodel.headerData(r.row(), Qt.Orientation.Vertical, Qt.ItemDataRole.DisplayRole) + + self._data["track"][ids] = np.zeros_like(ids, dtype=int) + trackid + self.populateTables() + self._timeline.setDetectionData(self._data) + def on_assignLeft(self): - pass + selection = self._unassigned_table.selectionModel() + rows = selection.selectedRows() + logging.debug("Assign %i detections to Track One", len(rows)) + self.assignTrack(rows, 1) def on_assignRight(self): - pass + selection = self._unassigned_table.selectionModel() + rows = selection.selectedRows() + logging.debug("Assign %i detections to Track Two", len(rows)) + self.assignTrack(rows, 2) def on_windowChanged(self, start, stop): logging.info("Timeline reports window change to range %f %f percent of data", start, stop) - self.populateTables() def on_windowSizeChanged(self, value):