Model and datasets adapted to the new YOLO format. Reinforcement learning is now implemented. correct_bboxes.py is now capable of exporting validated data to the training dataset.

Till Raab 2023-11-16 12:19:57 +01:00
parent 3ed6f1102c
commit 324c690841
6 changed files with 80 additions and 152 deletions

View File

@@ -26,8 +26,8 @@ DELTA_TIME = 60*10
TIME_OVERLAP = 60*1
# output parameters
DATA_DIR = 'data/images'
LABEL_DIR = 'data/labels'
DATA_DIR = 'data/rise_training/images'
LABEL_DIR = 'data/rise_training/labels'
OUTDIR = 'model_outputs'
INFERENCE_OUTDIR = 'inference_outputs'
for required_folders in [DATA_DIR, OUTDIR, INFERENCE_OUTDIR, LABEL_DIR]:
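The loop above is cut off by the hunk boundary; its body presumably just creates any missing output directories. A minimal sketch of that pattern, assuming os.makedirs (the body itself is not shown in the diff):

    import os

    for required_folders in [DATA_DIR, OUTDIR, INFERENCE_OUTDIR, LABEL_DIR]:
        # create the directory tree if it is missing; no-op otherwise
        os.makedirs(required_folders, exist_ok=True)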

View File

@@ -87,11 +87,13 @@ class Bbox_correct_UI(QMainWindow):
self.data_path = data_path
self.files = sorted(list(pathlib.Path(self.data_path).absolute().rglob('*images/*.png')))
self.close_without_saving_rois = False
rec = QApplication.desktop().screenGeometry()
height = rec.height()
width = rec.width()
self.resize(int(.8 * width), int(.8 * height))
self.setWindowTitle('efishSignalTracker') # set window title
# widget and layout
self.central_widget = QWidget(self)
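A side note on the screen-size query above: QApplication.desktop() is deprecated in Qt 5 and removed in Qt 6. Assuming the project uses PyQt5, a forward-compatible equivalent of the sizing block would be:

    from PyQt5.QtGui import QGuiApplication

    # availableGeometry() excludes taskbars/docks, unlike screenGeometry()
    rec = QGuiApplication.primaryScreen().availableGeometry()
    self.resize(int(.8 * rec.width()), int(.8 * rec.height()))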
@@ -100,19 +102,23 @@ class Bbox_correct_UI(QMainWindow):
self.central_widget.setLayout(self.central_Layout)
self.setCentralWidget(self.central_widget)
###########
self.highlighted_label = None
self.all_labels = []
self.load_or_create_file_dict()
new_name, new_file, new_checked = self.file_dict.iloc[0].values
self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')
# image widget
self.current_img = ImageWithBbox(self.files[0], parent=self)
self.current_img = ImageWithBbox(new_file, parent=self)
self.central_Layout.addWidget(self.current_img, 4)
# image select widget
self.scroll = QScrollArea() # Scroll Area which contains the widgets, set as the centralWidget
self.widget = QWidget() # Widget that contains the collection of Vertical Box
self.vbox = QVBoxLayout()
self.highlighted_label = None
self.all_labels = []
self.load_or_create_file_dict()
for i in range(len(self.file_dict)):
label = QLabel(f'{self.file_dict["name"][i]}')
# label.setFrameShape(QLabel.Panel)
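load_or_create_file_dict() itself lies outside the hunk; judging from the columns used here (name, files, checked, unpacked in that order by iloc[0].values), it plausibly builds and caches a pandas DataFrame along these lines (a sketch; the CSV location and persistence details are assumptions):

    import pathlib
    import pandas as pd

    def load_or_create_file_dict(data_path, csv_name='file_dict.csv'):
        # one row per spectrogram image: display name, file path, validated flag
        csv_path = pathlib.Path(data_path) / csv_name
        if csv_path.exists():
            return pd.read_csv(csv_path, sep=',', index_col=0)
        files = sorted(pathlib.Path(data_path).absolute().rglob('*images/*.png'))
        return pd.DataFrame({'name': [f.name for f in files],
                             'files': [str(f) for f in files],
                             'checked': [0] * len(files)})  # 1 once validated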
@@ -198,13 +204,15 @@ class Bbox_correct_UI(QMainWindow):
def switch_to_new_file(self, new_file, new_name):
self.readout_rois()
self.setWindowTitle(f'efishSignalTracker | {new_name}')
self.setWindowTitle(f'efishSignalTracker | {new_name} | {self.file_dict["checked"].sum()}/{len(self.file_dict)}')
self.central_Layout.removeWidget(self.current_img)
self.current_img = ImageWithBbox(new_file, parent=self)
self.central_Layout.insertWidget(0, self.current_img, 4)
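A PyQt detail in switch_to_new_file: removeWidget() only detaches the old image widget from the layout, it does not destroy it, so every file switch leaks an ImageWithBbox instance unless the old widget is released explicitly:

    self.central_Layout.removeWidget(self.current_img)
    self.current_img.deleteLater()  # actually schedule the old widget for deletion
    self.current_img = ImageWithBbox(new_file, parent=self)
    self.central_Layout.insertWidget(0, self.current_img, 4)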
def readout_rois(self):
if self.close_without_saving_rois:
return
new_labels = []
for roi in self.current_img.ROIs:
x0, y0 = roi.pos()
@@ -225,29 +233,33 @@
fd = QFileDialog()
export_path = fd.getExistingDirectory(self, 'Select Directory')
if export_path:
embed()
quit()
# ToDo: copy the validated files and delete from csv
# ToDo: copy entries to new/extended csv
# only files that are not in the training dataset so far
print(export_path)
export_idxs = list(self.file_dict['files'][self.file_dict['checked'] == 1].index)
keep_idxs = []
for export_file_path, export_idx in zip(list(self.file_dict['files'][self.file_dict['checked'] == 1]), export_idxs):
export_image_path = pathlib.Path(export_file_path)
export_label_path = export_image_path.parent.parent / 'labels' / pathlib.Path(export_image_path.name).with_suffix('.txt')
target_image_path = pathlib.Path(export_path) / 'images' / export_image_path.name
target_label_path = pathlib.Path(export_path) / 'labels' / export_label_path.name
if not target_image_path.exists():
os.rename(export_image_path, target_image_path)
os.rename(export_label_path, target_label_path)
else:
print('nope')
pass
keep_idxs.append(export_idx)
# self.file_dict.loc[self.file_dict['name'] == self.highlighted_label.text()].index.values[0]
drop_idxs = list(set(export_idxs) - set(keep_idxs))
self.file_dict = self.file_dict.drop(drop_idxs)
self.save_file_dict()
self.close_without_saving_rois = True
self.close()
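One robustness note on the export logic: os.rename() raises OSError when source and target sit on different filesystems (e.g. a mounted data drive). If that can happen here, shutil.move is the drop-in replacement, since it falls back to copy-and-delete:

    import shutil

    if not target_image_path.exists():
        shutil.move(str(export_image_path), str(target_image_path))
        shutil.move(str(export_label_path), str(target_label_path))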
def import_data(self):
fd = QFileDialog()
import_path = fd.getExistingDirectory(self, 'Select Directory')
if import_path:
# ToDo: copy the UN validated files and add them to csv
# only files that are not in the current folder
print(import_path)
else:
print('nope')
pass
def open(self):
# ToDo: to be implemented
pass
def init_MenuBar(self):

View File

@@ -26,6 +26,8 @@ class InferenceDataset(Dataset):
def __init__(self, dir_path):
self.dir_path = dir_path
self.all_images = sorted(list(Path(self.dir_path).rglob(f'*.png')))
self.file_names = np.array([Path(x).with_suffix('') for x in self.all_images])
# self.images = np.array(sorted(os.listdir(DATA_DIR)))
def __len__(self):
return len(self.all_images)
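A note on the added file_names attribute: Path.with_suffix('') only strips the extension and keeps the directory prefix, which is not the same as the bare stem (file name below invented for illustration):

    from pathlib import Path

    p = Path('data/rise_training/images/rec_00001.png')
    print(p.with_suffix(''))  # data/rise_training/images/rec_00001
    print(p.stem)             # rec_00001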
@@ -39,45 +41,6 @@
class CustomDataset(Dataset):
def __init__(self, dir_path, bbox_df):
self.dir_path = dir_path
self.bbox_df = bbox_df
self.all_images = np.array(sorted(pd.unique(self.bbox_df['image'])), dtype=str)
self.image_paths = list(map(lambda x: Path(self.dir_path)/x, self.all_images))
def __getitem__(self, idx):
image_name = self.all_images[idx]
image_path = os.path.join(self.dir_path, image_name)
img = Image.open(image_path)
img_tensor = F.to_tensor(img.convert('RGB'))
Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]
labels = np.ones(len(Cbbox), dtype=int)
boxes = torch.as_tensor(Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].values, dtype=torch.float32)
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# no crowd instances
iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
target["area"] = area
target["iscrowd"] = iscrowd
image_id = torch.tensor([idx])
target["image_id"] = image_id
target["image_name"] = image_name #ToDo: implement this as 3rd return...
return img_tensor, target
def __len__(self):
return len(self.all_images)
class CustomDataset_v2(Dataset):
def __init__(self, limited_idxs=None):
self.images = np.array(sorted(os.listdir(DATA_DIR)))
self.labels = np.array(sorted(os.listdir(LABEL_DIR)))
@@ -99,7 +62,6 @@ class CustomDataset_v2(Dataset):
boxes, labels, area, iscrowd = self.extract_bboxes(annotations)
target = {}
target["boxes"] = boxes
target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
@@ -147,27 +109,11 @@ def custom_train_test_split():
train_idxs = np.sort(data_idxs[int(0.2 * len(data_idxs)):])
test_idxs = np.sort(data_idxs[:int(0.2 * len(data_idxs))])
train_data = CustomDataset_v2(limited_idxs=train_idxs)
test_data = CustomDataset_v2(limited_idxs=test_idxs)
train_data = CustomDataset(limited_idxs=train_idxs)
test_data = CustomDataset(limited_idxs=test_idxs)
return train_data, test_data
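data_idxs is defined above the hunk; given the slicing here it is presumably a shuffled permutation of all sample indices, roughly like this (a sketch; n_samples and the absence of a fixed seed are assumptions):

    import numpy as np

    data_idxs = np.random.permutation(n_samples)            # shuffle before splitting
    test_idxs = np.sort(data_idxs[:int(0.2 * n_samples)])   # first 20% -> test
    train_idxs = np.sort(data_idxs[int(0.2 * n_samples):])  # remaining 80% -> train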
def create_train_or_test_dataset(path, train=True):
if train == True:
pfx='train'
print('Generate train dataset !')
else:
print('Generate test dataset !')
pfx='test'
csv_candidates = list(Path(path).rglob(f'*{pfx}*.csv'))
if len(csv_candidates) == 0:
print(f'no .csv files for *{pfx}* found in {Path(path)}')
quit()
else:
bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0)
return CustomDataset(path, bboxes)
def create_train_loader(train_dataset, num_workers=0):
train_loader = DataLoader(
@@ -191,17 +137,6 @@ def create_valid_loader(valid_dataset, num_workers=0):
)
return valid_loader
def create_inference_loader(inference_dataset, num_workers=0):
inference_loader = DataLoader(
inference_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn
)
return inference_loader
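The collate_fn passed to each loader is not shown in this diff; for detection data it is conventionally the re-zipping helper from the torchvision detection tutorials, since images in a batch carry different numbers of boxes and cannot be stacked by the default collate (assumed, not confirmed, to match this project's definition):

    def collate_fn(batch):
        # [(img1, tgt1), (img2, tgt2), ...] -> ((img1, img2, ...), (tgt1, tgt2, ...))
        return tuple(zip(*batch))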
if __name__ == '__main__':
train_data, test_data = custom_train_test_split()

View File

@@ -97,7 +97,7 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
return fig_title
def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,t0, t1, f0, f1):
def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,t0, t1, f0, f1, label_save_folder):
times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
@@ -185,13 +185,16 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
all_height
]).T
np.savetxt(LABEL_DIR/ Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
np.savetxt(label_save_folder / Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
return bbox_yolo_style
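Each row written here follows the YOLO label convention: class x_center y_center width height, with everything except the class id normalized to [0, 1] by the image size. For reference, converting a pixel-space corner box looks like this (a generic sketch with illustrative names, not code from the commit):

    def corner_to_yolo(x0, y0, x1, y1, img_w, img_h, cls=1):
        # normalized center coordinates plus normalized width and height
        return [cls,
                (x0 + x1) / 2 / img_w,
                (y0 + y1) / 2 / img_h,
                (x1 - x0) / img_w,
                (y1 - y0) / img_h]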
def main(args):
folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
pic_save_folder = DATA_DIR if not args.inference else (Path('data') / Path(args.folder).name)
pic_save_folder = Path('data') / Path(args.folder).name / 'images'
label_save_folder = Path('data') / Path(args.folder).name / 'labels'
# embed()
# quit()
if len(folders) == 0:
print('no datasets containing fill_times.npy found')
@@ -202,8 +205,10 @@ def main(args):
else:
print('generate inference dataset ... only image output')
if not (Path('data') / Path(args.folder).name).exists():
(Path('data') / Path(args.folder).name).mkdir(parents=True, exist_ok=True)
if not pic_save_folder.exists():
pic_save_folder.mkdir(parents=True, exist_ok=True)
if not label_save_folder.exists():
label_save_folder.mkdir(parents=True, exist_ok=True)
for enu, folder in enumerate(folders):
print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
@@ -253,52 +258,7 @@
if not args.inference:
bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
fish_baseline_freq_time, fish_baseline_freq,
pic_save_str,t0, t1, f0, f1)
#######################################################################
# if False:
# if bbox_yolo_style.shape[0] >= 1:
# f_res, t_res = freq[1] - freq[0], times[1] - times[0]
#
# fig_title = (
# f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(
# ' ', '0')
# fig = plt.figure(figsize=IMG_SIZE, num=fig_title)
# gs = gridspec.GridSpec(1, 1, bottom=0.1, left=0.1, right=0.95, top=0.95) #
# ax = fig.add_subplot(gs[0, 0])
# ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
# extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res))
# # ax.invert_yaxis()
# # ax.axis(False)
#
# for i in range(len(bbox_df)):
# # Cbbox = np.array(bbox_df.loc[i, ['x0', 'y0', 'x1', 'y1']].values, dtype=np.float32)
# Cbbox = bbox_df.loc[i, ['t0', 'f0', 't1', 'f1']]
# ax.add_patch(
# Rectangle((float(Cbbox['t0']) / 3600, float(Cbbox['f0'])),
# float(Cbbox['t1']) / 3600 - float(Cbbox['t0']) / 3600,
# float(Cbbox['f1']) - float(Cbbox['f0']),
# fill=False, color="white", linestyle='-', linewidth=2, zorder=10)
# )
#
# # print(bbox_yolo_style.T)
#
# for bbox in bbox_yolo_style:
# x0 = bbox[1] - bbox[3]/2 # x_center - width/2
# y0 = 1 - (bbox[2] + bbox[4]/2) # x_center - width/2
# w = bbox[3]
# h = bbox[4]
# ax.add_patch(
# Rectangle((x0, y0), w, h,
# fill=False, color="k", linestyle='--', linewidth=2, zorder=10,
# transform=ax.transAxes)
# )
# plt.show()
#######################################################################
# if not args.inference:
# print('save bboxes')
# bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',')
pic_save_str,t0, t1, f0, f1, label_save_folder)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluate electrode array recordings with multiple fish.')
@@ -306,4 +266,6 @@ if __name__ == '__main__':
parser.add_argument('-i', "--inference", action="store_true")
args = parser.parse_args()
# ToDo: put "images" in image folder
main(args)

View File

@@ -54,11 +54,35 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
with torch.inference_mode():
outputs = model(images)
# ToDo: save outputs in label folder !
for image, img_name, output in zip(images, img_names, outputs):
# x0, y0, x1, y1
yolo_labels = []
# for (x0, y0, x1, y1) in output['boxes'].cpu().numpy():
for Cbbox, score in zip(output['boxes'].cpu().numpy(), output['scores'].cpu().numpy()):
if score < detection_th:
continue
rel_x0 = Cbbox[0] / image.shape[-1]  # x coordinates normalized by image width
rel_y0 = Cbbox[1] / image.shape[-2]  # y coordinates normalized by image height
rel_x1 = Cbbox[2] / image.shape[-1]
rel_y1 = Cbbox[3] / image.shape[-2]
rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2
rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2
rel_width = rel_x1 - rel_x0
rel_height = rel_y1 - rel_y0
yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height])
label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt')
np.savetxt(label_path, yolo_labels)
plot_inference(image, img_name, output, detection_th, dataset_name)
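A formatting note on the label export: np.savetxt defaults to scientific notation for every column, including the integer class id. If downstream YOLO tooling expects the conventional plain layout, a per-column fmt fixes that, and guarding against an empty detection list avoids a column-count error (a suggestion, not part of the commit):

    if yolo_labels:  # skip images without confident detections
        np.savetxt(label_path, np.asarray(yolo_labels),
                   fmt=['%d'] + ['%.6f'] * 4)  # integer class, 6-decimal coords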
def main(args):
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)

View File

@@ -121,11 +121,6 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold):
# plt.show()
if __name__ == '__main__':
# train_data = create_train_or_test_dataset(DATA_DIR)
# test_data = create_train_or_test_dataset(DATA_DIR, train=False)
#
# train_loader = create_train_loader(train_data)
# test_loader = create_valid_loader(test_data)
train_data, test_data = custom_train_test_split()