diff --git a/correct_bboxes.py b/correct_bboxes.py index ccce71b..7b7f2e5 100644 --- a/correct_bboxes.py +++ b/correct_bboxes.py @@ -6,6 +6,7 @@ import pyqtgraph as pg from PIL import Image, ImageOps import numpy as np +import pandas as pd import torch import torchvision.transforms.functional as F @@ -13,8 +14,11 @@ from IPython import embed import pathlib class ImageWithBbox(pg.GraphicsLayoutWidget): +# class ImageWithBbox(pg.ImageView): def __init__(self, img_path, parent=None): super(ImageWithBbox, self).__init__(parent) + self.img_path = img_path + self.label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt') # image setup self.imgItem = pg.ImageItem(ColorMap='viridis') @@ -35,13 +39,20 @@ class ImageWithBbox(pg.GraphicsLayoutWidget): self.plot_widget.setYRange(0, self.img_array.shape[0], padding=0) self.plot_widget.setXRange(0, self.img_array.shape[1], padding=0) - label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt') - self.labels = np.loadtxt(label_path, delimiter=' ') + # label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt') + self.labels = np.loadtxt(self.label_path, delimiter=' ') + if len(self.labels) > 0 and len(self.labels.shape) == 1: + self.labels = np.expand_dims(self.labels, 0) # add ROIS self.ROIs = [] for enu, l in enumerate(self.labels): - x_center, y_center, width, height = l[1:] * self.img_array.shape[1] + # x_center, y_center, width, height = l[1:] * self.img_array.shape[1] + x_center = l[1] * self.img_array.shape[0] + y_center = l[2] * self.img_array.shape[1] + width = l[3] * self.img_array.shape[0] + height = l[4] * self.img_array.shape[1] + x0, y0, = x_center-width/2, y_center-height/2 ROI = pg.RectROI((x0, y0), size=(width, height), removable=True, sideScalers=True) ROI.sigRemoveRequested.connect(self.removeROI) @@ -68,9 +79,15 @@ class ImageWithBbox(pg.GraphicsLayoutWidget): self.plot_widget.addItem(new_ROI) class Bbox_correct_UI(QMainWindow): - def __init__(self, img_path, parent=None): + def __init__(self, data_path, parent=None): super(Bbox_correct_UI, self).__init__(parent) + self.data_path = data_path + files = sorted(list(pathlib.Path(self.data_path).absolute().rglob('*images/*.png'))) + + # embed() + # quit() + rec = QApplication.desktop().screenGeometry() height = rec.height() width = rec.width() @@ -80,31 +97,134 @@ class Bbox_correct_UI(QMainWindow): # widget and layout self.central_widget = QWidget(self) self.central_gridLayout = QGridLayout() - self.central_widget.setLayout(self.central_gridLayout) + self.central_Layout = QHBoxLayout() + self.central_widget.setLayout(self.central_Layout) self.setCentralWidget(self.central_widget) - self.current_img = ImageWithBbox(img_path, parent=self) - - self.central_gridLayout.addWidget(self.current_img, 0, 0) + # image widget + self.current_img = ImageWithBbox(files[0], parent=self) + self.central_Layout.addWidget(self.current_img, 4) + # image select widget + self.scroll = QScrollArea() # Scroll Area which contains the widgets, set as the centralWidget + self.widget = QWidget() # Widget that contains the collection of Vertical Box + self.vbox = QVBoxLayout() + + self.highlighted_label = None + self.all_labels = [] + # embed() + # quit() + self.file_dict = pd.DataFrame( + {'name': [f.name for f in files], + 'files': files, + 'text_color': ['red' for i in range(len(files))]} # change this to locked + ) + + for i in range(len(self.file_dict)): + label = QLabel(f'{self.file_dict["name"][i]}') + # label.setFrameShape(QLabel.Panel) + # label.setFrameShadow(QLabel.Sunken) + label.setAlignment(Qt.AlignRight) + label.mousePressEvent = lambda event, label=label: self.label_clicked(label) + # label.mousePressEvent = lambda event, label: self.label_clicked(label) + + if i == 0: + label.setStyleSheet("border: 2px solid black; " + "color : %s;" % (self.file_dict['text_color'][i])) + self.highlighted_label = label + else: + label.setStyleSheet("border: 1px solid gray; " + "color : %s;" % (self.file_dict['text_color'][i])) + self.vbox.addWidget(label) + self.all_labels.append(label) + + self.widget.setLayout(self.vbox) + + self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + self.scroll.setWidgetResizable(True) + self.scroll.setWidget(self.widget) + + self.central_Layout.addWidget(self.scroll, 1) self.add_actions() - def add_actions(self): - self.readout_rois = QAction('read', self) - self.readout_rois.triggered.connect(self.readout_rois_fn) - self.readout_rois.setShortcut(Qt.Key_Return) + def label_clicked(self, clicked_label): + if self.highlighted_label: + hl_mask = self.file_dict['name'] == self.highlighted_label.text() + hl_label_name, _, hl_label_text_color = self.file_dict[hl_mask].values[0] + self.highlighted_label.setStyleSheet("border: 1px solid gray; " + "color: %s;" % (hl_label_text_color)) - self.addAction(self.readout_rois) + mask = self.file_dict['name'] == clicked_label.text() + new_name, new_file, new_text_color = self.file_dict[mask].values[0] + clicked_label.setStyleSheet("border: 2px solid black;" + "color: %s;" % (new_text_color)) + self.highlighted_label = clicked_label - def readout_rois_fn(self): + self.switch_to_new_file(new_file, new_name) + + def switch_to_new_file(self, new_file, new_name): + self.readout_rois() + + self.setWindowTitle(f'efishSignalTracker | {new_name}') + + self.central_Layout.removeWidget(self.current_img) + self.current_img = ImageWithBbox(new_file, parent=self) + self.central_Layout.insertWidget(0, self.current_img, 4) + + def readout_rois(self): + new_labels = [] for roi in self.current_img.ROIs: - print(roi.pos(), roi.size()) + x0, y0 = roi.pos() + x0 /= self.current_img.img_array.shape[0] + y0 /= self.current_img.img_array.shape[1] + + w, h = roi.size() + w /= self.current_img.img_array.shape[0] + h /= self.current_img.img_array.shape[1] + + x_center = x0 + w/2 + y_center = y0 + h/2 + new_labels.append([1, x_center, y_center, w, h]) + new_labels = np.array(new_labels) + np.savetxt(self.current_img.label_path, new_labels) + + def lock_file(self): + + hl_mask = self.file_dict['name'] == self.highlighted_label.text() + hl_label_name, _, hl_label_text_color = self.file_dict[hl_mask].values[0] + + # ToDo: do everything with the index instead of mask + df_idx = self.file_dict.loc[self.file_dict['name'] == self.highlighted_label.text()].index.values[0] + self.file_dict.iloc[df_idx]['text_color'] = 'black' + + self.highlighted_label.setStyleSheet("border: 1px solid gray; " + "color: 'black';") + + new_idx = df_idx + 1 if df_idx < len(self.file_dict)-1 else 0 + new_name, new_file, new_color = self.file_dict.iloc[new_idx].values + + self.all_labels[new_idx].setStyleSheet("border: 2px solid black;" + "color: %s;" % (new_color)) + self.highlighted_label = self.all_labels[new_idx] + + self.switch_to_new_file(new_file, new_name) + # embed() + # quit() + # go to next + pass + + def add_actions(self): + self.lock_file_act = QAction('loc', self) + self.lock_file_act.triggered.connect(self.lock_file) + self.lock_file_act.setShortcut(Qt.Key_Space) + self.addAction(self.lock_file_act) def main_UI(): app = QApplication(sys.argv) # create application - img_path = sys.argv[1] - w = Bbox_correct_UI(img_path) # create window + data_path = sys.argv[1] + w = Bbox_correct_UI(data_path) # create window w.show() sys.exit(app.exec_())