working on bbox correct gui

This commit is contained in:
Till Raab 2023-11-14 12:22:48 +01:00
parent f286156607
commit 33f6bb8683

View File

@ -6,6 +6,7 @@ import pyqtgraph as pg
from PIL import Image, ImageOps from PIL import Image, ImageOps
import numpy as np import numpy as np
import pandas as pd
import torch import torch
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
@ -13,8 +14,11 @@ from IPython import embed
import pathlib import pathlib
class ImageWithBbox(pg.GraphicsLayoutWidget): class ImageWithBbox(pg.GraphicsLayoutWidget):
# class ImageWithBbox(pg.ImageView):
def __init__(self, img_path, parent=None): def __init__(self, img_path, parent=None):
super(ImageWithBbox, self).__init__(parent) super(ImageWithBbox, self).__init__(parent)
self.img_path = img_path
self.label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt')
# image setup # image setup
self.imgItem = pg.ImageItem(ColorMap='viridis') self.imgItem = pg.ImageItem(ColorMap='viridis')
@ -35,13 +39,20 @@ class ImageWithBbox(pg.GraphicsLayoutWidget):
self.plot_widget.setYRange(0, self.img_array.shape[0], padding=0) self.plot_widget.setYRange(0, self.img_array.shape[0], padding=0)
self.plot_widget.setXRange(0, self.img_array.shape[1], padding=0) self.plot_widget.setXRange(0, self.img_array.shape[1], padding=0)
label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt') # label_path = (pathlib.Path(img_path).parent.parent / 'labels' / pathlib.Path(img_path).name).with_suffix('.txt')
self.labels = np.loadtxt(label_path, delimiter=' ') self.labels = np.loadtxt(self.label_path, delimiter=' ')
if len(self.labels) > 0 and len(self.labels.shape) == 1:
self.labels = np.expand_dims(self.labels, 0)
# add ROIS # add ROIS
self.ROIs = [] self.ROIs = []
for enu, l in enumerate(self.labels): for enu, l in enumerate(self.labels):
x_center, y_center, width, height = l[1:] * self.img_array.shape[1] # x_center, y_center, width, height = l[1:] * self.img_array.shape[1]
x_center = l[1] * self.img_array.shape[0]
y_center = l[2] * self.img_array.shape[1]
width = l[3] * self.img_array.shape[0]
height = l[4] * self.img_array.shape[1]
x0, y0, = x_center-width/2, y_center-height/2 x0, y0, = x_center-width/2, y_center-height/2
ROI = pg.RectROI((x0, y0), size=(width, height), removable=True, sideScalers=True) ROI = pg.RectROI((x0, y0), size=(width, height), removable=True, sideScalers=True)
ROI.sigRemoveRequested.connect(self.removeROI) ROI.sigRemoveRequested.connect(self.removeROI)
@ -68,9 +79,15 @@ class ImageWithBbox(pg.GraphicsLayoutWidget):
self.plot_widget.addItem(new_ROI) self.plot_widget.addItem(new_ROI)
class Bbox_correct_UI(QMainWindow): class Bbox_correct_UI(QMainWindow):
def __init__(self, img_path, parent=None): def __init__(self, data_path, parent=None):
super(Bbox_correct_UI, self).__init__(parent) super(Bbox_correct_UI, self).__init__(parent)
self.data_path = data_path
files = sorted(list(pathlib.Path(self.data_path).absolute().rglob('*images/*.png')))
# embed()
# quit()
rec = QApplication.desktop().screenGeometry() rec = QApplication.desktop().screenGeometry()
height = rec.height() height = rec.height()
width = rec.width() width = rec.width()
@ -80,31 +97,134 @@ class Bbox_correct_UI(QMainWindow):
# widget and layout # widget and layout
self.central_widget = QWidget(self) self.central_widget = QWidget(self)
self.central_gridLayout = QGridLayout() self.central_gridLayout = QGridLayout()
self.central_widget.setLayout(self.central_gridLayout) self.central_Layout = QHBoxLayout()
self.central_widget.setLayout(self.central_Layout)
self.setCentralWidget(self.central_widget) self.setCentralWidget(self.central_widget)
self.current_img = ImageWithBbox(img_path, parent=self) # image widget
self.current_img = ImageWithBbox(files[0], parent=self)
self.central_gridLayout.addWidget(self.current_img, 0, 0) self.central_Layout.addWidget(self.current_img, 4)
# image select widget
self.scroll = QScrollArea() # Scroll Area which contains the widgets, set as the centralWidget
self.widget = QWidget() # Widget that contains the collection of Vertical Box
self.vbox = QVBoxLayout()
self.highlighted_label = None
self.all_labels = []
# embed()
# quit()
self.file_dict = pd.DataFrame(
{'name': [f.name for f in files],
'files': files,
'text_color': ['red' for i in range(len(files))]} # change this to locked
)
for i in range(len(self.file_dict)):
label = QLabel(f'{self.file_dict["name"][i]}')
# label.setFrameShape(QLabel.Panel)
# label.setFrameShadow(QLabel.Sunken)
label.setAlignment(Qt.AlignRight)
label.mousePressEvent = lambda event, label=label: self.label_clicked(label)
# label.mousePressEvent = lambda event, label: self.label_clicked(label)
if i == 0:
label.setStyleSheet("border: 2px solid black; "
"color : %s;" % (self.file_dict['text_color'][i]))
self.highlighted_label = label
else:
label.setStyleSheet("border: 1px solid gray; "
"color : %s;" % (self.file_dict['text_color'][i]))
self.vbox.addWidget(label)
self.all_labels.append(label)
self.widget.setLayout(self.vbox)
self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
self.scroll.setWidgetResizable(True)
self.scroll.setWidget(self.widget)
self.central_Layout.addWidget(self.scroll, 1)
self.add_actions() self.add_actions()
def add_actions(self): def label_clicked(self, clicked_label):
self.readout_rois = QAction('read', self) if self.highlighted_label:
self.readout_rois.triggered.connect(self.readout_rois_fn) hl_mask = self.file_dict['name'] == self.highlighted_label.text()
self.readout_rois.setShortcut(Qt.Key_Return) hl_label_name, _, hl_label_text_color = self.file_dict[hl_mask].values[0]
self.highlighted_label.setStyleSheet("border: 1px solid gray; "
"color: %s;" % (hl_label_text_color))
self.addAction(self.readout_rois) mask = self.file_dict['name'] == clicked_label.text()
new_name, new_file, new_text_color = self.file_dict[mask].values[0]
clicked_label.setStyleSheet("border: 2px solid black;"
"color: %s;" % (new_text_color))
self.highlighted_label = clicked_label
def readout_rois_fn(self): self.switch_to_new_file(new_file, new_name)
def switch_to_new_file(self, new_file, new_name):
self.readout_rois()
self.setWindowTitle(f'efishSignalTracker | {new_name}')
self.central_Layout.removeWidget(self.current_img)
self.current_img = ImageWithBbox(new_file, parent=self)
self.central_Layout.insertWidget(0, self.current_img, 4)
def readout_rois(self):
new_labels = []
for roi in self.current_img.ROIs: for roi in self.current_img.ROIs:
print(roi.pos(), roi.size()) x0, y0 = roi.pos()
x0 /= self.current_img.img_array.shape[0]
y0 /= self.current_img.img_array.shape[1]
w, h = roi.size()
w /= self.current_img.img_array.shape[0]
h /= self.current_img.img_array.shape[1]
x_center = x0 + w/2
y_center = y0 + h/2
new_labels.append([1, x_center, y_center, w, h])
new_labels = np.array(new_labels)
np.savetxt(self.current_img.label_path, new_labels)
def lock_file(self):
hl_mask = self.file_dict['name'] == self.highlighted_label.text()
hl_label_name, _, hl_label_text_color = self.file_dict[hl_mask].values[0]
# ToDo: do everything with the index instead of mask
df_idx = self.file_dict.loc[self.file_dict['name'] == self.highlighted_label.text()].index.values[0]
self.file_dict.iloc[df_idx]['text_color'] = 'black'
self.highlighted_label.setStyleSheet("border: 1px solid gray; "
"color: 'black';")
new_idx = df_idx + 1 if df_idx < len(self.file_dict)-1 else 0
new_name, new_file, new_color = self.file_dict.iloc[new_idx].values
self.all_labels[new_idx].setStyleSheet("border: 2px solid black;"
"color: %s;" % (new_color))
self.highlighted_label = self.all_labels[new_idx]
self.switch_to_new_file(new_file, new_name)
# embed()
# quit()
# go to next
pass
def add_actions(self):
self.lock_file_act = QAction('loc', self)
self.lock_file_act.triggered.connect(self.lock_file)
self.lock_file_act.setShortcut(Qt.Key_Space)
self.addAction(self.lock_file_act)
def main_UI(): def main_UI():
app = QApplication(sys.argv) # create application app = QApplication(sys.argv) # create application
img_path = sys.argv[1] data_path = sys.argv[1]
w = Bbox_correct_UI(img_path) # create window w = Bbox_correct_UI(data_path) # create window
w.show() w.show()
sys.exit(app.exec_()) sys.exit(app.exec_())