From f28615660742987d7de805ebb11b1940b680de8e Mon Sep 17 00:00:00 2001 From: Till Raab Date: Mon, 13 Nov 2023 15:10:37 +0100 Subject: [PATCH] work im progress ... --- correct_bboxes.py | 110 +++++++++++++++++++++++++++++----------------- 1 file changed, 70 insertions(+), 40 deletions(-) diff --git a/correct_bboxes.py b/correct_bboxes.py index f660cde..ccce71b 100644 --- a/correct_bboxes.py +++ b/correct_bboxes.py @@ -9,12 +9,67 @@ import numpy as np import torch import torchvision.transforms.functional as F +from IPython import embed import pathlib +class ImageWithBbox(pg.GraphicsLayoutWidget): + def __init__(self, img_path, parent=None): + super(ImageWithBbox, self).__init__(parent) + + # image setup + self.imgItem = pg.ImageItem(ColorMap='viridis') + self.plot_widget = self.addPlot(title="") + self.plot_widget.addItem(self.imgItem, colorMap='viridis') + self.plot_widget.scene().sigMouseClicked.connect(self.addROI) + + # self.imgItem.mouseClickEvent() + + self.imgItem.getViewBox().invertY(True) + self.imgItem.getViewBox().setMouseEnabled(x=False, y=False) + + # get image and labels + img = Image.open(img_path) + img_gray = ImageOps.grayscale(img) + self.img_array = np.array(img_gray).T + self.imgItem.setImage(np.array(self.img_array)) + self.plot_widget.setYRange(0, self.img_array.shape[0], padding=0) + self.plot_widget.setXRange(0, self.img_array.shape[1], padding=0) + + label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt') + self.labels = np.loadtxt(label_path, delimiter=' ') + + # add ROIS + self.ROIs = [] + for enu, l in enumerate(self.labels): + x_center, y_center, width, height = l[1:] * self.img_array.shape[1] + x0, y0, = x_center-width/2, y_center-height/2 + ROI = pg.RectROI((x0, y0), size=(width, height), removable=True, sideScalers=True) + ROI.sigRemoveRequested.connect(self.removeROI) + + self.ROIs.append(ROI) + self.plot_widget.addItem(ROI) + + def removeROI(self, roi): + if roi in self.ROIs: + self.ROIs.remove(roi) + self.plot_widget.removeItem(roi) + + def addROI(self, event): + # Check if the event is a double-click event + if event.double(): + pos = event.pos() + # Transform the mouse position to data coordinates + pos_data = self.plot_widget.getViewBox().mapToView(pos) + x, y = pos_data.x(), pos_data.y() + # Create a new ROI at the double-clicked position + new_ROI = pg.RectROI(pos=(x, y), size=(self.img_array.shape[0]*0.05, self.img_array.shape[1]*0.05), removable=True) + new_ROI.sigRemoveRequested.connect(self.removeROI) + self.ROIs.append(new_ROI) + self.plot_widget.addItem(new_ROI) + class Bbox_correct_UI(QMainWindow): def __init__(self, img_path, parent=None): super(Bbox_correct_UI, self).__init__(parent) - label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt') rec = QApplication.desktop().screenGeometry() height = rec.height() @@ -22,54 +77,29 @@ class Bbox_correct_UI(QMainWindow): self.resize(int(.8 * width), int(.8 * height)) self.setWindowTitle('efishSignalTracker') # set window title + # widget and layout self.central_widget = QWidget(self) - self.gridLayout = QGridLayout() - self.central_widget.setLayout(self.gridLayout) + self.central_gridLayout = QGridLayout() + self.central_widget.setLayout(self.central_gridLayout) self.setCentralWidget(self.central_widget) - self.plot_handels = [] - self.plot_widgets = [] - self.win = pg.GraphicsLayoutWidget() - - - self.plot_handels.append(pg.ImageItem(ColorMap='viridis')) - self.plot_widgets.append(self.win.addPlot(title="")) - # xxx.setMouseMode(xxx.RectMode) + self.current_img = ImageWithBbox(img_path, parent=self) - self.plot_widgets[0].addItem(self.plot_handels[0], colorMap='viridis') - # self.plot_widgets[0].setLabel('left', 'frequency [Hz]') - # self.plot_widgets[0].setLabel('bottom', 'time [s]') - - self.gridLayout.addWidget(self.win, 0, 0) - - self.plot_handels[0].getViewBox().invertY(True) - self.plot_handels[0].getViewBox().setMouseEnabled(x=False, y=False) - img = Image.open(img_path) - img_gray = ImageOps.grayscale(img) - img_array = np.array(img_gray).T + self.central_gridLayout.addWidget(self.current_img, 0, 0) + self.add_actions() - self.labels = np.loadtxt(label_path, delimiter=' ') - - self.plot_handels[0].setImage(np.array(img_array)) - - self.ROIs = [] - for enu, l in enumerate(self.labels): - x_center, y_center, width, height = l[1:] * img_array.shape[1] - x0, y0, = x_center-width/2, y_center-height/2 - ROI = pg.RectROI((x0, y0), size=(width, height), removable=True) - # ROI.sigRemoveRequested.connect(lambda: ) - self.ROIs.append(ROI) - self.plot_widgets[0].addItem(ROI) + def add_actions(self): + self.readout_rois = QAction('read', self) + self.readout_rois.triggered.connect(self.readout_rois_fn) + self.readout_rois.setShortcut(Qt.Key_Return) + self.addAction(self.readout_rois) - self.plot_widgets[0].setYRange(0, img_array.shape[0], padding=0) - self.plot_widgets[0].setXRange(0, img_array.shape[1], padding=0) + def readout_rois_fn(self): + for roi in self.current_img.ROIs: + print(roi.pos(), roi.size()) - # def kill_me(self, ROI): - # print('yay') - # print(ROI) - # self.win.removeItem(ROI) def main_UI(): app = QApplication(sys.argv) # create application