[detectionview] rubberband selection handling

This commit is contained in:
Jan Grewe 2025-01-25 14:48:05 +01:00
parent 24c8584105
commit a7f6d65e62

View File

@ -1,15 +1,12 @@
import logging
import numpy as np
import pandas as pd
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem
from PySide6.QtCore import Qt, QPointF
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem
from PySide6.QtCore import Qt, QPointF, QRectF, QPointF
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage
from fixtracks.info import PACKAGE_ROOT
from fixtracks.utils.reader import PickleLoader
from fixtracks.utils.signals import DetectionSignals
from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals
class Detection(QGraphicsEllipseItem):
signals = DetectionSignals()
@ -28,7 +25,54 @@ class Detection(QGraphicsEllipseItem):
super().hoverEnterEvent(event)
class DetectionScene(QGraphicsScene):
signals = DetectionSceneSignals()
def __init__(self):
super().__init__()
self.start_point = QPointF()
self.selection_rect = None
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
# Record the start point for the selection rectangle
self.start_point = event.scenePos()
# Create a temporary rectangle to visually show the selection
self.selection_rect = QGraphicsRectItem()
self.selection_rect.setPen(Qt.DashLine) # Dashed outline
self.selection_rect.setBrush(Qt.transparent)
self.addItem(self.selection_rect)
super().mousePressEvent(event)
def mouseMoveEvent(self, event):
if self.selection_rect is not None:
# Update the selection rectangle as the mouse moves
rect = QRectF(self.start_point, event.scenePos()).normalized()
self.selection_rect.setRect(rect)
super().mouseMoveEvent(event)
def mouseReleaseEvent(self, event):
if event.button() == Qt.LeftButton and self.selection_rect is not None:
# Perform selection of items within the selection rectangle
rect = self.selection_rect.rect()
# Remove the temporary selection rectangle
self.removeItem(self.selection_rect)
self.selection_rect = None
# Find all items that intersect with the selection rectangle
selected_items = self.items(rect, Qt.IntersectsItemShape)
for item in selected_items:
if not isinstance(item, Detection):
selected_items.remove(item)
item.setSelected(True) # Mark the item as selected
self.signals.itemsSelected.emit(selected_items)
super().mouseReleaseEvent(event)
class DetectionView(QWidget):
signals = DetectionViewSignals()
def __init__(self, parent=None):
super().__init__(parent)
@ -38,7 +82,6 @@ class DetectionView(QWidget):
self._view = QGraphicsView()
self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
self._view.setMouseTracking(True)
self._items = []
lyt = QVBoxLayout()
lyt.addWidget(self._view)
@ -46,25 +89,29 @@ class DetectionView(QWidget):
def setImage(self, image: QImage):
self._img = image
self._scene = QGraphicsScene()
self._scene = DetectionScene()
self._scene.signals.itemsSelected.connect(self.on_itemSelection)
self._pixmapitem = self._scene.addPixmap(QPixmap.fromImage(self._img))
self._view.setScene(self._scene)
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
# self._view.show()
def clearDetections(self):
for item in self._items:
self._scene.removeItem(item)
for it in self._scene.items():
if isinstance(it, Detection):
self._scene.removeItem(it)
del it
def addDetections(self, coordinates:np.array, ids:np.array, brush:QBrush):
def addDetections(self, coordinates:np.array, track_ids:np.array, detection_ids:np.array, brush:QBrush):
logging.info("DetectionView: add %i detections with color %s", coordinates.shape[0], brush.color.__str__())
image_rect = self._pixmapitem.boundingRect()
for i in range(coordinates.shape[0]):
x = coordinates[i, 0]
y = coordinates[i, 1]
item = Detection(image_rect.left() + x, image_rect.top() + y, 10, 10, brush=brush)
item.setData(0, ids[i])
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=brush)
item.setData(0, track_ids[i])
item.setData(1, detection_ids[i])
item = self._scene.addItem(item)
self._items.append(item)
logging.info("View contains %i items", len(self._scene.items()))
def fit_image_to_view(self):
"""Scale the image to fit the QGraphicsView."""
@ -75,8 +122,15 @@ class DetectionView(QWidget):
"""Handle window resizing to fit the image."""
super().resizeEvent(event)
self.fit_image_to_view()
def on_itemSelection(self, selected_items):
self.signals.itemsSelected.emit(selected_items)
def main():
def items_selected(items):
print(items)
import pickle
import numpy as np
from IPython import embed
@ -93,13 +147,16 @@ def main():
background_brush = QBrush(QColor.fromString("white"))
bg_coords = np.stack(df.keypoints[(df.track != 1) & (df.track != 2)].values,).astype(np.float32)[:,0,:]
bg_ids = df.track[(df.track != 1) & (df.track != 2)].values
bg_tracks = df.track[(df.track != 1) & (df.track != 2)].values
bg_ids = df.track[(df.track != 1) & (df.track != 2)].index.values
scnd_coords = np.stack(df.keypoints[(df.track == 2)].values,).astype(np.float32)[:,0,:]
scnd_ids = df.track[df.track == 2].values
scnd_tracks = df.track[df.track == 2].values
scnd_ids = df.track[(df.track == 2)].index.values
focus_coords = np.stack(df.keypoints[df.track == 1].values,).astype(np.float32)[:,0,:]
focus_ids = df.track[df.track == 1].values
focus_tracks = df.track[df.track == 1].values
focus_ids = df.track[(df.track == 2)].index.values
app = QApplication([])
window = QWidget()
@ -107,13 +164,14 @@ def main():
layout = QVBoxLayout()
view = DetectionView()
view.signals.itemsSelected.connect(items_selected)
layout.addWidget(view)
view.setImage(img)
view.addDetections(bg_coords, bg_ids, background_brush)
view.addDetections(focus_coords, focus_ids, focus_brush)
view.addDetections(scnd_coords, scnd_ids, second_brush)
view.addDetections(bg_coords, bg_tracks, bg_ids, background_brush)
view.addDetections(focus_coords, focus_tracks, focus_ids, focus_brush)
view.addDetections(scnd_coords, scnd_tracks, scnd_ids, second_brush)
window.setLayout(layout)
# window.show()
window.show()
app.exec()
if __name__ == "__main__":