[detectionview] rubberband selection handling
This commit is contained in:
parent
24c8584105
commit
a7f6d65e62
@ -1,15 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem
|
||||||
from PySide6.QtCore import Qt, QPointF
|
from PySide6.QtCore import Qt, QPointF, QRectF, QPointF
|
||||||
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage
|
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage
|
||||||
|
|
||||||
|
|
||||||
from fixtracks.info import PACKAGE_ROOT
|
from fixtracks.info import PACKAGE_ROOT
|
||||||
from fixtracks.utils.reader import PickleLoader
|
from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals
|
||||||
from fixtracks.utils.signals import DetectionSignals
|
|
||||||
|
|
||||||
class Detection(QGraphicsEllipseItem):
|
class Detection(QGraphicsEllipseItem):
|
||||||
signals = DetectionSignals()
|
signals = DetectionSignals()
|
||||||
@ -28,7 +25,54 @@ class Detection(QGraphicsEllipseItem):
|
|||||||
super().hoverEnterEvent(event)
|
super().hoverEnterEvent(event)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionScene(QGraphicsScene):
|
||||||
|
signals = DetectionSceneSignals()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.start_point = QPointF()
|
||||||
|
self.selection_rect = None
|
||||||
|
|
||||||
|
def mousePressEvent(self, event):
|
||||||
|
if event.button() == Qt.LeftButton:
|
||||||
|
# Record the start point for the selection rectangle
|
||||||
|
self.start_point = event.scenePos()
|
||||||
|
|
||||||
|
# Create a temporary rectangle to visually show the selection
|
||||||
|
self.selection_rect = QGraphicsRectItem()
|
||||||
|
self.selection_rect.setPen(Qt.DashLine) # Dashed outline
|
||||||
|
self.selection_rect.setBrush(Qt.transparent)
|
||||||
|
self.addItem(self.selection_rect)
|
||||||
|
super().mousePressEvent(event)
|
||||||
|
|
||||||
|
def mouseMoveEvent(self, event):
|
||||||
|
if self.selection_rect is not None:
|
||||||
|
# Update the selection rectangle as the mouse moves
|
||||||
|
rect = QRectF(self.start_point, event.scenePos()).normalized()
|
||||||
|
self.selection_rect.setRect(rect)
|
||||||
|
super().mouseMoveEvent(event)
|
||||||
|
|
||||||
|
def mouseReleaseEvent(self, event):
|
||||||
|
if event.button() == Qt.LeftButton and self.selection_rect is not None:
|
||||||
|
# Perform selection of items within the selection rectangle
|
||||||
|
rect = self.selection_rect.rect()
|
||||||
|
|
||||||
|
# Remove the temporary selection rectangle
|
||||||
|
self.removeItem(self.selection_rect)
|
||||||
|
self.selection_rect = None
|
||||||
|
|
||||||
|
# Find all items that intersect with the selection rectangle
|
||||||
|
selected_items = self.items(rect, Qt.IntersectsItemShape)
|
||||||
|
for item in selected_items:
|
||||||
|
if not isinstance(item, Detection):
|
||||||
|
selected_items.remove(item)
|
||||||
|
item.setSelected(True) # Mark the item as selected
|
||||||
|
self.signals.itemsSelected.emit(selected_items)
|
||||||
|
super().mouseReleaseEvent(event)
|
||||||
|
|
||||||
|
|
||||||
class DetectionView(QWidget):
|
class DetectionView(QWidget):
|
||||||
|
signals = DetectionViewSignals()
|
||||||
|
|
||||||
def __init__(self, parent=None):
|
def __init__(self, parent=None):
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
@ -38,7 +82,6 @@ class DetectionView(QWidget):
|
|||||||
self._view = QGraphicsView()
|
self._view = QGraphicsView()
|
||||||
self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
|
||||||
self._view.setMouseTracking(True)
|
self._view.setMouseTracking(True)
|
||||||
self._items = []
|
|
||||||
|
|
||||||
lyt = QVBoxLayout()
|
lyt = QVBoxLayout()
|
||||||
lyt.addWidget(self._view)
|
lyt.addWidget(self._view)
|
||||||
@ -46,25 +89,29 @@ class DetectionView(QWidget):
|
|||||||
|
|
||||||
def setImage(self, image: QImage):
|
def setImage(self, image: QImage):
|
||||||
self._img = image
|
self._img = image
|
||||||
self._scene = QGraphicsScene()
|
self._scene = DetectionScene()
|
||||||
|
self._scene.signals.itemsSelected.connect(self.on_itemSelection)
|
||||||
self._pixmapitem = self._scene.addPixmap(QPixmap.fromImage(self._img))
|
self._pixmapitem = self._scene.addPixmap(QPixmap.fromImage(self._img))
|
||||||
self._view.setScene(self._scene)
|
self._view.setScene(self._scene)
|
||||||
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
|
self._view.fitInView(self._scene.sceneRect(), aspectRadioMode=Qt.AspectRatioMode.KeepAspectRatio)
|
||||||
# self._view.show()
|
|
||||||
|
|
||||||
def clearDetections(self):
|
def clearDetections(self):
|
||||||
for item in self._items:
|
for it in self._scene.items():
|
||||||
self._scene.removeItem(item)
|
if isinstance(it, Detection):
|
||||||
|
self._scene.removeItem(it)
|
||||||
|
del it
|
||||||
|
|
||||||
def addDetections(self, coordinates:np.array, ids:np.array, brush:QBrush):
|
def addDetections(self, coordinates:np.array, track_ids:np.array, detection_ids:np.array, brush:QBrush):
|
||||||
|
logging.info("DetectionView: add %i detections with color %s", coordinates.shape[0], brush.color.__str__())
|
||||||
image_rect = self._pixmapitem.boundingRect()
|
image_rect = self._pixmapitem.boundingRect()
|
||||||
for i in range(coordinates.shape[0]):
|
for i in range(coordinates.shape[0]):
|
||||||
x = coordinates[i, 0]
|
x = coordinates[i, 0]
|
||||||
y = coordinates[i, 1]
|
y = coordinates[i, 1]
|
||||||
item = Detection(image_rect.left() + x, image_rect.top() + y, 10, 10, brush=brush)
|
item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=brush)
|
||||||
item.setData(0, ids[i])
|
item.setData(0, track_ids[i])
|
||||||
|
item.setData(1, detection_ids[i])
|
||||||
item = self._scene.addItem(item)
|
item = self._scene.addItem(item)
|
||||||
self._items.append(item)
|
logging.info("View contains %i items", len(self._scene.items()))
|
||||||
|
|
||||||
def fit_image_to_view(self):
|
def fit_image_to_view(self):
|
||||||
"""Scale the image to fit the QGraphicsView."""
|
"""Scale the image to fit the QGraphicsView."""
|
||||||
@ -76,7 +123,14 @@ class DetectionView(QWidget):
|
|||||||
super().resizeEvent(event)
|
super().resizeEvent(event)
|
||||||
self.fit_image_to_view()
|
self.fit_image_to_view()
|
||||||
|
|
||||||
|
def on_itemSelection(self, selected_items):
|
||||||
|
self.signals.itemsSelected.emit(selected_items)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
def items_selected(items):
|
||||||
|
print(items)
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
@ -93,13 +147,16 @@ def main():
|
|||||||
background_brush = QBrush(QColor.fromString("white"))
|
background_brush = QBrush(QColor.fromString("white"))
|
||||||
|
|
||||||
bg_coords = np.stack(df.keypoints[(df.track != 1) & (df.track != 2)].values,).astype(np.float32)[:,0,:]
|
bg_coords = np.stack(df.keypoints[(df.track != 1) & (df.track != 2)].values,).astype(np.float32)[:,0,:]
|
||||||
bg_ids = df.track[(df.track != 1) & (df.track != 2)].values
|
bg_tracks = df.track[(df.track != 1) & (df.track != 2)].values
|
||||||
|
bg_ids = df.track[(df.track != 1) & (df.track != 2)].index.values
|
||||||
|
|
||||||
scnd_coords = np.stack(df.keypoints[(df.track == 2)].values,).astype(np.float32)[:,0,:]
|
scnd_coords = np.stack(df.keypoints[(df.track == 2)].values,).astype(np.float32)[:,0,:]
|
||||||
scnd_ids = df.track[df.track == 2].values
|
scnd_tracks = df.track[df.track == 2].values
|
||||||
|
scnd_ids = df.track[(df.track == 2)].index.values
|
||||||
|
|
||||||
focus_coords = np.stack(df.keypoints[df.track == 1].values,).astype(np.float32)[:,0,:]
|
focus_coords = np.stack(df.keypoints[df.track == 1].values,).astype(np.float32)[:,0,:]
|
||||||
focus_ids = df.track[df.track == 1].values
|
focus_tracks = df.track[df.track == 1].values
|
||||||
|
focus_ids = df.track[(df.track == 2)].index.values
|
||||||
|
|
||||||
app = QApplication([])
|
app = QApplication([])
|
||||||
window = QWidget()
|
window = QWidget()
|
||||||
@ -107,13 +164,14 @@ def main():
|
|||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout()
|
||||||
|
|
||||||
view = DetectionView()
|
view = DetectionView()
|
||||||
|
view.signals.itemsSelected.connect(items_selected)
|
||||||
layout.addWidget(view)
|
layout.addWidget(view)
|
||||||
view.setImage(img)
|
view.setImage(img)
|
||||||
view.addDetections(bg_coords, bg_ids, background_brush)
|
view.addDetections(bg_coords, bg_tracks, bg_ids, background_brush)
|
||||||
view.addDetections(focus_coords, focus_ids, focus_brush)
|
view.addDetections(focus_coords, focus_tracks, focus_ids, focus_brush)
|
||||||
view.addDetections(scnd_coords, scnd_ids, second_brush)
|
view.addDetections(scnd_coords, scnd_tracks, scnd_ids, second_brush)
|
||||||
window.setLayout(layout)
|
window.setLayout(layout)
|
||||||
# window.show()
|
window.show()
|
||||||
app.exec()
|
app.exec()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user