227 lines
8.6 KiB
Python
227 lines
8.6 KiB
Python
import logging
|
|
import numpy as np
|
|
|
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QSizePolicy, QGraphicsView, QGraphicsScene, QGraphicsEllipseItem, QGraphicsRectItem
|
|
from PySide6.QtCore import Qt, QPointF, QRectF, QPointF
|
|
from PySide6.QtGui import QPixmap, QBrush, QColor, QImage, QPen
|
|
|
|
from fixtracks.info import PACKAGE_ROOT
|
|
from fixtracks.utils.signals import DetectionSignals, DetectionViewSignals, DetectionSceneSignals
|
|
from fixtracks.utils.enums import DetectionData, Tracks
|
|
from fixtracks.utils.trackingdata import TrackingData
|
|
|
|
|
|
class Detection(QGraphicsEllipseItem):
    """Selectable ellipse marker for a single detection.

    Mouse presses and hover-enter events are reported through ``signals``:
    each emission carries the value stored under data key 0 and the scene
    position of the triggering event.
    """

    # One signal object, defined at class level and therefore shared by
    # every Detection instance.
    signals = DetectionSignals()

    def __init__(self, x, y, width, height, brush):
        super().__init__(x, y, width, height)
        self.setBrush(brush)
        # Graphics items do not receive hover events unless they opt in.
        self.setAcceptHoverEvents(True)
        self.setFlags(QGraphicsRectItem.ItemIsSelectable)

    def _report(self, signal, event):
        """Emit *signal* with this item's key-0 data and the event's scene position."""
        where = event.scenePos()
        signal.emit(self.data(0), QPointF(where.x(), where.y()))

    def mousePressEvent(self, event):
        # The base-class handler is deliberately not invoked here.
        self._report(self.signals.clicked, event)

    def hoverEnterEvent(self, event):
        self._report(self.signals.hover, event)
        super().hoverEnterEvent(event)
|
|
|
|
|
|
class DetectionScene(QGraphicsScene):
    """Graphics scene with rubber-band and click selection of Detection items.

    Pressing the left button starts a dashed selection rectangle that follows
    the mouse. On release:

    * if the rectangle has a non-zero width, every ``Detection`` intersecting
      it is selected and that list is emitted via ``signals.itemsSelected``;
    * otherwise the ``Detection`` under the cursor (if any) is selected and
      the scene's full current selection is emitted.
    """

    signals = DetectionSceneSignals()

    def __init__(self):
        super().__init__()
        self.start_point = QPointF()
        self.selection_rect = None

    def _discard_selection_rect(self):
        """Remove the temporary selection rectangle from the scene, if present."""
        if self.selection_rect is not None:
            self.removeItem(self.selection_rect)
            self.selection_rect = None

    def mousePressEvent(self, event):
        if event.button() == Qt.LeftButton:
            # Defensive: never leave an earlier rectangle orphaned in the scene.
            self._discard_selection_rect()
            # Anchor corner of the selection rectangle.
            self.start_point = event.scenePos()
            # Temporary item that visualises the selection while dragging.
            self.selection_rect = QGraphicsRectItem()
            self.selection_rect.setPen(Qt.DashLine)  # dashed outline
            self.selection_rect.setBrush(Qt.transparent)
            self.addItem(self.selection_rect)
        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        if self.selection_rect is not None:
            # normalized() keeps width/height positive when dragging up/left.
            rect = QRectF(self.start_point, event.scenePos()).normalized()
            self.selection_rect.setRect(rect)
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        if event.button() == Qt.LeftButton and self.selection_rect is not None:
            rect = self.selection_rect.rect()
            # Remove the helper rectangle first so it is not picked up below.
            self._discard_selection_rect()
            if rect.width() > 0.0:
                # Build a filtered list instead of removing from the list while
                # iterating it. That pattern skips elements and also marked the
                # removed non-Detection items as selected.
                selected_items = [
                    item
                    for item in self.items(rect, Qt.IntersectsItemShape)
                    if isinstance(item, Detection)
                ]
                for item in selected_items:
                    item.setSelected(True)
                self.signals.itemsSelected.emit(selected_items)
            else:
                # Plain click: select the topmost item under the cursor.
                item = self.itemAt(event.scenePos(), self.views()[0].transform())
                if isinstance(item, Detection):
                    item.setSelected(True)
                self.signals.itemsSelected.emit(self.selectedItems())
        super().mouseReleaseEvent(event)
|
|
|
|
|
|
class DetectionView(QWidget):
    """Widget that shows an image with one marker per selected detection.

    Wraps a ``QGraphicsView`` displaying a ``DetectionScene``. Selections made
    in the scene are re-emitted through this widget's
    ``signals.itemsSelected``. Ctrl + mouse wheel zooms the view.
    """

    signals = DetectionViewSignals()

    # Width and height, in scene units, of the ellipse drawn per detection.
    _markerSize = 20

    def __init__(self, parent=None):
        super().__init__(parent)
        self._img = None
        self._data = None
        self._pixmapitem = None
        self._scene = DetectionScene()
        # Connect exactly once here. Connecting on every setImage() call would
        # stack duplicate connections and therefore duplicate emissions.
        self._scene.signals.itemsSelected.connect(self.on_itemSelection)
        self._view = QGraphicsView()
        self._view.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        self._view.setMouseTracking(True)
        self._mouseEnabled = True
        # NOTE(review): these zoom settings are stored but not yet used by
        # wheelEvent(), which scales without limits.
        self._zoomFactor = 1.15
        self._minZoom = 0.1
        self._maxZoom = 10
        self._currentZoom = 1.0
        lyt = QVBoxLayout()
        lyt.addWidget(self._view)
        self.setLayout(lyt)

    def wheelEvent(self, event):
        """Zoom on Ctrl + wheel; otherwise fall back to the default handling.

        The scale factor is ``1.001 ** delta``. ``delta`` is the horizontal
        wheel delta when it is non-zero, and the vertical delta otherwise.
        """
        if not self._mouseEnabled or event.modifiers() != Qt.ControlModifier:
            super().wheelEvent(event)
            return
        angle = event.angleDelta()
        delta = angle.x() or angle.y()
        factor = 1.001 ** delta
        self._view.scale(factor, factor)

    def setImage(self, image: QImage):
        """Show *image* as the scene background and fit it into the view.

        A previously set image is removed first, so repeated calls do not
        stack pixmaps on top of each other.
        """
        if self._pixmapitem is not None:
            self._scene.removeItem(self._pixmapitem)
            self._pixmapitem = None
        self._img = image
        self._pixmapitem = self._scene.addPixmap(QPixmap.fromImage(self._img))
        self._view.setScene(self._scene)
        # Fit to the pixmap, the same way resizeEvent() does. The scene rect
        # only ever grows, so it would go stale after an image is replaced.
        self.fit_image_to_view()

    def setData(self, data: TrackingData):
        """Store the tracking data used by updateDetections()."""
        self._data = data

    def clearDetections(self):
        """Remove every Detection marker from the scene; the image is kept."""
        # items() returns a list snapshot, so removing while looping is safe.
        for item in self._scene.items():
            if isinstance(item, Detection):
                self._scene.removeItem(item)

    def updateDetections(self, keypoint=-1):
        """Redraw one marker per currently selected detection.

        Args:
            keypoint: Index of the keypoint each marker is placed on. A
                negative value uses the centre of gravity instead.
        """
        logging.info("DetectionView.updateDetections!")
        self.clearDetections()
        if self._data is None:
            return
        frames = self._data.selectedData("frame")
        tracks = self._data.selectedData("track")
        ids = self._data.selectedData("index")
        coordinates = self._data.coordinates(selection=True)
        centercoordinates = self._data.centerOfGravity(selection=True)
        userlabeled = self._data.selectedData("userlabeled")
        scores = self._data.selectedData("confidence")

        if self._pixmapitem is not None:
            image_rect = self._pixmapitem.boundingRect()
        else:
            image_rect = QRectF(0, 0, 0, 0)
        size = self._markerSize
        half = size / 2

        rows = zip(ids, frames, tracks, userlabeled, scores)
        for i, (index, frame, track, labeled, score) in enumerate(rows):
            color = Tracks.fromValue(track).toColor()
            if keypoint >= 0:
                x = float(coordinates[i, keypoint, 0])
                y = float(coordinates[i, keypoint, 1])
            else:
                x = float(centercoordinates[i, 0])
                y = float(centercoordinates[i, 1])
            # An ellipse item's rect starts at its top-left corner. Shift by
            # half the marker size so the marker is centred on the point.
            item = Detection(image_rect.left() + x - half,
                             image_rect.top() + y - half,
                             size, size, brush=QBrush(color))
            item.setData(DetectionData.TRACK_ID.value, track)
            item.setData(DetectionData.ID.value, index)
            item.setData(DetectionData.COORDINATES.value, coordinates[i, :, :])
            item.setData(DetectionData.FRAME.value, frame)
            item.setData(DetectionData.USERLABELED.value, labeled)
            item.setData(DetectionData.SCORE.value, score)
            self._scene.addItem(item)

    def fit_image_to_view(self):
        """Scale the image to fit the QGraphicsView."""
        if self._pixmapitem is not None:
            self._view.fitInView(self._pixmapitem, Qt.KeepAspectRatio)

    def resizeEvent(self, event):
        """Handle window resizing to fit the image."""
        super().resizeEvent(event)
        self.fit_image_to_view()

    def on_itemSelection(self, selected_items):
        """Re-emit the scene's selection through this widget's signals."""
        self.signals.itemsSelected.emit(selected_items)
|
|
|
|
|
|
def main():
    """Interactive demo that shows an image with colour-coded detections.

    Loads a pickled DataFrame and an image from ``PACKAGE_ROOT / "data"``. It
    splits rows into track 1 (focus, red), track 2 (second, blue) and all
    other tracks (background, white).

    FIXME: ``DetectionView.addDetections`` no longer exists, so the demo
    fails at the first ``addDetections`` call. It still needs porting to
    ``setData()`` / ``updateDetections()``.
    """
    # Local imports keep demo-only dependencies out of module import time.
    import pickle

    from PySide6.QtWidgets import QApplication

    def items_selected(items):
        print("items selected")

    datafile = PACKAGE_ROOT / "data/merged_small.pkl"
    imgfile = PACKAGE_ROOT / "data/merged2.png"
    print(datafile)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(datafile, "rb") as f:
        df = pickle.load(f)
    img = QImage(str(imgfile))
    focus_brush = QBrush(QColor.fromString("red"))
    second_brush = QBrush(QColor.fromString("blue"))
    background_brush = QBrush(QColor.fromString("white"))

    def split_by_mask(mask):
        """Return (first-keypoint coords as float32, track values, row index) for *mask*."""
        coords = np.stack(df.keypoints[mask].values).astype(np.float32)[:, 0, :]
        return coords, df.track[mask].values, df.track[mask].index.values

    is_focus = df.track == 1
    is_second = df.track == 2
    is_background = (df.track != 1) & (df.track != 2)

    bg_coords, bg_tracks, bg_ids = split_by_mask(is_background)
    scnd_coords, scnd_tracks, scnd_ids = split_by_mask(is_second)
    # Previously focus_ids was taken from the track == 2 rows by mistake.
    focus_coords, focus_tracks, focus_ids = split_by_mask(is_focus)

    app = QApplication([])
    window = QWidget()
    window.setMinimumSize(200, 200)
    layout = QVBoxLayout()

    view = DetectionView()
    view.signals.itemsSelected.connect(items_selected)
    layout.addWidget(view)
    view.setImage(img)
    view.addDetections(bg_coords, bg_tracks, bg_ids, background_brush)
    view.addDetections(focus_coords, focus_tracks, focus_ids, focus_brush)
    view.addDetections(scnd_coords, scnd_tracks, scnd_ids, second_brush)
    window.setLayout(layout)
    window.show()
    app.exec()
|
|
|
|
if __name__ == "__main__":
    # Launch the interactive demo only when run as a script, never on import.
    main()
|