fixtracks/fixtracks/widgets/detectionview.py

227 lines
8.6 KiB
Python

import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem
from PySide6.QtCore import Qt, QPointF, QRectF, QPointF
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage, QPen
from fixtracks.info import PACKAGE_ROOT
from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals
from fixtracks.utils.enums import DetectionData, Tracks
from fixtracks.utils.trackingdata import TrackingData
class Detection(QGraphicsEllipseItem):
    """Selectable ellipse marker for a single detection.

    Mouse presses and hover-enter events are reported through the
    class-level ``signals`` object (one instance shared by every Detection),
    together with the item's data stored under key 0 and the scene position
    at which the event happened.
    """

    signals = DetectionSignals()

    def __init__(self, x, y, width, height, brush):
        super().__init__(x, y, width, height)
        # Only selectability is enabled; this replaces any previously set flags.
        self.setFlags(QGraphicsEllipseItem.GraphicsItemFlag.ItemIsSelectable)
        self.setAcceptHoverEvents(True)
        self.setBrush(brush)

    @staticmethod
    def _scene_point(event):
        """Return a fresh QPointF copy of the event's scene position."""
        pos = event.scenePos()
        return QPointF(pos.x(), pos.y())

    def mousePressEvent(self, event):
        """Emit ``signals.clicked``; the base-class press handler is not invoked."""
        self.signals.clicked.emit(self.data(0), self._scene_point(event))

    def hoverEnterEvent(self, event):
        """Emit ``signals.hover`` and then run the default hover handling."""
        self.signals.hover.emit(self.data(0), self._scene_point(event))
        super().hoverEnterEvent(event)
class DetectionScene(QGraphicsScene):
    """Graphics scene with rubber-band and click selection of Detection items.

    A left-button drag draws a dashed selection rectangle; on release every
    Detection intersecting it is selected. A left click without horizontal
    extent selects the item under the cursor. The result is published through
    the class-level ``signals.itemsSelected`` (shared by all DetectionScene
    instances).
    """

    signals = DetectionSceneSignals()

    def __init__(self):
        super().__init__()
        self.start_point = QPointF()
        self.selection_rect = None

    def mousePressEvent(self, event):
        """Start a rubber-band selection on left-button press."""
        if event.button() == Qt.LeftButton:
            # Record the start point for the selection rectangle
            self.start_point = event.scenePos()
            # Temporary item that visualises the selection area
            self.selection_rect = QGraphicsRectItem()
            self.selection_rect.setPen(Qt.DashLine)  # Dashed outline
            self.selection_rect.setBrush(Qt.transparent)
            self.addItem(self.selection_rect)
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Resize the rubber band while a selection is in progress."""
        if self.selection_rect is not None:
            rect = QRectF(self.start_point, event.scenePos()).normalized()
            self.selection_rect.setRect(rect)
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        """Finish the selection and emit ``signals.itemsSelected``.

        For a drag, the emitted list contains only the Detection items that
        intersect the rubber band. For a click (zero-width rectangle), the item
        under the cursor is selected and the scene's full selection is emitted.
        """
        if event.button() == Qt.LeftButton and self.selection_rect is not None:
            rect = self.selection_rect.rect()
            self.removeItem(self.selection_rect)
            self.selection_rect = None
            if rect.width() > 0.0:
                # Filter into a new list: removing from the list while iterating
                # it skipped elements and still selected the removed items.
                selected_items = [
                    it for it in self.items(rect, Qt.IntersectsItemShape)
                    if isinstance(it, Detection)
                ]
                for item in selected_items:
                    item.setSelected(True)  # Mark the item as selected
                self.signals.itemsSelected.emit(selected_items)
            else:
                item = self.itemAt(event.scenePos(), self.views()[0].transform())
                if item:
                    item.setSelected(True)
                self.signals.itemsSelected.emit(self.selectedItems())
        super().mouseReleaseEvent(event)
class DetectionView(QWidget):
    """Widget showing an image with Detection markers drawn on top of it.

    Selections made in the embedded DetectionScene are forwarded through the
    class-level ``signals.itemsSelected``.
    """

    signals = DetectionViewSignals()

    def __init__(self, parent=None):
        super().__init__(parent)
        self._img = None
        self._data = None
        self._pixmapitem = None
        self._scene = DetectionScene()
        # Connect once here. Connecting inside setImage() added another
        # connection (and thus duplicate emissions) on every new image.
        self._scene.signals.itemsSelected.connect(self.on_itemSelection)
        self._view = QGraphicsView()
        self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        self._view.setMouseTracking(True)
        self._mouseEnabled = True
        # NOTE(review): the zoom settings below are not read by wheelEvent,
        # so Ctrl+wheel zoom is currently unbounded.
        self._zoomFactor = 1.15
        self._minZoom = 0.1
        self._maxZoom = 10
        self._currentZoom = 1.0
        lyt = QVBoxLayout()
        lyt.addWidget(self._view)
        self.setLayout(lyt)

    def wheelEvent(self, event):
        """Zoom the view with Ctrl+wheel; otherwise use the default handling."""
        if not self._mouseEnabled:
            super().wheelEvent(event)
            return
        if event.modifiers() == Qt.ControlModifier:
            # Prefer the horizontal delta, fall back to the vertical one.
            delta = event.angleDelta().x() or event.angleDelta().y()
            sc = 1.001 ** delta
            self._view.scale(sc, sc)
        else:
            super().wheelEvent(event)

    def setImage(self, image: QImage):
        """Show *image* in the scene, replacing any previously shown image."""
        self._img = image
        if self._pixmapitem is not None:
            # Replace rather than stack: the old pixmap would otherwise remain
            # in the scene underneath the new one.
            self._scene.removeItem(self._pixmapitem)
        self._pixmapitem = self._scene.addPixmap(QPixmap.fromImage(self._img))
        # Keep the image behind the detection markers even when it is set
        # after the markers were added (equal z-values stack by insertion order).
        self._pixmapitem.setZValue(-1)
        self._view.setScene(self._scene)
        self._view.fitInView(self._scene.sceneRect(), Qt.AspectRatioMode.KeepAspectRatio)

    def setData(self, data: TrackingData):
        """Store the tracking data used by updateDetections()."""
        self._data = data

    def clearDetections(self):
        """Remove every Detection item from the scene; other items stay."""
        # items() returns a snapshot list, so removing while looping is safe.
        for item in self._scene.items():
            if isinstance(item, Detection):
                self._scene.removeItem(item)

    def updateDetections(self, keypoint=-1):
        """Redraw markers for the currently selected detections.

        Parameters
        ----------
        keypoint : int
            Index along the second axis of the coordinates array used as the
            marker position. A negative value uses the values returned by
            ``centerOfGravity(selection=True)`` instead.
        """
        logging.info("DetectionView.updateDetections!")
        self.clearDetections()
        if self._data is None:
            return
        frames = self._data.selectedData("frame")
        tracks = self._data.selectedData("track")
        ids = self._data.selectedData("index")
        coordinates = self._data.coordinates(selection=True)
        centercoordinates = self._data.centerOfGravity(selection=True)
        userlabeled = self._data.selectedData("userlabeled")
        scores = self._data.selectedData("confidence")
        image_rect = self._pixmapitem.boundingRect() if self._pixmapitem is not None else QRectF(0, 0, 0, 0)

        for i, (det_id, f, t, l, s) in enumerate(zip(ids, frames, tracks, userlabeled, scores)):
            color = Tracks.fromValue(t).toColor()
            if keypoint >= 0:
                x = coordinates[i, keypoint, 0]
                y = coordinates[i, keypoint, 1]
            else:
                x = centercoordinates[i, 0]
                y = centercoordinates[i, 1]
            # NOTE(review): (x, y) becomes the top-left corner of the 20x20
            # ellipse, so the marker is offset by half its size; confirm
            # whether centring on the point is intended.
            item = Detection(image_rect.left() + x, image_rect.top() + y, 20, 20, brush=QBrush(color))
            item.setData(DetectionData.TRACK_ID.value, t)
            item.setData(DetectionData.ID.value, det_id)
            item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
            item.setData(DetectionData.FRAME.value, f)
            item.setData(DetectionData.USERLABELED.value, l)
            item.setData(DetectionData.SCORE.value, s)
            self._scene.addItem(item)

    def fit_image_to_view(self):
        """Scale the image to fit the QGraphicsView."""
        if self._pixmapitem is not None:
            self._view.fitInView(self._pixmapitem, Qt.KeepAspectRatio)

    def resizeEvent(self, event):
        """Handle window resizing to fit the image."""
        super().resizeEvent(event)
        self.fit_image_to_view()

    def on_itemSelection(self, selected_items):
        """Re-emit the scene's selection through this widget's signals."""
        self.signals.itemsSelected.emit(selected_items)
def main():
    """Interactive demo showing detections from a pickled DataFrame on an image.

    FIXME: this demo no longer works. DetectionView has no ``addDetections``
    method; it now takes a TrackingData via ``setData`` and draws it with
    ``updateDetections``.
    """
    import pickle
    from PySide6.QtWidgets import QApplication

    def items_selected(items):
        print("items selected")

    datafile = PACKAGE_ROOT / "data/merged_small.pkl"
    imgfile = PACKAGE_ROOT / "data/merged2.png"
    print(datafile)
    # pickle can execute arbitrary code on load; only use trusted package data here.
    with open(datafile, "rb") as f:
        df = pickle.load(f)
    img = QImage(imgfile)
    focus_brush = QBrush(QColor.fromString("red"))
    second_brush = QBrush(QColor.fromString("blue"))
    background_brush = QBrush(QColor.fromString("white"))

    focus_mask = df.track == 1
    second_mask = df.track == 2
    bg_mask = ~(focus_mask | second_mask)

    def first_point(mask):
        # Stack the per-row keypoint arrays and keep index 0 along axis 1.
        return np.stack(df.keypoints[mask].values).astype(np.float32)[:, 0, :]

    bg_coords = first_point(bg_mask)
    bg_tracks = df.track[bg_mask].values
    bg_ids = df.track[bg_mask].index.values
    scnd_coords = first_point(second_mask)
    scnd_tracks = df.track[second_mask].values
    scnd_ids = df.track[second_mask].index.values
    focus_coords = first_point(focus_mask)
    focus_tracks = df.track[focus_mask].values
    # Previously taken from the track-2 mask by mistake.
    focus_ids = df.track[focus_mask].index.values

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    layout = QVBoxLayout()
    view = DetectionView()
    view.signals.itemsSelected.connect(items_selected)
    layout.addWidget(view)
    view.setImage(img)
    view.addDetections(bg_coords, bg_tracks, bg_ids, background_brush)
    view.addDetections(focus_coords, focus_tracks, focus_ids, focus_brush)
    view.addDetections(scnd_coords, scnd_tracks, scnd_ids, second_brush)
    window.setLayout(layout)
    window.show()
    app.exec()
if __name__ == "__main__":
    # Run the interactive demo when this module is executed directly.
    main()