From a7f6d65e62440c8c18427a212934d33e96e80b6a Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Sat, 25 Jan 2025 14:48:05 +0100 Subject: [PATCH] [detectionview] rubberband selection handling --- fixtracks/widgets/detectionview.py | 104 ++++++++++++++++++++++------- 1 file changed, 81 insertions(+), 23 deletions(-) diff --git a/fixtracks/widgets/detectionview.py b/fixtracks/widgets/detectionview.py index 5800cbd..7eb5a90 100644 --- a/fixtracks/widgets/detectionview.py +++ b/fixtracks/widgets/detectionview.py @@ -1,15 +1,12 @@ import logging import numpy as np -import pandas as pd -from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem -from PySide6.QtCore import Qt, QPointF +from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem +from PySide6.QtCore import Qt, QPointF, QRectF, QPointF from PySide6.QtGui import QPixmap, QBrush, QColor, QImage - from fixtracks.info import PACKAGE_ROOT -from fixtracks.utils.reader import PickleLoader -from fixtracks.utils.signals import DetectionSignals +from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals class Detection(QGraphicsEllipseItem): signals = DetectionSignals() @@ -28,7 +25,54 @@ class Detection(QGraphicsEllipseItem): super().hoverEnterEvent(event) +class DetectionScene(QGraphicsScene): + signals = DetectionSceneSignals() + + def __init__(self): + super().__init__() + self.start_point = QPointF() + self.selection_rect = None + + def mousePressEvent(self, event): + if event.button() == Qt.LeftButton: + # Record the start point for the selection rectangle + self.start_point = event.scenePos() + + # Create a temporary rectangle to visually show the selection + self.selection_rect = QGraphicsRectItem() + self.selection_rect.setPen(Qt.DashLine) # Dashed outline + self.selection_rect.setBrush(Qt.transparent) + self.addItem(self.selection_rect) + super().mousePressEvent(event) + + def mouseMoveEvent(self, event): + if self.selection_rect is not None: + # Update the selection rectangle as the mouse moves + rect = QRectF(self.start_point, event.scenePos()).normalized() + self.selection_rect.setRect(rect) + super().mouseMoveEvent(event) + + def mouseReleaseEvent(self, event): + if event.button() == Qt.LeftButton and self.selection_rect is not None: + # Perform selection of items within the selection rectangle + rect = self.selection_rect.rect() + + # Remove the temporary selection rectangle + self.removeItem(self.selection_rect) + self.selection_rect = None + + # Find all items that intersect with the selection rectangle + selected_items = self.items(rect, Qt.IntersectsItemShape) + for item in selected_items: + if not isinstance(item, Detection): + selected_items.remove(item) + item.setSelected(True) # Mark the item as selected + self.signals.itemsSelected.emit(selected_items) + super().mouseReleaseEvent(event) + + class DetectionView(QWidget): + signals = DetectionViewSignals() def __init__(self, parent=None): super().__init__(parent) @@ -38,7 +82,6 @@ class DetectionView(QWidget): self._view = QGraphicsView() self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) self._view.setMouseTracking(True) - self._items = [] lyt = QVBoxLayout() lyt.addWidget(self._view) @@ -46,25 +89,29 @@ class DetectionView(QWidget): def setImage(self, image: QImage): self._img = image - self._scene = QGraphicsScene() + self._scene = DetectionScene() + self._scene.signals.itemsSelected.connect(self.on_itemSelection) self._pixmapitem = self._scene.addPixmap(QPixmap.fromImage(self._img)) self._view.setScene(self._scene) self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio) - # self._view.show() def clearDetections(self): - for item in self._items: - self._scene.removeItem(item) + for it in self._scene.items(): + if isinstance(it, Detection): + self._scene.removeItem(it) + del it - def addDetections(self, coordinates:np.array, ids:np.array, brush:QBrush): + def addDetections(self, coordinates:np.array, track_ids:np.array, detection_ids:np.array, brush:QBrush): + logging.info("DetectionView: add %i detections with color %s", coordinates.shape[0], brush.color.__str__()) image_rect = self._pixmapitem.boundingRect() for i in range(coordinates.shape[0]): x = coordinates[i, 0] y = coordinates[i, 1] - item = Detection(image_rect.left() + x, image_rect.top() + y, 10, 10, brush=brush) - item.setData(0, ids[i]) + item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=brush) + item.setData(0, track_ids[i]) + item.setData(1, detection_ids[i]) item = self._scene.addItem(item) - self._items.append(item) + logging.info("View contains %i items", len(self._scene.items())) def fit_image_to_view(self): """Scale the image to fit the QGraphicsView.""" @@ -75,8 +122,15 @@ class DetectionView(QWidget): """Handle window resizing to fit the image.""" super().resizeEvent(event) self.fit_image_to_view() + + def on_itemSelection(self, selected_items): + self.signals.itemsSelected.emit(selected_items) + def main(): + def items_selected(items): + print(items) + import pickle import numpy as np from IPython import embed @@ -93,13 +147,16 @@ def main(): background_brush = QBrush(QColor.fromString("white")) bg_coords = np.stack(df.keypoints[(df.track != 1) & (df.track != 2)].values,).astype(np.float32)[:,0,:] - bg_ids = df.track[(df.track != 1) & (df.track != 2)].values - + bg_tracks = df.track[(df.track != 1) & (df.track != 2)].values + bg_ids = df.track[(df.track != 1) & (df.track != 2)].index.values + scnd_coords = np.stack(df.keypoints[(df.track == 2)].values,).astype(np.float32)[:,0,:] - scnd_ids = df.track[df.track == 2].values + scnd_tracks = df.track[df.track == 2].values + scnd_ids = df.track[(df.track == 2)].index.values focus_coords = np.stack(df.keypoints[df.track == 1].values,).astype(np.float32)[:,0,:] - focus_ids = df.track[df.track == 1].values + focus_tracks = df.track[df.track == 1].values + focus_ids = df.track[(df.track == 2)].index.values app = QApplication([]) window = QWidget() @@ -107,13 +164,14 @@ def main(): layout = QVBoxLayout() view = DetectionView() + view.signals.itemsSelected.connect(items_selected) layout.addWidget(view) view.setImage(img) - view.addDetections(bg_coords, bg_ids, background_brush) - view.addDetections(focus_coords, focus_ids, focus_brush) - view.addDetections(scnd_coords, scnd_ids, second_brush) + view.addDetections(bg_coords, bg_tracks, bg_ids, background_brush) + view.addDetections(focus_coords, focus_tracks, focus_ids, focus_brush) + view.addDetections(scnd_coords, scnd_tracks, scnd_ids, second_brush) window.setLayout(layout) - # window.show() + window.show() app.exec() if __name__ == "__main__":